import copy
import math
import numpy as np
import virtual_force
# Receiver noise power in dBm (converted to watts with dBm2w in cal_snr_UAV2UE).
Noise = -130
# NOTE(review): not referenced by any code visible in this file; presumably a
# per-UAV transmit-power budget (units unknown) — confirm against callers.
uav_power_upper = 3
# Used by update_connection through max(), i.e. as a *floor* on the power a UAV
# assigns to a newly received user, despite the "upper" in its name.
uav_aver_upper = 0.6
def init_preferred_coalition(snr,UE_list):
    """Build each UAV's initial coalition of users.

    Every user joins the UAV with the highest value in its row of ``snr``,
    provided that value reaches the user's ``snrth`` threshold; users whose
    best SNR falls short are left out of every coalition.

    :param snr: indexable as ``snr[user][uav]``; ``len(snr[0])`` is taken as
        the number of UAVs.
    :param UE_list: user objects exposing ``snrth``.
    :return: one list of user indices per UAV.
    """
    n_uav = len(snr[0])
    members = [[] for _ in range(n_uav)]
    for user_idx, ue in enumerate(UE_list):
        best = np.argmax(snr[user_idx])
        if snr[user_idx][best] >= ue.snrth:
            members[best].append(user_idx)
    return members

def get_load_coalition(BS_list,UE_list,snr):
    """Attach every user to its best UAV and return the resulting utility matrix.

    All ``UE.uav`` fields are first reset to -1 (unassigned). Each user then
    picks the UAV with the highest SNR in its row; if that SNR meets the
    user's ``snrth``, ``UE.uav`` is set to that UAV index and the SNR is
    recorded as the user's utility.

    :param BS_list: UAV list; only its length is used here.
    :param UE_list: user objects exposing ``snrth``; their ``uav`` is overwritten.
    :param snr: indexable as ``snr[user][uav]``.
    :return: array of shape ``(len(BS_list), len(UE_list))`` where
        ``utility[uav][user]`` is the SNR of the chosen link and 0 elsewhere.
    """
    for ue in UE_list:
        ue.uav = -1
    utility = np.zeros([len(BS_list), len(UE_list)])
    for user_idx, ue in enumerate(UE_list):
        row = snr[user_idx]
        best = np.argmax(row)
        best_snr = row[best]
        if best_snr >= ue.snrth:
            ue.uav = best
            utility[best][user_idx] = best_snr
    return utility

# user moves from coalition i to coalition j
# transfer factor
def get_transfer_factor(user,idx_i,idx_j,utility):
    """Return the scaled change in total utility for moving ``user`` from
    coalition ``idx_i`` to coalition ``idx_j``.

    The "after" total moves the user's current utility ``utility[idx_i][user]``
    unchanged into coalition ``idx_j``, and the other coalitions are summed the
    same way before and after. Mathematically, therefore, after == before and the
    result is 0; any non-zero output is floating-point rounding from the
    different summation orders, magnified by the ``1e13`` factor.

    NOTE(review): this looks vestigial (no caller in the code visible here).
    Unlike ``transfer_condition``, it never recomputes the utility the user
    would obtain at ``idx_j``. Confirm intent before relying on it.

    :param user: user (column) index.
    :param idx_i: coalition (row) currently holding the user.
    :param idx_j: candidate coalition (row).
    :param utility: 2-D array indexed ``utility[coalition][user]``.
    :return: ``((after total) - (before total)) * 1e13``.
    """
    # user from coal_i to coal_j
    # After the move: coalition idx_i without the user ...
    Ui_ = 0
    for i in range(len(utility[idx_i])):
        if i!=user:
            Ui_ += utility[idx_i][i]
    # ... coalition idx_j plus the user, carrying its current utility unchanged ...
    Uj_ = utility[idx_i][user]
    for j in range(len(utility[idx_j])):
        Uj_ += utility[idx_j][j]
    # ... and every other coalition, which the move does not touch.
    Um_sum = 0
    for i in range(len(utility)):
        if i != idx_i and i!=idx_j:
            for j in range(len(utility[i])):
                Um_sum += utility[i][j]

    # Before the move: current totals of idx_i, idx_j, and the rest.
    Ui = 0
    for i in range(len(utility[idx_i])):
        Ui += utility[idx_i][i]
    Uj = 0
    for i in range(len(utility[idx_j])):
        Uj += utility[idx_j][i]
    # Computed exactly like Um_sum above.
    U_sum = 0
    for i in range(len(utility)):
        if i!= idx_i and i!=idx_j:
            for j in range(len(utility[i])):
                U_sum += utility[i][j]

    # (after - before), scaled up by 1e13.
    return ((Ui_+Uj_+Um_sum)-(Ui+Uj+U_sum))*1e13
# pl(dB) = 10*log10(pt/pr); cal_snr_UAV2UE divides the transmit power by this ratio.
def dB2powerratio(dB):
    """Convert a decibel value to the corresponding linear power ratio, 10**(dB/10)."""
    exponent = dB / 10
    return np.power(10, exponent)
def dBm2w(dBm):
    """Convert a power level from dBm to watts: 10**(dBm/10) milliwatts, times 1e-3."""
    milliwatts = np.power(10, dBm / 10)
    return 0.001 * milliwatts
def pathlossA2G(dist2,h):
    '''
    Air-to-ground path loss in dB: a log-distance term plus an
    elevation-dependent excess (slant/vegetation) term.

    :param dist2: horizontal distance between the base station's ground
        projection and the user. It is multiplied by 100 before use —
        NOTE(review): presumably converts map/grid units to metres; confirm.
    :param h: base-station height, in the same units as the scaled distance.
    :return: total path loss in dB (log-distance term + slant term).
    '''
    f = 1.4e9  # carrier frequency in Hz (1.4 GHz); converted to MHz in the slant term
    d0 = 1  # reference distance
    c = 3e8  # propagation speed (speed of light)
    alpha = 3.5  # path-loss exponent
    pi = math.pi
    # Empirical coefficients of the excess-loss term A * f_MHz^C * dep_tree^E * (theta + G)^H
    A = 0.25
    C = 0.39
    E = 0.25
    G = 0
    H = 0.05
    dep_tree = 2  # vegetation depth
    horizontal = dist2 * 100
    dist3d = math.sqrt(horizontal**2 + h**2)
    # Named "fspl", but it uses exponent alpha (3.5) rather than the free-space value 2.
    fspl = 20 * np.log10(4 * pi * f * d0 / c) + 10 * alpha * math.log10(dist3d / d0)
    # Elevation angle of the base station seen from the user, in radians (math.atan2).
    # NOTE(review): empirical vegetation models of this form are often stated with the
    # angle in degrees — confirm which unit these coefficients assume.
    theta = math.atan2(h, horizontal)
    slant = A * np.power(f / math.pow(10, 6), C) * np.power(dep_tree, E) * np.power(theta + G, H)
    return fspl + slant

def cal_snr_UAV2UE(power,dist):
    """Return the linear (not dB) SNR at a user for a given UAV transmit power.

    The UAV height is fixed at 200 (passed as ``h`` to ``pathlossA2G``) and the
    noise floor is the module-level ``Noise`` (dBm, converted with ``dBm2w``).

    :param power: transmit power; divided by the linear path loss and then by
        the noise power in watts, so presumably in watts — confirm with callers.
    :param dist: horizontal distance, in the units ``pathlossA2G`` expects.
    :return: received power divided by noise power.
    """
    loss_linear = dB2powerratio(pathlossA2G(dist, h=200))
    noise_w = dBm2w(Noise)
    return power / loss_linear / noise_w
# Coalition idx_j considers admitting user `user`, who currently belongs to coalition idx_i.
def transfer_condition(user,idx_i,idx_j,utility,snr,BS_list,UE_list):
    """Return the utility gain of moving ``user`` from coalition ``idx_i`` to ``idx_j``.

    If ``snr[user][idx_j]`` is below the user's ``snrth``, the move is rejected
    and 0 is returned. Otherwise the user keeps its current power
    ``BS_list[idx_i].power_allocation[user]``. Its new SNR is evaluated with
    ``cal_snr_UAV2UE`` at the planar (x, y) distance between ``UE.loc`` and
    ``BS_list[idx_j].loc``, and the result is::

        (row idx_i without the user + row idx_j + new SNR) - (row idx_i + row idx_j)

    where "row" means the sum of ``utility[coalition]``.
    """
    ue = UE_list[user]
    if snr[user][idx_j] < ue.snrth:
        return 0
    row_i = utility[idx_i]
    row_j = utility[idx_j]
    # Coalition idx_i once the user has left (same left-to-right order as a plain loop).
    without_user = sum(value for k, value in enumerate(row_i) if k != user)
    tx_power = BS_list[idx_i].power_allocation[user]
    target_loc = BS_list[idx_j].loc
    dx = target_loc[0] - ue.loc[0]
    dy = target_loc[1] - ue.loc[1]
    planar_dist = np.sqrt(dx**2 + dy**2)
    # Coalition idx_j with the user added: start from the user's new SNR, then add the row.
    with_user = sum(row_j, cal_snr_UAV2UE(tx_power, planar_dist))
    before = sum(row_i) + sum(row_j)
    return (without_user + with_user) - before

# user transfer from coalition idx_i to coalition idx_j
def transfer_condition_withPC(user,idx_i,idx_j,utility,dist,BS_list,UE_list):
    """Return the utility gain, with power re-allocation, of moving ``user`` from UAV ``idx_i`` to ``idx_j``.

    The move is modelled as follows:

    * the leaving user's power at ``idx_i`` is split evenly among the other users
      that UAV still serves (its entries with non-zero ``power_allocation``);
    * the user keeps that same power at ``idx_j``;
    * UAV ``idx_j``'s existing allocations are unchanged.

    SNRs after the move are recomputed with ``cal_snr_UAV2UE``. The "before"
    totals are the current sums of rows ``idx_i`` and ``idx_j`` of ``utility``.

    :param user: index of the user being moved.
    :param idx_i: index of the UAV currently serving the user.
    :param idx_j: index of the candidate UAV.
    :param utility: array indexed ``utility[uav][user]``.
    :param dist: indexed ``dist[user][uav]``.
    :param BS_list: UAV objects exposing ``power_allocation`` (indexed by user).
    :param UE_list: user list; only its length is used.
    :return: (after-move total of both coalitions) - (current total).
    """
    alloc_i = BS_list[idx_i].power_allocation
    alloc_j = BS_list[idx_j].power_allocation
    n_users = len(UE_list)

    # Current totals of the two coalitions involved.
    Ui = sum(utility[idx_i])
    Uj = sum(utility[idx_j])

    # Share of the leaving user's power for each remaining member of idx_i:
    # non-zero allocations, minus the leaving user itself.
    remaining = len(np.nonzero(alloc_i)[0]) - 1
    eita_tmp = alloc_i[user] / remaining if remaining > 0 else 0

    # Coalition idx_i after the move. Only users UAV idx_i actually serves
    # (non-zero allocation) receive the extra power. Crediting unserved users with
    # cal_snr_UAV2UE(eita_tmp, ...) would add SNR that is absent from Ui and inflate
    # the gain. This matches update_connection, which only bumps entries > 0.
    Ui_ = sum(
        cal_snr_UAV2UE(alloc_i[k] + eita_tmp, dist[k][idx_i])
        for k in range(n_users)
        if k != user and alloc_i[k] != 0
    )

    # Coalition idx_j after the move: the moved user's SNR at its current power,
    # plus UAV idx_j's existing users at unchanged power (zero power -> zero SNR).
    Uj_ = sum(
        (cal_snr_UAV2UE(alloc_j[k], dist[k][idx_j]) for k in range(n_users)),
        cal_snr_UAV2UE(alloc_i[user], dist[user][idx_j]),
    )
    return (Ui_ + Uj_) - (Ui + Uj)

def get_link(UE_list,BS_list):
    """Build the UAV-user association implied by each user's ``uav`` field.

    Users with ``uav == -1`` (or any value <= -1) are unattached and ignored.
    Side effect: ``BS_list[k].link`` is set to the number of users attached to UAV k.

    :return: ``(link, linked)``. ``link`` is a ``(len(BS_list), len(UE_list))``
        array with ``link[k][u] == 1`` iff user u is attached to UAV k (else 0).
        ``linked[k]`` is the ascending list of user indices attached to UAV k.
    """
    link = np.zeros([len(BS_list), len(UE_list)])
    for user_idx, ue in enumerate(UE_list):
        if ue.uav > -1:
            link[ue.uav][user_idx] = 1
    linked = []
    for uav_idx, bs in enumerate(BS_list):
        members = [u for u in range(len(UE_list)) if link[uav_idx][u] == 1]
        linked.append(members)
        bs.link = len(members)
    return link, linked

def update_connection(UAV_i,UAV_j,user):
    """Move ``user`` from ``UAV_i`` to ``UAV_j``, adjusting both power allocations in place.

    Steps, as written (p = the user's power at UAV_i, n = number of non-zero
    entries in ``UAV_i.power_allocation``, the user included):

    1. ``eita_tmp = p / n``.
    2. Every entry of ``UAV_i.power_allocation`` that is > 0 — the user's own
       entry included — is increased by ``eita_tmp``.
    3. ``UAV_j.power_allocation[user] = max(<user's already-increased power at
       UAV_i> + eita_tmp, uav_aver_upper)``.
    4. ``UAV_i.power_allocation[user] = 0``.

    NOTE(review): when p > 0, total power is not conserved. UAV_i's other users
    gain (n-1)/n * p in total, while UAV_j's entry becomes max(p + 2p/n,
    uav_aver_upper). transfer_condition_withPC models the split over n-1 users
    instead. Because of max(), uav_aver_upper acts as a floor, not an upper bound.
    Confirm the intended model.
    NOTE(review): divides by n; if UAV_i has no non-zero entries this divides by
    zero. The callers in this file only call it when the user's entry is > 0.

    :param UAV_i: object with a mutable, user-indexed ``power_allocation``; current server.
    :param UAV_j: object with a mutable, user-indexed ``power_allocation``; new server.
    :param user: index of the user being moved.
    """
    power_totali = sum(UAV_i.power_allocation)  # NOTE(review): unused
    power_totalj = sum(UAV_j.power_allocation)  # NOTE(review): unused
    eita_tmp = UAV_i.power_allocation[user]/(len(np.nonzero(UAV_i.power_allocation)[0]))    # divisor counts the user too; a "-1" was apparently considered

    # print(sum(UAV_j.power_allocation), sum(UAV_i.power_allocation))
    # print(UAV_i.power_allocation[user],power_totali,power_totalj,eita_i,eita_j)
    # Bump every served (power > 0) entry, including the leaving user's own entry.
    for i in range(len(UAV_i.power_allocation)):
        if UAV_i.power_allocation[i]>0:
            UAV_i.power_allocation[i] += eita_tmp
        # UAV_i.power_allocation[i] *= eita_i
    # for i in range(len(UAV_j.power_allocation)):
    #     UAV_j.power_allocation[i] *= eita_j
    # print(UAV_j.power_allocation[user] ,UAV_i.power_allocation[user])
    # Reads the user's entry *after* the bump above, then adds eita_tmp again.
    UAV_j.power_allocation[user] = max(UAV_i.power_allocation[user] + eita_tmp,uav_aver_upper)
    UAV_i.power_allocation[user] = 0
    # print(sum(UAV_j.power_allocation), sum(UAV_i.power_allocation))

def game_compare(BS_list,UE_list,snr):
    """Associate users without the coalition-game refinement.

    ``get_load_coalition`` is run only for its side effect: it sets every
    ``UE.uav`` to the user's best-SNR UAV, or -1. Its utility matrix is
    discarded. The resulting ``get_link(UE_list, BS_list)`` pair
    ``(link, linked)`` is returned.
    """
    get_load_coalition(BS_list, UE_list, snr)
    return get_link(UE_list, BS_list)
def game_process2(BS_list,UE_list,snr,dist,link_upper=10,max_iter=5):
    """Coalition-formation game whose transfer criterion re-allocates power.

    Users start on their best-SNR UAV (``get_load_coalition``, which also sets
    every ``UE.uav``). Then, for ``max_iter`` rounds:

    1. For every attached user whose UAV has at most ``link_upper`` links, and
       every other UAV with at most ``link_upper`` links that reaches the user's
       ``snrth`` at the user's current power, the positive gain from
       ``transfer_condition_withPC`` is stored in ``transfer_factors[uav][user]``.
    2. UAVs are scanned in index order. The first UAV whose best candidate has a
       gain above 1e-11 (and whose current server gives it power > 0) takes that
       user via ``update_connection``. At most one transfer happens per round.

    NOTE(review): ``BS_list[k].link`` is read for the capacity checks but is only
    refreshed by ``get_link`` at the end, so counts are stale within the rounds.
    They must already be set before this is called (presumably by get_link).

    :param link_upper: UAVs with more than this many links neither give nor take users.
    :param max_iter: number of rounds.
    :return: ``get_link(UE_list, BS_list)`` for the final assignment.
    """
    utility = get_load_coalition(BS_list, UE_list, snr)
    for _ in range(max_iter):
        # transfer_factors[uav][user]: gain if `uav` took `user` over.
        transfer_factors = np.zeros([len(BS_list), len(UE_list)])
        for i in range(len(UE_list)):
            idx0 = UE_list[i].uav
            # Check attachment before indexing: uav == -1 would otherwise
            # silently index the last UAV through BS_list[-1].
            if idx0 < 0 or BS_list[idx0].link > link_upper:
                continue
            power_i = BS_list[idx0].power_allocation[i]
            for j in range(len(BS_list)):
                if j == idx0 or BS_list[j].link > link_upper:
                    continue
                # UAV j must meet the user's SNR threshold at the user's current power.
                if cal_snr_UAV2UE(power_i, dist[i][j]) < UE_list[i].snrth:
                    continue
                gain = transfer_condition_withPC(i, idx0, j, utility, dist, BS_list, UE_list)
                if gain > 0:
                    transfer_factors[j][i] = gain

        # Each UAV, in index order, tries to take its best candidate; the first
        # successful transfer ends the round.
        for i in range(len(BS_list)):
            idx = np.argmax(transfer_factors[i])
            if transfer_factors[i][idx] > 1e-11:
                uav_before = UE_list[idx].uav
                if BS_list[uav_before].power_allocation[idx] > 0:
                    update_connection(BS_list[uav_before], BS_list[i], idx)
                    UE_list[idx].uav = i
                    utility[i][idx] = cal_snr_UAV2UE(BS_list[i].power_allocation[idx], dist[idx][i])
                    utility[uav_before][idx] = 0
                    break
    return get_link(UE_list, BS_list)

def game_process(BS_list,UE_list,snr,dist,link_upper=10,max_iter=1):
    """Offload users from saturated UAVs to the UAV that would give them the highest SNR.

    Users start on their best-SNR UAV (``get_load_coalition``, which also sets
    every ``UE.uav``). Then, for ``max_iter`` rounds:

    1. Every attached user whose UAV has at least ``link_upper`` links is a
       candidate. For each other UAV with at most ``link_upper`` links, the SNR
       the user would get there at its current power is stored in
       ``transfer_factors[uav][user]`` if it meets the user's ``snrth``.
    2. UAVs are scanned in index order. The first UAV whose best candidate has a
       recorded SNR above 1e-11 (and whose current server gives it power > 0)
       takes that user via ``update_connection``. At most one transfer happens
       per round.

    NOTE(review): ``BS_list[k].link`` is read for the capacity checks but is only
    refreshed by ``get_link`` at the end. It must already be set before this is
    called (presumably by get_link).

    :param link_upper: donors need at least, receivers at most, this many links.
    :param max_iter: number of rounds.
    :return: ``get_link(UE_list, BS_list)`` for the final assignment.
    """
    # Called for its side effect of assigning every UE.uav; its utility matrix
    # is not needed because the transfer criterion here is the raw SNR.
    get_load_coalition(BS_list, UE_list, snr)
    for _ in range(max_iter):
        # transfer_factors[uav][user]: SNR the user would get at `uav`.
        transfer_factors = np.zeros([len(BS_list), len(UE_list)])
        for i in range(len(UE_list)):
            idx0 = UE_list[i].uav
            # Check attachment before indexing: uav == -1 would otherwise
            # silently index the last UAV through BS_list[-1].
            if idx0 < 0 or BS_list[idx0].link < link_upper:
                continue
            power_i = BS_list[idx0].power_allocation[i]
            for j in range(len(BS_list)):
                if j == idx0 or BS_list[j].link > link_upper:
                    continue
                snr_ij_new = cal_snr_UAV2UE(power_i, dist[i][j])
                if snr_ij_new < UE_list[i].snrth:
                    continue
                transfer_factors[j][i] = snr_ij_new

        # Each UAV, in index order, tries to take its best candidate; the first
        # successful transfer ends the round.
        for i in range(len(BS_list)):
            idx = np.argmax(transfer_factors[i])
            if transfer_factors[i][idx] > 1e-11:
                uav_before = UE_list[idx].uav
                if BS_list[uav_before].power_allocation[idx] > 0:
                    update_connection(BS_list[uav_before], BS_list[i], idx)
                    UE_list[idx].uav = i
                    break
    return get_link(UE_list, BS_list)

def game_process_rsma(BS_list,UE_list,snr,dist):
    """Association game under the name used for the RSMA setting.

    Its logic was a line-for-line copy of ``game_process``, so it delegates to
    that function to keep the two from drifting apart. The result is the
    ``(link, linked)`` pair from ``get_link`` for the final assignment.

    NOTE(review): if the RSMA case is meant to use a different transfer
    criterion, implement it here instead of delegating.
    """
    return game_process(BS_list, UE_list, snr, dist)
if __name__ == '__main__':
    # Demo of the coalition layout: one (initially empty) user list per UAV,
    # here 5 UAVs, with user 2 placed in coalition 0.
    coal = [[] for _ in range(5)]
    coal[0].append(2)
    print(coal)