import copy
import csv
import math
import time

import matplotlib
import matplotlib.pyplot as  plt
import numpy as np

import init
from virtual_force import *
# Imported after the star import on purpose: these bindings must win over any
# same-named objects that virtual_force might export.
from scipy.interpolate import griddata
import vforce2

# from vatest import *
# --- Scenario / algorithm parameters -------------------------------------
# Ground base-station counts. Their sum (sumgbs) is passed to init.main3 as its
# second positional argument (presumably the number of ground BSs — TODO confirm).
numfbs = 0
nummbs = 0
sumgbs = nummbs + numfbs
R_mbs = 5  # forwarded to init.main3 as R_mbs
R_uav = 25  # forwarded to init.main3 as R_uav
snrth = -5  # per-UE SNR threshold; read_UElist converts it with dB2powerratio (presumably dB)
itermax = 200  # NOTE(review): not referenced in the code visible in this file
init_upper = 20  # main() re-initialises while init_times <= init_upper, then adds a UAV
num1 = 1.01  # NOTE(review): num1-num3 are not referenced in the code visible in this file
num2 = 0
num3 = 5
num4 = 0  # used as the legend's borderaxespad in main()
# Font used for legend text, given as a path relative to the working directory.
font_set = matplotlib.font_manager.FontProperties(fname=r"HYQiHei-60S.ttf")


class UE:
    """A user equipment (ground user) served by the base stations."""

    def __init__(self):
        # Position as [x, y, z]; a fresh list per instance.
        self.loc = [0] * 3
        # Required SNR threshold; read_UElist fills it via dB2powerratio(snrth).
        self.snrth = 0


class BS:
    """A base station; its category label is kept in ``sort`` (default 'UAV')."""

    def __init__(self):
        self.loc = [0] * 3              # position as [x, y, z]
        self.sort = 'UAV'               # station category label
        self.power_upper = 0
        self.power_allocation = list()  # one entry per user (length = number of UEs)
        self.power_a2a = 0
        self.power = 0
        self.rate = 0
        self.rate_upper = 0


def read_UElist(ue_addr):
    """Load user equipments from a CSV file.

    The first row is treated as a header and skipped. Every following
    non-blank row must contain only numeric cells; its first two cells are
    the UE's x and y coordinates (z is set to 0).

    :param ue_addr: path of the CSV file
    :return: list of ``UE`` objects in file order, each with ``sort = 'ue'``
        and ``snrth = dB2powerratio(snrth)`` (module-level threshold)
    :raises ValueError: if a non-blank row contains a non-numeric cell
    """
    # The threshold is identical for every UE, so convert it only once.
    ue_snrth = dB2powerratio(snrth)
    UE_list = []
    with open(ue_addr, 'r', newline='') as fh:
        reader = csv.reader(fh, delimiter=',')
        next(reader, None)  # skip the header; an empty file yields no UEs
        for row in reader:
            # Tolerate blank lines (e.g. a trailing newline at end of file).
            if not any(cell.strip() for cell in row):
                continue
            values = [float(cell) for cell in row]
            new_ue = UE()
            new_ue.loc = [values[0], values[1], 0]
            new_ue.sort = 'ue'
            new_ue.snrth = ue_snrth
            UE_list.append(new_ue)
    return UE_list

def pathlossA2G_hotmap(dist2, h):
    '''
    Air-to-ground slant-path loss in dB used to build the heat map.

    The loss is a log-distance term (exponent ``alpha``) plus an excess term
    that depends on the elevation angle and a fixed vegetation depth.

    :param dist2: horizontal distance between the BS's ground projection and the user
    :param h: BS height, in the same distance unit as ``dist2``
    :return: path loss in dB
    :raises ValueError: if the 3-D distance is zero (log of zero is undefined)
    '''
    f = 1.4e9  # carrier frequency in Hz (the slant term below uses it in MHz)
    d0 = 1  # reference distance
    c = 3e8  # speed of light
    alpha = 3.5  # path-loss exponent
    pi = math.pi
    # Coefficients of the elevation/vegetation excess-loss term.
    A = 0.25
    C = 0.39
    E = 0.25
    G = 0
    H = 0.05
    dep_tree = 2  # vegetation depth
    dist3d = math.sqrt((dist2)**2 + h**2)
    if dist3d == 0:
        raise ValueError('pathlossA2G_hotmap: dist2 and h are both zero')
    fspl = 20 * np.log10(4 * pi * f * d0 / c) + 10 * alpha * math.log10(dist3d / d0)
    # Elevation angle in radians.
    # NOTE(review): if A..H come from a model that expects degrees, this should
    # be math.degrees(...) — confirm against the source of the coefficients.
    theta = math.atan2(h, dist2)
    slant = A * np.power(f / math.pow(10, 6), C) * np.power(dep_tree, E) * np.power(theta + G, H)
    return fspl + slant

def BS_list_save(title_, BS_list):
    """Write the base-station positions to ``<title_><len(BS_list)>.csv``.

    One CSV row is written per station, in list order, holding the entries of
    its ``loc``. An existing file with the same name is overwritten.

    :param title_: file-name prefix (may include a directory)
    :param BS_list: sequence of objects exposing a ``loc`` sequence
    """
    title = title_ + str(len(BS_list)) + '.csv'
    # ``with`` guarantees the file is closed even if a row fails to serialise.
    with open(title, 'w', newline='') as csv_file:
        csv.writer(csv_file).writerows(bs.loc for bs in BS_list)


def expend_data(x, y, z, smooth_degree=10, method='cubic'):
    """Resample data given on a regular (x, y) grid onto a finer regular grid.

    :param x: 1-D coordinates of the original grid columns
    :param y: 1-D coordinates of the original grid rows
    :param z: values on the original grid, either flattened in
        ``np.meshgrid(x, y)`` order or as a 2-D array of shape
        ``(len(y), len(x))``
    :param smooth_degree: refinement factor; each original interval is split
        into this many sub-intervals
    :param method: interpolation method passed to ``scipy.interpolate.griddata``
    :return: ``(M, N, U)`` — fine-grid x coordinates, y coordinates and
        interpolated values, each of shape
        ``((len(y) - 1) * smooth_degree + 1, (len(x) - 1) * smooth_degree + 1)``;
        ``U`` is NaN outside the convex hull of the input points
    """
    x = np.asarray(x)
    y = np.asarray(y)

    X, Y = np.meshgrid(x, y)
    coordinates = np.column_stack((X.ravel(), Y.ravel()))

    values = np.asarray(z)
    # Accept the 2-D grid directly; griddata needs one value per point.
    if values.ndim == 2 and values.shape == X.shape:
        values = values.ravel()

    m = np.linspace(x.min(), x.max(), (len(x) - 1) * smooth_degree + 1)
    n = np.linspace(y.min(), y.max(), (len(y) - 1) * smooth_degree + 1)
    M, N = np.meshgrid(m, n)

    U = griddata(coordinates, values, (M, N), method=method)

    return M, N, U

def _plot_deployment(UE_list, BS_list, linked):
    """Open a new figure showing the UE positions and the BS-UE links."""
    plt.figure()
    ue_x = [ue.loc[0] for ue in UE_list]
    ue_y = [ue.loc[1] for ue in UE_list]
    # One scatter call for all UEs so the legend gets a single 'UE' entry.
    plt.scatter(ue_x, ue_y, color='blue', marker='x', s=40, label='UE')
    plt.xlim(-10, 110)
    plt.ylim(-10, 110)
    draw_uavus_link(UE_list, BS_list, linked, basecolor='red', size=40)
    plt.legend(loc='upper right', borderaxespad=num4, prop=font_set)


def _build_hotmap(BS_list, size=100, step=5, power=0.2, cap=10):
    """Return a ``size x size`` map of summed received power over a noise floor.

    Element ``[x, y]`` holds, summed over all base stations,
    ``(power / linear_path_loss) / dBm2w(-120)``, capped at ``cap`` after each
    addition. Only cells whose indices are multiples of ``step`` are evaluated;
    every other still-zero cell copies the evaluated cell at the top-left
    corner of its ``step x step`` tile.
    """
    hotmap = np.zeros([size, size])
    noise = dBm2w(-120)
    samples = range(0, size, step)
    for bs in BS_list:
        x_bs, y_bs, h_bs = bs.loc[0], bs.loc[1], bs.loc[2]
        for x in samples:
            for y in samples:
                dist2 = np.sqrt((x_bs - x) ** 2 + (y_bs - y) ** 2)
                # NOTE(review): the *100 appears to rescale grid units to the
                # path-loss model's distance unit — confirm.
                pl = dB2powerratio(pathlossA2G_hotmap(dist2 * 100, h_bs))
                hotmap[x, y] += (power / pl) / noise
                if hotmap[x, y] > cap:
                    hotmap[x, y] = cap
    # Spread each evaluated (top-left) value over its tile, but only into
    # cells that are still zero.
    anchors = hotmap[::step, ::step]
    tiled = np.repeat(np.repeat(anchors, step, axis=0), step, axis=1)[:size, :size]
    return np.where(hotmap == 0, tiled, hotmap)


def _save_matrix_csv(csv_name, matrix):
    """Write each row of *matrix* as one CSV row to *csv_name* (overwriting it)."""
    with open(csv_name, 'w', newline='') as csv_file:
        csv.writer(csv_file).writerows(matrix)


def _draw_heatmap(hotmap):
    """Draw *hotmap* (indexed as [x, y]) with pcolormesh and save heatmap1.tif."""
    n_x, n_y = hotmap.shape
    xx, yy = np.meshgrid(np.arange(n_x), np.arange(n_y))
    fig, ax = plt.subplots()
    # pcolormesh expects C[row, col] == value at (x[col], y[row]); hotmap is
    # indexed [x, y], so it must be transposed to avoid mirroring the map.
    mesh = ax.pcolormesh(xx, yy, hotmap.T, cmap='viridis_r')
    fig.colorbar(mesh, ax=ax, label='AUPR')
    plt.xlabel(r'$\overline{p}$')
    plt.ylabel(r'$\overline{q}$')
    plt.savefig('heatmap1.tif', dpi=300)


def main(addr):
    """Deploy UAV base stations until all UEs are covered, then plot a heat map.

    Starting from 4 UAVs, the deployment from ``init.main3`` is refined with
    virtual forces. When coverage (``cal_coverate``) is not 1, the deployment
    is re-initialised; once ``init_times`` exceeds ``init_upper``, one more UAV
    is added. The final positions are saved with ``BS_list_save``. A received-
    power heat map is written to ``hotmap1.csv`` and drawn to ``heatmap1.tif``.

    :param addr: path of the UE CSV file (see ``read_UElist``)
    :return: number of base stations in the final deployment
    """
    title_ = 'hotmap'

    # Load user positions; every user gets the same required SNR.
    UE_list = read_UElist(addr)
    numuav = 4
    BS_list = init.main3(UE_list.copy(), sumgbs, R_mbs=R_mbs, sumuav=numuav, R_uav=R_uav)
    snr, _ = cal_snr(UE_list, BS_list)

    # Unconnected links are expected to get zero transmit power.
    link, linked = comm_connection(UE_list, BS_list, snr)
    power_equal_allocation(BS_list, UE_list, linked, link)

    coverate = cal_coverate(link)
    draw_uavus_link_circle(UE_list, BS_list, linked, basecolor='red', size=40)
    print(len(BS_list), coverate)

    # Search for the smallest number of UAVs that achieves full coverage.
    init_times = 0
    while coverate != 1:
        # Move the base stations according to attractive/repulsive forces.
        update_bs_power(BS_list)
        Fa, _ = virtual_fa(UE_list, BS_list)
        Fr, _ = virtual_fr(BS_list)
        F = total_force(Fa, Fr)
        v = force2v(F, Max_Speed=5)
        update_state_position(BS_list, v)
        snr, _ = cal_snr(UE_list, BS_list)
        link, linked = comm_connection(UE_list, BS_list, snr)
        update_bs_power(BS_list)
        coverate = cal_coverate(link)

        if coverate == 1:
            break
        if init_times <= init_upper:
            # Not fully covered: try a fresh initialisation with the same count.
            BS_list = init.main3(UE_list.copy(), sumgbs, R_mbs=R_mbs, sumuav=numuav, R_uav=R_uav)
        else:
            # Too many failed attempts: add a UAV and re-initialise.
            init_times = 0
            numuav += 1
            BS_list = init.main3(UE_list.copy(), sumgbs, R_mbs=R_mbs, sumuav=numuav, R_uav=R_uav)
            snr, _ = cal_snr(UE_list, BS_list)
            link, linked = comm_connection(UE_list, BS_list, snr)
            draw_uavus_link_circle(UE_list, BS_list, linked, basecolor='red', size=40)
            coverate = cal_coverate(link)
            print(len(BS_list), coverate)
            BS_list_save(title_, BS_list)
            _plot_deployment(UE_list, BS_list, linked)
        init_times += 1

    BS_list_save(title_, BS_list)
    _plot_deployment(UE_list, BS_list, linked)

    # Draw the heat map from the final UAV positions.
    print('开始画图', len(BS_list))
    draw_uavus_link_circle(UE_list, BS_list, linked, basecolor='red', size=40)
    start = time.time()
    power = 0.2
    hotmap = _build_hotmap(BS_list, power=power)
    print(1, power / dB2powerratio(pathlossA2G_hotmap(0, 200)) / dBm2w(-120))
    print(time.time() - start)

    _save_matrix_csv('hotmap1.csv', hotmap)
    _draw_heatmap(hotmap)
    plt.show()

    return len(BS_list)


if __name__ == '__main__':
    import sys

    # The UE CSV path may be given as the first command-line argument; the
    # original hard-coded dataset path is kept as the default.
    default_addr = r"D:\pycharm\deploy_algorithms\data\list_ue100_2.csv"
    addr = sys.argv[1] if len(sys.argv) > 1 else default_addr
    result = main(addr)
    print(result)
