import math
import numpy as np
from sklearn.cluster import KMeans
import km
import GDOP
from itertools import combinations
from pprint import pprint


import tool
import copy


class Match:
    """Match users to UAVs for positioning and communication.

    On construction the instance first assigns every user a group of three
    UAVs for positioning (``p_connection``) and then assigns users a single
    UAV for communication (``c_connection``).  Both kinds of assignment
    consume capacity from the same per-UAV counter ``uavConnectNum``, which
    is capped by ``serveMax``.
    """

    def __init__(self, users, uavs):
        """Store the scenario and immediately compute all matchings.

        Args:
            users: indexable collection of user positions; only ``user[0]``
                and ``user[1]`` are read by this class.
            uavs: indexable collection of UAV positions read as
                ``uav[0]``, ``uav[1]``, ``uav[2]``.
        """
        # NOTE(review): there is no ``id`` parameter, so this stores the
        # builtin ``id`` function.  Kept unchanged for backward compatibility.
        self.id = id
        self.allUsers = users
        # Original note: 4 UAVs, each serving 20 people.
        self.uavs = uavs
        self.serveMax = 20
        self.distance = self.getUavAndUserDistance()
        # Number of links (positioning + communication) each UAV already has.
        self.uavConnectNum = [0] * len(self.uavs)
        self.uav_group = list(combinations(self.uavs, 3))
        self.snr_thre = 0
        self.Pmax = 0.1

        # Positioning is matched first, so it gets first pick of capacity.
        self.p_connection, self.c_connection = self.get_all_match()
        (self.uav_to_user,
         self.uav_to_cuser,
         self.uav_to_puser) = self.get_uav_to_user()

    def getUavInitialPos(self, k):
        """Return ``k`` UAV positions placed at K-means centres of the users.

        Each position is ``[int(x), int(y), 100]``: the cluster centre with
        its coordinates truncated by ``int`` and a fixed height of 100.
        This method is not called anywhere inside the class.
        """
        model = KMeans(n_clusters=k,  # number of clusters
                       random_state=1,  # fixed seed: deterministic centroid init
                       max_iter=300,  # max iterations of a single k-means run
                       ).fit(self.allUsers)
        height = 100
        return [[int(center[0]), int(center[1]), height]
                for center in model.cluster_centers_]

    def getUavAndUserDistance(self):
        """Return the ``(user_num, uav_num)`` matrix of user-UAV distances.

        NOTE(review): the matrix is stored as ``float16``, so distances are
        rounded to roughly three significant digits and overflow to ``inf``
        above 65504.  Left unchanged because the matching weights are derived
        from these exact values - confirm whether this precision is intended.
        """
        distance = np.zeros((len(self.allUsers), len(self.uavs)),
                            dtype=np.float16)
        for i, user in enumerate(self.allUsers):
            for j, uav in enumerate(self.uavs):
                distance[i][j] = self.calculateDistance(user, uav)
        return distance

    def calculateDistance(self, user, uav):
        """Return the 3-D distance between a user and a UAV.

        The user contributes only ``user[0]`` and ``user[1]``; the vertical
        separation is taken as ``uav[2]``.
        """
        distanceSquare = ((user[0] - uav[0]) ** 2
                          + (user[1] - uav[1]) ** 2
                          + uav[2] ** 2)
        return math.sqrt(distanceSquare)

    def get_appropriate_uavs(self, user):
        """Return indices of UAVs that can still serve ``user``.

        A UAV qualifies when it has spare capacity
        (``uavConnectNum < serveMax``) and ``tool.cal_snr`` for the pair,
        evaluated at ``Pmax``, is strictly above ``snr_thre``.
        """
        return [
            idx for idx, served in enumerate(self.uavConnectNum)
            if served < self.serveMax
            and tool.cal_snr(self.uavs[idx], user, self.Pmax) > self.snr_thre
        ]

    def getPositionConnection(self, users):
        """Assign each user the three-UAV group with the smallest GDOP.

        For every user, in order: find the UAVs it can use, enumerate all
        three-UAV combinations, evaluate ``GDOP.calculate_gdop`` for each,
        keep the first combination with the smallest value, and charge one
        link to each chosen UAV.

        Returns:
            dict mapping user index to a tuple of three UAV indices.  Users
            with fewer than three eligible UAVs get no entry (previously
            this case raised ``IndexError``).
        """
        user_to_uav_group = {}
        for i, user in enumerate(users):
            candidates = self.get_appropriate_uavs(user)
            groups = list(combinations(candidates, 3))
            if not groups:
                # No triple can be formed, so this user cannot be positioned.
                continue

            # Default to the first group so that a run of inf/NaN GDOP values
            # still selects a group, exactly as before.
            best_group = groups[0]
            gdop_min = float('inf')
            for group in groups:
                uavs_pos = [self.uavs[k] for k in group]
                # The user is copied before the call, as in the original code
                # (presumably because calculate_gdop mutates its argument).
                gdop_current = GDOP.calculate_gdop(uavs_pos, copy.copy(user))
                if gdop_current < gdop_min:
                    gdop_min = gdop_current
                    best_group = group

            user_to_uav_group[i] = best_group
            self.update_serve_num(best_group)
        return user_to_uav_group

    def getPConnection(self):
        """Return the positioning matching computed over all users."""
        return self.getPositionConnection(self.allUsers)

    def getCommunicationConnection(self):
        """Assign users to single UAVs for communication.

        Every UAV with spare capacity is expanded into one "slot" node per
        remaining link, each user is connected to every slot with weight
        ``-distance``, and the edge list is handed to
        ``km.run_kuhn_munkres``.  Each returned pair is read as
        ``(user_index, slot_id)`` and mapped back to its UAV.

        Returns:
            dict mapping user index to UAV index.  An empty dict is returned
            when there are no edges (no users or no spare capacity).
        """
        user_num = len(self.allUsers)
        available = [k for k, v in enumerate(self.uavConnectNum)
                     if v < self.serveMax]
        # Slot ids are encoded as ``uav + stride * slot_number``.  The stride
        # was a fixed 100; widening it when there are more than 100 UAVs
        # keeps the ids unique and changes nothing for smaller fleets.
        stride = max(100, len(self.uavs))

        graph = []
        slot_to_uav = {}
        for i in range(user_num):
            for j in available:
                weight = -self.distance[i][j]
                for number in range(self.serveMax - self.uavConnectNum[j]):
                    slot = j + stride * number
                    graph.append((i, slot, weight))
                    slot_to_uav[slot] = j

        if not graph:
            # Nothing to match; avoid calling the solver on an empty graph.
            return {}

        connection = km.run_kuhn_munkres(graph)
        c_connection = {connect[0]: slot_to_uav[connect[1]]
                        for connect in connection}

        self.update_serve_num_by_c_connection(c_connection)
        return c_connection

    def update_serve_num(self, uavs):
        """Add one link to every UAV index in ``uavs``."""
        for u in uavs:
            self.uavConnectNum[u] += 1

    def update_serve_num_by_c_connection(self, c_connection):
        """Add one link to the UAV of every entry in ``c_connection``."""
        for uav_index in c_connection.values():
            self.uavConnectNum[uav_index] += 1

    def get_all_match(self):
        """Compute and return ``(p_connection, c_connection)``.

        Positioning runs first; communication then uses whatever capacity
        is left.
        """
        p_connection = self.getPConnection()
        c_connection = self.getCommunicationConnection()
        return p_connection, c_connection

    def get_uav_to_user(self):
        """Invert the user-to-UAV matchings into per-UAV user lists.

        Returns:
            ``(uav_to_user, uav_to_cuser, uav_to_puser)``, each a dict keyed
            by UAV index.  ``uav_to_puser`` holds positioning users,
            ``uav_to_cuser`` holds communication users, and ``uav_to_user``
            holds both, with communication users stored as ``index + 100``.
        """
        uav_count = len(self.uavs)
        uav_to_user = {i: [] for i in range(uav_count)}
        uav_to_cuser = {i: [] for i in range(uav_count)}
        uav_to_puser = {i: [] for i in range(uav_count)}

        # Numbering rule from the original code: positioning users 0-19,
        # communication users 100-119.
        for user, group in self.p_connection.items():
            for uav_index in group:
                uav_to_user[uav_index].append(user)
                uav_to_puser[uav_index].append(user)

        for user, uav_index in self.c_connection.items():
            uav_to_user[uav_index].append(user + 100)
            uav_to_cuser[uav_index].append(user)

        return uav_to_user, uav_to_cuser, uav_to_puser

    def dbmToW(self, p):
        """Convert a power from dBm to watts."""
        return pow(10, p / 10) / 1000


if __name__ == '__main__':
    # Demo run.  Match requires users and UAVs; the previous call passed
    # neither and failed with a TypeError.
    # Scenario taken from the notes in this file: 4 UAVs at height 100 over
    # the centres listed in getUavInitialPos, serving 20 users.
    # NOTE(review): users are assumed to be 2-D [x, y] lists inside a
    # 1000 x 1000 area - confirm against what GDOP and tool expect.
    rng = np.random.RandomState(1)
    demo_users = rng.uniform(0, 1000, size=(20, 2)).tolist()
    demo_uavs = [[250, 250, 100], [250, 750, 100],
                 [750, 250, 100], [750, 750, 100]]

    m = Match(demo_users, demo_uavs)
    print(m.p_connection,m.c_connection)
    print(m.uav_to_user)
    print(m.uavConnectNum)
    # The former print(m.get_min_speb()) was removed: Match defines no
    # get_min_speb method, so the call raised AttributeError.
