from itertools import combinations

import Users
import tool
import GDOP
import copy
import math

def get_p_association_first(users, uavs, service_num_p, service_num_c, *,
                            gdop_fn=None, distance_fn=None):
    """Associate every user with a positioning UAV triple and a communication UAV.

    Phase 1 (positioning): users are processed in index order.
    - Each user is given the 3-UAV combination with the smallest ``gdop_fn``
      value among combinations that still have capacity left.
    - Every combination starts with ``service_num_p`` slots.
    - Ties keep the earliest combination in ``itertools.combinations`` order.
    - If no value beats ``inf``, the first available combination is used.

    Phase 2 (communication): each user picks, from its own triple, the UAV
    with the smallest ``distance_fn`` value that still has capacity left.
    - Every UAV starts with ``service_num_c`` slots.
    - The chosen UAV is then removed from the user's positioning list.

    Args:
        users: indexable sequence of user positions (``users[i]``).
        uavs: indexable sequence of UAV positions (``uavs[k]``).
        service_num_p: positioning capacity of each 3-UAV combination.
        service_num_c: communication capacity of each UAV.
        gdop_fn: ``f(list_of_3_uav_positions, user) -> number``; defaults to
            ``GDOP.calculate_gdop``.
        distance_fn: ``f(uav_position, user) -> number``; defaults to
            ``tool.calculate_3d_distance``.

    Returns:
        ``(c_connection, p_connection)``, both dicts keyed by user index.
        - ``c_connection[i]`` is the communication UAV id of user ``i``.
        - ``p_connection[i]`` lists the two remaining UAV ids of its triple.

    Raises:
        ValueError: if a user cannot be served in either phase. This happens
            when no combination has positioning capacity left (including when
            there are fewer than 3 UAVs), or when none of the user's three
            UAVs has communication capacity left.
    """
    if gdop_fn is None:
        gdop_fn = GDOP.calculate_gdop
    if distance_fn is None:
        distance_fn = tool.calculate_3d_distance

    p_connection = _assign_positioning_groups(users, uavs, service_num_p, gdop_fn)
    c_connection = _assign_communication_uavs(
        users, uavs, p_connection, service_num_c, distance_fn)

    # The communication UAV is no longer listed among the positioning UAVs.
    for i in range(len(p_connection)):
        p_connection[i].remove(c_connection[i])
    return c_connection, p_connection


def _assign_positioning_groups(users, uavs, service_num_p, gdop_fn):
    """Give each user, in order, the available UAV triple with the smallest GDOP."""
    groups = list(combinations(range(len(uavs)), 3))
    remaining = [service_num_p] * len(groups)
    p_connection = {}
    for i in range(len(users)):
        # Indices into `groups` of the combinations that still have capacity.
        available = [g for g in range(len(groups)) if remaining[g] > 0]
        if not available:
            raise ValueError(
                'no UAV combination has positioning capacity left for user {}'.format(i))
        best = available[0]  # fallback when no GDOP value beats inf
        gdop_min = float('inf')
        for g in available:
            uavs_pos = [uavs[k] for k in groups[g]]
            # Pass a copy of the user, as the original code did.
            gdop_current = gdop_fn(uavs_pos, copy.copy(users[i]))
            if gdop_current < gdop_min:
                gdop_min = gdop_current
                best = g
        print(gdop_min)
        p_connection[i] = list(groups[best])
        # `best` indexes `groups` directly, so the capacity is charged to the
        # combination that was actually assigned.
        remaining[best] -= 1
    return p_connection


def _assign_communication_uavs(users, uavs, p_connection, service_num_c, distance_fn):
    """Pick, for each user, the nearest UAV of its triple that has capacity left."""
    remaining = [service_num_c] * len(uavs)
    c_connection = {}
    for i in range(len(users)):
        best_dis = float('inf')
        node = -1
        for uav in p_connection[i]:
            if remaining[uav] > 0:
                dis = distance_fn(uavs[uav], users[i])
                if best_dis > dis:
                    best_dis = dis
                    node = uav
        if node == -1:
            # Fail here instead of decrementing uavs[-1]'s counter by accident.
            raise ValueError(
                'no UAV in group {} has communication capacity left for user {}'.format(
                    p_connection[i], i))
        c_connection[i] = node
        remaining[node] -= 1
    return c_connection

if __name__ == '__main__':
    # Demo run on ./data.csv with initial UAV positions from tool.getUavInitialPos.
    users = Users.getUsers("./data.csv")
    uavs = tool.getUavInitialPos(3, users)
    group_count = len(list(combinations(range(len(uavs)), 3)))
    # Ceiling division makes the total capacity of each kind at least the
    # number of users.
    per_group_capacity = math.ceil(len(users) / group_count)
    per_uav_capacity = math.ceil(len(users) / len(uavs))
    print(get_p_association_first(users, uavs, per_group_capacity, per_uav_capacity))