import init
import pyswarms as ps
from rsma import *
from UA_game import *

class UE:
    """A single user (UE) record.

    Fields are plain public attributes that the rest of the program fills in
    and mutates directly.
    """

    def __init__(self):
        # 3-element location, starting at the origin; read_UElist overwrites
        # it with the first two CSV columns and 0.
        self.loc = [0] * 3
        # SNR threshold; read_UElist replaces it with dB2powerratio(snrth).
        self.snrth = 0
        # Index of the associated UAV; -1 presumably means "not associated yet".
        self.uav = -1

class BS:
    """A base-station record, tagged as a UAV by default.

    All fields are plain public attributes that other code fills in and
    mutates.
    """

    def __init__(self):
        self.loc = [0] * 3  # 3-element position, starting at the origin
        self.sort = 'UAV'  # station type tag
        # Power bookkeeping.
        self.power_upper = 0
        self.power_allocation = []  # one entry per user
        self.power_a2a = 0
        self.power = 0
        # Rate bookkeeping.
        self.rate = 0
        self.rate_upper = 0

# Ground base-station counts (presumably femto and macro BSs); sumgbs is
# their total and is passed to init.main_rsma.
numfbs = 0
nummbs = 0
sumgbs = nummbs + numfbs

# Radii passed to init.main_rsma for macro BSs and UAVs (units not stated here).
R_mbs = R_uav = 25

# Initial power split between the "public" and "private" parts; the two
# ratios sum to 1 when the module is imported.
initial_ratio = 0.5
public_power_ratio = initial_ratio
private_power_ratio = 1 - public_power_ratio

# Not referenced elsewhere in this file.
varth = 1
eta_step_size = 0.02

# Per-user SNR threshold in dB; read_UElist converts it with dB2powerratio.
snrth = -5

def read_UElist(ue_addr):
    """Load users from a comma-separated file.

    The first row is treated as a header and skipped, blank rows are ignored,
    and every remaining cell is parsed as ``float``.

    Args:
        ue_addr: path to the CSV file.

    Returns:
        A tuple ``(ue_list, UE_list)``. ``ue_list`` holds the parsed rows as
        lists of floats. ``UE_list`` holds one ``UE`` per row, located at
        ``[row[0], row[1], 0]``. Each ``UE`` has ``snrth`` set to
        ``dB2powerratio(snrth)`` (module-level ``snrth``, in dB), ``sort='ue'``
        and ``uav=-1``.
    """
    # csv is not imported explicitly at the top of this file (it only leaks in
    # through a wildcard import), so bring it into scope here.
    import csv

    ue_list = []
    # newline='' is what the csv module expects for files handed to csv.reader.
    with open(ue_addr, 'r', newline='') as fh:
        reader = csv.reader(fh, delimiter=',')
        next(reader, None)  # skip the header row; tolerate an empty file
        for row in reader:
            if row:  # csv.reader yields [] for blank lines
                ue_list.append([float(cell) for cell in row])

    # The threshold is identical for every user, so convert it once.
    snr_threshold = dB2powerratio(snrth)
    UE_list = []
    for row in ue_list:
        new_ue = UE()
        new_ue.loc = [row[0], row[1], 0]
        new_ue.sort = 'ue'
        new_ue.uav = -1
        new_ue.snrth = snr_threshold
        UE_list.append(new_ue)
    return ue_list, UE_list


def get_sinr_variance(sinr_list):
    """Return the population variance of the nonzero entries of ``sinr_list``.

    Zero entries are excluded before the variance is taken; in this file zero
    appears to mark "no link" in the SINR matrix (TODO confirm against
    cal_rsma_sinr).

    Args:
        sinr_list: iterable of numeric SINR values (list or 1-D array).

    Returns:
        The variance (divisor N, not N-1) of the nonzero values. Returns 0.0
        when there are no nonzero values. The previous code warned on the
        empty mean and then raised ZeroDivisionError in that case.
    """
    nonzero_values = [x for x in sinr_list if x != 0]
    if not nonzero_values:
        return 0.0
    # np.var uses ddof=0 by default, i.e. the same population variance the
    # hand-written sum-of-squares / len formula produced.
    return np.var(nonzero_values)


def _apply_eta(uav, eta):
    """Store the common-power ratio ``eta`` (length-1) and derived common power on ``uav``."""
    uav.eta = np.array(eta, dtype=float).reshape(1)
    uav.common_power = uav.eta[0] * uav_power_upper


def _eta_fitness(positions, bs_index, UE_list, BS_list, link):
    """PSO objective: SINR variance of BS ``bs_index`` for each candidate eta.

    pyswarms passes the whole swarm as an array of shape
    (n_particles, dimensions) and expects one cost per particle back. Each
    candidate is applied to the BS before the allocation and SINR are
    recomputed, so the cost actually depends on the particle's position.
    NOTE(review): this assumes power_rsma_equal_allocation / cal_rsma_sinr
    read ``eta`` / ``common_power`` from the BS object. Those are the only
    attributes this file sets from the optimised value. Confirm in rsma.
    """
    uav = BS_list[bs_index]
    costs = np.empty(len(positions))
    for p, position in enumerate(positions):
        _apply_eta(uav, position)
        power_rsma_equal_allocation(BS_list, UE_list, link)
        snr, _ = cal_rsma_sinr(UE_list, BS_list)
        costs[p] = get_sinr_variance(np.transpose(snr)[bs_index])
    return costs


def _optimize_eta(bs_index, UE_list, BS_list, link):
    """Run a 1-D global-best PSO over eta in [0.2, 0.8]; return the best position."""
    options = {'c1': 0.5, 'c2': 0.3, 'w': 0.9}
    bounds = (np.array([0.2]), np.array([0.8]))
    optimizer = ps.single.GlobalBestPSO(n_particles=100, dimensions=1,
                                        options=options, bounds=bounds)
    # optimize() returns (best_cost, best_position); the position is what we need.
    _best_cost, best_pos = optimizer.optimize(
        lambda positions: _eta_fitness(positions, bs_index, UE_list, BS_list, link),
        iters=10)
    return best_pos


def optimize_common_power_eta(UE_list, BS_list):
    """Tune each eligible BS's common-power ratio ``eta`` with PSO.

    Steps:
      1. Compute SINR/distances, run ``game_process`` to obtain the user
         links, apply ``power_rsma_equal_allocation`` and recompute SINR.
      2. For each BS, take its row of the transposed SINR matrix from step 1.
         Skip the BS if it has no nonzero entries or if its smallest nonzero
         |SINR| is below 0.1.
      3. Otherwise search eta in [0.2, 0.8] with PSO, minimising the variance
         of that BS's nonzero SINR values. Store the result as ``uav.eta``
         (length-1 array) and ``uav.common_power = eta[0] * uav_power_upper``.
         Then re-run the allocation so the state matches the chosen eta.

    BSs are processed in order, so later searches see the eta already chosen
    for earlier BSs. Skipped BSs get no ``eta`` assigned here.

    Args:
        UE_list: user objects, passed through to the rsma / UA_game helpers.
        BS_list: base-station objects; mutated in place.
    """
    snr, dist = cal_rsma_sinr(UE_list, BS_list)
    link, _ = game_process(BS_list, UE_list, snr, dist)
    power_rsma_equal_allocation(BS_list, UE_list, link)
    snr, dist = cal_rsma_sinr(UE_list, BS_list)

    # Row i holds BS i's SINR towards every user (assumes snr is indexed
    # [ue][bs] — TODO confirm). This snapshot is intentionally not refreshed
    # inside the loop: eligibility is judged after the initial allocation.
    sinr_matrix = np.transpose(snr)

    for i, uav in enumerate(BS_list):
        nonzero_abs = [abs(x) for x in sinr_matrix[i] if x != 0]
        if not nonzero_abs or min(nonzero_abs) < 0.1:
            continue
        best_eta = _optimize_eta(i, UE_list, BS_list, link)
        _apply_eta(uav, best_eta)
        # The last PSO evaluation left the allocation for some other particle;
        # recompute it for the eta actually chosen.
        power_rsma_equal_allocation(BS_list, UE_list, link)
if __name__ == '__main__':
    # Example run: load users, build the BS list via init.main_rsma, then tune
    # each UAV's common-power ratio.
    addr = r"D:\pycharm\deploy_algorithms\data\list_ue50_6.csv"
    ue_list, UE_list = read_UElist(addr)
    numuav = 4
    R = 23  # only used by the alternative KMeans deployment noted below
    BS_list = init.main_rsma(UE_list.copy(), sumgbs, R_mbs=R_mbs, sumuav=numuav, R_uav=R_uav)
    optimize_common_power_eta(UE_list, BS_list)
    for uav in BS_list:
        # BSs skipped by optimize_common_power_eta are not assigned an ``eta``
        # there and may not have one at all; getattr avoids AttributeError.
        print(getattr(uav, 'eta', None), uav.power_allocation, sum(uav.power_allocation))
    # Alternative deployment (requires importing the KMeans module):
    # BS_list = KMeans.main(ue_list, R)