import copy
import math

import numpy as np
from matplotlib import pyplot as plt

from envs.AerialVehicle import AerialVehicle
from numpy import random
import torch
from envs.UserForUAV import UserForUAV

# --- Population sizes -------------------------------------------------------
CustomerNum = 40  # number of users served by the UAV swarm
AerialVehiclesNum = 5  # number of UAVs (one movement agent per UAV)
TotalContentNum = 30  # size of the content catalogue

# --- Map and motion ---------------------------------------------------------
MaxPlaceX = 500  # upper x bound of the map
MaxPlaceY = 500  # upper y bound of the map
Move = 15  # scale applied to the distance component of a UAV action

# --- Caching ----------------------------------------------------------------
ContentSize = 256  # size of one content item; a cache slot holds ContentSize / K
CacheLimit = 3600  # total cache capacity, in the same unit as ContentSize

# --- Radio/channel parameters (forwarded to AerialVehicle) ------------------
Hight = 200  # presumably the UAV altitude -- confirm in AerialVehicle
EnvA = 11.9
EnvB = 0.13
Frequency = 2.0e9
Bandwidth = 4.0e7
TransmitPower = 2
SpeedOfLight = 3.0e8
AvgLOS = 6
AvgNLOS = 20
ConstrainLOS = 0.02
Noise = 1e-13
DownSize = 0  # number of leading UAVs whose cache is wiped after each cache update

# --- User model -------------------------------------------------------------
FalvorNum = 2  # forwarded to UserForUAV
K = 3  # initial k: cache slot size divisor and per-user reward threshold

# --- Request service limits (forwarded to UserForUAV.tryGetCache) -----------
transSpeedBaseLine = 0.5
timeLImit = 20

if torch.cuda.is_available():
    device = torch.device("cuda")
else:
    device = torch.device("cpu")

# --- Learning hyper-parameters (not referenced elsewhere in this file) ------
HID = [256] * 3  # hidden layer widths
HID_SIZE = len(HID)  # number of hidden layers (3)
LOG_SIG_MAX = 2  # upper clamp for a log value ("max_log")
LOG_SIG_MIN = -20  # lower clamp for a log value ("min_log")
alpha = 0.036
MIN_K = 2
RANDOM_STEP = 2000
class EnvCore(object):
    """Multi-agent environment for UAV positioning, user matching and caching.

    Agent layout, shared by the ``actions`` argument of :meth:`step` and by
    the observation/reward/done lists it returns:

    * indices ``0 .. AerialVehiclesNum - 1``: one movement agent per UAV,
      whose action is ``(distance fraction, angle fraction)``;
    * index ``AerialVehiclesNum``: the user/UAV matching agent;
    * index ``AerialVehiclesNum + 1``: the cache-placement agent.
    """

    # Observation sizes available without instantiating the class;
    # __init__ recomputes instance-level values via getDimension1..4().
    dimension1 = AerialVehiclesNum + 1
    # NOTE(review): getDimension2() returns 2*N + CustomerNum + TotalContentNum,
    # which differs from this class-level value -- confirm which one callers need.
    dimension2 = 2 * AerialVehiclesNum
    dimension3 = int(CacheLimit // (ContentSize / K)) + CustomerNum
    dimension4 = 2 * AerialVehiclesNum + CustomerNum + TotalContentNum
    # The right-hand side resolves to the module-level constant.
    AerialVehiclesNum = AerialVehiclesNum

    def __init__(self):
        """Create UAVs and users, place them on the map and build the empty cache."""
        self.CustomerNum = CustomerNum
        self.AerialVehiclesNum = AerialVehiclesNum
        self.agent_num = AerialVehiclesNum

        # Map boundaries.
        self.MaxPlaceX = MaxPlaceX
        self.MaxPlaceY = MaxPlaceY
        self.MinPlaceX = 0
        self.MinPlaceY = 0
        self.PlaceList = []

        self.k = K
        self.Move = Move
        self.FirstStep = True

        # UAV start positions are drawn once, uniformly over the map;
        # RandomUAVPosition() moves the UAVs back to them on every reset.
        xDelta = self.MaxPlaceX - self.MinPlaceX
        yDelta = self.MaxPlaceY - self.MinPlaceY
        self.initX = xDelta * np.random.uniform(0, 1, self.AerialVehiclesNum) + self.MinPlaceX
        self.initY = yDelta * np.random.uniform(0, 1, self.AerialVehiclesNum) + self.MinPlaceY

        self.state = []
        self.action = []
        # Observation / action size of each UAV movement agent.
        self.obs_dim = self.getStateDimension()
        self.action_dim = self.getActionDimension()

        self.AerialVehicles = [self.newUAV() for _ in range(AerialVehiclesNum)]
        self.Customers = [self.newUser(i, self.AerialVehicles) for i in range(CustomerNum)]

        self.resetCustomer()
        self.resetUAV()

        # Step/episode counters and running per-episode statistics.
        self.nowStep = 0
        self.epi = 0
        self.bags = 0
        self.slices = 0
        self.time = 0

        # One entry per finished episode (used for plotting).
        self.allBags = []
        self.allSlices = []
        self.ks = []
        self.epis = []

        # Shared cache: CacheNum slots of size ContentSize / k fit into CacheLimit.
        self.ContentNum = TotalContentNum
        self.CacheLimit = CacheLimit
        size = ContentSize / self.k
        self.CacheNum = int(self.CacheLimit // size)
        self.CacheContent = [0] * self.CacheNum
        self.CacheNumNow = 0  # number of cache slots filled so far
        self.CacheDict = {}  # content index -> hits counted in the current step
        self.Recommend = 0  # content index chosen by the last cache action

        self.AllocReward = 0
        self.DownSize = DownSize

        self.dimension1 = self.getDimension1()
        self.dimension2 = self.getDimension2()
        self.dimension3 = self.getDimension3()
        self.dimension4 = self.getDimension4()
        print("11111111111111111111111")
        print(self.dimension1)
        print(self.dimension2)
        print(self.dimension3)
        print(self.dimension4)
        # Sum of all agents' observation sizes (one dimension1 per UAV agent).
        print(self.dimension1 * self.AerialVehiclesNum + self.dimension2 + self.dimension3 + self.dimension4)

        self.trace = []
        self.clearTrace()

    def resetUAV(self):
        """Move UAVs to their start positions, reset k and clear service lists."""
        self.resetUAVPosition()
        self.resetUAVNumK()
        self.resetServiceList()

    def resetUAVCache(self):
        """Make every UAV cache every content item of the catalogue."""
        for uav in self.AerialVehicles:
            for content in range(TotalContentNum):
                uav.cacheContent(content)

    def resetUAVPosition(self):
        """Move every UAV back to its stored start position."""
        self.RandomUAVPosition()

    def resetUAVNumK(self):
        """Reset k to the module default K (k is currently kept fixed)."""
        self.setUAVNumK(K)

    def setUAVNumK(self, k):
        """Set k on the environment and on every UAV."""
        self.k = k
        for uav in self.AerialVehicles:
            uav.setK(k)

    def step(self, actions):
        """Apply one joint action and return ``[obs, rewards, dones, infos]``.

        ``actions[i]`` for ``i < AerialVehiclesNum`` moves UAV ``i`` by
        ``act[0] * Move`` in direction ``act[1] * pi``;
        ``actions[AerialVehiclesNum]`` is passed to :meth:`addService`; and
        ``actions[AerialVehiclesNum + 1]`` is passed to :meth:`resetCache`.
        Every ``200`` steps the episode statistics are logged and reset.
        """
        sub_agent_obs = []
        sub_agent_reward = []
        sub_agent_done = []
        sub_agent_info = []
        self.resetServiceList()

        # Move the UAVs; moveToByDist returns a per-UAV penalty value.
        punish = []
        for i, uav in enumerate(self.AerialVehicles):
            act = actions[i]
            dist = act[0] * self.Move
            pi = act[1] * math.pi
            punish.append(uav.moveToByDist(dist, pi))

        self.addService(actions[self.AerialVehiclesNum])
        self.resetCache(actions[self.AerialVehiclesNum + 1])

        # Let every user try to fetch its requested content.
        totalReward = 0
        num = 0  # users whose reward reached k
        cacheMit = 0  # users whose request is in the shared cache
        totalMiss = 0
        missList = []
        times = []
        for customer in self.Customers:
            reward, delay, miss, requestIndex = customer.tryGetCache(transSpeedBaseLine, timeLImit)
            times.append(delay)
            totalReward += reward
            missList.append(miss)
            totalMiss += miss
            if reward >= self.k:
                num += 1
            if requestIndex in self.CacheContent:
                cacheMit += 1
                self.CacheDict[requestIndex] += 1

        # A user with misses gets the worst base-station delay among the
        # first `miss` UAVs instead of its own delay.
        # NOTE(review): raises IndexError if a miss count exceeds the number
        # of UAVs -- confirm tryGetCache keeps miss <= AerialVehiclesNum.
        for i, miss in enumerate(missList):
            delays = [self.AerialVehicles[j].cal_bs_communication_delay(totalMiss)
                      for j in range(miss)]
            if delays:
                times[i] = max(delays)

        totalTime = sum(times) / 200

        # NOTE(review): the summed delay is *added* to the reward; confirm the
        # sign convention of the time value returned by tryGetCache.
        totalReward = totalReward / self.k + num * 50 + totalTime

        self.appendTrace()
        self.nowStep += 1
        self.time += totalTime
        self.bags += num
        self.slices += totalReward

        if self.nowStep % 200 == 0:
            self._endEpisode(times, missList)

        # UAV movement agents: shared reward minus their own movement penalty.
        for i in range(len(self.AerialVehicles)):
            sub_agent_reward.append([totalReward - punish[i] * 10])
            sub_agent_obs.append(self.getStateForUAVAgent(i))
            sub_agent_done.append(False)

        # Matching agent.
        sub_agent_reward.append([totalReward + self.AllocReward])
        sub_agent_obs.append(self.getStateForMatchAgent())
        sub_agent_done.append(False)

        # Cache agent.
        sub_agent_reward.append([totalReward + cacheMit * 0.1 + self.RecommendReward()])
        sub_agent_obs.append(self.getStateForCacheAgent())
        sub_agent_done.append(False)

        sub_agent_info.append({})

        return [sub_agent_obs, sub_agent_reward, sub_agent_done, sub_agent_info]

    def _endEpisode(self, times, missList):
        """Log, record, optionally plot and reset the per-episode statistics."""
        print('bags: ' + str(self.bags))
        print('totalReward: ' + str(self.slices))
        print('time: ' + str(self.time))

        if self.epi > 600:
            self.plotTrace()

        self.allBags.append(self.bags)
        self.allSlices.append(self.slices)
        self.ks.append(self.k)
        self.epis.append(self.epi)
        self.epi += 1

        if self.epi % 200 == 0:
            self.plotPic(self.epis, self.allBags)
            self.plotPic(self.epis, self.allSlices)
            self.plotPic(self.epis, self.ks)

            print(self.allBags)
            print(self.allSlices)
            print(times)
            print(missList)

        # Both phases of the DownSize schedule (epi > 500 and epi > 550)
        # currently set it to 0.
        if self.epi > 500:
            self.DownSize = 0

        self.clearTrace()

        self.bags = 0
        self.slices = 0
        self.time = 0

    def addService(self, act):
        """Assign users to UAVs from the matching agent's action.

        Pair (customer ``i``, UAV ``j``) is enabled when ``act[2 * num] == 1``
        with ``num = i * len(AerialVehicles) + j``. ``AllocReward`` counts the
        users assigned to at least ``k`` UAVs.
        Assumes ``act`` has two entries per (user, UAV) pair, of which only the
        first is read. TODO: confirm this against the action space.
        """
        self.AllocReward = 0
        uavCount = len(self.AerialVehicles)
        reward = [0] * len(self.Customers)

        for i in range(len(self.Customers)):
            for j in range(uavCount):
                num = i * uavCount + j
                if act[2 * num] == 1:
                    self.AerialVehicles[j].addService(i)
                    reward[i] += 1

        self.AllocReward += sum(1 for r in reward if r >= self.k)

    def addService22(self, act):
        """Debug variant of addService that ignores ``act``.

        It uses a fixed allocation pattern of 100 entries, which covers 20
        users x 5 UAVs with one entry per pair. Users beyond the pattern get
        no service, where previously they raised IndexError.
        """
        self.AllocReward = 0
        # Fixed debugging pattern; deliberately overrides the incoming action.
        act = [1, 1, 1, 0, 0, 1, 1, 0, 1, 0, 0, 1, 1, 1, 0, 0, 0, 1, 1, 1, 0, 0, 1, 1, 1, 0, 1, 1, 1, 0, 1, 1, 1, 0, 0, 1, 1, 1, 0, 0, 0, 1, 1, 1, 0, 0, 0, 1, 1, 1, 1, 1, 0, 0, 1, 1, 0, 0, 1, 1, 0, 1, 1, 1, 0, 1, 1, 1, 0, 0, 1, 0, 1, 1, 0, 1, 1, 0, 0, 1, 1, 1, 1, 0, 0, 1, 0, 0, 1, 1, 0, 0, 1, 1, 1, 0, 1, 1, 1, 0]
        uavCount = len(self.AerialVehicles)
        # One counter per user (was wrongly sized by TotalContentNum).
        reward = [0] * len(self.Customers)

        for i in range(len(self.Customers)):
            for j in range(uavCount):
                num = i * uavCount + j
                if num < len(act) and act[num] == 1:
                    self.AerialVehicles[j].addService(i)
                    reward[i] += 1

        self.AllocReward += sum(1 for r in reward if r >= self.k)

    def clearTrace(self):
        """Empty the recorded trajectories, keeping one empty list per UAV."""
        self.trace.clear()
        self.trace.extend([] for _ in self.AerialVehicles)

    def appendTrace(self):
        """Record every UAV's current ``[x, y]`` position in its trajectory."""
        for path, uav in zip(self.trace, self.AerialVehicles):
            path.append([uav.PlaceX, uav.PlaceY])

    def RecommendReward(self):
        """Return 10 if the last cached content index is below 10, else 0."""
        return 10 if self.Recommend < 10 else 0

    def resetCache(self, act):
        """Apply the cache agent's action and push the cache to every UAV.

        The first index with ``act[j] == 1`` is inserted into the shared cache.
        While there are free slots, it fills the next slot. Once the cache is
        full, every slot holding the least-hit content (see ``findLeast``) is
        overwritten. All UAVs then cache the same content list, and the first
        ``DownSize`` UAVs have their caches wiped.
        """
        index = self.findCache(act)
        self.Recommend = index
        if index not in self.CacheContent:
            if self.CacheNumNow < self.CacheNum:
                self.CacheContent[self.CacheNumNow] = index
                self.CacheNumNow += 1
            else:
                victim = self.findLeast()
                for slot, content in enumerate(self.CacheContent):
                    if content == victim:
                        self.CacheContent[slot] = index

        self.CacheDict.clear()
        for uav in self.AerialVehicles:
            uav.clearCacheContent()
        for content in self.CacheContent:
            self.CacheDict[content] = 0
            for uav in self.AerialVehicles:
                uav.cacheContent(content)

        # Slicing caps DownSize at the number of UAVs.
        for uav in self.AerialVehicles[:self.DownSize]:
            uav.clearCacheContent()

    def findLeast(self):
        """Return the cached content with the fewest hits.

        The first key wins on ties. Returns 0 if no count is below 10000.
        """
        least = 10000
        res = 0
        for key, hits in self.CacheDict.items():
            if hits < least:
                least = hits
                res = key
        return res

    def findCache(self, act):
        """Return the first index ``j`` with ``act[j] == 1``, or 0 if there is none."""
        return next((j for j, a in enumerate(act) if a == 1), 0)

    def findIndex(self, score, num):
        """Return the indices of the ``num`` largest values of ``score``, largest first.

        Picked entries are masked with ``-inf``. The old ``-1`` mask returned
        duplicate indices when scores were <= -1. Assumes ``score`` is 1-D.
        """
        t = np.array(score, dtype=float)
        max_index = []
        for _ in range(num):
            index = np.argmax(t)
            t[index] = -np.inf
            max_index.append(index)
        return max_index

    def getStateForMatchAgent(self):
        """Build the matching agent's observation.

        The observation contains normalised UAV positions, then each user's
        requested content index, then a 0/1 flag per content item.
        Cache flags are read from UAV 0 only: resetCache gives all UAVs the
        same content, except the first DownSize UAVs, which are wiped.
        """
        state = []
        for uav in self.AerialVehicles:
            state.append(uav.getPlaceX() / self.MaxPlaceX)
            state.append(uav.getPlaceY() / self.MaxPlaceY)
        state.extend(customer.getRequestIndex() for customer in self.Customers)
        reference = self.AerialVehicles[0]
        state.extend(1 if reference.whetherCacheContent(i) else 0
                     for i in range(TotalContentNum))
        return state

    def getStateForCacheAgent(self):
        """Build the cache agent's observation: cache slots, then user requests."""
        return list(self.CacheContent) + [c.getRequestIndex() for c in self.Customers]

    def resetState2(self):
        """Rebuild ``self.state`` as the matching-agent observation and return it."""
        self.state = self.getStateForMatchAgent()
        return self.state

    def getStateDimension2(self):
        """Return the size of the matching-agent observation."""
        return 2 * len(self.AerialVehicles) + len(self.Customers) + TotalContentNum

    def reset(self):
        """Move UAVs to their start positions, clear services and return all observations."""
        self.RandomUAVPosition()
        self.resetServiceList()

        sub_agent_obs = [self.getStateForUAVAgent(i) for i in range(len(self.AerialVehicles))]
        sub_agent_obs.append(self.getStateForMatchAgent())
        sub_agent_obs.append(self.getStateForCacheAgent())
        return sub_agent_obs

    def getStateForUAVAgent(self, index):
        """Build UAV ``index``'s observation.

        The observation is its normalised (x, y) position followed by its
        distance to every other UAV, divided by MaxPlaceX.
        """
        me = self.AerialVehicles[index]
        s = [me.getPlaceX() / self.MaxPlaceX, me.getPlaceY() / self.MaxPlaceY]
        for j, other in enumerate(self.AerialVehicles):
            if j != index:
                s.append(me.getDist(other.getPlaceX(), other.getPlaceY()) / self.MaxPlaceX)
        return s

    def getStateDimension(self):
        """Return the observation size of one UAV movement agent."""
        return self.AerialVehiclesNum + 1

    def getDimension1(self):
        """Return the observation size of one UAV movement agent."""
        return self.AerialVehiclesNum + 1

    def getDimension2(self):
        """Return the observation size of the matching agent."""
        return 2 * self.AerialVehiclesNum + self.CustomerNum + TotalContentNum

    def getDimension3(self):
        """Return the observation size of the cache agent."""
        return len(self.CacheContent) + self.CustomerNum

    def getDimension4(self):
        """Return 2*N + CustomerNum + TotalContentNum (same as getDimension2)."""
        return 2 * self.AerialVehiclesNum + self.CustomerNum + TotalContentNum

    def getActionDimension(self):
        """Return the action size of one UAV movement agent (distance, angle)."""
        return 2

    def resetCustomer(self):
        """Re-place all users on the map."""
        self.resetCustomerPosition()

    # UAV resets are kept separate because they will later be driven by an agent.

    def RandomUAVPosition(self):
        """Move every UAV to the start position drawn once in __init__."""
        for uav, x, y in zip(self.AerialVehicles, self.initX, self.initY):
            uav.moveTo(x, y)

    def resetServiceList(self):
        """Clear the service lists of all UAVs and all users."""
        for uav in self.AerialVehicles:
            uav.clearServiceList()
        for customer in self.Customers:
            customer.clearService()

    def resetCustomerPosition(self):
        """Place each user uniformly at random on the map and print the coordinates."""
        xDelta = self.MaxPlaceX - self.MinPlaceX
        yDelta = self.MaxPlaceY - self.MinPlaceY
        xx = xDelta * np.random.uniform(0, 1, self.CustomerNum) + self.MinPlaceX
        yy = yDelta * np.random.uniform(0, 1, self.CustomerNum) + self.MinPlaceY
        print(xx)
        print(yy)

        for customer, x, y in zip(self.Customers, xx, yy):
            customer.moveTo(x, y)

    def plotTrace(self):
        """Plot the users and each UAV's recorded trajectory, then show the figure.

        Start points are drawn as orange stars and end points as pink stars.
        """
        color = ['red', 'green', 'blue', 'purple', 'yellow']

        for i, customer in enumerate(self.Customers):
            plt.plot(customer.PlaceX, customer.PlaceY, 'o', color='black')
            plt.text(customer.PlaceX, customer.PlaceY, str(i), fontsize=10, ha='right', va='bottom')

        paths = [path for path in self.trace]
        for i, path in enumerate(paths):
            if not path:
                continue
            xs = [p[0] for p in path]
            ys = [p[1] for p in path]
            # One polyline per UAV, labelled so plt.legend() has entries.
            plt.plot(xs, ys, ',-', color=color[i % len(color)], label='UAV' + str(i + 1))

        # Drawn after all trajectories so the markers stay on top.
        for path in paths:
            if not path:
                continue
            plt.plot(path[0][0], path[0][1], '*', color='orange')
            plt.plot(path[-1][0], path[-1][1], '*', color='pink')

        plt.legend()
        plt.show()

    def newUAV(self):
        """Create an AerialVehicle configured with the module-level constants."""
        return AerialVehicle(CustomerNum, TotalContentNum, ContentSize, Hight,
                             EnvA, EnvB, Frequency, Bandwidth, TransmitPower, SpeedOfLight,
                             AvgLOS, AvgNLOS, ConstrainLOS, Noise, self.MaxPlaceX, self.MinPlaceX,
                             self.MaxPlaceY, self.MinPlaceY)

    def newUser(self, UserId, AerialVehicles):
        """Create user ``UserId`` and give it the list of UAVs."""
        user = UserForUAV(UserId, TotalContentNum, FalvorNum, K)
        user.setUAVList(AerialVehicles)
        return user

    def plotPic(self, x, y):
        """Show a line plot of ``y`` against ``x`` (x axis 'episode', y axis 'reward')."""
        plt.xlabel('episode')
        plt.ylabel('reward')
        plt.plot(x, y)
        plt.show()



#r = EnvCore()
#r.AerialVehicles[0].moveTo(0,0)

#r.AerialVehicles[0].addService(1)
#r.AerialVehicles[0].addService(2)
#r.AerialVehicles[0].addService(3)
#r.AerialVehicles[0].addService(4)
#s = r.AerialVehicles[0].getTransSpeed(500, 0)
#print(s)