import numpy as np
import time
import matplotlib.pyplot as plt
import csv
import init
from matplotlib.patches import Circle
import matplotlib
import math
# Font file used for the legend text of the coverage plots.
font_set = matplotlib.font_manager.FontProperties(fname=r"HYQiHei-60S.ttf")

# UAV altitude; pathlossA2G combines it with 100 * horizontal distance, so it
# shares that unit (presumably metres).
h = 200

# Power settings; main() copies several of these onto every deployed UAV.
uav_power_aver = 0.2  # W
uav_power_upper = 6
uav_aver_upper = 0.4
power_t = 0.2

# Noise power in dBm (converted to watts with dBm2w in cal_network_rate).
Noise = -130
class UE:
    """A ground user (user equipment) record.

    Plain attribute container; kmeans() creates one per sample and replaces
    ``loc`` with a 2-element [x, y] list.
    """

    def __init__(self):
        # loc: position; snrth: SNR threshold; uav: serving-UAV index
        # (-1 presumably means "not associated yet" -- not read in the code shown).
        self.loc, self.snrth, self.uav = [0, 0, 0], 0, -1
class BS:
    """A base-station / UAV record.

    Only ``loc`` is initialised here; main() attaches the remaining fields
    (sort, rate, power settings) after construction.
    """

    def __init__(self):
        self.loc = [0] * 3  # position; main() overwrites it with [x, y, h]

def euclDistance(vector1, vector2):
    """Return the Euclidean distance between two points as a scalar.

    Accepts 1-D arrays, lists/tuples, or 1xN row vectors (including
    ``np.matrix``). All elements of the difference are summed, so a row-vector
    input still yields a single distance rather than a per-coordinate array
    (which the builtin ``sum`` over a 2-D input would produce).
    """
    diff = np.asarray(vector2, dtype=float) - np.asarray(vector1, dtype=float)
    return np.sqrt(np.sum(np.square(diff)))

# Initialise centroids with random samples:
# pick k samples from the data set at random as the initial centroids.
def initCentroids(dataSet, k):
    """Return ``k`` randomly chosen rows of ``dataSet`` as initial centroids.

    Rows are drawn without replacement whenever ``k <= numSamples`` so that no
    two initial centroids coincide (a duplicated centroid leaves one cluster
    empty from the start). If more centroids than samples are requested,
    sampling falls back to drawing with replacement.

    :param dataSet: (numSamples, dim) array of samples.
    :param k: number of centroids to draw.
    :return: (k, dim) float array of centroids.
    """
    numSamples, dim = dataSet.shape  # number of rows / columns
    centroids = np.zeros((k, dim))
    indices = np.random.choice(numSamples, size=k, replace=k > numSamples)
    centroids[:, :] = dataSet[indices, :]
    return centroids

# Distance-constrained k-means clustering.
def _kmeans_assign(dataSet, centroids, max_distance, unassigned):
    """Label each sample with its nearest centroid if that centroid is closer than max_distance.

    :return: (labels, sq_dist). labels[i] is the index of the nearest centroid,
        or ``unassigned`` when no centroid lies strictly within max_distance;
        sq_dist[i] is the squared distance from sample i to its nearest centroid.
    """
    numSamples = dataSet.shape[0]
    if len(centroids) == 0:
        return np.full(numSamples, unassigned), np.full(numSamples, np.inf)
    # (numSamples, numCentroids) matrix of Euclidean sample-to-centroid distances
    diff = dataSet[:, np.newaxis, :] - centroids[np.newaxis, :, :]
    dist = np.sqrt(np.sum(diff ** 2, axis=2))
    nearest = np.argmin(dist, axis=1)  # first minimum on ties
    nearest_dist = dist[np.arange(numSamples), nearest]
    labels = np.where(nearest_dist < max_distance, nearest, unassigned)
    return labels, nearest_dist ** 2


def _kmeans_update(dataSet, centroids, labels):
    """Move every centroid (in place) to the mean of the samples labelled with it.

    Centroids without members stay where they are; unassigned samples do not
    influence any centroid.
    """
    for j in range(len(centroids)):
        members = dataSet[labels == j]
        if len(members):
            centroids[j, :] = np.mean(members, axis=0)


def kmeans(dataSet, k, R, max_iter=300):
    """Cluster samples into k clusters whose members lie within R of their centroid.

    Variant of k-means in which a sample joins its nearest centroid only if that
    centroid is strictly closer than ``R``; otherwise the sample is left
    unassigned and is excluded from the centroid updates. Iterates until the
    assignment stops changing or ``max_iter`` update rounds have run.

    :param dataSet: (numSamples, dim) array of sample coordinates; columns 0/1
        are used as the x/y positions handed to ``init.init_ch``.
    :param k: number of clusters; also passed to ``init.init_ch``.
    :param R: maximum (exclusive) sample-to-centroid distance for membership.
    :param max_iter: safety cap on the number of update rounds.
    :return: (centroids, clusterAssment). centroids is a float array of the
        final centre positions. clusterAssment is a (numSamples, 2) float array:
        column 0 is the cluster index (k + 1 for samples with no centroid within
        R), column 1 the squared distance to the nearest centroid.
    """
    dataSet = np.asarray(dataSet, dtype=float)
    unassigned = k + 1  # label used for samples no centroid can reach

    # Step 1: seed the centroids. init.init_ch is a project helper (not shown);
    # it is assumed to return a (k, dim) array-like of centre positions -- TODO confirm.
    UE_list = []
    for point in dataSet:
        new_ue = UE()
        new_ue.loc = [point[0], point[1]]
        UE_list.append(new_ue)
    # float copy: the mean updates must not be truncated or alias init_ch's data
    centroids = np.array(init.init_ch(UE_list, k, R), dtype=float)

    # Steps 2-4: alternate assignment and centroid update until stable.
    labels, sq_dist = _kmeans_assign(dataSet, centroids, R, unassigned)
    for _ in range(max_iter):
        _kmeans_update(dataSet, centroids, labels)
        new_labels, sq_dist = _kmeans_assign(dataSet, centroids, R, unassigned)
        if np.array_equal(new_labels, labels):
            break
        labels = new_labels
    return centroids, np.column_stack((labels.astype(float), sq_dist))


# Show the clustering result (only available for 2-D data).
# centroids holds the k cluster centres; clusterAssment holds one row per
# sample: column 0 is its cluster index, column 1 the squared distance to it.
def showCluster(dataSet, k, centroids, clusterAssment):
    """Plot samples in their cluster's marker and each centroid as a large marker.

    Samples whose label is not in [0, k) (e.g. unassigned ones) are not drawn.
    Finishes with a blocking ``plt.show()``.

    :return: 1 if the data is not 2-D or k exceeds the available markers,
        otherwise None.
    """
    numSamples, dim = dataSet.shape
    if dim != 2:
        print("Sorry! I can not draw because the dimension of your data is not 2!")
        return 1

    mark = ['or', 'ob', 'og', 'ok', '^r', '+r', 'sr', 'dr', '<r', 'pr']
    # One centroid marker per cluster, same length as ``mark``.
    centroid_mark = ['Dr', 'Db', 'Dg', 'Dk', '^b', '+b', 'sb', 'db', '<b', 'pb']

    if k > len(mark):
        print("Sorry! Your k is too large! please contact wojiushimogui")
        return 1

    points = np.asarray(dataSet)
    labels = np.asarray(clusterAssment)[:, 0].astype(int)
    # draw all samples, one marker-only plot call per cluster
    for j in range(k):
        members = points[labels == j]
        if len(members):
            plt.plot(members[:, 0], members[:, 1], mark[j])
    # draw the centroids; a format string is required -- without one, plot()
    # draws a one-point line with no marker, i.e. nothing visible
    for i in range(k):
        plt.plot(centroids[i, 0], centroids[i, 1], centroid_mark[i], markersize=12)

    plt.show()

def read_UElist(ue_addr):
    """Read UE coordinates from a CSV file.

    The first row is treated as a header and skipped; every other non-blank
    row is converted to a list of floats.

    :param ue_addr: path of the CSV file.
    :return: list of rows (lists of float); empty if the file has no data rows.
    :raises ValueError: if a field cannot be parsed as a float.
    """
    ue_list = []
    # newline='' is what the csv module expects for files it parses
    with open(ue_addr, 'r', newline='') as csv_in:
        reader = csv.reader(csv_in, delimiter=',')
        next(reader, None)  # skip the header; an empty file simply yields nothing
        for row in reader:
            if not row:  # blank line -> would become an empty UE row
                continue
            ue_list.append([float(value) for value in row])
    return ue_list

def cal_unclustered(k, clusterAssment):
    """Count samples whose cluster label (column 0) is >= k, i.e. outside the k clusters."""
    return sum(1 for row_idx in range(len(clusterAssment)) if clusterAssment[row_idx, 0] >= k)

def get_link(uav_list, ue_list, clusterAssment):
    """Build the UAV-to-UE association from cluster labels.

    :param clusterAssment: one row per UE; column 0 is the UE's cluster/UAV
        index. Labels >= len(uav_list) (unassigned UEs) produce no link.
    :return: (link, linked). link is a (len(uav_list), len(ue_list)) float 0/1
        matrix with link[u][e] = 1 when UE e belongs to UAV u; linked[u] is the
        ascending list of UE indices served by UAV u.
    """
    num_uav = len(uav_list)
    link = np.zeros([num_uav, len(ue_list)])
    for ue_idx, row in enumerate(clusterAssment):
        label = row[0]
        if label < num_uav:
            link[int(label)][ue_idx] = 1

    linked = [np.flatnonzero(link[uav_idx] == 1).tolist() for uav_idx in range(num_uav)]
    return link, linked

def draw_uavus_link(ue_list, uav_list, linked, basecolor='red', size=40):
    """Plot UAVs, UEs and a green line for every UAV-to-UE association.

    Draws onto the current matplotlib axes; it does not create, show or save a
    figure. Only the UAV markers carry a legend label ('Deployed UAV').

    :param ue_list: UE positions; elements 0/1 of each entry are x/y.
    :param uav_list: UAV positions; elements 0/1 of each entry are x/y.
    :param linked: linked[i] lists the indices into ue_list served by UAV i.
    :param basecolor: colour of the UAV markers.
    :param size: marker size of the UAV markers.
    """
    if len(uav_list):  # guard: an empty deployment used to raise IndexError
        plt.scatter([uav[0] for uav in uav_list], [uav[1] for uav in uav_list],
                    color=basecolor, marker='v', s=size, label='Deployed UAV')

    if len(ue_list):
        plt.scatter([ue[0] for ue in ue_list], [ue[1] for ue in ue_list],
                    color='blue', marker='x', s=40)

    for uav_idx, served in enumerate(linked):
        uav_x, uav_y = uav_list[uav_idx][0], uav_list[uav_idx][1]
        for ue_idx in served:
            plt.plot([uav_x, ue_list[ue_idx][0]], [uav_y, ue_list[ue_idx][1]],
                     color='g', linestyle='-', linewidth=1)

def _write_csv_rows(csv_name, rows):
    """Write ``rows`` to the CSV file ``csv_name`` (overwriting it); the file is always closed."""
    with open(csv_name, 'w', newline='') as csv_file:
        csv.writer(csv_file).writerows(rows)


def draw_uavus_link_circle2(ue_list, uav_list, linked, basecolor='red', size=40, name='process_', radius=23):
    """Save a deployment to CSV and plot it with coverage circles into a PDF.

    Side effects:
      * writes ``'location' + name + str(len(uav_list)) + '.csv'`` (one row per UAV);
      * writes ``'link' + name + str(len(uav_list)) + '.csv'`` (row i = UE
        indices served by UAV i);
      * draws into the matplotlib figure named ``'fig'``, saves it as
        ``name + 'uav' + str(len(uav_list)) + '.pdf'`` and then clears it.

    Positions are multiplied by ``scale = 100`` for plotting (axes labelled in
    m), so they appear to be stored in units of 100 m.

    :param radius: coverage-circle radius in position units (default 23, the
        value that used to be hard-coded here).
    """
    num_uav = len(uav_list)
    _write_csv_rows('location' + name + str(num_uav) + '.csv', list(uav_list))
    _write_csv_rows('link' + name + str(num_uav) + '.csv', [linked[i] for i in range(num_uav)])

    fig = plt.figure('fig')
    ax = fig.add_subplot(111)  # single axes (first subplot of a 1x1 grid)
    scale = 100

    # association lines, kept below the markers
    for uav_idx, served in enumerate(linked):
        for ue_idx in served:
            plt.plot([uav_list[uav_idx][0] * scale, ue_list[ue_idx][0] * scale],
                     [uav_list[uav_idx][1] * scale, ue_list[ue_idx][1] * scale],
                     color='g', linestyle='-', linewidth=1, zorder=1)
    # translucent coverage disc around every UAV
    for uav in uav_list:
        ax.add_patch(Circle(xy=(uav[0] * scale, uav[1] * scale), radius=radius * scale,
                            alpha=0.1, color='b'))

    # a UE is covered when it appears in the link list of any UAV
    covered = set()
    for uav_idx in range(num_uav):
        covered.update(linked[uav_idx])
    covered_ues = [ue for idx, ue in enumerate(ue_list) if idx in covered]
    uncovered_ues = [ue for idx, ue in enumerate(ue_list) if idx not in covered]

    # Labelled artists are added in the legend order: UE, UAV, uncovered UE.
    if covered_ues:
        plt.scatter([ue[0] * scale for ue in covered_ues], [ue[1] * scale for ue in covered_ues],
                    color='blue', marker='x', s=40, label='UE')
    if num_uav:
        plt.scatter([uav[0] * scale for uav in uav_list], [uav[1] * scale for uav in uav_list],
                    color=basecolor, marker='v', s=size, label='UAV', zorder=3)
    if uncovered_ues:
        plt.scatter([ue[0] * scale for ue in uncovered_ues], [ue[1] * scale for ue in uncovered_ues],
                    color='k', marker='x', s=40, label='uncovered UE')

    plt.xlim(-10 * scale, 110 * scale)
    plt.ylim(-10 * scale, 110 * scale)
    plt.yticks(fontproperties='Times New Roman', size=15)  # tick font and size
    plt.xticks(fontproperties='Times New Roman', size=15)
    plt.xlabel('x(m)')
    plt.ylabel('y(m)')
    plt.gca().set_aspect('equal', adjustable='box')
    plt.legend(bbox_to_anchor=(1, 1), loc='upper right', borderaxespad=0, prop=font_set, ncol=3)
    fig.savefig(name + 'uav' + str(num_uav) + '.pdf')
    fig.clf()
# Path-loss convention: pl(dB) = 10*log10(pt/pr), i.e. the received power is
# pt / 10**(pl/10) -- this is how cal_network_rate applies it.
def pathlossA2G(dist2, h):
    '''
    Air-to-ground path loss in dB.

    Sum of a log-distance term (exponent 3.5) and the ITU-R P.833 slant-path
    vegetation loss A * f**C * d**E * (theta + G)**H, where f is in MHz, d is
    the vegetation depth in metres and theta is the elevation angle in DEGREES.

    :param dist2: horizontal distance between the UAV's ground projection and
        the user; multiplied by 100 before use, so presumably in units of 100 m.
    :param h: UAV height, in the same unit as 100 * dist2 (presumably metres).
    :return: path loss in dB.
    '''
    f = 1.4e9  # carrier frequency in Hz; f / 1e6 below converts it to MHz
    d0 = 1  # reference distance of the log-distance model
    c = 3e8  # speed of light, m/s
    alpha = 3.5  # path-loss exponent
    pi = math.pi
    # ITU-R P.833 slant-path vegetation parameters
    A = 0.25
    C = 0.39
    E = 0.25
    G = 0
    H = 0.05
    dep_tree = 2  # vegetation depth
    horizontal = dist2 * 100
    dist3d = math.hypot(horizontal, h)
    fspl = 20 * np.log10(4 * pi * f * d0 / c) + 10 * alpha * math.log10(dist3d / d0)
    # The P.833 fit takes the elevation angle in degrees; atan2 returns radians.
    theta = math.degrees(math.atan2(h, horizontal))
    slant = A * np.power(f / math.pow(10, 6), C) * np.power(dep_tree, E) * np.power(theta + G, H)
    return fspl + slant
def dB2powerratio(dB):
    """Convert a level in dB to a linear power ratio: 10**(dB/10)."""
    exponent = dB / 10
    return np.power(10, exponent)
def dBm2w(dBm):
    """Convert a power level in dBm to watts: 1 mW * 10**(dBm/10)."""
    milliwatts = np.power(10, dBm / 10)
    return 0.001 * milliwatts
def cal_network_rate(uav_list, ue_list, linked, power=0.2, height=200):
    """Sum log10(1 + SNR) over every UAV-to-UE association.

    For UAV i and each UE index in linked[i], the SNR is
    (power / pl) / dBm2w(Noise), where pl is the linear path loss
    dB2powerratio(pathlossA2G(d, h=height)) and d is the horizontal distance.

    NOTE(review): a Shannon rate in bit/s/Hz would use log2; log10 is kept so
    results stay comparable with earlier runs (it only rescales every term by
    the constant log10(2)).

    :param power: UAV transmit power in W (default 0.2, as previously hard-coded).
    :param height: UAV height passed to pathlossA2G (default 200, as before).
    :return: the summed rate term (0 when there are no links).
    """
    noise_w = dBm2w(Noise)  # noise power converted from dBm to W
    total_rate = 0
    for uav_idx in range(len(uav_list)):
        uav = uav_list[uav_idx]
        for ue_idx in linked[uav_idx]:
            ue = ue_list[ue_idx]
            dist2 = math.hypot(uav[0] - ue[0], uav[1] - ue[1])
            pl = dB2powerratio(pathlossA2G(dist2, h=height))  # linear pt/pr
            snr = (power / pl) / noise_w
            total_rate += np.log10(1 + snr)
    return total_rate
def _deploy_uavs(ue_list, R):
    """Increase the number of UAVs until kmeans() leaves no UE unassigned.

    Starts with k = 4 clusters. For every larger k tried, prints ``k rate``
    (rate from cal_network_rate) and draws/saves the intermediate deployment
    via draw_uavus_link_circle2 with name='process_'.

    :param ue_list: UE rows; every row is truncated to the length of the first.
    :param R: maximum UE-to-UAV distance passed to kmeans().
    :return: (uav_list, clusterAssment) from the final kmeans() call;
        clusterAssment column 0 is each UE's assigned cluster.
    :raises ValueError: if ue_list is empty, or if R <= 0 -- kmeans() assigns a
        UE only when its distance is strictly below R, so no k could ever cover
        every UE and the search would never terminate.
    """
    if len(ue_list) == 0:
        raise ValueError("ue_list is empty: there are no UEs to cover")
    if R <= 0:
        raise ValueError("R must be positive, got %r" % (R,))

    n_cols = len(ue_list[0])
    matrix_kmeans = np.array([row[:n_cols] for row in ue_list], dtype=float)

    k = 4
    uav_list, clusterAssment = kmeans(matrix_kmeans, k, R)
    while cal_unclustered(k, clusterAssment):
        k += 1
        uav_list, clusterAssment = kmeans(matrix_kmeans, k, R)
        _, linked = get_link(uav_list, ue_list, clusterAssment)
        print(k, cal_network_rate(uav_list, ue_list, linked))
        draw_uavus_link_circle2(ue_list, uav_list, linked, basecolor='red', size=40, name='process_')
    return uav_list, clusterAssment


def main(ue_list, R):
    """Deploy enough UAVs to cover every UE and return them as BS objects.

    Each returned BS has loc = [x, y, h] (h is the module-level altitude),
    sort = 'UAV', rate = power = 0, a per-UE power_allocation filled with
    power_t, and the module-level power limits.

    :param ue_list: UE rows as returned by read_UElist.
    :param R: maximum UE-to-UAV distance.
    :return: list of BS objects, one per deployed UAV.
    """
    uav_list, _ = _deploy_uavs(ue_list, R)
    BS_list = []
    for uav in uav_list:
        new_uav = BS()
        new_uav.loc = [uav[0], uav[1], h]
        new_uav.sort = 'UAV'
        new_uav.rate = 0
        new_uav.power_allocation = np.ones(len(ue_list)) * power_t  # one entry per UE
        new_uav.power_upper = uav_power_upper
        new_uav.power = 0
        new_uav.power_aver_upper = uav_aver_upper
        new_uav.power_a2a = power_t
        BS_list.append(new_uav)
    return BS_list


if __name__ == '__main__':
    addr = r"D:\pycharm\deploy_algorithms\data\list_ue140_1.csv"
    ue_list = read_UElist(addr)
    # main() returns BS objects, not (uav_list, clusterAssment), so the summary
    # below uses the deployment helper directly.
    uav_list, clusterAssment = _deploy_uavs(ue_list, R=23)
    _, linked = get_link(uav_list, ue_list, clusterAssment)
    # number of UEs, number left uncovered, number of UAVs
    print(len(ue_list), cal_unclustered(len(uav_list), clusterAssment), len(uav_list))
    plt.figure()
    draw_uavus_link(ue_list, uav_list, linked, basecolor='red', size=40, )
    plt.show()
