import copy

import numpy as np
import random
import math
import time
import matplotlib.pyplot as plt
from matplotlib.patches import Circle

import csv
# --- Rate caps ---------------------------------------------------------------
maxRate = 500  # not referenced elsewhere in this file
maxRatebs = 600  # rate_upper given to the stations built in main3 / main_rsma
#UAV_R = 18
R_forest = 15  # default coverage radius of calc_coverate (18 also noted as an alternative)
#R_LOS = UAV_R
ROW = 100  # number of rows of the deployment grid built by Create_map
COL = 100  # number of columns of the deployment grid built by Create_map
# UE_TOTAL = 150
#UAV_R2 = UAV_R ** 2
h = 200  # UAV altitude: used by pl_slant and appended as the UAV z coordinate (units presumably m — confirm)
r_th = 0  # dB; not referenced elsewhere in this file
# --- Log-distance path-loss (fspl) parameters ---------------------------------
f = 1.4e3  # carrier frequency in MHz (1.4 GHz)
d0 = 1  # reference distance of the log-distance model
c = 3e8  # speed of light (m/s)
alpha = 3.5  # path-loss exponent
pi = math.pi
# --- Power settings ----------------------------------------------------------
uav_power_aver = 0.2  # W; assigned as power_upper of the ground MBSs in main3 / main_rsma
uav_power_upper = 3  # UAV power_upper (also the base of common_power in main_rsma)
uav_aver_upper = 0.6  # UAV power_aver_upper
power_t = 0.2  # initial per-user UAV power_allocation and power_a2a
#R_1 = 8#20
#R_2 = 2#24

# --- Vegetation slant-loss coefficients: pl_slant = A * f**C * dep_tree**E * (theta + G)**H
A = 0.25
C = 0.39
E = 0.25
G = 0
H = 0.05
dep_tree = 2  # vegetation depth
# Free-space reference term 20*log10(4*pi*f/c) for f = 2 GHz and c = 3e8 m/s
glo_pl_f = (20 * math.log10(2000000000)) + 20 * math.log10(4 * (math.pi) / 3 / 100000000)
# Presumably SNR thresholds (dB) for UAV and ground-BS links — not referenced elsewhere in this file
snr_thuav = 3
snr_thbs = 15
Noise = 0  # noise term subtracted from the link budget in Snr_check
class BS:
    """A base station in the deployment: either a ground MBS or an aerial UAV.

    Only the common fields are set here; deployment code elsewhere in this
    module (e.g. ``main3``) attaches further attributes such as rate and power
    settings after construction.
    """

    def __init__(self):
        # Placeholder position and station kind (original note: three kinds exist).
        self.loc, self.sort = [-1, -1], 'UAV'
        # Accumulated LoS value and coverage radius, both unset.
        self.los_sum, self.R = 0, 0
class UE:
    """A ground user to be served by the deployed stations."""

    def __init__(self):
        self.loc = [-1, -1]  # placeholder position
        # User class (original note: users split into two classes by required SNR)
        # and connection radius.
        self.sort, self.R = 1, 0
        self.LoS, self.isCover = 1, 0

    # No-op finaliser (original note: "destroy").
    def __del__(self):
        pass
def Ue2Class(ue_list, R=None):
    """Convert raw coordinate rows into ``UE`` objects of user class 2.

    :param ue_list: sequence of rows; entries 0 and 1 are the user's (x, y).
    :param R: connection radius given to every user. When omitted, a
        module-level ``R_2`` is used if one has been defined (its definition is
        commented out in this file); otherwise 2, the value recorded on that
        commented-out line.
    :return: list of ``UE`` objects, in the same order as *ue_list*.
    """
    if R is None:
        # The old body read R_2 unconditionally and raised NameError because
        # the constant is commented out.
        R = globals().get('R_2', 2)
    UE_list = []
    for row in ue_list:
        new_ue = UE()
        new_ue.loc = [row[0], row[1]]
        # Every user is class 2; the old `np.random.random(1) < 1` guard was
        # always true and only consumed a random draw.
        new_ue.sort = 2
        new_ue.R = R
        UE_list.append(new_ue)
    return UE_list

def fspl(dist):
    """Log-distance path loss (dB) at distance *dist*.

    PL(d) = 20*log10(4*pi*f*d0/c) + 10*alpha*log10(d/d0)

    The module constant ``f`` is stored in MHz (the slant model in
    ``pl_slant`` needs MHz). The free-space reference term divides by the speed
    of light ``c`` in m/s, so the frequency must be in Hz. It is converted here.
    Without the conversion the reference loss is about 120 dB too small and
    even negative.

    :param dist: distance in metres (original note: d is in metres).
    :return: path loss in dB. For ``dist == 0`` only the reference term is
        returned, which avoids log10(0).
    """
    f_hz = f * 1e6  # MHz -> Hz
    ref_loss = 20 * np.log10(4 * pi * f_hz * d0 / c)
    if dist == 0:
        return ref_loss
    return ref_loss + 10 * alpha * math.log10(dist / d0)



def pl_slant(dist):
    """Vegetation slant-path loss term: A * f**C * dep_tree**E * (theta + G)**H.

    *theta* is the elevation angle ``atan2(h, dist)`` in radians. ``h`` is the
    module-level UAV altitude and ``f`` is the module-level carrier frequency
    in MHz. The result is added to ``fspl``'s dB value by its callers.
    """
    elevation = math.atan2(h, dist)
    freq_term = np.power(f, C)
    depth_term = np.power(dep_tree, E)
    angle_term = np.power(elevation + G, H)
    return A * freq_term * depth_term * angle_term



def calc_coverate(uav_list, ue_list, R=R_forest):
    """Return the fraction of users within horizontal distance *R* of some UAV.

    :param uav_list: sequence of UAV positions; entries 0 and 1 are (x, y).
    :param ue_list: sequence of user positions; entries 0 and 1 are (x, y).
    :param R: coverage radius, in the same units as the coordinates.
    :return: covered users / total users, in [0, 1]. Returns 0.0 when
        *ue_list* is empty instead of raising ZeroDivisionError.
    """
    n_users = len(ue_list)
    if n_users == 0:
        return 0.0
    covered = 0
    for ue in ue_list:
        ue_x, ue_y = ue[0], ue[1]
        # any() stops at the first UAV in range instead of scanning all of them.
        if any(math.sqrt((uav[0] - ue_x) ** 2 + (uav[1] - ue_y) ** 2) <= R
               for uav in uav_list):
            covered += 1
    return covered / n_users


def calc_var_pl(uav_list, UE_list, channel=1):
    """Score a UAV deployment by path-loss spread and mean user throughput.

    Positions are grid coordinates. Distances are multiplied by 200 before the
    path-loss models are applied (NOTE(review): presumably 200 m per grid
    unit — confirm).

    :param uav_list: sequence of UAV positions; entries 0 and 1 are (x, y).
    :param UE_list: sequence of ``UE``-like objects exposing ``.loc`` and ``.R``.
    :param channel: when 1, each UAV's path-loss sum also includes the slant
        (vegetation) term and 6 dB Gaussian shadowing; otherwise only ``fspl``.
    :return: ``(var_pl, mean_throughput)``:
        * ``var_pl`` is the variance over UAVs of the summed path loss to the
          users within twice their radius ``R`` (distance 0 excluded),
          divided by 1000.
        * ``mean_throughput`` is the average over users of the summed
          per-link rate to every UAV within the user's radius ``R``.
        Returns ``(0.0, 0.0)`` when either list is empty.
    """
    if len(uav_list) == 0 or len(UE_list) == 0:
        # Previously: ZeroDivisionError with no UAVs, NaN + RuntimeWarning with no users.
        return 0.0, 0.0

    grid_to_m = 200
    p_s = 20  # transmit power (presumably dBm)
    p_n = -95  # noise power (presumably dBm)

    def _horizontal_dist(uav, ue):
        return math.sqrt((uav[0] - ue.loc[0]) ** 2 + (uav[1] - ue.loc[1]) ** 2)

    # Per-UAV sum of the path loss to the users it can reach.
    pl_uav = []
    for uav in uav_list:
        pl_i = 0
        for ue in UE_list:
            dist = _horizontal_dist(uav, ue)
            if 0 < dist <= 2 * ue.R:
                pl_i += fspl(dist * grid_to_m)
                if channel == 1:
                    # size=None draws a scalar. size=1 produced 1-element arrays,
                    # and mixing those with plain 0 sums gave np.var a ragged list.
                    pl_i += pl_slant(dist * grid_to_m) + np.random.normal(scale=6)
        pl_uav.append(pl_i)
    var_pl = np.var(pl_uav) / 1000

    # Per-user throughput summed over every UAV within the user's radius.
    throughput = []
    for ue in UE_list:
        throughput_i = 0
        for uav in uav_list:
            dist = _horizontal_dist(uav, ue)
            if dist <= ue.R:
                # Loss of THIS link only. The old code accumulated pl_i across
                # UAVs, charging each extra link with the losses of earlier ones.
                pl_link = (fspl(dist * grid_to_m) + pl_slant(dist * grid_to_m)
                           + np.random.normal(scale=6))
                # NOTE(review): p_s - pl_link - p_n is an SNR in dB, but
                # log2(1 + x) expects a linear SNR, and abs() hides negative
                # values. Kept unchanged pending confirmation of the rate model.
                throughput_i += math.log2(abs(1 + p_s - pl_link - p_n))
        throughput.append(throughput_i)

    return var_pl, np.average(throughput)


def read_csv(addr):
    """Read a comma-separated numeric file into a list of float rows.

    :param addr: path to the CSV file.
    :return: one ``list[float]`` per CSV row, in file order (a blank line
        yields ``[]``).
    :raises ValueError: if a field cannot be parsed as a float.
    """
    # newline='' is required by the csv module so that line endings and quoted
    # newlines are handled by the reader. utf-8-sig strips the BOM that
    # spreadsheet exports often prepend; a BOM would otherwise make float() fail
    # on the first field.
    with open(addr, 'r', newline='', encoding='utf-8-sig') as fh:
        return [[float(cell) for cell in row]
                for row in csv.reader(fh, delimiter=',')]


def cal_los_mat(loc, R, height2, pl_f):
    """LoS path loss for pattern cell *loc*, in the form ``20*log10(d/100) + pl_f``.

    The horizontal offset is measured from the pattern centre ``(R, R)``. It is
    combined with the squared height *height2* to give the 3-D distance ``d``.
    """
    dx = loc[0] - R
    dy = loc[1] - R
    slant_range = math.sqrt(dx ** 2 + dy ** 2 + height2)
    return pl_f + 20 * math.log10(slant_range / 100)

def Create_Boundary_pattern(R1, height, pl_f):
    """Build the connectivity and LoS-loss stencils centred on a user.

    Both returned arrays have size ``(2*R1+1, 2*R1+1)``. Cell ``(R1, R1)``
    represents the user. The other cells are candidate station positions
    relative to it. A cell is inside the stencil when its squared offset from
    the centre is at most ``R1**2``.

    :param R1: connection radius in grid cells.
    :param height: station height, combined with the horizontal offset by
        ``cal_los_mat``.
    :param pl_f: constant path-loss offset passed through to ``cal_los_mat``.
    :return: ``(con_pattern, mat_pattern)``. ``con_pattern`` is 1 inside the
        disc (station can connect) and 0 outside. ``mat_pattern`` holds the
        ``cal_los_mat`` value for each inside cell and 0 outside.
    """
    side = 2 * R1 + 1
    height_sq = height ** 2
    radius_sq = R1 ** 2
    con_pattern = np.zeros((side, side))
    mat_pattern = np.zeros((side, side))
    for row in range(side):
        for col in range(side):
            if (row - R1) ** 2 + (col - R1) ** 2 <= radius_sq:
                con_pattern[row, col] = 1
                mat_pattern[row, col] = cal_los_mat([row, col], R1, height_sq, pl_f)
    return con_pattern, mat_pattern


def Create_Boundary_list(R1, row_in, col_in, row_x, col_y, patt):
    """Paste a square pattern, centred on ``(row_x, col_y)``, onto an empty map.

    Starting from a user's coordinates, this marks which station positions on
    the global map can connect to that user. Summed over users, it gives how
    many users a station at each position would cover. Any part of the pattern
    that falls outside the map is clipped away.

    :param R1: pattern radius (connection radius); *patt* is expected to be
        ``(2*R1+1, 2*R1+1)``.
    :param row_in: number of rows of the global map.
    :param col_in: number of columns of the global map.
    :param row_x: row index of the pattern centre.
    :param col_y: column index of the pattern centre.
    :param patt: the pattern to paste.
    :return: ``(row_in, col_in)`` float ndarray, zero outside the pasted window.
    """
    canvas = np.zeros((row_in, col_in))
    # Top-left corner of the full pattern in map coordinates (may be negative).
    top, left = row_x - R1, col_y - R1
    # Window on the map, clipped to the map bounds...
    m_row_up, m_col_left = max(top, 0), max(left, 0)
    m_row_down = min(row_x + R1 + 1, row_in)
    m_col_right = min(col_y + R1 + 1, col_in)
    # ...and the same window in pattern coordinates (shifted by the corner).
    canvas[m_row_up:m_row_down, m_col_left:m_col_right] = patt[
        m_row_up - top:m_row_down - top,
        m_col_left - left:m_col_right - left,
    ]
    return canvas

def Create_map(UE_list, c_pattern1,R):
    """Sum *c_pattern1*, centred on every user's grid cell, into a ROW x COL map.

    Each user's ``loc`` is truncated to ``int``. The pattern is pasted (and
    clipped at the map border) by ``Create_Boundary_list``, and overlapping
    contributions add up. With a 0/1 connectivity pattern, each cell then holds
    the number of users a station placed there could reach.

    :param UE_list: iterable of ``UE``-like objects exposing ``.loc``.
    :param c_pattern1: pattern to paste around each user.
    :param R: pattern radius, forwarded to ``Create_Boundary_list`` as ``R1``.
    :return: ``(ROW, COL)`` float ndarray.
    """
    total = np.zeros((ROW, COL))
    for ue in UE_list:
        ue_row, ue_col = int(ue.loc[0]), int(ue.loc[1])
        total += Create_Boundary_list(R, ROW, COL, ue_row, ue_col, c_pattern1)
    return total


def Distance_check(ue, uav_loc,R):
    """Return 1 if *uav_loc* is within horizontal distance *R* of ``ue.loc``, else 0."""
    dx = ue.loc[0] - uav_loc[0]
    dy = ue.loc[1] - uav_loc[1]
    return 1 if math.sqrt(dx ** 2 + dy ** 2) <= R else 0

def pathlossA2G(dist2,h):
    '''
    Air-to-ground slant-path loss (dB) between a station and a user.

    The loss is the sum of two terms:
      * a log-distance term 20*log10(4*pi*f*d0/c) + 10*alpha*log10(d3/d0),
        with f = 1.4 GHz and 3-D distance d3 = sqrt(dist2**2 + h**2);
      * a vegetation slant term A * f_MHz**C * dep_tree**E * (theta + G)**H,
        with elevation theta = atan2(h, dist2) in radians.

    When d3 == 0 (a ground station, h = 0, located exactly on the user) the
    distance term is taken at the reference distance d0 instead of raising on
    log10(0). This matches how ``fspl`` treats a zero distance.

    :param dist2: horizontal distance between the station's ground projection and the user
    :param h: station height
    :return: total slant-path loss in dB
    '''
    f = 1.4e9  # carrier frequency in Hz (1.4 GHz)
    d0 = 1  # reference distance
    c = 3e8  # speed of light (m/s)
    alpha = 3.5  # path-loss exponent
    A, C, E, G, H = 0.25, 0.39, 0.25, 0, 0.05  # slant-model coefficients
    dep_tree = 2  # vegetation depth
    dist3d = math.sqrt(dist2 ** 2 + h ** 2)
    distance_pl = 20 * np.log10(4 * math.pi * f * d0 / c)
    if dist3d != 0:
        distance_pl += 10 * alpha * math.log10(dist3d / d0)
    theta = math.atan2(h, dist2)
    # The slant model expects the frequency in MHz.
    slant = A * np.power(f / math.pow(10, 6), C) * np.power(dep_tree, E) * np.power(theta + G, H)
    return distance_pl + slant

def Snr_check(ue,uav):
    """Return 1 if ``uav.power_allocation - PL - Noise`` reaches ``ue.snrth``, else 0.

    PL is ``pathlossA2G`` over the horizontal distance between ``ue.loc`` and
    ``uav.loc``, at height ``uav.loc[2]``.

    NOTE(review): ``UE.__init__`` does not set ``snrth``; it must be attached by
    the caller. ``main3`` / ``main_rsma`` give UAVs an ndarray
    ``power_allocation`` and MBSs an empty list. With those objects the
    comparison yields an array, or a TypeError for the list, rather than a
    single bool. The power values also look like watts while PL is in dB —
    confirm the intended units.
    """
    horizontal = math.sqrt((ue.loc[0] - uav.loc[0]) ** 2 + (ue.loc[1] - uav.loc[1]) ** 2)
    path_loss = pathlossA2G(horizontal, uav.loc[2])
    margin = uav.power_allocation - path_loss - Noise
    return 1 if margin >= ue.snrth else 0

def Solve_process(UE_list, bound_pattern1, los_pattern1,bound_pattern2, los_pattern2,  ue_num_list):
    """Place one UAV at a cell reaching the most users, tie-broken by LoS value.

    Steps:
      1. Build a coverage map and keep only the cells that reach the most users.
      2. Append that maximum count to ``ue_num_list``.
      3. Among the kept cells, pick one with the smallest non-zero summed LoS
         value, breaking ties at random.
      4. Delete the users within range of that cell from ``UE_list``, which is
         mutated in place.

    NOTE(review): this function looks unusable as written:
      * ``Create_map``'s third parameter is the radius ``R``, but
        ``bound_pattern2`` / ``los_pattern2`` are passed there. If these are
        pattern arrays, as the parameter names suggest, the radius arithmetic
        in ``Create_Boundary_list`` will not work. Confirm the intent.
      * ``Distance_check`` takes three arguments ``(ue, uav_loc, R)`` but is
        called with two below, which raises TypeError whenever ``UE_list`` is
        non-empty.
    ``Solve_process3`` is the variant called by ``init_ch`` and ``main2``.

    :return: ``[row, col]`` of the chosen cell.
    """

    hot_map = Create_map(UE_list, bound_pattern1,bound_pattern2)
    max_val = np.max(hot_map)
    # Keep only the cells reaching the most users -> deploy one UAV covering as many users as possible
    hot_map[hot_map < max_val] = 0
    hot_map[hot_map > 0] = 1

    ue_num_list.append(max_val)

    los_map = Create_map(UE_list, los_pattern1,los_pattern2)
    los_map = los_map * hot_map
    min_val = 0
    if len(los_map[np.nonzero(los_map)]):
        min_val = np.min(los_map[np.nonzero(los_map)])
    # Choose the UAV position (random tie-break among cells equal to the minimum)
    aim = np.where(los_map == min_val)
    rand_num = random.randint(0, len(aim[0]) - 1)
    aim_loc = [aim[0][rand_num], aim[1][rand_num]]
    # Already-covered users must not influence later placements (iterate backwards so deletions are safe)
    for i in range(len(UE_list) - 1, -1, -1):
        if Distance_check(UE_list[i], aim_loc) == 1:
            del UE_list[i]

    return aim_loc

def Solvve_init_limit(UE_list, con_pattern1, los_pattern1, con_pattern2, los_pattern2, ue_num_list,numuav):
    """Collect *numuav* UAV positions by calling ``Solve_process`` repeatedly.

    ``UE_list`` and ``ue_num_list`` are passed straight through, so each
    ``Solve_process`` call mutates them.

    NOTE(review): ``Solve_process`` calls the three-argument ``Distance_check``
    with two arguments, so this path looks unusable as written — confirm
    before relying on it.

    :return: list of ``[row, col]`` positions, in placement order.
    """
    positions = []
    while len(positions) < numuav:
        # One UAV per iteration; covered users are pruned inside Solve_process.
        next_pos = Solve_process(UE_list, con_pattern1, los_pattern1,
                                 con_pattern2, los_pattern2, ue_num_list)
        positions.append(next_pos)
    return positions

def Draw_result(u_list, solve_list, R1, ue_num_list):
    """Plot users, placed stations and their coverage discs; save ``solve-greedy.jpg``.

    Users are drawn as yellow crosses and stations as red triangles. Each
    station gets a translucent blue circle of radius *R1* and the label
    ``"<index>--<int(ue_num_list[index])>"`` (presumably the number of users
    it covers). The figure is cleared after saving.
    """
    fig = plt.figure('fig')
    ax = fig.add_subplot(111)  # single subplot: 1 row, 1 column, index 1
    for user in u_list:
        plt.plot(user[0], user[1], "xy")
    for idx, station in enumerate(solve_list):
        x, y = station[0], station[1]
        plt.plot(x, y, "vr")
        ax.add_patch(Circle(xy=(x, y), radius=R1, alpha=0.1, color='b'))
        plt.text(x, y, f"{idx + 1}--{int(ue_num_list[idx])}", fontsize=15)
    fig.savefig('solve-greedy.jpg')
    fig.clf()

# Greedy initialisation: each call places one station.
def Solve_process3(UE_list, bound_pattern1,bs,R):
    """Greedily place one station where it reaches the most remaining users.

    ``Create_map`` counts, for every grid cell, how many users in *UE_list*
    could connect to a station there, using the connectivity pattern
    *bound_pattern1* of radius *R*. A cell achieving the maximum is picked with
    ``random.randint``; if no user can be reached, any cell may be picked.
    Every user within *R* of the chosen cell is then removed from *UE_list*
    (mutated in place), so later calls ignore already-covered users.

    :param bs: unused; kept for interface compatibility.
    :return: ``[row, col]`` of the chosen cell.
    """
    counts = Create_map(UE_list, bound_pattern1, R)
    best = np.max(counts)
    # Binarise: 1 on the cells achieving the maximum, 0 elsewhere.
    counts[counts < best] = 0
    counts[counts > 0] = 1
    score = 200 * counts
    positive = score[np.nonzero(score)]
    target = np.min(positive) if len(positive) else 0
    rows, cols = np.where(score == target)
    pick = random.randint(0, len(rows) - 1)
    aim_loc = [rows[pick], cols[pick]]
    # Drop the users the new station already covers.
    UE_list[:] = [ue for ue in UE_list if Distance_check(ue, aim_loc, R) != 1]
    return aim_loc

def init_ch(UE_list,numuav,R):
    """Greedily choose *numuav* UAV positions and return them as a float array.

    Each iteration places one UAV via ``Solve_process3`` at a grid cell that
    reaches the most still-uncovered users within radius *R*. The covered
    users are then removed from *UE_list*, which is mutated in place.

    :param UE_list: list of ``UE`` objects; covered users are deleted from it.
    :param numuav: number of UAVs to place.
    :param R: coverage radius in grid cells.
    :return: float ndarray of shape ``(numuav, 2)``. Each row is a grid
        position in the same (``loc[0]``, ``loc[1]``) order as ``UE.loc``.
        The shape is ``(0, 2)`` when *numuav* is 0; the old code raised
        IndexError in that case.
    """
    glo_height = 18
    # Only the connectivity pattern is needed. glo_pl_f is the module-level
    # constant with the same formula the old code recomputed locally.
    con_pattern1, _ = Create_Boundary_pattern(R, glo_height, glo_pl_f)

    uav_list = []
    while len(uav_list) < numuav:
        uav_list.append(Solve_process3(UE_list, con_pattern1, [], R))  # one UAV per iteration

    if not uav_list:
        return np.zeros((0, 2))
    return np.array(uav_list, dtype=float)


def main2(UE_list, numuav, bs, R, bs_R=None):
    """Greedily place *numuav* stations over the users not served by *bs*.

    First, every user within *bs_R* of an existing station in *bs* is removed
    from *UE_list*. Then ``Solve_process3`` is called repeatedly, each call
    placing one station of radius *R* and removing the users it covers.
    *UE_list* is mutated in place throughout.

    :param UE_list: list of ``UE`` objects.
    :param numuav: number of new stations to place.
    :param bs: existing station positions; entries 0 and 1 are (row, col).
    :param R: coverage radius of the new stations, in grid cells.
    :param bs_R: coverage radius of the existing stations in *bs*; defaults
        to *R*. The old code omitted the radius argument of
        ``Distance_check`` here, which raised TypeError whenever *bs* was
        non-empty.
    :return: list of ``[row, col]`` positions of the new stations.
    """
    if bs_R is None:
        bs_R = R

    # Users already served by an existing station do not influence placement.
    for aim_loc in bs:
        for j in range(len(UE_list) - 1, -1, -1):
            if Distance_check(UE_list[j], aim_loc, bs_R) == 1:
                del UE_list[j]

    glo_height = 18
    # glo_pl_f is the module-level constant with the same formula the old code
    # recomputed locally. Only the connectivity pattern is used.
    con_pattern1, _ = Create_Boundary_pattern(R, glo_height, glo_pl_f)

    uav_list = []
    while len(uav_list) < numuav:
        uav_list.append(Solve_process3(UE_list, con_pattern1, bs, R))  # one station per iteration
    return uav_list

def UAVlisttoClass(uav_list_init, uavnum, sort,R):
    """Wrap raw station positions into ``BS`` objects.

    :param uav_list_init: sequence of positions. Each one is stored as the
        station's ``loc`` without copying, so later ``loc.append(...)`` calls
        also modify the caller's entry.
    :param uavnum: used only as a flag. When falsy, an empty list is returned;
        otherwise one ``BS`` is built per entry of *uav_list_init*.
    :param sort: station kind assigned to every ``BS``.
    :param R: coverage radius, applied only when *sort* is ``'MBS'``.
    :return: list of ``BS`` objects.
    """
    if not uavnum:
        return []
    stations = []
    for position in uav_list_init:
        station = BS()
        station.loc = position
        station.sort = sort
        if sort == 'MBS':
            station.R = R
        stations.append(station)
    return stations
def _deploy_stations(UE_list, sumgbs, R_mbs, sumuav, R_uav):
    """Greedily place *sumgbs* ground MBSs, then *sumuav* UAVs, over *UE_list*.

    Users covered by an MBS are removed from *UE_list* (mutated in place)
    before the UAVs are placed, so the UAVs target the remaining users.

    :return: ``(mbs_list, uav_locs)``, i.e. the configured MBS ``BS`` objects
        and the raw ``[row, col]`` UAV positions.
    """
    mbs_locs = main2(UE_list, sumgbs, [], R_mbs)
    mbs_list = UAVlisttoClass(mbs_locs, len(mbs_locs), 'MBS', R_mbs)
    for mbs in mbs_list:
        mbs.loc.append(0)  # ground station: z = 0
        mbs.sort = 'MBS'
        mbs.rate = 0
        mbs.rate_upper = maxRatebs
        mbs.power_upper = uav_power_aver
        mbs.power_allocation = []
        mbs.power = 0
        mbs.power_a2a = 0
    uav_locs = main2(UE_list, sumuav, [], R_uav)
    return mbs_list, uav_locs


def _new_uav(loc, n_users, eta=None):
    """Build a UAV ``BS`` at *loc*; the altitude ``h`` is appended to *loc* in place.

    :param n_users: length of the per-user ``power_allocation`` vector.
    :param eta: when given, stored as ``eta`` and used to set
        ``common_power = eta * uav_power_upper``; otherwise ``common_power = 0``.
    """
    uav = BS()
    loc.append(h)
    uav.loc = loc
    uav.sort = 'UAV'
    uav.rate = 0
    uav.rate_upper = maxRatebs
    uav.power_allocation = np.ones(n_users) * power_t
    uav.power_upper = uav_power_upper
    uav.power = 0
    uav.link = 0
    if eta is None:
        uav.common_power = 0
    else:
        uav.eta = eta
        uav.common_power = uav.eta * uav_power_upper
    uav.power_aver_upper = uav_aver_upper
    uav.power_a2a = power_t
    return uav


def main3(UE_list,sumgbs,R_mbs,sumuav,R_uav):
    """Deploy MBSs then UAVs greedily and return the initialised UAV ``BS`` objects.

    *UE_list* is mutated: users covered by the placed stations are removed.
    Each UAV's ``power_allocation`` has one entry per user in the ORIGINAL
    *UE_list*, with ``common_power = 0``.

    NOTE(review): the configured MBS objects are discarded and only UAVs are
    returned. MBS placement still matters because it prunes *UE_list* before
    the UAVs are placed. Confirm whether MBSs should be part of the result.
    """
    n_users = len(UE_list)  # user count before deployment prunes UE_list
    _mbs_list, uav_locs = _deploy_stations(UE_list, sumgbs, R_mbs, sumuav, R_uav)
    return [_new_uav(loc, n_users) for loc in uav_locs]


def main_rsma(UE_list,sumgbs,R_mbs,sumuav,R_uav):
    """Same as ``main3`` but each UAV also gets ``eta = 0.5`` and
    ``common_power = eta * uav_power_upper``.

    Presumably ``eta`` is the RSMA common-stream power split — confirm.
    As in ``main3``, the configured MBS objects are not returned.
    """
    n_users = len(UE_list)  # user count before deployment prunes UE_list
    _mbs_list, uav_locs = _deploy_stations(UE_list, sumgbs, R_mbs, sumuav, R_uav)
    return [_new_uav(loc, n_users, eta=0.5) for loc in uav_locs]
if __name__ == '__main__':
    # Each CSV row is one user; columns 60 and 61 are used as its (x, y) position.
    ue_file = r"D:\pycharm\deploy_algorithms\data0718\list_victim30_" + '.csv'
    fireman_file = r"D:\pycharm\deploy_algorithms\data0718\list_fireman30_" + '.csv'

    origin_ue_loc_list = read_csv(ue_file)
    fireman_list = read_csv(fireman_file)
    print(len(fireman_list), len(fireman_list[0]))

    R = 4  # connection radius given to every user
    UE_list = []
    for rows, kind in ((origin_ue_loc_list, 'ue'), (fireman_list, 'fireman')):
        for row in rows:
            user = UE()
            user.loc = [row[60], row[61]]
            user.sort = kind
            user.R = R
            UE_list.append(user)
    us_px = [user.loc[0] for user in UE_list]
    us_py = [user.loc[1] for user in UE_list]
    print(us_px, us_py)

    sumgbs = 4
    # Deploy on a copy so the original user list stays intact.
    UE_list1 = copy.deepcopy(UE_list)
    BS_list = main3(UE_list1, sumgbs, R_mbs=4, sumuav=7, R_uav=8)