import copy

import numpy as np
import random
import math
import operator
import pickle
import time
import matplotlib.pyplot as plt
from matplotlib.patches import Circle
import progressbar
import csv
from virtual_force import *


ROW = 100  # rows of the grid used for UAV placement matrices
COL = 100  # columns of that grid
POP_SIZE = 200  # GA population size
UAV_R = 18  # max horizontal UE-UAV distance (loc units) for a link in comm_connection
snr_thuav = 3  # not referenced in the visible part of this file
f = 1.4e9  # carrier frequency in Hz (1.4 GHz); pl_slant divides by 1e6 to get MHz
d0 = 1  # reference distance of the path-loss model (m)
c = 3e8  # propagation speed (speed of light, m/s) in the free-space term of fspl
alpha = 3.5  # path-loss exponent
pi = math.pi
# Coefficients of the slant-path (vegetation) loss computed in pl_slant:
#   A * f_MHz**C * dep_tree**E * (theta + G)**H
A = 0.25
C = 0.39
E = 0.25
G = 0
H = 0.05
dep_tree = 2  # vegetation depth

N_power = -130  # noise power in dBm (subtracted in cal_rate)
"""
    input : ue_position uav_position
    output : power
"""

class UE:
    """A ground user (UE) with a position and the candidate UAV spots that cover it."""

    def __init__(self):
        # loc: [x, y, z]; -1 appears to be a "not placed yet" placeholder.
        # list_loc: every position from which a deployed UAV could cover this UE.
        # rate: placeholder value until a rate is assigned elsewhere.
        self.loc, self.list_loc, self.rate = [-1, -1, 0], [], -1

class UAV:
    """A UAV base station: position, served clients and bookkeeping fields."""

    def __init__(self):
        # loc: [x, y, z]; client_list: locations of the UEs it serves;
        # los_sum: accumulated LoS metric; power: initial power value (30).
        self.loc, self.client_list = [-1, -1, 0], []
        self.los_sum, self.power = 0, 30


class SOLUTIN:
    """One GA individual: a power matrix plus its fitness."""

    def __init__(self):
        # mat starts as an uninitialized ROW x COL buffer (np.empty); callers such
        # as Create_solution_space replace it with a (num_uav, num_ue) matrix.
        # fit stays -1 until cal_fit assigns it.
        self.mat, self.fit, self.rate_total = np.empty((ROW, COL)), -1, 0



def cal_distance2(list_A, list_B):
    """Return the squared Euclidean distance between two points.

    Only the first two coordinates (x, y) are used; any further entries are ignored.

    :param list_A: point ``[x, y, ...]``
    :param list_B: point ``[x, y, ...]``
    """
    dx = list_A[0] - list_B[0]
    dy = list_A[1] - list_B[1]
    return dx ** 2 + dy ** 2


def randomNumVaccinePersonTotal(maxValue, num):
    '''Split ``maxValue`` into ``num`` random positive integers that sum to it.

    ``num - 1`` distinct cut points are sampled from ``1 .. maxValue-1``. The
    gaps between consecutive cut points (with 0 and maxValue added as the
    ends) are the parts.

    :param maxValue: total to split (converted with ``int``)
    :param num: number of parts
    :return: list of ``num`` positive ints summing to ``int(maxValue)``
    '''
    total = int(maxValue)
    cuts = sorted(random.sample(range(1, total), k=num - 1) + [0, total])
    return [upper - lower for lower, upper in zip(cuts, cuts[1:])]

def randomInitPower(uespace, uavspace, linked, power_scheme):
    '''Return a randomly perturbed copy of ``power_scheme`` for one chromosome.

    For every linked (UAV i, UE j) pair, a value drawn uniformly from
    ``[-w, w]`` is added to ``power_scheme[i][j]``, where
    ``w = 0.05 + 0.2 * uespace[j].snrth``. Unlinked entries are copied
    unchanged, and the input matrix is not modified.

    :param uespace: UEs; each must expose ``snrth``
    :param uavspace: UAVs (only its length is used)
    :param linked: per-UAV lists of linked UE indices
    :param power_scheme: base power matrix indexed ``[uav][ue]``
    :return: new, perturbed power matrix
    '''
    jittered = copy.deepcopy(power_scheme)
    base_spread = 0.05
    for uav_idx in range(len(uavspace)):
        for ue_idx in linked[uav_idx]:
            half_width = base_spread + uespace[ue_idx].snrth * 0.2
            jittered[uav_idx][ue_idx] += np.random.uniform(-half_width, half_width)
    return jittered


def Create_solution_space(uespace, uavspace, solution_list, linked, power_scheme):
    """Append POP_SIZE randomly perturbed initial solutions to ``solution_list``.

    Each solution's ``mat`` is a (len(uavspace), len(uespace)) matrix. It is
    zero everywhere except on linked (UAV, UE) pairs, which hold the
    perturbed power produced by randomInitPower.

    :param uespace: list of UEs (randomInitPower reads ``snrth`` from them)
    :param uavspace: list of UAVs
    :param solution_list: list that receives the new SOLUTIN objects (mutated)
    :param linked: per-UAV lists of linked UE indices
    :param power_scheme: base power matrix indexed ``[uav][ue]``; not modified
    """
    base_power = copy.deepcopy(power_scheme)
    n_uav, n_ue = len(uavspace), len(uespace)
    for _ in range(POP_SIZE):
        individual = SOLUTIN()
        perturbed = randomInitPower(uespace, uavspace, linked, base_power)
        matrix = np.zeros((n_uav, n_ue))
        for uav_idx in range(n_uav):
            for ue_idx in linked[uav_idx]:
                matrix[uav_idx][ue_idx] = perturbed[uav_idx][ue_idx]
        individual.mat = matrix
        solution_list.append(individual)



def Create_UAV_matrix(space):
    """Randomly deploy UAVs so that each UE has one at a candidate position.

    For every UE, a position is drawn from ``ue.list_loc``. A UAV is created
    there unless a previous draw already chose the same position.

    :param space: list of UEs (each with ``list_loc`` of integer ``[x, y, ...]``)
    :return: ``(matrix, uav_list)`` -- a ROW x COL 0/1 occupancy matrix and the
             list of created UAV objects
    """
    seen_locs = []
    uav_list = []
    occupancy = np.zeros((ROW, COL))
    for ue in space:
        loc = random.choice(ue.list_loc)
        if loc in seen_locs:
            continue
        uav = UAV()
        uav.loc = loc[:]
        seen_locs.append(loc)
        uav_list.append(uav)
        occupancy[loc[0], loc[1]] = 1
    return occupancy, uav_list



def Fit_func(mat, ue_space, uav_space):
    """Sum cal_rate over every (UAV, UE) pair with non-zero power in ``mat``.

    Not called in the visible part of this file; the GA uses cal_fit instead.

    :param mat: power matrix indexed ``mat[uav_idx][ue_idx]``
    :param ue_space: list of UEs
    :param uav_space: list of UAVs
    :return: total rate (0 if every entry is zero)
    """
    total = 0
    for ue_idx, ue in enumerate(ue_space):
        for uav_idx, uav in enumerate(uav_space):
            link_power = mat[uav_idx][ue_idx]
            if link_power:
                total += cal_rate(link_power, ue, uav)
    return total

def fspl(dist):
    """Log-distance path loss in dB.

    Computes ``20*log10(4*pi*f*d0/c) + 10*alpha*log10(dist/d0)``. At
    ``dist == 0`` only the reference term is returned, which avoids log10(0).

    :param dist: distance in metres
    """
    reference_db = 20 * np.log10(4 * pi * f * d0 / c)
    if dist == 0:
        return reference_db
    return reference_db + 10 * alpha * math.log10(dist / d0)
def pl_slant(dist, h):
    """Slant-path (vegetation) excess loss in dB.

    Computes ``A * f_MHz**C * dep_tree**E * (theta + G)**H``, where
    ``theta = atan2(h, dist)`` is the elevation angle in radians and
    ``f_MHz`` is the module carrier frequency ``f`` (Hz) converted to MHz.

    :param dist: horizontal distance
    :param h: height
    """
    elevation = math.atan2(h, dist)
    freq_mhz = f / math.pow(10, 6)
    return A * np.power(freq_mhz, C) * np.power(dep_tree, E) * np.power(elevation + G, H)
def total_pl(dist, h):
    """Total path loss in dB: fspl(dist) plus pl_slant(dist, h)."""
    return fspl(dist) + pl_slant(dist, h)
def cal_rate(power, ue, uav):
    """Return ``log10(1 + 10**(snr_dB/10))`` for one UAV->UE link.

    The 3-D distance scales the x/y offsets by 100 (the z offset is used
    as-is), and ``snr_dB = power - total_pl(distance, uav height) - N_power``.

    NOTE(review): elsewhere (cal_fit) the power-matrix entries are treated as
    linear watts (divided by dBm2w(Noise)), but here ``power`` is subtracted as
    if it were in dBm. The x/y scale (100) also differs from pathlossA2G (30),
    and the rate uses log10 rather than log2. Confirm the intended units
    before relying on this.

    :param power: transmit power of the UAV towards this UE
    :param ue: UE with ``loc`` [x, y, z]
    :param uav: UAV with ``loc`` [x, y, z]
    """
    dx = (uav.loc[0] - ue.loc[0]) * 100
    dy = (uav.loc[1] - ue.loc[1]) * 100
    dz = uav.loc[2] - ue.loc[2]
    distance_m = math.sqrt(dx ** 2 + dy ** 2 + dz ** 2)
    snr_db = power - total_pl(distance_m, uav.loc[2]) - N_power
    return np.log10(1 + np.power(10, snr_db / 10))

def cal_fit(ue_space, uav_space, solution_space, linked):
    """Compute and store the fitness of every solution, in place.

    ``solution.fit`` is the sum, over all linked (UAV j, UE i) pairs, of
    ``log10(1 + snr)``, where
    ``snr = (solution.mat[j][i] / pl) / dBm2w(Noise)`` and
    ``pl = dB2powerratio(pathlossA2G(horizontal distance, UAV height))``.

    The path loss of a link depends only on geometry, not on the candidate
    solution. It is therefore computed once per link instead of once per link
    per solution. This assumes dB2powerratio/dBm2w are pure conversions
    (defined outside this file -- confirm). The summation order is unchanged,
    so the fitness values are identical.

    :param ue_space: list of UEs (``loc`` [x, y, ...])
    :param uav_space: list of UAVs (``loc`` [x, y, height])
    :param solution_space: solutions whose ``mat`` is indexed ``[uav][ue]``
    :param linked: per-UAV lists of linked UE indices
    """
    noise_w = dBm2w(Noise)
    linked_sets = [set(ue_indices) for ue_indices in linked]

    # (uav index, ue index, linear path loss), in the original i-outer/j-inner order.
    links = []
    for i, ue in enumerate(ue_space):
        for j, uav in enumerate(uav_space):
            if i in linked_sets[j]:
                dist = math.sqrt((uav.loc[0] - ue.loc[0]) ** 2 + (
                            uav.loc[1] - ue.loc[1]) ** 2)
                pl = dB2powerratio(pathlossA2G(dist, uav.loc[2]))
                links.append((j, i, pl))

    for solution in solution_space:
        throughput = 0
        for j, i, pl in links:
            snr = (solution.mat[j][i] / pl) / noise_w
            throughput += np.log10(1 + snr)
        solution.fit = throughput


def pathlossA2G(dist2, h):
    '''Air-to-ground slant path loss in dB.

    The horizontal distance is scaled by 30 before use. The result is the
    log-distance free-space loss (1.4 GHz, exponent 3.5, d0 = 1) over the 3-D
    distance, plus the vegetation slant term
    ``0.25 * f_MHz**0.39 * depth**0.25 * theta**0.05`` (depth = 2).

    :param dist2: horizontal distance between the UAV's ground projection and the UE
    :param h: UAV height
    :return: air-to-ground slant path loss (dB)
    '''
    carrier_hz = 1.4e9
    ref_dist = 1
    light_speed = 3e8
    exponent = 3.5
    a_coef, c_coef, e_coef, g_coef, h_coef = 0.25, 0.39, 0.25, 0, 0.05
    foliage_depth = 2  # vegetation depth

    horizontal = dist2 * 30
    dist3d = math.sqrt(horizontal ** 2 + h ** 2)
    free_space = (20 * np.log10(4 * math.pi * carrier_hz * ref_dist / light_speed)
                  + 10 * exponent * math.log10(dist3d / ref_dist))
    elevation = math.atan2(h, horizontal)
    slant = (a_coef * np.power(carrier_hz / math.pow(10, 6), c_coef)
             * np.power(foliage_depth, e_coef) * np.power(elevation + g_coef, h_coef))
    return free_space + slant




def Sort_fitness(space):
    """Sort ``space`` in place by ``fit``, highest first (stable for ties)."""
    space.sort(key=lambda solution: solution.fit, reverse=True)


def RouletteWheelSelection(space):
    """Pick two distinct parent indices with fitness-proportional probability.

    Each solution's weight is its ``fit``; negative or NaN fitness counts as
    zero weight. The second parent is drawn from the remaining solutions, so
    the two indices always differ. When all remaining weights are zero, the
    pick is uniform among the eligible solutions.

    :param space: sequence of solutions exposing a numeric ``fit``
    :return: tuple ``(pair1, pair2)`` of distinct indices into ``space``
    :raises ValueError: if ``space`` holds fewer than two solutions
    """
    n = len(space)
    if n < 2:
        raise ValueError("RouletteWheelSelection needs at least two solutions")
    # `w > 0` is False for NaN, so NaN fitness also maps to weight 0.
    weights = [s.fit if s.fit > 0 else 0.0 for s in space]

    pair1 = _spin_wheel(weights)
    if pair1 is None:
        pair1 = random.randrange(n)

    remaining = list(weights)
    remaining[pair1] = 0.0
    pair2 = _spin_wheel(remaining)
    if pair2 is None:
        # Uniform over every index except pair1.
        pair2 = random.randrange(n - 1)
        if pair2 >= pair1:
            pair2 += 1
    return pair1, pair2


def _spin_wheel(weights):
    """Draw an index with probability proportional to ``weights``; None if all are zero."""
    total = sum(weights)
    if not total > 0:
        return None
    r = random.uniform(0, total)
    acc = 0.0
    last_positive = None
    for idx, w in enumerate(weights):
        if w > 0:
            last_positive = idx
            acc += w
            if r < acc:
                return idx
    # r can reach `total` through float rounding; fall back to the last positive slot.
    return last_positive


def Rebuild_loc_list(solution_c):
    """Rebuild ``solution_c.list`` as one UAV per positive entry of ``solution_c.mat``.

    Each UAV's ``loc`` is set to ``[row, col]`` of a positive matrix entry, in
    row-major order (np.where is used for speed).

    NOTE(review): for the (num_uav, num_ue) power matrices built by
    Create_solution_space, these are (uav, ue) indices rather than grid
    coordinates -- confirm that is intended.
    """
    hits = np.where(solution_c.mat > 0)
    rebuilt = []
    for row_idx, col_idx in zip(hits[0], hits[1]):
        uav = UAV()
        uav.loc = [row_idx, col_idx]
        rebuilt.append(uav)
    solution_c.list = rebuilt


def Rand_Exchange(space, children_space):
    """Create two children by block crossover and append them to ``children_space``.

    Two parents are chosen with RouletteWheelSelection. A random cut point
    ``(c_row, c_col)`` splits the power matrix into four quadrants:

    - child1 takes the top-left and bottom-right quadrants from parent
      ``pair1`` and the other two quadrants from ``pair2``.
    - child2 is the mirror image.

    Both children get their ``list`` rebuilt via Rebuild_loc_list.

    :param space: current population (sequence of SOLUTIN)
    :param children_space: list receiving the two new children (mutated)
    """
    pair1, pair2 = RouletteWheelSelection(space)
    mat1 = space[pair1].mat
    mat2 = space[pair2].mat
    c_row = random.randint(0, len(space[0].mat) - 1)
    c_col = random.randint(0, len(space[0].mat[0]) - 1)

    child1 = SOLUTIN()
    child2 = SOLUTIN()
    child1.mat = np.copy(mat2)
    child2.mat = np.copy(mat1)

    # Open-ended slices always reach the matrix edge; bounding them by the grid
    # constants ROW/COL would silently skip rows/columns beyond 100.
    child1.mat[:c_row, :c_col] = mat1[:c_row, :c_col]
    child1.mat[c_row:, c_col:] = mat1[c_row:, c_col:]
    child2.mat[:c_row, :c_col] = mat2[:c_row, :c_col]
    child2.mat[c_row:, c_col:] = mat2[c_row:, c_col:]

    Rebuild_loc_list(child1)
    Rebuild_loc_list(child2)

    children_space.append(child1)
    children_space.append(child2)


def Create_children(space):
    """Produce one generation of offspring by repeated block crossover.

    Rand_Exchange runs ``ceil(POP_SIZE / 4)`` times and appends two children
    each time, so the result holds ``2 * ceil(POP_SIZE / 4)`` solutions.

    :param space: current population
    :return: list of new child solutions
    """
    offspring = []
    rounds = math.ceil(POP_SIZE / 4)
    while rounds > 0:
        Rand_Exchange(space, offspring)
        rounds -= 1
    return offspring


def Clear_client_list(class_solution):
    """Reset ``class_solution.add`` to 0 and give every UAV in ``.list`` a fresh empty client list."""
    class_solution.add = 0
    for uav in class_solution.list:
        uav.client_list = []


def All_UE_check(space_ue, class_solution):
    """Make sure every UE is covered by a UAV, adding UAVs where needed (in place).

    Client lists are cleared first. For each UE, the UAVs are scanned from the
    end of ``class_solution.list`` backwards, and the UE's ``loc`` is appended
    to the first UAV whose squared distance is <= UAV_R2.

    If no UAV qualifies, a new UAV is placed at a random position from
    ``ue.list_loc`` and the UE is attached to it. In that case
    ``class_solution.add`` is decremented and ``class_solution.mat`` is
    marked at that position.

    NOTE(review): UAV_R2 is not defined in the visible part of this file
    (presumably from `from virtual_force import *`) -- confirm.

    :param space_ue: list of UEs
    :param class_solution: solution with ``list``, ``mat`` and ``add``
    """
    Clear_client_list(class_solution)
    uav_list = class_solution.list
    for ue in space_ue:
        ue_loc = ue.loc
        for idx in range(len(uav_list) - 1, -1, -1):
            if cal_distance2(ue_loc, uav_list[idx].loc) <= UAV_R2:
                uav_list[idx].client_list.append(ue_loc)
                break
        else:
            class_solution.add -= 1
            new_loc = random.choice(ue.list_loc)
            new_uav = UAV()
            new_uav.loc = new_loc
            new_uav.client_list.append(ue.loc)
            uav_list.append(new_uav)
            class_solution.mat[new_loc[0], new_loc[1]] = 1


def Generation2pic(fit_list, file_name="gen.png", show=True):
    """Plot the best-fitness curve over generations and save it as an image.

    The x-axis spans the whole record; it was previously fixed at 0-100,
    which cut off later generations. The figure is saved before it is shown,
    because a blocking ``plt.show()`` closes it.

    :param fit_list: best fitness per generation
    :param file_name: output image path (default ``"gen.png"``)
    :param show: whether to display the figure interactively
    """
    plt.rcParams['font.sans-serif'] = ['SimHei']  # CJK font for the Chinese axis labels
    plt.rcParams['axes.unicode_minus'] = False  # keep '-' from rendering as a box
    fig = plt.figure()
    plt.plot(fit_list, color='black', linewidth=2.0)

    plt.xlabel('遗传代数')  # "generation"
    plt.ylabel('适应度')  # "fitness"

    plt.xlim(0, max(len(fit_list) - 1, 1))
    plt.grid()

    fig.savefig(file_name, dpi=600)
    if show:
        plt.show()





def Save_result(class_solution, file_name='uav_save.pickle'):
    """Pickle ``class_solution`` to ``file_name`` using the highest protocol.

    ``file_name`` defaults to ``'uav_save.pickle'``, so one-argument calls work.
    """
    with open(file_name, "wb") as fp:
        pickle.dump(class_solution, fp, protocol=pickle.HIGHEST_PROTOCOL)


def Solution2csv(space_solution, file_name_csv):
    """Write each solution's UAV deployment to a CSV file.

    Each solution is written as:

    - a header row ``x, y, number of UE, fit, <fit>, add, <add>``;
    - one row per UAV in ``solution.list``: ``[x, y, len(client_list)]``;
    - a blank separator row.

    The file is closed even if writing fails.

    :param space_solution: solutions exposing ``fit``, ``add`` and ``list``
    :param file_name_csv: output CSV path
    """
    with open(file_name_csv, 'w', newline='') as csv_file:
        writer = csv.writer(csv_file)
        for solution in space_solution:
            rows = [['x', 'y', 'number of UE', 'fit', solution.fit, 'add', solution.add]]
            rows.extend([uav.loc[0], uav.loc[1], len(uav.client_list)]
                        for uav in solution.list)
            rows.append([])
            writer.writerows(rows)




def Solution2csv_los(space_solution, file_name_csv, ue_num):
    """Write each solution's UAV deployment plus LoS averages to a CSV file.

    Each solution is written as:

    - a header row with ``fit``, ``add`` and ``los_total / ue_num``;
    - one row per UAV: ``[x, y, client count, los_sum / client count]``;
    - a blank separator row.

    A UAV with no clients (or ``ue_num == 0``) previously raised
    ZeroDivisionError; its average is now written as an empty field.

    :param space_solution: solutions exposing ``fit``, ``add``, ``los_total`` and ``list``
    :param file_name_csv: output CSV path
    :param ue_num: number of UEs used to average ``los_total``
    """
    with open(file_name_csv, 'w', newline='') as csv_file:
        writer = csv.writer(csv_file)
        for solution in space_solution:
            los_total_ave = solution.los_total / ue_num if ue_num else ''
            rows = [['x', 'y', 'number of UE', 'los_ave', 'fit', solution.fit, 'add', solution.add,
                     'los_total_ave', los_total_ave]]
            for uav in solution.list:
                client_num = len(uav.client_list)
                los_ave = uav.los_sum / client_num if client_num else ''
                rows.append([uav.loc[0], uav.loc[1], client_num, los_ave])
            rows.append([])
            writer.writerows(rows)


def generation2csv(list_record, file_name="generation_record.csv"):
    """Write the per-generation fitness record to a one-column CSV.

    The first row is the header ``fit``, followed by one row per value. The
    file is closed even if writing fails.

    :param list_record: fitness values, one per generation
    :param file_name: output path (default ``"generation_record.csv"``)
    """
    rows = [["fit"]] + [[value] for value in list_record]
    with open(file_name, 'w', newline='') as csv_file:
        csv.writer(csv_file).writerows(rows)

def comm_connection(UE_list, BS_list):
    '''
    Associate each UE with at most one base station (UAV) by proximity.

    The candidate stations for a UE are those whose horizontal distance is
    within 0.1 of the nearest station. Among them, the last one (in BS_list
    order) whose accumulated power is still below 2.4 is chosen. The link is
    kept only if that distance is <= UAV_R, and each accepted link reserves
    0.2 power on its station.

    Side effect: sets ``BS_list[i].link`` to the number of UEs linked to station i.

    :param UE_list: list of UEs (objects with ``loc``)
    :param BS_list: list of stations (objects with ``loc``)
    :return: ``(link, linked, power_scheme)``:
             link -- 0/1 float matrix of shape (len(BS_list), len(UE_list));
             linked -- per-station lists of linked UE indices;
             power_scheme -- 0.2 on linked entries, 0 elsewhere
    '''
    n_ue, n_bs = len(UE_list), len(BS_list)

    # Horizontal (x, y) distance from every UE to every station.
    dist = np.zeros([n_ue, n_bs])
    for u, ue in enumerate(UE_list):
        for b, bs in enumerate(BS_list):
            dist[u][b] = math.sqrt(
                (bs.loc[0] - ue.loc[0]) ** 2 + (bs.loc[1] - ue.loc[1]) ** 2)

    power_scheme = np.ones([n_bs, n_ue]) * 0.2
    link = np.zeros([n_bs, n_ue])
    used_power = np.zeros(n_bs)

    for u in range(n_ue):
        row = dist[u]
        chosen = -1
        if len(row):
            nearest = min(row)
            for b, d in enumerate(row):
                # Near-tie with the closest station and power budget still left.
                if abs(d - nearest) < 0.1 and used_power[b] < 2.4:
                    chosen = b
        if chosen >= 0 and row[chosen] <= UAV_R:
            link[chosen][u] = 1
            used_power[chosen] += power_scheme[chosen][u]

    # Unlinked (station, UE) pairs get no power.
    power_scheme[link == 0] = 0

    linked = []
    for b, bs in enumerate(BS_list):
        served = [u for u in range(n_ue) if link[b][u] == 1]
        linked.append(served)
        bs.link = len(served)

    return link, linked, power_scheme


def connectionUavUe_class(UE_list, BS_list,):
    '''
    Associate UEs with UAVs by strongest SNR and derive an equal-share power split.

    SNR is computed for every UE/UAV pair from the UAV's ``power_aver``, the
    air-to-ground path loss and the noise power. Each UE picks the UAV with
    the highest SNR. It is linked only if:

    - that SNR is at least ``UE_list[i].snr``, and
    - the UAV's accumulated load (0.15 per linked UE) is still below ``maxpower``.

    A UE that fails the SNR check is skipped, and the remaining UEs are still
    processed. Previously a ``break`` dropped every later UE.

    Afterwards each UAV splits ``maxpower`` equally over its linked UEs. A
    share above the UAV's ``power_aver_upper`` is replaced by 0.18.

    Side effect: sets ``BS_list[i].serveUser`` to the number of UEs linked to UAV i.

    :param UE_list: UEs exposing ``loc`` and ``snr``
    :param BS_list: UAVs exposing ``loc`` (x, y, height), ``power_aver`` and ``power_aver_upper``
    :return: ``(uelinked, linked, power_scheme)`` -- 0/1 vector marking linked UEs,
             per-UAV lists of linked UE indices, and the (num_uav, num_ue) power matrix
    '''
    numuav = len(BS_list)
    numue = len(UE_list)
    maxpower = 2

    uelinked = np.zeros(numue)
    snrUavUe = np.zeros([numue, numuav])
    powertmp = np.zeros(numuav)
    # Provisional per-link power, used only to track each UAV's load while linking.
    power_scheme = np.ones([numuav, numue]) * 0.15

    for i, ue in enumerate(UE_list):
        for j, bs in enumerate(BS_list):
            dist = math.sqrt(
                (bs.loc[0] - ue.loc[0]) ** 2 + (bs.loc[1] - ue.loc[1]) ** 2)
            pl = dB2powerratio(pathlossA2G(dist, bs.loc[2]))
            snrUavUe[i][j] = (bs.power_aver / pl) / dBm2w(Noise)

    linked = [[] for _ in range(numuav)]
    for i, ue in enumerate(UE_list):
        best = np.argmax(snrUavUe[i])  # UAV with the highest SNR for UE i
        if snrUavUe[i][best] < ue.snr:
            # Even the best UAV cannot serve this UE; move on to the next UE.
            continue
        if powertmp[best] < maxpower:
            powertmp[best] += power_scheme[best][i]
            linked[best].append(i)
            uelinked[i] = 1

    power_scheme = np.zeros([numuav, numue])
    for j, ue_indices in enumerate(linked):
        if not ue_indices:
            continue
        share = maxpower / len(ue_indices)
        if share > BS_list[j].power_aver_upper:
            share = 0.18
        for i in ue_indices:
            power_scheme[j][i] = share

    for bs, ue_indices in zip(BS_list, linked):
        bs.serveUser = len(ue_indices)
    return uelinked, linked, power_scheme

def limit(solution_space, max_link_power=0.2, max_uav_power=2):
    """Enforce per-link and per-UAV power limits on every solution, in place.

    Each entry of ``solution.mat`` (row = UAV, column = UE) is clipped to
    ``max_link_power``. Any row whose total then exceeds ``max_uav_power`` is
    scaled down proportionally so that it sums to ``max_uav_power``.
    Negative entries are left untouched.

    The previous version had two defects. It tested row 0's sum for every
    row, and it multiplied by the excess ratio, which increased power instead
    of reducing it.

    :param solution_space: solutions whose ``mat`` is a float numpy array
    :param max_link_power: cap for a single UAV->UE link (default 0.2)
    :param max_uav_power: total power budget per UAV row (default 2)
    """
    for solution in solution_space:
        mat = solution.mat
        mat[mat > max_link_power] = max_link_power
        for row in mat:  # each row is a view, so in-place scaling updates mat
            total = row.sum()
            if total > max_uav_power:
                row *= max_uav_power / total

def Solve_process(ue_space, uav_space, num_gen=200):
    """Run the GA power allocation and return the final population.

    Steps:

    1. Link UEs to UAVs with comm_connection.
    2. Build and power-limit an initial population, then score and sort it.
    3. For ``num_gen`` generations:
       - breed children;
       - apply the power limits to the children (previously the parents were
         limited instead, so children were never constrained);
       - score the children and merge them into the population;
       - keep the best POP_SIZE solutions.

    :param ue_space: list of UEs
    :param uav_space: list of UAVs
    :param num_gen: number of generations (default 200)
    :return: ``(solution_space, record_list)`` -- the final population sorted
             by fitness, and the best fitness after initialisation and after
             each generation
    """
    widgets = ["task-schedule: ", progressbar.Percentage(), " ", progressbar.Bar(), " ", progressbar.Timer(), '   ',
               progressbar.ETA()]

    record_list = []  # best fitness per generation (convergence record)
    solution_space = []
    _link, linked, power_scheme = comm_connection(ue_space, uav_space)

    # Initial population, constrained to the power limits.
    Create_solution_space(ue_space, uav_space, solution_space, linked, power_scheme)
    limit(solution_space)

    cal_fit(ue_space, uav_space, solution_space, linked)
    Sort_fitness(solution_space)
    record_list.append(solution_space[0].fit)

    pbar = progressbar.ProgressBar(maxval=num_gen, widgets=widgets).start()
    for n in range(num_gen):
        pbar.update(n + 1)
        children_space = Create_children(solution_space)
        # Crossover can push a UAV row over its budget, so constrain the new children.
        limit(children_space)
        cal_fit(ue_space, uav_space, children_space, linked)
        solution_space.extend(children_space)

        Sort_fitness(solution_space)
        solution_space = solution_space[:POP_SIZE]
        record_list.append(solution_space[0].fit)

    pbar.finish()

    return solution_space, record_list




def Save_process(solution_space, record_list):
    """Save the final population to ``uav_save.pickle`` and ``uav_save.csv``.

    ``record_list`` is accepted for API compatibility but is currently not written.
    """
    print("Start to save result...")
    # Pass the file name explicitly; a one-argument call raises TypeError when
    # Save_result has no default for it.
    Save_result(solution_space, "uav_save.pickle")
    # NOTE(review): UE_TOTAL is not defined in the visible part of this file
    # (presumably supplied by `from virtual_force import *`) -- confirm. Solution2csv_los
    # also reads .list/.add/.los_total, which SOLUTIN itself does not define.
    Solution2csv_los(solution_space, "uav_save.csv", UE_TOTAL)
    print("Save finished.")


def main(file):
    """Load UEs from ``<file>.pickle``, run the GA with a single UAV and print the result.

    Security: ``pickle.load`` can execute arbitrary code, so only load trusted files.

    :param file: path of the pickle file without the ``.pickle`` extension
    """
    file_name = file + ".pickle"
    with open(file_name, 'rb') as fp:
        ue_space = pickle.load(fp)

    uav = UAV()
    # NOTE(review): this loc has no height component, but cal_fit reads
    # uav.loc[2], so any linked UE would raise IndexError. Confirm the intended
    # altitude before running.
    uav.loc = [1, 1]
    uav_space = [uav]

    solution_space, record_list = Solve_process(ue_space, uav_space)

    print(len(solution_space), solution_space)




if __name__ == '__main__':
    # Machine-specific path to the UE list CSV; adjust for your environment.
    addr = r"D:\pycharm\va_3d\data\list_ue60_2.csv"

    # NOTE(review): read_UElist, greedy_init and R_uav are not defined in the
    # visible part of this file; presumably they come from
    # `from virtual_force import *` -- confirm.
    UE_list = read_UElist(addr)
    print(len(UE_list), UE_list[0].loc)

    # Initial UAV deployment via greedy_init.main3 (defined elsewhere).
    numuav = 7  # number of UAVs; commented-out alternative: int(avernum[int(len(us_px) / 10 - 5)]) - 2
    UE_list1 = UE_list.copy()
    BS_list = greedy_init.main3(UE_list1, 0, R_mbs=15, sumuav=numuav, R_uav=R_uav)

    # Run the GA, then plot the best fitness per generation.
    solution_space, record_list = Solve_process(UE_list, BS_list)
    print(record_list)
    plt.figure()
    x = np.arange(len(record_list))
    plt.plot(x,record_list)
    plt.show()
    print(solution_space[0].mat)
