# -*- coding:utf-8 -*-
import csv
import os

import matplotlib.pyplot as plt
import numpy as np

from functions import *
from cluster_deploy import ClusterResult

# Re-cluster the UEs (i.e. recompute the UAV positions) every this many steps.
UAV_UPDATE_FRE = 4

# Weights of the two fitness terms: ping-pong ratio vs. disturbance ratio.
W_PP = 0.4
W_DIS = 1.0 - W_PP  # the two weights sum to 1

class PSO:
    """Particle-swarm search for two handover parameters of UAV-served UEs.

    Each particle is a 2-vector x. The comment on ``rangepop`` labels the
    dimensions HOM and TTT: ``getDistiniation`` uses x[0] as an RSRP margin and
    ``checkSwitch`` compares x[1] against a per-UE step counter. A particle is
    scored by simulating ``deploy_times`` steps (``deploy``) and combining the
    per-step ping-pong and disturbance ratios (``fitOfSwitch``).

    NOTE(review): ``__init__`` and ``deploy`` read the module globals
    ``n_steps``, ``positions``, ``K`` and ``eps``, which in this file are only
    defined inside the ``__main__`` block.
    """

    def __init__(self, ue_list):
        self.uav_list = []
        self.ue_list = ue_list
        self.weight = 1                   # inertia weight applied to the previous velocity
        self.lr = (0.49445, 1.49445)      # (personal-best, global-best) acceleration coefficients
        self.maxgen = 30
        self.rangepop = (5, 256)          # HOM TTT
        self.sizepop = 30
        self.rangespeed = (-1, 10)

        self.deploy_times = n_steps

        self.cluster = []
        # Per-UE handover bookkeeping: desOfUE, srcOfUE, timerOfUE, enableSwitch,
        # stateOfSwitch, disturbCnt, pingpongRecord, pingpongCnt.
        self._reset_switch_state()
        # Per-step ratios collected during one deploy() run.
        self.pingpongP = []
        self.disturbP = []

        self.rsrpMaxtrix = []

    def _reset_switch_state(self):
        """Reset every per-UE handover list to its initial value."""
        n = len(self.ue_list)
        self.desOfUE = [-1] * n          # target UAV index, -1 = none chosen yet
        self.srcOfUE = [0] * n           # serving UAV index
        self.timerOfUE = [0] * n         # steps spent in an ongoing handover
        self.enableSwitch = [0] * n      # 0 idle, 1 switching, -1 switch failed
        # 0 = default; checkSwitch writes 1 (timer exceeded x[1]) or -1 (otherwise).
        self.stateOfSwitch = [0] * n
        self.disturbCnt = [0] * n        # steps spent with a failed handover
        # Source-target-source is counted as ping-pong; this is the switch history.
        self.pingpongRecord = [[] for _ in range(n)]
        self.pingpongCnt = [0] * n

    def func(self, x):
        """Default benchmark fitness (not used by the handover search).

        :param x: particle position; only x[0] and x[1] are read.
        :return: the fitness value y of the particle.
        """
        wave = np.exp((np.cos(2 * np.pi * x[0]) + np.cos(2 * np.pi * x[1])) / 2)
        if x[0] == 0 and x[1] == 0:
            # sin(r)/r is 0/0 at the origin, so that term is omitted there.
            return wave - 2.71289
        r = np.sqrt(x[0] ** 2 + x[1] ** 2)
        return np.sin(r) / r + wave - 2.71289

    def fitOfSwitch(self):
        """Fitness of the latest deploy() run.

        Returns -1 + W_PP * mean(pingpongP) + W_DIS * mean(disturbP); an empty
        list contributes nothing.

        NOTE(review): the swarm maximizes this value (argmax / ``>`` in
        getinitbest and process), so higher ping-pong and disturbance ratios
        score better. Confirm this sign is intended.
        """
        res = -1
        if self.pingpongP:
            res += W_PP * np.mean(self.pingpongP)
        if self.disturbP:
            res += W_DIS * np.mean(self.disturbP)
        return res

    def getPropPP(self):
        """Append the ping-pong ratio (ping-pongs / recorded switches) to pingpongP.

        The ratio is 0 when nothing has been recorded yet.
        """
        n = len(self.ue_list)
        total = sum(len(record) for record in self.pingpongRecord[:n])  # total switch entries
        pp_cnt = sum(self.pingpongCnt[:n])
        self.pingpongP.append(pp_cnt / total if total != 0 else 0)

    def getPropDis(self):
        """Append the fraction of UEs with a non-zero disturbance count to disturbP.

        The fraction is 0 for an empty UE list (instead of ZeroDivisionError).
        """
        n = len(self.ue_list)
        disturbed = sum(1 for cnt in self.disturbCnt[:n] if cnt)
        self.disturbP.append(disturbed / n if n else 0)

    def initpopvfit(self, sizepop):
        """Create random initial positions and velocities, then score each particle.

        :param sizepop: number of particles.
        :return: (pop, v, fitness); pop and v have shape (sizepop, 2).

        NOTE(review): values are drawn from [0, 2*rangepop[0]) for dimension 0
        and [0, 2*rangepop[1]) for dimension 1, whereas process() clamps BOTH
        dimensions to [rangepop[0], rangepop[1]]. Confirm which reading of
        rangepop is intended.
        """
        pop = np.zeros((sizepop, 2))
        v = np.zeros((sizepop, 2))
        fitness = np.ones(sizepop)
        for i in range(sizepop):
            pop[i] = [np.random.rand() * self.rangepop[0] * 2, np.random.rand() * self.rangepop[1] * 2]
            v[i] = [np.random.rand() * self.rangepop[0] * 2, np.random.rand() * self.rangepop[1] * 2]
            self.deploy(pop[i])
            fitness[i] = self.fitOfSwitch()
        print('初始化')
        return pop, v, fitness

    def getinitbest(self, fitness, pop):
        """Return (gbestpop, gbestfitness, pbestpop, pbestfitness) for the initial swarm.

        The global best is the highest-fitness particle. The personal bests
        start as copies of the positions and fitnesses, so later in-place
        changes to pop/fitness do not alter them.
        """
        gbestpop, gbestfitness = pop[fitness.argmax()].copy(), fitness.max()
        pbestpop, pbestfitness = pop.copy(), fitness.copy()
        return gbestpop, gbestfitness, pbestpop, pbestfitness

    def getRSRP(self, ue_list, uav_list):
        """Fill rsrpMaxtrix so that rsrpMaxtrix[i][j] = get_rss(ue_list[i], uav_list[j])."""
        self.rsrpMaxtrix = [
            [get_rss(ue_list[i], uav_list[j]) for j in range(len(uav_list))]
            for i in range(len(ue_list))
        ]

    def getDistiniation(self, x):
        """Choose handover targets, or fail ongoing handovers, from rsrpMaxtrix.

        For UE i, UAV k is a candidate when some lower-indexed UAV j satisfies
        rsrp[k] > rsrp[j] + x[0]. If there are no candidates, UE i is left
        untouched. Otherwise:
        - not switching (enableSwitch[i] == 0): the strongest candidate becomes
          desOfUE[i] and enableSwitch[i] = 1, unless it is already srcOfUE[i];
        - already switching: the handover is marked failed (enableSwitch[i] = -1)
          when the target's RSRP no longer exceeds the source's by x[0].

        Greedy choice; a proper matching step is still to be done.
        """
        margin = x[0]
        for i in range(len(self.ue_list)):
            rsrp_i = self.rsrpMaxtrix[i]
            candidate = []
            values = []
            for j in range(len(rsrp_i)):
                for k in range(j + 1, len(rsrp_i)):
                    if rsrp_i[k] > rsrp_i[j] + margin:
                        candidate.append(k)
                        values.append(rsrp_i[k])
            if not candidate:
                continue
            if self.enableSwitch[i]:  # already switching (1) or already failed (-1)
                # Compare the chosen target against the serving UAV. The j/k loop
                # indices above only hold the last pair visited, not this handover.
                # NOTE(review): assumes srcOfUE[i] is a valid UAV index (cluster label).
                src, des = self.srcOfUE[i], self.desOfUE[i]
                if rsrp_i[des] < rsrp_i[src] + margin:
                    self.enableSwitch[i] = -1  # switch failed
            else:  # trigger a switch
                idx = int(np.argmax(values))
                if candidate[idx] != self.srcOfUE[i]:
                    self.desOfUE[i] = candidate[idx]  # switch condition 1 (power) met
                    self.enableSwitch[i] = 1

    def resetDes(self, x):
        """Set each UE's serving UAV from the current clustering, then pick targets.

        srcOfUE[i] is taken from self.cluster.labels_DBS[i], i.e. the cluster
        label is used as the UAV index. A matching step is still to be added.
        """
        for i in range(len(self.ue_list)):
            self.srcOfUE[i] = self.cluster.labels_DBS[i]
        self.getDistiniation(x)

    def checkPingPong(self):
        """Increment pingpongCnt[i] when UE i's record has >= 3 entries and record[-1] == record[-3].

        NOTE(review): this is re-checked every step, so the same pattern is
        counted again on each step while it remains at the end of the record.
        """
        for i in range(len(self.ue_list)):
            record_i = self.pingpongRecord[i]
            if len(record_i) >= 3 and record_i[-1] == record_i[-3]:
                self.pingpongCnt[i] += 1

    def checkSwitch(self, x):
        """Advance handover timers/counters for one step and log switches.

        - enableSwitch[i] == 1: increment timerOfUE[i]; stateOfSwitch[i] becomes
          1 once the timer exceeds x[1], otherwise -1.
        - enableSwitch[i] == -1: increment disturbCnt[i].
        Whenever stateOfSwitch[i] is non-zero, the switch is appended to
        pingpongRecord[i].

        NOTE(review): within one deploy() run, stateOfSwitch, timerOfUE and
        enableSwitch are never cleared after a switch, so the record grows every
        step. When its last entry already equals desOfUE, (src, des) is
        appended again, which checkPingPong then counts as a ping-pong.
        Confirm this is intended.
        """
        for i in range(len(self.ue_list)):
            if self.enableSwitch[i] == 1:  # switching
                self.timerOfUE[i] += 1
                self.stateOfSwitch[i] = 1 if self.timerOfUE[i] > x[1] else -1
            elif self.enableSwitch[i] == -1:  # failed switch
                self.disturbCnt[i] += 1
            if self.stateOfSwitch[i]:
                record = self.pingpongRecord[i]
                if record and record[-1] != self.desOfUE[i]:
                    record.append(self.desOfUE[i])
                else:
                    record.append(self.srcOfUE[i])
                    record.append(self.desOfUE[i])

    def deploy(self, x):
        """Simulate deploy_times steps with handover parameters x and collect ratios.

        Every step appends one ping-pong ratio to pingpongP and one disturbance
        ratio to disturbP; read the combined score with fitOfSwitch(). The UAVs
        are re-clustered every UAV_UPDATE_FRE steps, starting at step 0.

        Reads the module globals ``positions`` (indexed positions[:, step, :];
        presumably (n_ue, n_steps, coords), TODO confirm), ``K`` and ``eps``.
        """
        # Each particle is scored from a clean slate; otherwise switch history
        # from previously evaluated particles would leak into this score.
        self._reset_switch_state()
        self.disturbP = []
        self.pingpongP = []
        uav_list = self.uav_list
        for i in range(self.deploy_times):
            ue_list = positions[:, i, :]
            if i % UAV_UPDATE_FRE == 0:  # update UAV positions (also covers step 0)
                c1 = ClusterResult(ue_list, K, eps)
                c1.process()
                self.cluster = c1
                uav_list = c1.result_DBS
                self.uav_list = uav_list
            self.getRSRP(ue_list, uav_list)
            self.resetDes(x)
            self.checkSwitch(x)
            self.checkPingPong()
            self.getPropPP()    # fitness term 1
            self.getPropDis()   # fitness term 2

    def save_result(self, result, csv_name='pso_result_high_pp0.csv'):
        """Write `result` as (generation index, value) rows to the CSV file csv_name."""
        rows = [[i, value] for i, value in enumerate(result)]
        with open(csv_name, 'w', newline='') as csv_file:
            csv.writer(csv_file).writerows(rows)

    def process(self):
        """Run the swarm for maxgen generations, then save and plot the cost curve.

        After each generation -gbestfitness is stored in `result`. The array is
        printed, written with save_result() and plotted.
        """
        pop, v, fitness = self.initpopvfit(self.sizepop)
        gbestpop, gbestfitness, pbestpop, pbestfitness = self.getinitbest(fitness, pop)

        result = np.zeros(self.maxgen)
        t = 0.5  # blend factor between the (halved) velocity and the old position
        for i in range(self.maxgen):
            print(i)
            # Velocity update: inertia + pull toward personal and global bests.
            for j in range(self.sizepop):
                v[j] = (self.weight * v[j]
                        + self.lr[0] * np.random.rand() * (pbestpop[j] - pop[j])
                        + self.lr[1] * np.random.rand() * (gbestpop - pop[j]))
            np.clip(v, self.rangespeed[0], self.rangespeed[1], out=v)

            # Position update.
            # NOTE(review): a blend of velocity and position, not the textbook pop += v.
            pop[:] = t * (0.5 * v) + (1 - t) * pop
            np.clip(pop, self.rangepop[0], self.rangepop[1], out=pop)

            # Fitness update: simulate each particle at its NEW position, so that
            # fitness[j] is particle j's own score.
            for j in range(self.sizepop):
                self.deploy(pop[j])
                fitness[j] = self.fitOfSwitch()

            improved = fitness > pbestfitness
            pbestfitness[improved] = fitness[improved]
            pbestpop[improved] = pop[improved]

            # The global best is the best personal-best POSITION, not the
            # particle's current position.
            best = pbestfitness.argmax()
            if pbestfitness[best] > gbestfitness:
                gbestfitness = pbestfitness[best]
                gbestpop = pbestpop[best].copy()

            result[i] = -gbestfitness
        print(result)
        self.save_result(result)
        plt.figure()
        plt.plot(result)
        plt.show()

if __name__ == '__main__':
    # PSO reads `n_steps`, `positions`, `K` and `eps` as module globals, so they
    # must all be defined here before PSO is constructed and run.
    # os.path.join keeps the path portable (a literal backslash only works on Windows).
    addr = os.path.join("data", "list_ue100_7.csv")
    R = 20  # NOTE(review): not referenced anywhere else in this file
    ue_list_0 = read_UElist(addr)[:30]  # simulate only the first 30 UEs
    n_steps = 10    # number of simulation steps (events)
    positions = random_walk_2D(ue_list_0, n_steps)
    K = 3
    eps = 20
    p1 = PSO(positions[:, 0, :])

    p1.process()
