import math
import Users
import numpy as np
import random
from pprint import pprint
import constants as C
import tool
import ObjectiveFuction as OF

class PSO:
    """Particle-swarm optimiser over the 3-D positions of ``k`` UAVs.

    One particle is a ``(k, 3)`` array of UAV coordinates; the swarm holds
    ``N`` particles.  Fitness is supplied by the caller of :meth:`pso` and is
    maximised.
    """

    def __init__(self, pre_uavs, users, r_c, r_p, k, N=70, maxgen=200):
        """Store the problem data and compute the initial resource allocation.

        Args:
            pre_uavs: UAV positions compared against candidates in
                :meth:`punish_f` and forwarded to the fitness callback.
            users: User data forwarded to the allocation and fitness functions.
            r_c: Forwarded to ``OF.objective_resource_avg`` and the fitness
                callback (meaning not visible in this file -- TODO confirm).
            r_p: Forwarded to the fitness callback (meaning not visible in
                this file -- TODO confirm).
            k: Number of UAVs per particle.
            N: Number of particles in the swarm.
            maxgen: Number of generations run by :meth:`pso`.
        """
        # TODO: read these parameters from a JSON configuration later.
        self.N = N
        self.k = k
        self.pre_uavs = pre_uavs
        self.users = users
        self.d = 3  # dimensionality of one UAV position
        self.limit = [C.min_range, C.max_range]  # x/y position bounds
        self.vlimit = [-500, 500]  # velocity bounds
        self.vlimit_h = [-10, 10]  # height-velocity bounds (currently unused by pso)
        self.r_c = r_c
        self.r_p = r_p
        self.min_height = C.min_height
        self.max_height = C.max_height
        self.wmax = 0.9  # upper inertia weight
        self.wmin = 0.4  # lower inertia weight
        self.c1 = 1.496  # cognitive (self) learning factor
        self.c2 = 1.496  # social (swarm) learning factor
        self.maxgen = maxgen
        self.init_alloc()

    # TODO: add a penalty that keeps UAVs from colliding with each other.

    def punish_f(self, uavs):
        """Return ``-inf`` if any UAV moved further than it can travel, else 0.

        The reachable distance is ``(C.time_slot - C.run_time) * C.speed``;
        each candidate UAV is compared with the UAV at the same index in
        ``self.pre_uavs``.
        """
        max_dis = (C.time_slot - C.run_time) * C.speed
        # any() stops at the first violation instead of scanning every UAV.
        too_far = any(
            tool.calculate_3d_distance_uavs(uav, self.pre_uavs[i]) > max_dis
            for i, uav in enumerate(uavs)
        )
        return float('-inf') if too_far else 0

    def re_init(self, pre_uavs, users, r_c, r_p, c_association, p_association, p_all, b_all):
        """Replace the problem data and allocation state used by later runs."""
        self.pre_uavs = pre_uavs
        # The original assigned only ``self.user`` (typo), so ``self.users``
        # -- the attribute actually read by pso/init_alloc -- was never updated.
        self.users = users
        self.user = users  # legacy misspelled attribute, kept for compatibility
        self.r_c = r_c
        self.r_p = r_p
        self.c_association = c_association
        self.p_association = p_association
        self.p_all = p_all
        self.b_all = b_all

    def initParticle(self, users):
        """Create the initial swarm positions and velocities.

        Particle 0 comes from ``tool.getUavInitialPos(self.k, users)``; every
        other particle gets random integer x/y inside ``self.limit`` and the
        minimum height.

        Returns:
            ``(X, V)``: float arrays, ``V`` of shape ``(N, k, 3)`` drawn
            uniformly from ``self.vlimit``; ``X`` holds one entry per particle.
        """
        V = np.random.uniform(self.vlimit[0], self.vlimit[1], (self.N, self.k, 3))
        lo, hi = self.limit
        X = []
        for i in range(self.N):
            if i == 0:
                uavs = tool.getUavInitialPos(self.k, users)
            else:
                uavs = [
                    [np.random.randint(lo, hi), np.random.randint(lo, hi), self.min_height]
                    for _ in range(self.k)
                ]
            X.append(uavs)
        # Float dtype so later fractional position updates are never truncated.
        return np.array(X, dtype=float), V

    def init_alloc(self):
        """Compute and store the allocation from ``OF.objective_resource_avg``."""
        (self.c_association,
         self.p_association,
         self.p_all,
         self.b_all) = OF.objective_resource_avg(self.pre_uavs, self.users, self.r_c)

    def initBest(self, X):
        """Initialise personal/global best positions and fitness values.

        Returns:
            ``(p_pos, global_pos, p_best, global_best)`` where ``p_pos`` is a
            copy of ``X``, ``global_pos`` is a ``(k, 3)`` zero array,
            ``p_best`` is an ``(N, 1)`` array of ``-inf`` and ``global_best``
            is ``-inf``.
        """
        p_pos = X.copy()  # best position seen by each particle
        global_pos = np.zeros((self.k, 3))  # best position seen by the swarm
        p_best = np.full((self.N, 1), float('-inf'))  # best fitness per particle
        global_best = float('-inf')  # best fitness seen by the swarm
        return p_pos, global_pos, p_best, global_best

    def pso(self, X, V, update_w, get_fitness):
        """Run ``self.maxgen`` generations and return the best solution found.

        Args:
            X: Initial positions, one ``(k, 3)`` entry per particle.  Not
                modified; a float copy is used internally.
            V: Initial velocities, same shape as ``X``.  Updated in place when
                it is already a float ndarray.
            update_w: Callable returning the inertia weight; called as
                ``update_w(wmin, wmax, gen, maxgen, gen, fitness)``.
                NOTE(review): the generation index is passed twice, as in the
                original code -- confirm against update_w's signature.
            get_fitness: Callable whose first return value is the fitness
                (higher is better) of one particle; called with
                ``(pre_uavs, particle, users, r_c, r_p, p_all, b_all,
                c_association, p_association)`` and must return a 3-tuple.

        Returns:
            ``(global_pos, global_best, result)``: the best ``(k, 3)``
            position, its fitness (a 1-element array once any particle has
            scored above ``-inf``, otherwise the float ``-inf``), and a
            length-``maxgen`` array with the best fitness after each
            generation.
        """
        # Work in float so personal bests are not truncated if X holds ints.
        X = np.array(X, dtype=float)
        V = np.asarray(V, dtype=float)

        p_pos, global_pos, p_best, global_best = self.initBest(X)
        # Plain-float mirror of global_best.  The original kept a *view* of a
        # p_best row here, which changed silently whenever that particle
        # improved and so prevented global_pos from being refreshed.
        best_value = float(global_best)
        result = np.zeros(self.maxgen)
        fitness = np.zeros(self.N)
        v_lo, v_hi = self.vlimit
        x_lo, x_hi = self.limit

        for gen in range(self.maxgen):
            # Evaluate every particle.
            for i, particle in enumerate(X):
                fitness[i], _, _ = get_fitness(
                    self.pre_uavs, particle, self.users, self.r_c, self.r_p,
                    self.p_all, self.b_all, self.c_association, self.p_association)

            # Update personal bests.
            improved = fitness > p_best[:, 0]
            p_best[improved, 0] = fitness[improved]
            p_pos[improved] = X[improved]

            # Update the swarm best from the personal bests.
            leader = int(np.argmax(p_best))
            leader_value = float(p_best[leader, 0])
            if leader_value > best_value:
                best_value = leader_value
                global_best = p_best[leader].copy()
                global_pos = p_pos[leader].copy()

            # Velocity update: one random factor per particle and per term.
            w = update_w(self.wmin, self.wmax, gen, self.maxgen, gen, fitness)
            for i in range(self.N):
                r1 = random.random()
                r2 = random.random()
                V[i] = (V[i] * w
                        + self.c1 * r1 * (p_pos[i] - X[i])
                        + self.c2 * r2 * (global_pos - X[i]))
            np.clip(V, v_lo, v_hi, out=V)

            # Position update; x/y are clamped to the area, height is pinned.
            X = X + V
            X[..., :2] = np.clip(X[..., :2], x_lo, x_hi)
            X[..., 2] = self.min_height

            result[gen] = best_value
            print(gen, global_pos, global_best)

        return global_pos, global_best, result



if __name__ == '__main__':
    # Example data only: two hard-coded lists of three-component UAV
    # coordinates.  Nothing here constructs or runs a PSO instance.
    uavs = [
        [868.34535881, 315.18340086, 50.0],
        [292.96443022, 729.7612657, 77.75450997],
        [732.25467955, 927.62861821, 50.0],
        [303.40107629, 240.31208079, 150.0],
    ]
    # Scratch calculations kept for reference:
    # print(0.012 / 8 * 1.225 * 0.05 * 0.79 * 400 ** 3 * 0.5 ** 3)
    # print((1+0.1)*100**1.5/math.sqrt(2*1.225*0.79))
    best = [
        [335, 772, 67],
        [382, 739, 75],
        [1758, 1683, 137],
        [885, 762, 133],
        [403, 1196, 94],
        [1049, 358, 88],
    ]







