import numpy as np
import random
import math
import time
import matplotlib.pyplot as plt
from matplotlib.patches import Circle
import csv
import GAfunction
import power_ga
from scipy.spatial import ConvexHull
from ray import *

# Coverage radii (grid units) and deployment-grid size.
UAV_R = 25  # coverage radius of a UAV for UEs
R_forest = 25  # default radius used by calc_coverate
R_LOS = UAV_R  # range used for ground-BS removal and UAV-UE association
ROW, COL = 100, 100
X_SIZE = ROW  # row stride used to decode flat cell indices in addGroundBS
UAV_R2 = UAV_R ** 2
numbs = 1  # number of ground base stations placed in main()

# Coefficients of the vegetation slant-path term used by pl_slant:
#   A * f**C * dep_tree**E * (theta + G)**H
A, C, E, G, H = 0.25, 0.39, 0.25, 0, 0.05
dep_tree = 2  # vegetation depth
# Equals 20*log10(4*pi*f/c) with f = 2e9 and c = 3e8.
glo_pl_f = (20 * math.log10(2000000000)) + 20 * math.log10(4 * (math.pi) / 3 / 100000000)
def pl_slant(dist):
    """Return the vegetation slant-path term A * f**C * dep_tree**E * (theta + G)**H.

    theta is the elevation angle atan2(h, dist) in radians. ``h`` and ``f`` are
    module-level constants assigned further down this file.
    """
    elevation = math.atan2(h, dist)
    depth_factor = np.power(dep_tree, E)
    return A * np.power(f, C) * depth_factor * np.power(elevation + G, H)

class UE:
    """Ground user equipment record.

    ``loc`` is [x, y, z] and starts as the placeholder [-1, -1, 0]. Callers
    such as listtrans attach further attributes (e.g. ``snr``) after construction.
    """

    def __init__(self):
        self.loc, self.list_loc = [-1, -1, 0], []


class UAV:
    """UAV base-station record.

    ``loc`` is [x, y, z] and starts as the placeholder [-1, -1, 200].
    ``client_list`` and ``los_sum`` start empty/zero.
    """

    def __init__(self):
        self.loc, self.client_list, self.los_sum = [-1, -1, 200], [], 0


def Get_UE_list(space_ue):
    """Return the ``loc`` attribute of every object in ``space_ue``, in order.

    The returned list holds references to the objects' own ``loc`` values (no copies).
    """
    return [ue.loc for ue in space_ue]


def cal_los_mat(loc, R, height2, pl_f):
    """Return the path loss (dB) from the pattern centre (R, R) to grid cell ``loc``.

    The 3-D distance combines the in-plane offset with ``height2`` (a squared
    height). The result is 20*log10(distance / 100) + pl_f.
    """
    dx = loc[0] - R
    dy = loc[1] - R
    slant = math.sqrt(dx ** 2 + dy ** 2 + height2)
    return pl_f + 20 * math.log10(slant / 100)


def Create_Boundary_pattern(R1, height, pl_f):
    """Build the (2*R1+1) x (2*R1+1) coverage and path-loss templates.

    A cell (row, col) belongs to the disc when (row-R1)**2 + (col-R1)**2 <= R1**2.
    For such cells ``con_pattern`` is 1 and ``mat_pattern`` holds
    cal_los_mat([row, col], R1, height**2, pl_f). Cells outside the disc stay 0
    in both templates.

    :return: (con_pattern, mat_pattern) as float numpy arrays.
    """
    size = 2 * R1 + 1
    height2 = height ** 2
    radius2 = R1 ** 2
    con_pattern = np.zeros((size, size))
    mat_pattern = np.zeros((size, size))
    for row in range(size):
        for col in range(size):
            if (row - R1) ** 2 + (col - R1) ** 2 > radius2:
                continue
            con_pattern[row, col] = 1
            mat_pattern[row, col] = cal_los_mat([row, col], R1, height2, pl_f)
    return con_pattern, mat_pattern


def Create_Boundary_list(R1, row_in, col_in, row_x, col_y, patt_con):
    """Stamp a (2*R1+1)-square template onto a zero grid, centred on (row_x, col_y).

    :param R1: template half-width; ``patt_con`` is expected to be (2*R1+1) square.
    :param row_in, col_in: shape of the output grid.
    :param row_x, col_y: row and column of the template centre on the grid.
    :param patt_con: template values (e.g. coverage mask or path-loss pattern).
    :return: float array of shape (row_in, col_in). Parts of the template that
        fall outside the grid are clipped.
    """
    mat_con = np.zeros((row_in, col_in))
    # Grid coordinates of the template's [0, 0] element.
    top = row_x - R1
    left = col_y - R1
    r_lo, r_hi = max(top, 0), min(row_x + R1 + 1, row_in)
    c_lo, c_hi = max(left, 0), min(col_y + R1 + 1, col_in)
    mat_con[r_lo:r_hi, c_lo:c_hi] = patt_con[r_lo - top:r_hi - top, c_lo - left:c_hi - left]
    return mat_con


def Func_convexhall(ue_list, tar_loc):
    """Return convex-hull vertex indices of ``ue_list``, rotated to start at ``tar_loc``.

    With more than two points the hull comes from scipy's ConvexHull. Otherwise
    every index is treated as a hull vertex. If either coordinate of ``tar_loc``
    is negative (the [-1, -1] "no head chosen yet" marker), the head vertex is
    picked with random.choice. Otherwise the head is ue_list.index(tar_loc),
    which must itself be a hull vertex.
    """
    if len(ue_list) <= 2:
        hull_idx = list(range(len(ue_list)))
    else:
        hull_idx = ConvexHull(np.asarray(ue_list)).vertices.tolist()

    no_head_yet = tar_loc[0] < 0 or tar_loc[1] < 0
    head = random.choice(hull_idx) if no_head_yet else ue_list.index(tar_loc)
    pos = hull_idx.index(head)
    return hull_idx[pos:] + hull_idx[:pos]


def Area_arr_Create(ue_list, index_list, A_row, A_col, R1, patt_con):
    """Sum the template ``patt_con`` stamped at each UE selected by ``index_list``.

    :return: (A_row, A_col) float array. With a 0/1 coverage template, each
        cell holds the number of selected UEs whose disc covers it.
    """
    total = np.zeros((A_row, A_col))
    for idx in index_list:
        loc = ue_list[idx]
        total += Create_Boundary_list(R1, A_row, A_col, loc[0], loc[1], patt_con)
    return total


def Get_below_2R(ue_list, edge_index, tar_loc, R1):
    """Return the non-hull UEs lying within distance R1 of ``tar_loc``.

    Entries at ``edge_index`` are removed from a shallow copy of ``ue_list``
    (the input list is not modified). Then only points whose squared distance
    to ``tar_loc`` is <= R1**2 are kept, in their original order.

    NOTE(review): despite the name, and the caller's comment about "2x radius",
    the threshold is R1, not 2*R1. Confirm which was intended.
    """
    remaining = ue_list.copy()
    for idx in sorted(edge_index, reverse=True):
        del remaining[idx]
    radius2 = R1 ** 2
    return [p for p in remaining
            if (p[0] - tar_loc[0]) ** 2 + (p[1] - tar_loc[1]) ** 2 <= radius2]


def _pick_min_cost_cell(cost):
    """Return [row, col] of a random cell holding the smallest non-zero value of ``cost``.

    Zero cells are treated as "not a candidate". Raises ValueError if ``cost``
    has no non-zero cell.
    """
    min_val = np.min(cost[np.nonzero(cost)])
    rows, cols = np.where(cost == min_val)
    pick = random.randint(0, len(rows) - 1)
    return [rows[pick], cols[pick]]


def Process_round(u_list, solve_list, solve_mat, tar_loc, A_row, A_col, R1, con_pattern, los_pattern):
    """Place one more UAV and remove the UEs it newly covers.

    Args:
        u_list: remaining (not yet covered) UE locations. Covered entries are
            deleted from this list in place.
        solve_list: UAV positions chosen so far. The new [row, col] is appended in place.
        solve_mat: sum of the coverage masks of the chosen UAVs. It is updated
            in place and also returned.
        tar_loc: edge UE to use as the hull head, or [-1, -1] to let
            Func_convexhall pick a random hull vertex.
        A_row, A_col: size of the deployment grid.
        R1: UAV coverage radius in grid cells.
        con_pattern: 0/1 disc template (see Create_Boundary_pattern).
        los_pattern: per-cell path-loss template of the same shape.

    Returns:
        (solve_mat, new_tar_loc). ``new_tar_loc`` is the first hull UE (in hull
        order) that is still uncovered, or [-1, -1] if all of them are covered.
    """
    # Convex hull of the remaining UEs, rotated so that tar_loc is the head node.
    edge_index = Func_convexhall(u_list, tar_loc)
    head = u_list[edge_index[0]]

    edge_count = np.zeros((A_row, A_col))  # number of hull UEs covering each cell
    edge_loss = np.zeros((A_row, A_col))   # summed path loss to those hull UEs
    for x in edge_index:
        edge_count += Create_Boundary_list(R1, A_row, A_col, u_list[x][0], u_list[x][1], con_pattern)
        edge_loss += Create_Boundary_list(R1, A_row, A_col, u_list[x][0], u_list[x][1], los_pattern)

    # Candidate area: cells that cover the head and the largest number of hull UEs.
    N_Area = edge_count * Create_Boundary_list(R1, A_row, A_col, head[0], head[1], con_pattern)
    max_val = np.max(N_Area)
    N_Area[N_Area < max_val] = 0
    N_Area[N_Area > 0] = 1
    # Edge-only cost. It is non-zero at least where N_Area is 1, so it always has a candidate.
    edge_cost = edge_loss * N_Area

    # Non-hull UEs near the head (Get_below_2R keeps those within R1).
    s_list = Get_below_2R(u_list, edge_index, [head[0], head[1]], R1)
    if s_list:
        s_list_mat = np.zeros((A_row, A_col))
        s_list_loss_mat = np.zeros((A_row, A_col))
        for x in s_list:
            s_list_mat += Create_Boundary_list(R1, A_row, A_col, x[0], x[1], con_pattern)
            s_list_loss_mat += Create_Boundary_list(R1, A_row, A_col, x[0], x[1], los_pattern)
        # Within the candidate area, keep cells covering the most inner UEs.
        s_list_mat = s_list_mat * N_Area
        max_val = np.max(s_list_mat)
        s_list_mat[s_list_mat < max_val] = 0
        s_list_mat[s_list_mat > 0] = 1
        cost = (s_list_loss_mat + edge_loss) * s_list_mat
        if not np.any(cost):
            # No inner UE covers any candidate cell. Previously the code compared
            # against an all-zero row, which matched every grid cell and placed
            # the UAV at a random spot anywhere. Fall back to the edge-only choice.
            cost = edge_cost
    else:
        # No inner UEs: choose purely from the edge candidate area.
        cost = edge_cost

    aim_loc = _pick_min_cost_cell(cost)
    solve_list.append(aim_loc)
    solve_mat += Create_Boundary_list(R1, A_row, A_col, aim_loc[0], aim_loc[1], con_pattern)

    # Walk the hull in order and return the first edge UE that is still uncovered.
    new_tar_loc = [-1, -1]
    for x in edge_index:
        if solve_mat[u_list[x][0], u_list[x][1]] <= 0:
            new_tar_loc = u_list[x]
            break

    # Drop every UE that is now covered (iterate backwards so deletion is safe).
    for i in range(len(u_list) - 1, -1, -1):
        if solve_mat[u_list[i][0], u_list[i][1]] > 0:
            del u_list[i]

    return solve_mat, new_tar_loc


def Draw_result(u_list, solve_list, R1, index_pic):
    """Plot UEs and UAV coverage discs in grid units, then save and clear the figure.

    The image is written to 'solve-opt--<index_pic + 1>.jpg'.
    """
    fig = plt.figure('fig')
    ax = fig.add_subplot(111)  # single subplot of a 1x1 layout
    for ue in u_list:
        plt.plot(ue[0], ue[1], "xy")
    for label, uav in enumerate(solve_list, start=1):
        ux, uy = uav[0], uav[1]
        plt.plot(ux, uy, "vr")
        ax.add_patch(Circle(xy=(ux, uy), radius=R1, alpha=0.1, color='b'))
        plt.text(ux, uy, str(label), fontsize=15)
    fig.savefig('solve-opt--' + str(index_pic + 1) + '.jpg')
    fig.clf()


def Draw_result_line(u_list, solve_list, R1, index_pic):
    """Plot a deployment in km (grid coordinates / 10), with UAVs joined in placement order.

    Coverage discs are numbered from 1. UEs and UAVs are scattered with a
    legend, and consecutive UAVs are connected by dashed orange segments.
    The figure is saved at 600 dpi to 'solve-opt--<index_pic + 1>.jpg' and
    then cleared.
    """
    scale = 10.0
    fig = plt.figure('fig')
    ax = fig.add_subplot(111)  # single subplot of a 1x1 layout

    uav_xs = [uav[0] / scale for uav in solve_list]
    uav_ys = [uav[1] / scale for uav in solve_list]
    for label, (ux, uy) in enumerate(zip(uav_xs, uav_ys), start=1):
        ax.add_patch(Circle(xy=(ux, uy), radius=R1 / scale, alpha=0.1, color='b'))
        plt.text(ux, uy, str(label), fontsize=15)

    ue_xs = [ue[0] / scale for ue in u_list]
    ue_ys = [ue[1] / scale for ue in u_list]
    type_ue = ax.scatter(ue_xs, ue_ys, marker='x', color='y', linewidth=1.0)
    type_uav = ax.scatter(uav_xs, uav_ys, marker='v', color='r')
    ax.legend((type_ue, type_uav), (u'UE', u'UAV'))

    # Dashed segment from each UAV to the next one in deployment order.
    starts = zip(uav_xs, uav_ys)
    ends = zip(uav_xs[1:], uav_ys[1:])
    for (x0, y0), (x1, y1) in zip(starts, ends):
        plt.plot([x0, x1], [y0, y1], dashes=[6, 2], color='orange', linewidth=1.0)

    plt.xlabel("x(km)")
    plt.ylabel("y(km)")

    plt.xlim(-2, 12)
    plt.ylim(-2, 12)
    plt.gca().set_aspect(1)

    fig.savefig('solve-opt--' + str(index_pic + 1) + '.jpg', dpi=600)
    fig.clf()


def read_csv(addr):
    """Read integer UE coordinates from a CSV file whose first row is a header.

    :param addr: path of the CSV file.
    :return: list of rows, each converted to a list of ints. Blank lines are
        skipped. An empty file (no header) yields [].
    :raises ValueError: if a non-blank cell is not an integer.
    """
    ue_list = []
    # The csv module requires files opened with newline='' so it can handle line endings itself.
    with open(addr, 'r', newline='') as fh:
        reader = csv.reader(fh, delimiter=',')
        next(reader, None)  # skip the header; tolerate a completely empty file
        for row in reader:
            if not row:
                # Blank lines used to become [] entries that crashed later indexing.
                continue
            ue_list.append([int(cell) for cell in row])
    return ue_list


# List of all convex-hull nodes; used on the first pass to try every hull vertex as a start.
def Func_convexhall_list(ue_list):
    """Return the elements of ``ue_list`` that are convex-hull vertices.

    With more than two points, the hull and its vertex order come from scipy's
    ConvexHull. Otherwise every element is returned in its original order.
    The result is always a new list.
    """
    if len(ue_list) <= 2:
        return [ue_list[k] for k in range(len(ue_list))]
    vertex_ids = ConvexHull(np.asarray(ue_list)).vertices.tolist()  # hull point indices
    return [ue_list[k] for k in vertex_ids]


# Ground base stations
def addGroundBS(us_1d, basenum):
    """Place ``basenum`` ground base stations chosen by GAfunction's density tournament.

    Each selected flat cell index ``k`` is decoded with X_SIZE into
    (int(k / X_SIZE), k % X_SIZE).

    :return: ((rows, cols), coveredplace). ``coveredplace`` concatenates
        GAfunction.base_coveredplace(k, R=R_LOS) for every chosen k.
    """
    density = GAfunction.user_density(us_1d, R_LOS)
    rows, cols = [], []
    coveredplace = []
    for _ in range(basenum):
        cell = GAfunction.select_density_tournament(density)
        rows.append(int(cell / X_SIZE))
        cols.append(cell % X_SIZE)
        coveredplace.extend(GAfunction.base_coveredplace(int(cell), R=R_LOS))
    return (rows, cols), coveredplace

def main(addr, uavnum):
    """Run the convex-hull UAV deployment on one UE data set and evaluate throughput.

    Every convex-hull vertex of the UE set is tried as the starting edge UE.
    Each attempt places UAVs with Process_round until ``uavnum`` UAVs are
    placed or every UE is covered, and is plotted with Draw_result_line.
    Transmit powers are then optimised with power_ga.Solve_process and the
    throughput is computed with calc_var_pl.

    NOTE(review): the evaluated deployment is the one built from the LAST hull
    vertex, not necessarily the one using the fewest UAVs. Confirm this is intended.

    Args:
        addr: CSV file with a header row and integer UE coordinates (see read_csv).
        uavnum: maximum number of UAVs placed per attempt.

    Returns:
        (number of UAVs in the evaluated deployment, elapsed seconds, throughput).

    Raises:
        ValueError: if the file contains no UE locations.
    """
    A_row = ROW
    A_col = COL
    R1 = UAV_R

    origin_ue_loc_list = read_csv(addr)
    if not origin_ue_loc_list:
        raise ValueError("no UE locations found in " + str(addr))

    ue_list = list(origin_ue_loc_list)
    us_px = [ue[0] for ue in ue_list]
    us_py = [ue[1] for ue in ue_list]
    us_1d = GAfunction.us1d((us_px, us_py))
    bs, coveredplace = addGroundBS(us_1d, numbs)

    # Remove UEs already within range of a ground BS. ``delnum`` compensates for
    # index shifts caused by earlier deletions.
    # NOTE(review): ue_list is not used after this point. The deployment below runs
    # on origin_ue_loc_list, so these removals currently have no effect.
    delnum = 0
    for i in range(len(us_px)):
        for j in range(numbs):
            dist = math.sqrt((us_px[i] - bs[0][j]) ** 2 + (us_py[i] - bs[1][j]) ** 2)
            if dist <= R_LOS:
                del ue_list[i - delnum]
                delnum += 1
                # A UE covered by several BSs must be removed only once.
                # Without this break the following UE would be deleted instead.
                break

    starttime = time.time()
    pl_f = (20 * math.log10(2000000000)) + 20 * math.log10(4 * (math.pi) / 3 / 100000000)
    glo_height = 18
    con_pattern, los_pattern = Create_Boundary_pattern(R1, glo_height, pl_f)

    # Convex-hull vertices of all UEs. Each one is tried as the first edge UE.
    tar_loc_list = Func_convexhall_list(origin_ue_loc_list)

    solve_list = []
    for i, t_loc in enumerate(tar_loc_list):
        solve_list = []
        solve_mat = np.zeros((ROW, COL))
        up_ue_loc_list = origin_ue_loc_list.copy()
        # Keep placing UAVs while the budget allows and some UE is still uncovered.
        while len(solve_list) < uavnum and up_ue_loc_list:
            solve_mat, t_loc = Process_round(up_ue_loc_list, solve_list, solve_mat, t_loc, A_row, A_col, R1,
                                             con_pattern, los_pattern)
        Draw_result_line(origin_ue_loc_list, solve_list, R1, i)

    # Throughput of the (last) deployment under GA-optimised transmit powers.
    UE_list, BS_list = listtrans(solve_list, origin_ue_loc_list)
    solution_space, record_list = power_ga.Solve_process(UE_list, BS_list)
    power_scheme = solution_space[0].mat
    throughput = calc_var_pl(solve_list, origin_ue_loc_list, power_scheme)

    endtime = time.time()
    return len(solve_list), endtime - starttime, throughput

def listtrans(solve_list, ue_list):
    """Wrap raw coordinates into UE and UAV objects for power_ga.

    Every UE gets loc [x, y, 0] and snr = snrth = -5. Every UAV gets
    loc [x, y, 200], power_aver = 40 and maxRate = 500.

    :return: (UE_list, BS_list) in the same order as the inputs.
    """
    UE_list = []
    for point in ue_list:
        user = UE()
        user.loc = [point[0], point[1], 0]
        user.snr = -5
        user.snrth = -5
        UE_list.append(user)

    BS_list = []
    for site in solve_list:
        station = UAV()
        station.loc = [site[0], site[1], 200]
        station.power_aver = 40
        station.maxRate = 500
        BS_list.append(station)

    return UE_list, BS_list


h = 100  # height used by pl_slant's elevation angle atan2(h, dist)
r_th = 0  # dB
# ---------------------- parameters of fspl (log-distance path loss)
f = 1.4e3  # carrier frequency, MHz
d0, c, alpha = 1, 3e8, 3.5  # reference distance, speed of light (m/s), path-loss exponent
pi = math.pi
glo_pl_f = (20 * math.log10(2000000000)) + 20 * math.log10(4 * (math.pi) / 3 / 100000000)
def fspl(dist):
    """Return the log-distance path loss in dB at distance ``dist`` (metres).

    Computes 20*log10(4*pi*f*d0/c) + 10*alpha*log10(dist/d0) using the module
    constants ``f`` (MHz), ``d0``, ``c`` (m/s), ``alpha`` and ``pi``. At
    ``dist == 0`` only the reference term is returned, which avoids log10(0).
    """
    # f is stored in MHz, but the Friis term needs Hz to match c in m/s.
    # Using MHz directly gave a negative "loss" of about -85 dB at d0.
    f_hz = f * 1e6
    ref_loss = 20 * np.log10(4 * pi * f_hz * d0 / c)
    if dist == 0:
        return ref_loss
    return ref_loss + 10 * alpha * math.log10(dist / d0)
# Association under a per-UAV service limit
def connectionUavUe(us_p, uav_posi, numMax, r_cover=None):
    """Greedily attach each UE to its nearest in-range UAV that still has capacity.

    UEs are processed in index order. If a UE's nearest UAV is full, the next
    nearest in-range UAV is tried, and so on.

    :param us_p: (xs, ys) UE coordinates.
    :param uav_posi: (xs, ys) UAV coordinates.
    :param numMax: maximum number of UEs one UAV may serve.
    :param r_cover: coverage radius; defaults to the module constant R_LOS.
    :return: (uelinked, linked). ``uelinked[i]`` is 1.0 if UE i was attached
        and 0.0 otherwise. ``linked[j]`` lists the UE indices served by UAV j,
        in attachment order. Each UE appears in at most one list.
    """
    radius = R_LOS if r_cover is None else r_cover
    numuav = len(uav_posi[0])
    numue = len(us_p[0])

    uelinked = np.zeros(numue)
    # UAVs that are out of range (or later found full) get +inf, so they can never
    # win argmin. The old finite sentinel (200) looped forever once radius >= 200.
    distUavUe = np.full([numue, numuav], np.inf)
    for i in range(numue):
        for j in range(numuav):
            dist = math.sqrt((uav_posi[0][j] - us_p[0][i]) ** 2 + (uav_posi[1][j] - us_p[1][i]) ** 2)
            if dist <= radius:
                distUavUe[i][j] = dist

    linked = [[] for _ in range(numuav)]
    if numuav == 0:
        # argmin over an empty row would raise. With no UAVs, nobody is attached.
        return uelinked, linked

    for i in range(numue):
        row = distUavUe[i]  # view; writes below update distUavUe
        while True:
            nearest = int(np.argmin(row))
            if row[nearest] > radius:
                break  # no in-range UAV with spare capacity remains
            if len(linked[nearest]) < numMax:
                linked[nearest].append(i)
                uelinked[i] = 1
                break
            row[nearest] = np.inf  # this UAV is full; try the next nearest

    return uelinked, linked

def calc_coverate(uav_list, ue_list, R=R_forest):
    """Return the fraction of UEs within in-plane distance ``R`` of at least one UAV.

    Only the first two coordinates of each entry are used. An empty ``ue_list``
    raises ZeroDivisionError.
    """
    covered = 0
    for ue in ue_list:
        if any(math.sqrt((uav[0] - ue[0]) ** 2 + (uav[1] - ue[1]) ** 2) <= R for uav in uav_list):
            covered += 1
    return covered / len(ue_list)

def calc_var_pl(uav_list, origin_ue_loc_list, power_scheme):
    """Return the summed log10(1 + SNR) over all UEs attached to a UAV.

    UEs are attached with connectionUavUe(numMax=100). For UE i served by UAV j:
    snr_dB = power_scheme[j][i] - pathlossA2G(dist, 200) - (-115), where dist is
    the in-plane UAV-UE distance. Unattached UEs contribute nothing.
    NOTE(review): power_scheme is presumably in dBm to match the -115 noise
    level. Confirm against power_ga.
    """
    us_px = [ue[0] for ue in origin_ue_loc_list]
    us_py = [ue[1] for ue in origin_ue_loc_list]
    uav_x = [uav[0] for uav in uav_list]
    uav_y = [uav[1] for uav in uav_list]
    _, linked = connectionUavUe((us_px, us_py), (uav_x, uav_y), numMax=100)

    noise_db = -115
    # connectionUavUe attaches each UE to at most one UAV, so this map is exact.
    server_of = {ue: j for j, members in enumerate(linked) for ue in members}

    throughput = 0
    for i in range(len(us_px)):
        j = server_of.get(i)
        if j is None:
            continue
        dist = math.sqrt((uav_x[j] - us_px[i]) ** 2 + (uav_y[j] - us_py[i]) ** 2)
        snr = power_scheme[j][i] - pathlossA2G(dist, 200) - noise_db
        throughput += math.log10(1 + np.power(10, snr / 10))
    return throughput

def calc_var_rsma_pl(uav_list, origin_ue_loc_list, power_scheme):
    """Same as calc_var_pl, but each attached UE's SNR also includes a Rayleigh term.

    For each attached UE, in UE index order, ``ray = np.mean(get_rayleigh_gain())``
    is drawn. Then snr_dB = power_scheme[j][i] - pathlossA2G(dist, 200) + 115
    - 10*log10(ray), and log10(1 + SNR) is summed.
    NOTE(review): subtracting 10*log10(ray) lowers SNR as the fading gain grows.
    If get_rayleigh_gain returns a power gain, the sign looks inverted. Confirm
    against the ``ray`` module.
    """
    us_px = [ue[0] for ue in origin_ue_loc_list]
    us_py = [ue[1] for ue in origin_ue_loc_list]
    uav_x = [uav[0] for uav in uav_list]
    uav_y = [uav[1] for uav in uav_list]
    _, linked = connectionUavUe((us_px, us_py), (uav_x, uav_y), numMax=100)

    noise_db = -115
    # connectionUavUe attaches each UE to at most one UAV, so this map is exact.
    server_of = {ue: j for j, members in enumerate(linked) for ue in members}

    throughput = 0
    for i in range(len(us_px)):
        j = server_of.get(i)
        if j is None:
            continue
        dist = math.sqrt((uav_x[j] - us_px[i]) ** 2 + (uav_y[j] - us_py[i]) ** 2)
        ray = np.mean(get_rayleigh_gain())
        snr = power_scheme[j][i] - pathlossA2G(dist, 200) - noise_db - 10 * np.log10(ray)
        throughput += math.log10(1 + np.power(10, snr / 10))
    return throughput

def pathlossA2G(dist2, h):
    '''
    Air-to-ground slant path loss in dB: log-distance loss plus a vegetation term.

    :param dist2: horizontal distance between the BS's ground projection and the
        user, in grid units. It is multiplied by 100 before being combined with
        ``h``, so it is assumed 1 grid unit = 100 units of ``h`` (presumably metres).
    :param h: BS height
    :return: air-to-ground slant path loss (dB)
    '''
    f = 1.4e9  # carrier frequency, Hz
    d0 = 1
    c = 3e8
    alpha = 3.5
    pi = math.pi
    A = 0.25
    C = 0.39
    E = 0.25
    G = 0
    H = 0.05
    dep_tree = 2  # vegetation depth
    horizontal = dist2 * 100  # same unit as h
    dist3d = math.sqrt(horizontal ** 2 + h ** 2)
    free_space = 20 * np.log10(4 * pi * f * d0 / c) + 10 * alpha * math.log10(dist3d / d0)
    # The elevation angle must use the same (scaled) horizontal distance as dist3d.
    # Using the raw dist2 mixed grid units with h and overstated the angle.
    theta = math.atan2(h, horizontal)
    # NOTE(review): theta is in radians here. ITU-R P.833-style vegetation
    # models usually take the elevation angle in degrees. Confirm the intended unit.
    slant = A * np.power(f / math.pow(10, 6), C) * np.power(dep_tree, E) * np.power(theta + G, H)
    return free_space + slant

if __name__ == '__main__':
    # Example run with a hard-coded data set path and a budget of 7 UAVs.
    print(main(r"D:\pycharm\va_3d\data\list_ue90_2.csv", 7))