import math
import scipy.integrate
import numpy as np
from sklearn.cluster import KMeans

import GDOP
from itertools import combinations
import copy

import Users
# from Match import Match
import km

def calculate_3d_distance(a, b):
    """Return the Euclidean distance from point ``a`` to the ground point below ``b``.

    ``a`` is read as ``(x, y, z)``. Only ``b[0]`` and ``b[1]`` are read: ``b`` is
    treated as lying on the ground plane (z = 0), so any third coordinate of
    ``b`` is ignored. Callers pass the UAV as ``a`` and the user as ``b``.
    """
    dx = a[0] - b[0]
    dy = a[1] - b[1]
    dz = a[2]  # height of ``a`` above the ground plane on which ``b`` sits
    squared_norm = math.pow(dx, 2) + math.pow(dy, 2) + math.pow(dz, 2)
    return math.sqrt(squared_norm)


def calculate_2d_distance(a, b):
    """Return the planar (x, y) Euclidean distance between ``a`` and ``b``.

    Any coordinates beyond the first two are ignored.
    """
    dx, dy = a[0] - b[0], a[1] - b[1]
    return math.sqrt(math.pow(dx, 2) + math.pow(dy, 2))


def forest_path_loss(uav, user):
    """Return the air-to-ground path loss (dB) through a forest canopy.

    The loss is the sum of three terms:

    * a log-distance term ``10 * alpha * log10(d)`` with ``alpha = 3.5``;
    * a foliage attenuation term ``A * f_MHz**C * d_foliage**E * (theta + G)**H``,
      where ``theta`` is the elevation angle in degrees and ``d_foliage`` is the
      slant path length through a foliage layer of vertical thickness 5
      (presumably the ITU-R P.833 slant-path vegetation model -- confirm);
    * the free-space loss at a 1 m reference distance,
      ``20 * log10(4 * pi * f_c / c)``.

    Args:
        uav: UAV position ``(x, y, z)``. ``uav[2]`` is its height and must be
            non-zero (it is divided by).
        user: user position. Only ``user[0]`` and ``user[1]`` are used;
            ``calculate_3d_distance`` places the user on the ground.

    Returns:
        Total path loss in dB.
    """
    alpha = 3.5  # exponent of the log-distance term
    # Foliage-loss coefficients for A * f_MHz**C * d_foliage**E * (theta + G)**H.
    A = 0.25
    C = 0.39
    E = 0.25
    G = 0
    H = 0.05
    foliage_height = 5  # vertical thickness of the foliage layer
    f_c = 2 * 10 ** 9  # carrier frequency in Hz (converted to MHz below)
    # Speed of light in m/s. 3e8 matches the constant used by path_loss_dis and
    # free_path_loss; a value of 1e8 would inflate the free-space term by ~9.54 dB.
    c = 3 * 10 ** 8

    dis = calculate_3d_distance(uav, user)
    path_loss = 10 * alpha * math.log10(dis)
    theta = math.degrees(math.asin(uav[2] / dis))  # elevation angle in degrees
    # Slant distance travelled inside the foliage layer: thickness / sin(theta).
    dis1 = foliage_height * dis / uav[2]
    slant = A * math.pow(f_c / 1e6, C) * math.pow(dis1, E) * math.pow(theta + G, H)
    fspl = 20 * math.log10(4 * math.pi * f_c / c)  # free-space loss at 1 m
    return slant + fspl + path_loss


def expect_L(l1, sigma=3):
    """Return the expected linear gain ``E[10 ** (-(l1 + X) / 10)]`` with ``X ~ N(0, sigma**2)``.

    ``l1`` is a path loss in dB and ``X`` a zero-mean Gaussian shadowing term in
    dB, so the result is the mean linear channel gain.

    The expectation is evaluated in closed form using the normal
    moment-generating function ``E[exp(t X)] = exp(sigma**2 * t**2 / 2)``::

        10 ** (-l1 / 10) * exp((sigma * ln(10) / 10) ** 2 / 2)

    Numerical integration is avoided on purpose. For realistic path losses the
    integrand is ~1e-10, far below ``scipy.integrate.quad``'s default absolute
    tolerance (~1.5e-8). quad then accepts its first, badly resolved estimate
    and returns a result with a large relative error.

    Args:
        l1: path loss in dB.
        sigma: standard deviation of the shadowing term in dB (default 3).

    Returns:
        The expected linear gain (a positive float).
    """
    # 10 ** (-x / 10) == exp(-t * x) with t = ln(10) / 10; the sign of t does not
    # matter because the MGF of a zero-mean normal depends only on t**2.
    t = math.log(10) / 10
    return 10 ** (-l1 / 10) * math.exp((sigma * t) ** 2 / 2)

def path_loss_dis(dis,h):
    """Return the mean air-to-ground path loss (dB) at 3-D distance ``dis`` from a UAV at height ``h``.

    The loss is the free-space loss for a 2 GHz carrier (c = 3e8) plus an
    excess loss. The excess loss averages 1 dB (LoS) and 20 dB (NLoS), weighted
    by the LoS probability ``1 / (1 + a * exp(-b * (theta - a)))``, where
    ``theta`` is the elevation angle in degrees, a = 9.6 and b = 0.28.
    ``dis`` must be positive and ``h <= dis``.
    """
    env_a, env_b = 9.6, 0.28  # environment constants of the LoS-probability sigmoid
    eta_los, eta_nlos = 1, 20  # mean excess loss (dB) for LoS / NLoS links
    fspl = (20 * math.log10(dis)
            + (20 * math.log10(2000000000))
            + 20 * math.log10(4 * (math.pi) / 3 / 100000000))
    elevation = math.asin(h / dis) / math.pi * 180
    p_los = 1 / (1 + env_a * math.exp(-env_b * (elevation - env_a)))
    return p_los * eta_los + (1 - p_los) * eta_nlos + fspl

def free_path_loss(dis):
    """Return the free-space path loss in dB at distance ``dis`` for a 2 GHz carrier.

    Uses c = 3e8, so ``dis`` is in metres. ``dis`` must be positive.
    """
    distance_term = 20 * math.log10(dis)
    frequency_term = 20 * math.log10(2000000000)
    constant_term = 20 * math.log10(4 * (math.pi) / 3 / 100000000)
    return distance_term + (frequency_term) + constant_term

def cal_snr(uav, user, p, noise=-117):
    """Return the received SNR in dB of a UAV-to-user link.

    The path loss comes from ``path_loss_dis`` evaluated at the UAV-user distance
    and the UAV height.

    Args:
        uav: UAV position ``(x, y, z)``; ``uav[2]`` is its height.
        user: user position (placed on the ground by ``calculate_3d_distance``).
        p: transmit power. Must be in watts, because the noise power is
            converted from dBm to W before the ratio is taken.
        noise: noise power in dBm. Defaults to -117, the previously hard-coded
            value and the same default ``get_rest_p`` uses.

    Returns:
        SNR in dB. Note that ``cal_data_rate`` expects a linear SNR.
    """
    dis = calculate_3d_distance(uav, user)
    loss = path_loss_dis(dis, uav[2])
    noise_power = (10 ** (noise / 10)) / 1000  # dBm -> W
    snr = p * (10 ** (-loss / 10)) / noise_power  # linear SNR
    return 10 * math.log10(snr)


def cal_data_rate(snr, bandwidth):
    """Return the Shannon capacity ``bandwidth * log2(1 + snr)``.

    ``snr`` must be a linear power ratio, not dB (``cal_snr`` returns dB).
    The result is in bit/s when ``bandwidth`` is in Hz.
    """
    spectral_efficiency = math.log2(1 + snr)
    return bandwidth * spectral_efficiency


def dbmToW(p):
    """Convert a power level ``p`` from dBm to watts."""
    milliwatts = 10 ** (p / 10)
    return milliwatts / 1000

def get_user_loc_uav(users,p_association):
    """Group the positioning UAVs by user.

    ``p_association`` is indexed ``[uav][user]``; an entry equal to 1 marks that
    UAV as one of the user's positioning UAVs.

    Returns:
        dict mapping every user index ``0..len(users)-1`` to the list of UAV
        indices assigned to it, in ascending order.
    """
    user_loc_uav = {user_idx: [] for user_idx in range(len(users))}
    for uav_idx in range(len(p_association)):
        row = p_association[uav_idx]
        for user_idx in range(len(row)):
            if row[user_idx] == 1:
                user_loc_uav[user_idx].append(uav_idx)
    return user_loc_uav

def get_association(users,uavs):
    """Greedy user-UAV association (placeholder).

    Not implemented yet: both arguments are ignored and ``None`` is returned.
    """
    return None


def getUavInitialPos(k, users, height=100):
    """Place ``k`` UAVs above the k-means cluster centres of the users.

    Args:
        k: number of UAVs, used as the number of clusters.
        users: user positions, passed unchanged to ``KMeans.fit``.
        height: altitude given to every UAV. Defaults to 100, the value that was
            previously hard-coded.

    Returns:
        list of ``[x, y, height]`` lists, one per cluster centre. ``x`` and ``y``
        are the centre's first two coordinates truncated with ``int()``.
    """
    model = KMeans(n_clusters=k,  # number of clusters
                   random_state=1,  # fixed seed: deterministic centroid initialisation
                   max_iter=300,  # max iterations of a single k-means run (default 300)
                   ).fit(users)
    return [[int(centre[0]), int(centre[1]), height]
            for centre in model.cluster_centers_]


def get_c_association(users,uavs_pos,serveMax):
    """Assign each user a communication UAV via Kuhn-Munkres matching.

    Every UAV ``j`` is expanded into ``serveMax`` virtual slots with ids
    ``j + stride * number`` (``number`` in ``range(serveMax)``), so one UAV can be
    matched to at most ``serveMax`` users. The edge weight between user ``i`` and
    any slot of UAV ``j`` is ``-calculate_3d_distance(uavs_pos[j], users[i])``.
    Assuming ``km.run_kuhn_munkres`` returns a maximum-weight matching
    (TODO confirm), this minimises the total user-UAV distance.

    Args:
        users: user positions.
        uavs_pos: UAV positions ``(x, y, z)``.
        serveMax: maximum number of users per UAV.

    Returns:
        dict mapping user index -> UAV index. It is built from each matched item
        ``(connect[0], connect[1], ...)`` returned by the matcher, with the slot
        id mapped back to its UAV.
    """
    user_num = len(users)
    uav_num = len(uavs_pos)
    # Slot-id stride. A fixed stride of 100 makes ids collide once there are
    # more than 100 UAVs (UAV 100 slot 0 == UAV 0 slot 1); max(100, uav_num)
    # keeps the ids unchanged for <= 100 UAVs and unique beyond that.
    stride = max(100, uav_num)
    slot_to_uav = {j + stride * number: j
                   for j in range(uav_num)
                   for number in range(serveMax)}

    graph = []
    for i in range(user_num):
        for j in range(uav_num):
            # Weights are kept as full-precision floats. A float16 matrix would
            # round distances (steps of 1 between 1024 and 2048, coarser above)
            # and overflow to -inf beyond 65504, creating spurious ties.
            weight = -calculate_3d_distance(uavs_pos[j], users[i])
            for number in range(serveMax):
                graph.append((i, j + stride * number, weight))

    connection = km.run_kuhn_munkres(graph)
    return {connect[0]: slot_to_uav[connect[1]] for connect in connection}

def get_uav_to_user(c_connection,uavs):
    """Invert a user -> communication-UAV mapping.

    Returns:
        dict mapping every UAV index ``0..len(uavs)-1`` to the list of users (in
        ``c_connection`` iteration order) whose communication UAV it is.
    """
    uav_to_user = {uav_idx: [] for uav_idx in range(len(uavs))}
    for user_idx in c_connection:
        serving_uav = c_connection[user_idx]
        uav_to_user[serving_uav].append(user_idx)
    return uav_to_user

# Once each user's communication UAV is chosen, pick the two other UAVs that,
# together with it, give that user the lowest GDOP.
def get_p_association(users,uavs,c_connection,service_num):
    """Choose two positioning UAVs per user to join its communication UAV.

    With exactly 3 UAVs, each user simply gets the two UAVs other than its
    communication UAV, and ``service_num`` is unused. Otherwise every 3-UAV
    combination (``itertools.combinations``) receives a budget of ``service_num``.
    Users are processed in index order. Each user considers the combinations
    that still have budget and contain its communication UAV, and takes the one
    with the smallest ``GDOP.calculate_gdop`` value (ties keep the earliest
    candidate). That combination's budget is then decremented by one.

    Args:
        users: user positions, indexed ``0..len(users)-1``.
        uavs: UAV positions, indexed ``0..len(uavs)-1``.
        c_connection: dict user index -> communication UAV index (as returned
            by ``get_c_association``).
        service_num: budget given to each 3-UAV combination (not per UAV).

    Returns:
        dict user index -> ``[uav_a, uav_b]``, the two UAVs that join the user's
        communication UAV for positioning.

    Raises:
        KeyError: if a user index ``0..len(users)-1`` is missing from
            ``c_connection``.
        ValueError: with exactly 3 UAVs, if a communication UAV index is not
            0, 1 or 2.
        IndexError: if no combination with remaining budget contains the
            user's communication UAV (e.g. fewer than 3 UAVs).
    """
    p_connection={}
    for i in c_connection:
        p_connection[i]=[]
    if len(uavs)==3:
        # Only one triple exists: the positioning UAVs are the other two.
        for i in range(len(users)):
            uavs_all=[i for i in range(len(uavs))]
            uavs_all.remove(c_connection[i])
            p_connection[i]=[uavs_all[0],uavs_all[1]]
    else:
        uavs_initial = [i for i in range(len(uavs))]
        uav_group_all = list(combinations(uavs_initial, 3))
        # Remaining budget per 3-UAV combination (indexed like uav_group_all).
        # NOTE(review): the budget is per combination, not per UAV, even though
        # the caller in __main__ sizes service_num as ceil(2 * users / UAVs).
        uav_serive_num = [service_num for i in range(len(uav_group_all))]
        # Reverse lookup: combination tuple -> its index in uav_group_all.
        uav_group_dic={}
        for i in range(len(uav_group_all)):
            uav_group_dic[uav_group_all[i]]=i
        for i in range(len(users)):
            # Indices of the combinations that still have budget left.
            uav_need = [i for i in range(len(uav_group_all)) if uav_serive_num[i] > 0]
            uav_groups = [uav_group_all[i] for i in uav_need]  # all available combinations
            # Keep combinations containing the communication UAV and drop that
            # UAV from each, leaving the candidate pair of positioning UAVs.
            uav_group=[]
            uav_dic={}  # candidate index -> original combination tuple
            count=0
            for group in uav_groups:
                if c_connection[i] in group:
                    tmp=[group[0],group[1],group[2]]
                    tmp.remove(c_connection[i])
                    uav_group.append(tmp)
                    uav_dic[count]=group  # remember which combination this candidate came from
                    count+=1

            # Pick the candidate pair with the smallest GDOP (strict <, so the
            # earliest candidate wins ties).
            gdop_min = float('inf')
            index = 0
            for j in range(len(uav_group)):
                uavs_pos = [uavs[c_connection[i]], uavs[uav_group[j][0]],uavs[uav_group[j][1]]]
                # A shallow copy of the user is passed, presumably so that
                # calculate_gdop cannot mutate the stored position -- confirm.
                gdop_current = GDOP.calculate_gdop(uavs_pos, copy.copy(users[i]))
                if gdop_current < gdop_min:
                    gdop_min = gdop_current
                    index = j
            # Raises IndexError here if there was no candidate combination.
            p_connection[i].extend([uav_group[index][0], uav_group[index][1]])
            # Consume one unit of the chosen combination's budget.
            uav_serive_num[uav_group_dic[uav_dic[index]]]-=1
    return p_connection


def get_p_by_snr(snr,uav,user,noise):
    """Return the transmit power (W) needed to reach a target SNR at the user.

    This is the inverse of the SNR formula in ``cal_snr``:
    ``p = 10**(snr/10) * N / 10**(-PL/10)``. Here ``PL`` is ``path_loss_dis`` at
    the UAV-user distance and the UAV height, and ``N`` is the noise power
    converted from dBm to W.

    Args:
        snr: target SNR in dB.
        uav: UAV position ``(x, y, z)``; ``uav[2]`` is its height.
        user: user position.
        noise: noise power in dBm.
    """
    pl = path_loss_dis(calculate_3d_distance(uav, user), uav[2])
    noise_w = dbmToW(noise)
    return (10 ** (snr / 10)) * noise_w / 10 ** (-pl / 10)

def get_rest_p(users,uavs,p_connection,snr_thre=0,noise=-117,pmax=0.1):
    """Return each UAV's power budget left after serving its positioning links.

    Every UAV starts with ``pmax``. For each (user, positioning UAV) pair in
    ``p_connection`` (user index -> list of UAV indices), the power that
    ``get_p_by_snr`` says is needed to reach ``snr_thre`` dB at that user with
    ``noise`` dBm is subtracted. Values are not clamped and can go negative
    when a UAV's budget is exceeded.
    """
    p_rest = [pmax] * len(uavs)
    for user_index, uav_group in p_connection.items():
        for uav_index in uav_group:
            needed = get_p_by_snr(snr_thre, uavs[uav_index], users[user_index], noise)
            p_rest[uav_index] -= needed
    return p_rest


if __name__ == '__main__':
    # Demo: place 3 UAVs over the k-means centres of the users in ./data.csv.
    users = Users.getUsers("./data.csv")
    uavs = getUavInitialPos(3, users)
    n_users, n_uavs = len(users), len(uavs)
    # Each UAV may serve ceil(users / UAVs) users for communication.
    c_connection = get_c_association(users, uavs, math.ceil(n_users / n_uavs))
    # Each user needs two positioning UAVs, hence the factor 2.
    p_connection = get_p_association(users, uavs, c_connection, math.ceil(n_users * 2 / n_uavs))
    # Power UAV 2 needs for a 20 dB SNR at user 0 with -117 dBm noise.
    print(get_p_by_snr(20, uavs[2], users[0], -117))