import numpy as np
import math
import copy
import matplotlib.pyplot as plt

# Simulation constants.
# Positions in this file are expressed in grid cells; horizontal distances are
# multiplied by 100 before they enter the path-loss formulas.
Noise = -115  # noise power; subtracted from (tx power - path loss) to form an SNR
power_t = 40  # power offset added to the per-user allocation in the attraction force; presumably also the initial per-user power -- TODO confirm
snr_upper1 = -10  # a user attracts a base station only while their SNR exceeds this
Ka = 10  # gain of the attraction force
Kr = 1000  # gain of the repulsion force
snr_upper2 = 0  # two base stations repel only while their air-to-air SNR exceeds this
d_upper2 = 25  # reference distance in the repulsion magnitude Kr * (d_upper2 - dist)
Max_Speed_uav = 1  # default bound used when a force is squashed into a velocity

X_SIZE = 100  # area extent along x, used by the boundary repulsion
Y_SIZE = 100  # area extent along y, used by the boundary repulsion
power_step = 10  # not referenced in the visible part of this file
# User (UE) and base-station (BS) model classes
class UE:
    """Ground user described by a position and a minimum-SNR requirement."""

    def __init__(self):
        # SNR threshold the serving link has to reach before it is established.
        self.snrth = 0
        # Position as a three-element list; the code in this file reads
        # entries 0 and 1 as the horizontal x / y coordinates.
        self.loc = [0] * 3

class BS:
    """Base station (a UAV by default) with its position, power and rate state."""

    def __init__(self):
        # Kind of station and position; entry 2 of ``loc`` is used as the height.
        self.sort = 'UAV'
        self.loc = [0] * 3
        # Transmit-power bookkeeping.
        self.power = 0  # total power, written as sum(power_allocation) by the helpers
        self.power_upper = 0  # forces act on this station only while power < power_upper
        self.power_a2a = 0  # power used for the air-to-air SNR
        self.power_allocation = list()  # per-user power, one entry per user
        # Rate bookkeeping.
        self.rate = 0
        self.rate_upper = 0
# Air-to-ground channel: path-loss computation
def pathlossA2G(dist2, h):
    """
    Return the air-to-ground path loss in dB.

    The result is a log-distance term (reference distance ``d0`` = 1,
    exponent ``alpha`` = 3.5) plus a slant-path excess term
    ``A * f_MHz**C * dep_tree**E * (theta + G)**H`` that depends on the
    elevation angle ``theta``.

    :param dist2: horizontal distance between the base station's ground
        projection and the user, in grid cells (scaled by 100 internally)
    :param h: base-station height, in the same unit as ``dist2 * 100``
    :return: slant-path loss between the base station and the user
    """
    f = 1.4e9  # carrier frequency in Hz (converted to MHz for the slant term)
    d0 = 1  # reference distance of the log-distance model
    c = 3e8  # speed of light
    alpha = 3.5  # path-loss exponent
    # Coefficients of the slant-path excess-loss term.
    A = 0.25
    C = 0.39
    E = 0.25
    G = 0
    H = 0.05
    dep_tree = 2  # vegetation depth

    horizontal = dist2 * 100
    # Clamp to the reference distance: the log-distance model is not defined
    # below d0, and a zero distance would make log10 raise.
    dist3d = max(math.sqrt(horizontal ** 2 + h ** 2), d0)
    fspl = 20 * np.log10(4 * math.pi * f * d0 / c) + 10 * alpha * math.log10(dist3d / d0)
    # Both arguments must share one unit, so the scaled horizontal distance
    # is used here just like in dist3d.
    # NOTE(review): theta is in radians; confirm the unit the A..H
    # coefficients were fitted for.
    theta = math.atan2(h, horizontal)
    slant = A * np.power(f / 1e6, C) * np.power(dep_tree, E) * np.power(theta + G, H)
    return fspl + slant

# Air-to-air channel: SNR and distance between base stations
def pathlossA2A(BS_list):
    """
    Compute the pairwise air-to-air SNR and horizontal distance of base stations.

    Only ``loc[0]`` and ``loc[1]`` are used; heights are not part of the
    distance. ``snr[i][j]`` is derived from the ``power_a2a`` of station ``i``.
    Diagonal entries of both matrices stay 0.

    :param BS_list: base stations; ``loc`` and ``power_a2a`` are read
    :return: ``(snr, dist)``, two square arrays with one row/column per
        base station; ``dist`` is in grid cells
    """
    f = 1.4e9  # carrier frequency in Hz
    d0 = 1  # reference distance of the log-distance model
    c = 3e8  # speed of light
    alpha = 3.5  # path-loss exponent
    count = len(BS_list)
    snr = np.zeros([count, count])
    dist = np.zeros([count, count])
    # The reference-distance term does not depend on the pair.
    fspl_ref = 20 * np.log10(4 * math.pi * f * d0 / c)
    for i, sender in enumerate(BS_list):
        for j, receiver in enumerate(BS_list):
            if j == i:
                continue
            dist2 = math.sqrt((sender.loc[0] - receiver.loc[0]) ** 2
                              + (sender.loc[1] - receiver.loc[1]) ** 2)
            # Clamp to d0 so that co-located stations do not hit log10(0).
            scaled = max(dist2 * 100, d0)
            fspl = fspl_ref + 10 * alpha * math.log10(scaled / d0)
            dist[i][j] = dist2
            snr[i][j] = sender.power_a2a - fspl - Noise
    return snr, dist
def cal_snr(UE_list,BS_list):
    """
    Compute the SNR and horizontal distance of every user / base-station pair.

    :param UE_list: users; ``loc[0]`` and ``loc[1]`` are read
    :param BS_list: base stations; ``loc`` and ``power_allocation`` are read
    :return: ``(snr, dist)``, both indexed ``[user][base_station]``
    """
    shape = [len(UE_list), len(BS_list)]
    # Every entry is overwritten below, so a plain zero matrix is enough.
    snr = np.zeros(shape)
    dist = np.zeros(shape)
    for bs_idx, station in enumerate(BS_list):
        for ue_idx, user in enumerate(UE_list):
            dx = station.loc[0] - user.loc[0]
            dy = station.loc[1] - user.loc[1]
            horizontal = np.sqrt(dx ** 2 + dy ** 2)
            dist[ue_idx][bs_idx] = horizontal
            loss = pathlossA2G(horizontal, station.loc[2])
            snr[ue_idx][bs_idx] = station.power_allocation[ue_idx] - loss - Noise
    return snr, dist

# 连接关系判定
def comm_connection(UE_list,BS_list,snr):
    '''
    :param UE_list: 用户列表
    :param BS_list: 基站列表
    :return: 01连接关系矩阵 , size：基站数*用户数
    '''

    link = np.zeros([len(BS_list),len(UE_list)])
    linked = []
    # 为用户选择基站，从用户的角度出发
    for i in range(len(UE_list)):
        # 选择信噪比最大的基站建立连接

        idx = [k for k, x in enumerate(snr[i]) if abs(snr[i][k] - max(snr[i])) < 0.1]
        # 如果存在多个最大信噪比 选择基站速率最小的，因为提供服务的空间更大

        idx0 = idx[0]
        if len(idx) > 1:
            for k in idx:
                if BS_list[k].rate < BS_list[idx0].rate:
                    idx0 = k
        if snr[i][idx0] >= UE_list[i].snrth:
            link[idx0][i] = 1
            BS_list[idx0].rate += np.log10(1+np.power(10, 1 * snr[i][idx0] / 10))
    # 未建立连接 功率为0
    for i in range(len(BS_list)):
        for j in range(len(UE_list)):
            if link[i][j] == 0:
                BS_list[i].power_allocation[j] = 0
        BS_list[i].power = sum(BS_list[i].power_allocation)
    for i in range(len(BS_list)):
        tmp1 = []
        for j in range(len(UE_list)):
            if link[i][j] == 1:
                tmp1.append(j)
        linked.append(tmp1)
    return link,linked
# Coverage ratio
def cal_coverate(link):
    """
    Return the fraction of users that are linked to at least one base station.

    :param link: 0/1 matrix indexed ``link[base_station][user]``
    :return: covered users divided by all users; 0.0 when the matrix has no
        base-station rows or no user columns
    """
    link_arr = np.asarray(link)
    # No rows or no columns: nobody can be covered, and there is nothing to
    # divide by.
    if link_arr.ndim != 2 or link_arr.shape[1] == 0:
        return 0.0
    covered = (link_arr == 1).any(axis=0)
    return int(np.count_nonzero(covered)) / link_arr.shape[1]


# Virtual attraction: resultant force and its per-user components
def virtual_fa(UE_list,BS_list):
    """
    Compute the attraction each user exerts on each base station.

    A user pulls a base station only when their SNR exceeds ``snr_upper1`` and
    the station's ``power`` is below its ``power_upper``. The magnitude is
    ``Ka * (power_allocation[j] + power_t) * |snr - snrth|`` divided by
    ``dist**2 + (height / 100)**2``; the direction points from the station to
    the user in the horizontal plane.

    :param UE_list: users
    :param BS_list: base stations
    :return: ``(Fa, Fa_list)``: ``Fa[i]`` is the 2-D resultant on base station
        ``i``; ``Fa_list[i][j]`` is the 2-D contribution of user ``j``
    """
    snr, dist2 = cal_snr(UE_list, BS_list)

    Fa = []       # one resultant per base station
    Fa_list = []  # per base station, one component per user; Fa is their sum
    for i, station in enumerate(BS_list):
        components = []
        resultant = [0, 0]
        for j, user in enumerate(UE_list):
            force = [0, 0]
            horizontal = dist2[j][i]
            # A user exactly below the station has no horizontal direction;
            # dividing by that zero distance would yield NaN, so it
            # contributes no force.
            if (horizontal > 0 and snr[j][i] > snr_upper1
                    and station.power < station.power_upper):
                magnitude = (Ka * (station.power_allocation[j] + power_t)
                             * abs(snr[j][i] - user.snrth)
                             / (horizontal ** 2 + (station.loc[2] / 100) ** 2))
                force[0] = magnitude * (user.loc[0] - station.loc[0]) / horizontal
                force[1] = magnitude * (user.loc[1] - station.loc[1]) / horizontal
                resultant[0] += force[0]
                resultant[1] += force[1]
            components.append(force)
        Fa_list.append(components)
        Fa.append(resultant)

    return Fa, Fa_list

# Virtual repulsion: between base stations, and from the area boundary
def virtual_fr(BS_list):
    """
    Compute the repulsion acting on every base station.

    Station ``j`` pushes station ``i`` when the air-to-air SNR ``snr[i][j]``
    exceeds ``snr_upper2`` and the ``power`` of ``i`` is below its
    ``power_upper``. The magnitude is ``Kr * (d_upper2 - dist)`` along the line
    joining the two stations. Afterwards a boundary term is added to the
    resultant of stations closer than 0.1 grid cells to an edge of the
    ``X_SIZE`` x ``Y_SIZE`` area; it is not recorded in ``Fr_list``.

    :param BS_list: base stations
    :return: ``(Fr, Fr_list)``: ``Fr[i]`` is the 2-D resultant on station
        ``i``; ``Fr_list[i][j]`` is the 2-D vector toward station ``j`` that
        is subtracted from that resultant
    """
    snr, dist = pathlossA2A(BS_list)
    Fr = []
    Fr_list = []  # one row per base station, one entry per other base station
    for i, station in enumerate(BS_list):
        components = []
        resultant = [0, 0]
        for j, other in enumerate(BS_list):
            # Repulsion on station i caused by station j.
            force = [0, 0]
            separation = dist[i][j]
            # Co-located stations have no direction; skipping them avoids 0/0.
            if (j != i and separation > 0 and snr[i][j] > snr_upper2
                    and station.power < station.power_upper):
                magnitude = Kr * (d_upper2 - separation)
                force[0] = magnitude * (other.loc[0] - station.loc[0]) / separation
                force[1] = magnitude * (other.loc[1] - station.loc[1]) / separation
                resultant[0] -= force[0]
                resultant[1] -= force[1]
            components.append(force)
        Fr_list.append(components)
        Fr.append(resultant)

    # Keep stations at least 0.1 grid cells (10 m) away from the boundary.
    for i, station in enumerate(BS_list):
        if station.loc[0] < 0.1:
            Fr[i][0] += Kr * abs(station.loc[0])
        if (X_SIZE - station.loc[0]) < 0.1:
            Fr[i][0] -= Kr * abs(station.loc[0] - X_SIZE)
        if station.loc[1] < 0.1:
            Fr[i][1] += Kr * abs(station.loc[1])
        if (Y_SIZE - station.loc[1]) < 0.1:
            # The y edge is measured against Y_SIZE, matching the condition.
            Fr[i][1] -= Kr * abs(station.loc[1] - Y_SIZE)
    return Fr, Fr_list
# Resultant of attraction and repulsion
def total_force(Fa,Fr):
    """
    Add the attraction and repulsion vectors of every base station.

    :param Fa: list of 2-D attraction vectors, one per base station
    :param Fr: list of 2-D repulsion vectors, indexed like ``Fa``
    :return: list of ``[x, y]`` sums, one per entry of ``Fa``
    """
    return [
        [pull[0] + Fr[index][0], pull[1] + Fr[index][1]]
        for index, pull in enumerate(Fa)
    ]




# Squash forces into bounded velocities
def force2v(F,Max_Speed = Max_Speed_uav):
    """
    Map every 2-D force to a velocity whose components are bounded.

    Each component is passed through ``atan`` and rescaled, so it lies
    strictly between ``-Max_Speed`` and ``Max_Speed``; the two components are
    bounded independently of each other.

    :param F: sequence of 2-D forces
    :param Max_Speed: bound applied to each velocity component
    :return: list of ``[vx, vy]`` velocities, one per force
    """
    def _squash(component):
        return math.atan2(component, 1) * 2 / math.pi * Max_Speed

    return [[_squash(force[0]), _squash(force[1])] for force in F]

# Squash scalar magnitudes into bounded values
def power2v(F,Max_Speed = Max_Speed_uav):
    """
    Map every scalar in ``F`` to a value bounded by ``Max_Speed``.

    :param F: sequence of scalar magnitudes
    :param Max_Speed: bound of the output, reached only asymptotically
    :return: list with one bounded value per input
    """
    return [math.atan2(magnitude, 1) * 2 / math.pi * Max_Speed for magnitude in F]

def forcelist2vlist(F_list):
    """
    Squash a matrix of 2-D forces into bounded 2-D velocities.

    Every component goes through ``atan`` and is rescaled so that it stays
    within ``Max_Speed_uav``. The number of columns is taken from the first
    row.

    :param F_list: nested list indexed ``F_list[i][j] -> [fx, fy]``
    :return: nested list of the same layout holding ``[vx, vy]``
    """
    return [
        [
            [
                math.atan2(F_list[row][col][0], 1) * 2 / math.pi * Max_Speed_uav,
                math.atan2(F_list[row][col][1], 1) * 2 / math.pi * Max_Speed_uav,
            ]
            for col in range(len(F_list[0]))
        ]
        for row in range(len(F_list))
    ]
# Move the base stations and report the distance travelled
def update_state_position(BS_list,v):
    """
    Shift every base station by its velocity, in place.

    :param BS_list: base stations; ``loc[0]`` and ``loc[1]`` are updated
    :param v: list of ``[vx, vy]`` steps, indexed like ``BS_list``
    :return: sum of the Euclidean lengths of all applied steps
    """
    travelled = 0
    for index, station in enumerate(BS_list):
        step = v[index]
        station.loc[0] += step[0]
        station.loc[1] += step[1]
        travelled += math.sqrt((step[0]) ** 2 + (step[1]) ** 2)
    return travelled

def total_fr_power(Fr_list):
    """
    Return the magnitude of the summed repulsion vectors of each base station.

    The number of columns is taken from the first row.

    :param Fr_list: nested list indexed ``Fr_list[i][j] -> [fx, fy]``
    :return: list with one magnitude per row of ``Fr_list``
    """
    magnitudes = []
    for row in range(len(Fr_list)):
        sum_x = 0
        sum_y = 0
        # Accumulate in order rather than with sum(), so the floating-point
        # result does not depend on the summation algorithm.
        for col in range(len(Fr_list[0])):
            component = Fr_list[row][col]
            sum_x += component[0]
            sum_y += component[1]
        magnitudes.append(math.sqrt(sum_x ** 2 + sum_y ** 2))
    return magnitudes

def cal_network_rate(BS_list,UE_list):
    """
    Sum ``log2(1 + 10**(snr/10))`` over every base-station / user pair.

    The SNR of a pair is its allocated power minus the air-to-ground path
    loss minus ``Noise``.
    NOTE(review): pairs with a zero allocation are included as well --
    confirm that this is intended.

    :param BS_list: base stations; ``loc`` and ``power_allocation`` are read
    :param UE_list: users; ``loc[0]`` and ``loc[1]`` are read
    :return: accumulated rate over all pairs (0 when there are none)
    """
    total_rate = 0
    for station in BS_list:
        for ue_idx, user in enumerate(UE_list):
            dx = station.loc[0] - user.loc[0]
            dy = station.loc[1] - user.loc[1]
            horizontal = np.sqrt(dx ** 2 + dy ** 2)
            loss = pathlossA2G(horizontal, station.loc[2])
            snr = station.power_allocation[ue_idx] - loss - Noise
            total_rate += np.log2(1 + np.power(10, 1 * snr / 10))
    return total_rate
def update_bs_power(BS_list):
    """
    Refresh ``power`` of every base station as the sum of its allocation.

    :param BS_list: base stations, updated in place
    """
    for station in BS_list:
        station.power = sum(station.power_allocation)
# Update the per-user transmit power of the base stations
def update_state_power(BS_list,va_list,vr_list,link):
    """
    Adjust every per-user power allocation in place.

    First the whole allocation of station ``i`` is scaled by
    ``(power - vr_list[i]) / power``. Then, per user: unlinked users get 0, a
    linked user with no power yet gets ``power_t``, and any other linked user
    gains the magnitude of ``va_list[i][j]``. ``BS.power`` itself is not
    refreshed here.

    :param BS_list: base stations; ``power_allocation`` is modified
    :param va_list: attraction terms indexed ``va_list[i][j] -> [x, y]``
    :param vr_list: one scalar repulsion term per base station
    :param link: 0/1 matrix indexed ``link[base_station][user]``
    """
    for i, station in enumerate(BS_list):
        allocation = station.power_allocation
        # Repulsion shrinks the whole allocation by a common factor. With zero
        # total power there is nothing to scale, and the ratio is undefined.
        if vr_list[i] != 0 and station.power != 0:
            eita = (station.power - vr_list[i]) / station.power
            for j in range(len(allocation)):
                allocation[j] *= eita

        # Attraction only applies to established links; the rest carry no power.
        for j in range(len(link[i])):
            if link[i][j] == 0:
                allocation[j] = 0
            elif allocation[j] == 0:
                # Link established for the first time: start from power_t.
                allocation[j] = power_t
            else:
                allocation[j] += math.sqrt((va_list[i][j][0]) ** 2 + (va_list[i][j][1]) ** 2)

def update_state_power2(BS_list, va, vr_list, link):
    """
    Adjust every per-user power allocation in place, scaling by the resultant.

    First the whole allocation of station ``i`` is scaled by
    ``(power - vr_list[i]) / power``. Then, per user: unlinked users get 0, a
    linked user with no power yet gets ``power_t``, and any other linked user
    is scaled by ``(power + |va[i]|) / power``. ``BS.power`` itself is not
    refreshed here.

    :param BS_list: base stations; ``power_allocation`` is modified
    :param va: one 2-D attraction term per base station
    :param vr_list: one scalar repulsion term per base station
    :param link: 0/1 matrix indexed ``link[base_station][user]``
    """
    for i, station in enumerate(BS_list):
        allocation = station.power_allocation
        # Repulsion shrinks the whole allocation by a common factor. With zero
        # total power there is nothing to scale, and the ratio is undefined.
        if vr_list[i] != 0 and station.power != 0:
            eita = (station.power - vr_list[i]) / station.power
            for j in range(len(allocation)):
                allocation[j] *= eita

        # Attraction only applies to established links; the rest carry no power.
        for j in range(len(link[i])):
            if link[i][j] == 0:
                allocation[j] = 0
            elif allocation[j] == 0:
                # Link established for the first time: start from power_t.
                allocation[j] = power_t
            elif station.power != 0:
                # Grow the allocation in proportion to the attraction; skipped
                # when the recorded total power is zero (ratio undefined).
                eita = (station.power + math.sqrt(va[i][0] ** 2 + va[i][1] ** 2)) / station.power
                allocation[j] *= eita

def draw_uav_update(BS_list,basecolor='red',LoSscale=1,scalecolor='k',size=40,islegend=1):
    """
    Draw the base stations, their coverage circles and their 1-based numbers.

    Everything is drawn on the current matplotlib axes.

    :param BS_list: base stations; ``loc[0]`` and ``loc[1]`` are read
    :param basecolor: colour of the station markers
    :param LoSscale: 0 skips the coverage circles; the legend entry is only
        added when it equals 1
    :param scalecolor: colour of the coverage circles
    :param size: marker size of the stations
    :param islegend: the legend entry is only added when it equals 1
    """
    # Fixed coverage radius in grid cells. A per-type radius (for example for
    # 'MBS' stations) appeared only as commented-out code and is not active.
    R = 26

    for index, station in enumerate(BS_list):
        centre_x, centre_y = station.loc[0], station.loc[1]
        plt.scatter(centre_x, centre_y, color=basecolor, marker='v', s=size)
        if LoSscale != 0:
            # Sample the circle as an upper and a lower half over one x grid.
            xs = np.arange(centre_x - R, centre_x + R, 0.0001)
            half_height = np.sqrt(R ** 2 - np.power((xs - centre_x), 2))
            upper = half_height + centre_y
            lower = -1 * half_height + centre_y
            plt.plot(xs, upper, xs, lower, color=scalecolor, linestyle=':', linewidth=2)
            # One extra labelled line on the first station feeds the legend.
            if index == 0 and LoSscale == 1 and islegend == 1:
                plt.plot(xs, upper, color='k', linestyle=':', linewidth=2,label='基站覆盖边界')

    # Number the stations next to their (rounded) positions.
    for number, station in enumerate(BS_list, start=1):
        plt.text(round(station.loc[0], 2), round(station.loc[1], 2), number,
                 ha='center', va='bottom', fontsize=10)

def draw_uavus_link(UE_list,BS_list,linked,basecolor='red',size=40,):
    """
    Draw base stations, users and a line for every established link.

    Everything is drawn on the current matplotlib axes.

    :param UE_list: users; ``loc[0]`` and ``loc[1]`` are read
    :param BS_list: base stations; ``loc[0]`` and ``loc[1]`` are read
    :param linked: ``linked[i]`` lists the user indices served by station ``i``
    :param basecolor: colour of the station markers
    :param size: marker size of the stations
    """
    for station in BS_list:
        plt.scatter(station.loc[0], station.loc[1], color=basecolor, marker='v', s=size)

    for user in UE_list:
        plt.scatter(user.loc[0], user.loc[1], color='blue', marker='x', s=40)

    # The first station is drawn once more with a label so that the legend
    # gets a single entry; skipped when there is no station at all.
    if len(BS_list) > 0:
        plt.scatter(BS_list[0].loc[0],BS_list[0].loc[1],color=basecolor, marker='v', s=size,label='无人机基站')

    for bs_idx, served in enumerate(linked):
        station = BS_list[bs_idx]
        for ue_idx in served:
            user = UE_list[ue_idx]
            plt.plot([station.loc[0], user.loc[0]], [station.loc[1], user.loc[1]],
                     color='g', linestyle='-', linewidth=1)