import math
import scipy.integrate
import numpy as np
from sklearn.cluster import KMeans

import GDOP
from itertools import combinations
import copy

import Users
# from Match import Match
import km
import constants as C

def calculate_3d_distance(a, b):
    """Return the distance from point ``a`` to the ground point below ``b``.

    ``b`` is treated as lying at height 0: only ``b[0]`` and ``b[1]`` are
    read, while ``a[2]`` is used as the height of ``a``. Use
    calculate_3d_distance_uavs when both points carry a height.
    """
    dx = a[0] - b[0]
    dy = a[1] - b[1]
    dz = a[2] - 0
    return math.sqrt(math.pow(dx, 2) + math.pow(dy, 2) + math.pow(dz, 2))

def calculate_3d_distance_uavs(a, b):
    """Return the Euclidean distance between two 3-D points ``a`` and ``b``.

    All three coordinates of both points are used.
    """
    return math.sqrt(sum(math.pow(a[axis] - b[axis], 2) for axis in range(3)))
# Horizontal (2-D) distance.
def calculate_2d_distance(a, b):
    """Return the Euclidean distance between ``a`` and ``b`` in the x-y plane.

    Only the first two coordinates of each point are read; any height is
    ignored.
    """
    return math.sqrt(sum(math.pow(a[axis] - b[axis], 2) for axis in (0, 1)))

def cal_dis_for_user(uavs,user):
    """Print the distance from every UAV in ``uavs`` to ``user``.

    Distances come from calculate_3d_distance, in the order of ``uavs``.
    The list is printed, not returned; the function returns ``None``.
    """
    distances = [calculate_3d_distance(uav, user) for uav in uavs]
    print(distances)

# Forest-channel path loss between a UAV and a ground user.
def forest_path_loss(uav, user):
    """Return the forest-channel path loss in dB between ``uav`` and ``user``.

    The loss is the sum of three terms:

    * a log-distance term ``10 * alpha * log10(d)`` with ``alpha = 3``;
    * a foliage term ``A * f_MHz**C * d**E * (theta + G)**H``, where
      ``theta`` is the elevation angle in degrees and the foliage depth is
      taken to be the full UAV-user distance ``d``;
    * the free-space loss at a 1 m reference distance,
      ``20 * log10(4 * pi * 1 * f_c / c)`` with ``c`` the speed of light.

    Args:
        uav: UAV position ``(x, y, h)``; ``uav[2]`` is the height used for
            the elevation angle.
        user: ground user position; only ``user[0]`` and ``user[1]`` are
            read (through calculate_3d_distance).

    Returns:
        Path loss in dB.

    Raises:
        ValueError: if the UAV-user distance is 0 (``log10(0)``).
    """
    speed_of_light = 3e8  # m/s (was 1e8, which added ~9.5 dB to the FSPL term)
    f_c = 2 * 10 ** 9  # carrier frequency, Hz
    alpha = 3  # log-distance exponent (3.5 was tried previously)
    # Foliage-term coefficients. NOTE(review): these look like the in-leaf
    # slant-path set of ITU-R P.833 — confirm. ``C_exp`` is named so it does
    # not shadow the ``constants as C`` module.
    A = 0.25
    C_exp = 0.39
    E = 0.25
    G = 0
    H = 0.05

    dis = calculate_3d_distance(uav, user)
    path_loss = 10 * alpha * math.log10(dis)
    # Elevation angle of the UAV as seen from the user, in degrees.
    theta = math.degrees(math.asin(uav[2] / dis))
    foliage_depth = dis  # foliage depth assumed equal to the whole link distance
    slant = (A * math.pow(f_c / 1e6, C_exp) * math.pow(foliage_depth, E)
             * math.pow(theta + G, H))
    fspl = 20 * math.log10(4 * math.pi * 1 * f_c / speed_of_light)
    return slant + fspl + path_loss
def forest_path_loss_dis(dis, h=50):
    """Return the forest-channel path loss in dB for a given link distance.

    The loss is the sum of three terms:

    * a log-distance term ``10 * alpha * log10(dis)`` with ``alpha = 3``;
    * a foliage term ``A * f_MHz**C * dis**E * (theta + G)**H``, where
      ``theta = degrees(asin(h / dis))`` is the elevation angle and the
      foliage depth is taken to be the full distance ``dis``;
    * the free-space loss at a 1 m reference distance,
      ``20 * log10(4 * pi * 1 * f_c / c)`` with ``c`` the speed of light.

    Args:
        dis: UAV-user slant distance.
        h: UAV height used for the elevation angle. Defaults to 50, the value
            that was previously hard-coded.

    Returns:
        Path loss in dB.

    Raises:
        ValueError: if ``dis`` is not positive or ``h / dis`` lies outside
            ``[-1, 1]``.
    """
    speed_of_light = 3e8  # m/s (was 1e8, which added ~9.5 dB to the FSPL term)
    f_c = 2 * 10 ** 9  # carrier frequency, Hz
    alpha = 3  # log-distance exponent (3.5 was tried previously)
    # Foliage-term coefficients. NOTE(review): these look like the in-leaf
    # slant-path set of ITU-R P.833 — confirm. ``C_exp`` is named so it does
    # not shadow the ``constants as C`` module.
    A = 0.25
    C_exp = 0.39
    E = 0.25
    G = 0
    H = 0.05

    path_loss = 10 * alpha * math.log10(dis)
    theta = math.degrees(math.asin(h / dis))
    foliage_depth = dis  # foliage depth assumed equal to the whole link distance
    slant = (A * math.pow(f_c / 1e6, C_exp) * math.pow(foliage_depth, E)
             * math.pow(theta + G, H))
    fspl = 20 * math.log10(4 * math.pi * 1 * f_c / speed_of_light)
    return slant + fspl + path_loss


def expect_L(l1, sigma=3):
    """Return the expected linear channel gain under Gaussian (dB) shadowing.

    Computes ``E[10 ** (-(l1 + X) / 10)]`` for ``X ~ N(0, sigma**2)``: the
    mean linear gain when the loss ``l1`` (presumably a path loss in dB) is
    perturbed by zero-mean shadowing with standard deviation ``sigma`` dB.

    The gain is log-normal, so the mean has the closed form
    ``10 ** (-l1 / 10) * exp((sigma * ln(10) / 10) ** 2 / 2)``.

    Args:
        l1: loss in dB.
        sigma: shadowing standard deviation in dB (default 3, as before).

    Returns:
        Expected linear gain (float).
    """
    # Closed form instead of scipy.integrate.quad. quad's default absolute
    # tolerance (~1.5e-8) dwarfs typical gains (~1e-10 at 100 dB), so it could
    # accept a badly wrong first estimate.
    k = math.log(10) / 10
    return 10 ** (-l1 / 10) * math.exp((sigma * k) ** 2 / 2)

def expect(l1):
    """Return ``E[l1 + X]`` for zero-mean Gaussian shadowing ``X`` (sigma = 3 dB).

    Because ``X`` has zero mean, the expectation is exactly ``l1``. The
    previous numerical integration over ``[-100, 100]`` reproduced this only
    up to quadrature error, at the cost of an adaptive integration per call.

    Args:
        l1: loss value (presumably in dB).

    Returns:
        ``l1`` as a float.
    """
    return float(l1)

def path_loss_dis(dis,h):
    """Return the mean air-to-ground path loss in dB for slant distance ``dis``.

    Free-space loss at 2 GHz (``c = 3e8``) plus an excess loss that blends a
    line-of-sight value (1 dB) and a non-line-of-sight value (20 dB). The
    weights are the LoS probability
    ``1 / (1 + a * exp(-b * (theta - a)))`` with ``a = 9.6``, ``b = 0.28``
    and ``theta = asin(h / dis)`` in degrees.

    Args:
        dis: UAV-user slant distance.
        h: UAV height.
    """
    env_a, env_b = 9.6, 0.28
    excess_los_db, excess_nlos_db = 1, 20
    fspl_db = (20 * math.log10(dis) + (20 * math.log10(2000000000))
               + 20 * math.log10(4 * (math.pi) / 3 / 100000000))
    elevation_deg = math.asin(h / dis) / math.pi * 180
    p_los = 1 / (1 + env_a * math.exp(-env_b * (elevation_deg - env_a)))
    p_nlos = 1 - p_los
    return p_los * excess_los_db + p_nlos * excess_nlos_db + fspl_db

def free_path_loss(dis):
    """Return the free-space path loss in dB at distance ``dis`` for a 2 GHz carrier.

    Computed as ``20*log10(d) + 20*log10(f) + 20*log10(4*pi/c)`` with
    ``f = 2e9`` and ``c = 3e8``.
    """
    distance_term = 20 * math.log10(dis)
    return distance_term + (20 * math.log10(2000000000)) + 20 * math.log10(4 * (math.pi) / 3 / 100000000)

# SNR in dB (air-to-ground channel).
def cal_snr(uav, user, p):
    """Return the received SNR in dB for a UAV transmitting with power ``p``.

    The path loss comes from path_loss_dis on the calculate_3d_distance slant
    range and the UAV height ``uav[2]``. The noise level ``C.N0`` is converted
    from dBm to watts, so ``p`` is presumably in watts.
    """
    slant_range = calculate_3d_distance(uav, user)
    loss_db = path_loss_dis(slant_range, uav[2])
    noise_w = (10 ** (C.N0 / 10)) / 1000
    snr_linear = p * (10 ** (-loss_db / 10)) / noise_w
    return 10 * math.log10(snr_linear)


def cal_snr2(uav, user, p):
    """Return the received SNR as a linear ratio (not dB).

    The model is the same as for the dB variant: path_loss_dis on the slant
    range, with noise ``C.N0`` converted from dBm to watts.
    """
    loss_db = path_loss_dis(calculate_3d_distance(uav, user), uav[2])
    noise_w = (10 ** (C.N0 / 10)) / 1000
    channel_gain = 10 ** (-loss_db / 10)
    return p * channel_gain / noise_w

def cal_snr_forest(uav, user, p):
    """Return the forest-channel SNR in dB for transmit power ``p``.

    The path loss comes from forest_path_loss. The noise level ``C.noise`` is
    converted from dBm to watts.
    """
    loss_db = forest_path_loss(uav, user)
    noise_w = (10 ** (C.noise / 10)) / 1000
    channel_gain = 10 ** (-loss_db / 10)
    return 10 * math.log10(p * channel_gain / noise_w)

def cal_snr_f(uav, user, p):
    """Return the forest-channel SNR as a linear ratio for transmit power ``p``.

    The path loss comes from forest_path_loss. The noise level ``C.noise`` is
    converted from dBm to watts, so ``p`` is presumably in watts.

    Returns:
        Linear SNR (not dB).
    """
    # forest_path_loss computes the distance itself; the separate (unused)
    # distance computation was dropped.
    loss = forest_path_loss(uav, user)
    noise_power = (10 ** (C.noise / 10)) / 1000
    return p * (10 ** (-loss / 10)) / noise_power

def path_loss(uav,user):
    """Return the air-to-ground path loss in dB between ``uav`` and ``user``.

    This is a thin wrapper over path_loss_dis. The slant distance comes from
    calculate_3d_distance and the UAV height from ``uav[2]``. The body
    previously duplicated path_loss_dis line for line. Delegating keeps the
    two from drifting apart while giving identical results.
    """
    return path_loss_dis(calculate_3d_distance(uav, user), uav[2])

def cal_u(R,d,h):
    """Print and return the ratio of ``log(1 + SNR)`` at ground ranges ``R`` and ``R + d``.

    Both SNRs use forest_path_loss_dis on the slant distances
    ``sqrt(R**2 + h**2)`` and ``sqrt((R + d)**2 + h**2)``. The noise level
    ``C.noise`` is converted from dBm to watts.

    Args:
        R: horizontal range of the nearer user.
        d: extra horizontal range of the farther user.
        h: UAV height used for the slant distances. forest_path_loss_dis is
            called without a height, so its elevation angle does not use ``h``.

    Returns:
        ``log(1 + snr1) / log(1 + snr2)``, which is also printed.
    """
    dis1 = math.sqrt(R ** 2 + h ** 2)
    dis2 = math.sqrt((R + d) ** 2 + h ** 2)
    loss1 = forest_path_loss_dis(dis1)
    loss2 = forest_path_loss_dis(dis2)
    noise_power = (10 ** (C.noise / 10)) / 1000
    # NOTE(review): the transmit-power factor is C.pmax squared, whereas every
    # other SNR helper multiplies by the power once. Kept as-is so previously
    # recorded results do not change; confirm whether this is intended.
    tx_factor = C.pmax * C.pmax
    snr1 = tx_factor * (10 ** (-loss1 / 10)) / noise_power
    snr2 = tx_factor * (10 ** (-loss2 / 10)) / noise_power
    u = math.log(1 + snr1) / math.log(1 + snr2)
    print(u)
    return u

def cal_max_radius(Rmin):
    """Return the coverage radius for the forest channel (currently fixed at 900).

    This is work in progress. The function derives the linear channel gain
    ``(10 ** (Rmin / C.B_max) - 1) * noise / C.pmax`` and the matching loss
    in dB, but does not use them yet: the radius is hard-coded to 900. It
    prints two values:

    * the exponent ``Rmin / C.B_max``;
    * the forest path loss at that radius for a UAV height of 50.

    NOTE(review): if the exponent is meant to invert the Shannon rate
    ``B * log2(1 + snr)`` used by cal_rate, the base should be 2, not 10.

    Raises:
        ValueError: if the derived gain is not positive (``log10`` fails).
            Assuming positive ``C.B_max`` and ``C.pmax``, this happens when
            ``Rmin <= 0``.
    """
    exponent = Rmin / C.B_max
    print(exponent)
    noise_w = (10 ** (C.noise / 10)) / 1000
    required_gain = (10 ** exponent - 1) * noise_w / C.pmax
    # Unused for now, but still evaluated so invalid inputs fail as before.
    required_loss_db = -10 * math.log10(required_gain)  # noqa: F841
    radius = 900
    uav_height = 50
    slant_range = math.sqrt(radius ** 2 + uav_height ** 2)
    loss_at_radius = forest_path_loss_dis(slant_range)
    print(loss_at_radius)
    return radius



# Data rate over the forest channel.
def cal_rate(b,p,uav,user):
    """Return the Shannon rate ``b * log2(1 + SNR)`` of a UAV-user link.

    The linear SNR comes from cal_snr_f (forest channel) with transmit
    power ``p``. ``b`` is the bandwidth.
    """
    link_snr = cal_snr_f(uav, user, p)
    return b * math.log2(1 + link_snr)

# Data transmission rate.
def cal_data_rate(snr, bandwidth):
    """Return the Shannon rate ``bandwidth * log2(1 + snr)`` for a linear ``snr``."""
    spectral_efficiency = math.log2(1 + snr)
    return bandwidth * spectral_efficiency


def dbmToW(p):
    """Convert a power level ``p`` from dBm to watts."""
    milliwatts = pow(10, p / 10)
    return milliwatts / 1000

# Positioning UAV group of each user.
def get_user_loc_uav(users,p_association):
    """Map every user index to the list of UAV indices that position it.

    ``p_association[i][j] == 1`` marks UAV ``i`` as positioning user ``j``.
    Every user in ``range(len(users))`` gets an entry, even one that is empty.
    UAV indices are listed in ascending order.
    """
    user_loc_uav = {user_idx: [] for user_idx in range(len(users))}
    for uav_idx, row in enumerate(p_association):
        for user_idx, flag in enumerate(row):
            if flag == 1:
                user_loc_uav[user_idx].append(uav_idx)
    return user_loc_uav


def getUavInitialPos(k, users):
    """Place ``k`` UAVs at the k-means cluster centres of ``users``.

    KMeans uses ``random_state=1``, so centroid initialisation is
    deterministic, and runs at most 300 iterations. Each UAV is placed at the
    integer-truncated ``(x, y)`` of its centre and a fixed height of 50.

    Returns:
        list of ``[x, y, 50]`` positions, one per cluster.
    """
    model = KMeans(
        n_clusters=k,
        random_state=1,
        max_iter=300,
    ).fit(users)
    fixed_height = 50
    return [[int(centre[0]), int(centre[1]), fixed_height] for centre in model.cluster_centers_]

import skfuzzy as fuzz

def fuzzy_c_means_clustering(points, n_clusters, m=2, error=0.005, max_iter=1000, seed=None):
    """Return the cluster centres found by fuzzy c-means (scikit-fuzzy).

    ``points`` is converted to an array and transposed before the call.
    skfuzzy's ``cmeans`` expects one column per sample. ``m`` is the
    fuzziness exponent; ``error`` and ``max_iter`` are the stopping criteria;
    ``seed`` makes the initial partition reproducible.
    """
    samples = np.array(points)
    centres, _u, _u0, _d, _jm, _p, _fpc = fuzz.cluster.cmeans(
        samples.T, n_clusters, m, error, max_iter, seed=seed)
    return centres
def getUavInitialPos_fcm(k, users):
    """Place ``k`` UAVs at the fuzzy c-means centres of ``users``.

    Each UAV sits at the integer-truncated ``(x, y)`` of a centre from
    fuzzy_c_means_clustering, at a fixed height of 50.
    """
    centres = fuzzy_c_means_clustering(users, k)
    uav_height = 50
    return [[int(centre[0]), int(centre[1]), uav_height] for centre in centres]


def get_c_association(users,uavs_pos,serveMax):
    """Assign every user to one communication UAV by weighted bipartite matching.

    Each UAV is split into ``serveMax`` slots. User ``i`` is connected to every
    slot of UAV ``j`` with weight equal to the negative distance from
    calculate_3d_distance. If ``km.run_kuhn_munkres`` maximizes total weight
    (presumably; confirm in ``km``), this favours short links while each UAV
    serves at most ``serveMax`` users.

    Args:
        users: user positions.
        uavs_pos: UAV positions.
        serveMax: maximum number of users per UAV.

    Returns:
        dict mapping user index -> UAV index, one entry per matched pair
        returned by the matcher.
    """
    user_num = len(users)
    uav_num = len(uavs_pos)
    # Slot id = uav + stride * slot_number. A fixed stride of 100 made ids of
    # different UAVs collide once there were more than 100 UAVs. Keeping 100
    # as the minimum leaves ids unchanged for smaller fleets.
    stride = max(100, uav_num)
    # Full-precision weights. The former float16 array kept only ~3
    # significant digits (and overflowed past 65504), which turned distinct
    # distances into artificial ties.
    weight = [[-calculate_3d_distance(uavs_pos[j], users[i]) for j in range(uav_num)]
              for i in range(user_num)]
    # Built once rather than once per user.
    slot_to_uav = {j + stride * number: j for j in range(uav_num) for number in range(serveMax)}
    graph = [(i, j + stride * number, weight[i][j])
             for i in range(user_num)
             for j in range(uav_num)
             for number in range(serveMax)]

    connection = km.run_kuhn_munkres(graph)
    c_connection = {}
    for connect in connection:
        c_connection[connect[0]] = slot_to_uav[connect[1]]
    return c_connection

def get_uav_to_user(c_connection,uavs):
    """Invert a user -> communication-UAV mapping into UAV -> list of users.

    Every UAV index in ``range(len(uavs))`` gets an entry. Users appear in the
    iteration order of ``c_connection``.
    """
    uav_to_user = {uav_idx: [] for uav_idx in range(len(uavs))}
    for user_idx, uav_idx in c_connection.items():
        uav_to_user[uav_idx].append(user_idx)
    return uav_to_user

# After each user's communication UAV has been chosen, pick the two other UAVs
# that, together with it, give that user the lowest GDOP.
def get_p_association(users,uavs,c_connection,service_num):
    """Choose the two positioning UAVs that complement each user's communication UAV.

    Args:
        users: user positions, indexed by user index. ``users[i]`` is
            shallow-copied before being passed to ``GDOP.calculate_gdop``
            (presumably because that function may modify it).
        uavs: UAV positions, indexed by UAV index.
        c_connection: mapping user index -> communication UAV index. It is
            expected to contain every index in ``range(len(users))``; a missing
            user raises KeyError.
        service_num: how many users each 3-UAV group may serve. Only used when
            the number of UAVs is not exactly 3.

    Returns:
        dict mapping user index -> list of the two other UAV indices.

    Notes:
        * With exactly 3 UAVs, every user gets the two UAVs that are not its
          communication UAV, and ``service_num`` is ignored.
        * Otherwise every 3-UAV combination starts with a budget of
          ``service_num``. Users are processed in index order. For each user,
          only combinations that contain its communication UAV and have budget
          left are considered.
        * The pair with the smallest GDOP wins (the first one on ties; index 0
          if no GDOP is below infinity). That combination's budget is then
          decremented by one.
        * If no eligible combination is left for a user (or there are fewer
          than 3 UAVs), ``uav_group`` is empty and ``uav_group[index]`` raises
          IndexError.
    """
    p_connection={}
    for i in c_connection:
        p_connection[i]=[]
    if len(uavs)==3:
        # Only one possible triple: the positioning pair is the two remaining UAVs.
        for i in range(len(users)):
            uavs_all=[i for i in range(len(uavs))]
            uavs_all.remove(c_connection[i])
            p_connection[i]=[uavs_all[0],uavs_all[1]]
    else:
        uavs_initial = [i for i in range(len(uavs))]
        # Every candidate 3-UAV group, each with its own remaining service budget.
        uav_group_all = list(combinations(uavs_initial, 3))
        uav_serive_num = [service_num for i in range(len(uav_group_all))]
        # Group tuple -> its index in uav_group_all / uav_serive_num.
        uav_group_dic={}
        for i in range(len(uav_group_all)):
            uav_group_dic[uav_group_all[i]]=i
        # uav_serive_num=[service_num for i in range(len(uavs))]
        for i in range(len(users)):
            # Indices of all groups whose service budget is still above 0.
            # uav_need=[i for i in range(len(uavs)) if uav_serive_num[i]>0]
            uav_need = [i for i in range(len(uav_group_all)) if uav_serive_num[i] > 0]
            uav_groups = [uav_group_all[i] for i in uav_need]# all groups still available
            # Keep the groups that contain this user's communication UAV, and drop
            # that UAV from each, leaving the candidate positioning pairs.
            uav_group=[]
            uav_dic={}
            count=0
            for group in uav_groups:
                if c_connection[i] in group:
                    tmp=[group[0],group[1],group[2]]
                    tmp.remove(c_connection[i])
                    uav_group.append(tmp)
                    uav_dic[count]=group# position in uav_group -> original group tuple
                    count+=1
                    # uav_need.remove(c_connection[i])
            # uav_group=list(combinations(uav_need, 2))

            # Pick the candidate pair with the smallest GDOP (strict <, so the
            # first minimum wins).
            gdop_min = float('inf')
            index = 0
            for j in range(len(uav_group)):
                uavs_pos = [uavs[c_connection[i]], uavs[uav_group[j][0]],uavs[uav_group[j][1]]]
                gdop_current = GDOP.calculate_gdop(uavs_pos, copy.copy(users[i]))
                if gdop_current < gdop_min:
                    gdop_min = gdop_current
                    index = j
            # print(gdop_min)
            p_connection[i].extend([uav_group[index][0], uav_group[index][1]])
            # Consume one unit of the chosen group's service budget.
            uav_serive_num[uav_group_dic[uav_dic[index]]]-=1
    # print(p_connection)
    return p_connection

# def cal_max_rate():


def get_p_by_snr(snr,uav,user,noise):
    """Return the transmit power needed to reach ``snr`` dB at ``user``.

    The path loss comes from path_loss_dis (air-to-ground model) on the slant
    range and the UAV height. ``noise`` is given in dBm and converted with
    dbmToW, so the result is in watts.
    """
    loss_db = path_loss_dis(calculate_3d_distance(uav, user), uav[2])
    noise_w = dbmToW(noise)
    return (10 ** (snr / 10)) * noise_w / 10 ** (-loss_db / 10)

def get_p_by_snr_f(snr,uav,user,noise):
    """Return the transmit power needed to reach ``snr`` dB over the forest channel.

    The path loss comes from forest_path_loss. ``noise`` is given in dBm and
    converted with dbmToW, so the result is in watts.
    """
    # forest_path_loss computes the distance itself; the separate (unused)
    # distance computation was dropped.
    pl = forest_path_loss(uav, user)
    noise = dbmToW(noise)
    p = (10 ** (snr / 10)) * noise / 10 ** (-pl / 10)
    return p
def get_rest_p(users,uavs,p_connection,snr_thre=C.snr_thre,noise=C.N0,pmax=C.pmax):
    """Return each UAV's power left after serving the users it positions.

    Every UAV starts with ``pmax``. For each user in ``p_connection``, each of
    its positioning UAVs spends the power get_p_by_snr reports as necessary
    to reach ``snr_thre`` dB at that user. Remainders at or below 0 are
    clamped to 0.
    """
    p_rest = [pmax] * len(uavs)
    for user_index in p_connection:
        for uav_index in p_connection[user_index]:
            spent = get_p_by_snr(snr_thre, uavs[uav_index], users[user_index], noise)
            remaining = p_rest[uav_index] - spent
            p_rest[uav_index] = 0 if remaining <= 0 else remaining
    return p_rest

def get_rest_p_f(users,uavs,p_connection,snr_thre=0,noise=C.N0):
    """Return each UAV's power left after serving its positioning users, unclamped.

    This is the same bookkeeping as get_rest_p, with three differences: every
    UAV starts from ``C.pmax``, the SNR target defaults to 0 dB, and the
    remainder is not clamped (it may go negative).

    NOTE(review): elsewhere in this file the ``_f`` suffix marks the forest
    channel (cal_snr_f, get_p_by_snr_f). This function nevertheless calls
    get_p_by_snr, i.e. the air-to-ground model, with ``C.N0`` noise. Confirm
    whether get_p_by_snr_f (and ``C.noise``) was intended.
    """
    p_rest = [C.pmax] * len(uavs)
    for user_index in p_connection:
        for uav_index in p_connection[user_index]:
            p_rest[uav_index] -= get_p_by_snr(snr_thre, uavs[uav_index], users[user_index], noise)
    return p_rest

def get_dis(i,uavs_pos,users,uav_to_user):
    """Return the distances from UAV ``i`` to each user it serves.

    Distances come from calculate_3d_distance, in the order of
    ``uav_to_user[i]``.
    """
    return [calculate_3d_distance(uavs_pos[i], users[user_idx]) for user_idx in uav_to_user[i]]

def init_p(uavs_pos,users):
    """Build an initial communication-power allocation for every UAV.

    Steps:

    1. Associate users with communication UAVs (get_c_association, at most
       ``ceil(len(users) / len(uavs_pos))`` users per UAV).
    2. Choose positioning UAVs (get_p_association) and compute the power each
       UAV has left after positioning (get_rest_p).
    3. Split each UAV's remaining power among its communication users in
       proportion to their distance, so farther users get a larger share.

    Returns:
        list with one entry per UAV. Each entry is a list of ``len(users)``
        powers, non-zero only for the users that UAV communicates with.
    """
    c_association = get_c_association(users, uavs_pos, math.ceil(len(users) / len(uavs_pos)))
    uav_to_user = get_uav_to_user(c_association, uavs_pos)
    p_association = get_p_association(users, uavs_pos, c_association, len(users) * 2)
    p_rest = get_rest_p(users, uavs_pos, p_association)
    p_alloc = []
    for i in range(len(uavs_pos)):
        dis = get_dis(i, uavs_pos, users, uav_to_user)
        # Summed once per UAV instead of once per served user (was O(n^2)).
        total = sum(dis)
        p_new = [0] * len(users)
        for user_idx, d in zip(uav_to_user[i], dis):
            p_new[user_idx] = d / total * p_rest[i]
        p_alloc.append(p_new)
    return p_alloc

def cal_fly_energy(dis,speed):
    """Return the propulsion energy to fly distance ``dis`` at constant ``speed``.

    The propulsion power is

        P(V) = P_0 * (1 + 3 V^2 / U_tip^2)
             + P_i * (sqrt(1 + V^4 / (4 v_0^4)) - V^2 / (2 v_0^2)) ** 0.5
             + 0.5 * d_0 * rho * sigma * A * V^3

    with all constants taken from ``C``. NOTE(review): this appears to be the
    usual rotary-wing model (blade-profile, induced and parasite terms) —
    confirm. The energy is ``P(V) * dis / speed``, i.e. power times flight
    time; ``speed`` must therefore be non-zero.
    """
    # The former e1/e2/e3 terms were computed but never used (e2 also had an
    # operator-precedence error), so they were removed.
    blade_profile = C.P_0 * (1 + (3 * speed**2) / (C.U_tip**2))
    induced = C.P_i * (np.sqrt(1 + (speed**4 / (4 * (C.v_0**4)))) - speed**2 / (2 * (C.v_0**2)))**0.5
    parasite = 0.5 * C.d_0 * C.rho * C.sigma * C.A * speed**3
    power = blade_profile + induced + parasite
    return power * dis / speed

if __name__ == '__main__':
    # Demo run: load users from ./data.csv and place 4 UAVs by k-means. Then
    # print the forest-channel radius diagnostics and the rate ratio from
    # cal_u (both functions print their own results).
    users = Users.getUsers("./data.csv")
    uavs = getUavInitialPos(4, users)
    r = cal_max_radius(0.3)
    # Values noted by the author for this call: 1.13, 1.21.
    u = cal_u(r, 10, 50)