import copy
import matplotlib.pyplot as plt
import numpy as np
from functions import *
from centroid_cluster import Cluster
from sklearn.cluster import DBSCAN

class Deploy:
    """Greedy UAV placement over a set of user equipments (UEs).

    Starting from UE 0, a UAV position is proposed as the centroid of that UE's
    neighbours (UEs closer than ``2 * R``). After the first UAV is placed, a
    bounded number of rounds pick the still-uncovered candidate with the
    smallest ``rho`` (density, per ``cluster.rhoList``). Each round either marks
    that UE covered by an existing nearby UAV or places a new UAV.

    Attributes:
        ue_list: UE positions; every entry is indexed as ``(x, y)``.
        cluster: object exposing ``numOfUE``, ``disMatrix``, ``rhoList`` and
            ``ue_list`` (presumably a ``centroid_cluster.Cluster``; verify
            against caller).
    """

    def __init__(self, ue_list, cluster):
        self.ue_list = ue_list
        self.covered_ue_list = []   # C: indices of covered UEs (may contain repeats)
        self.candidate_list = []    # E: neighbours of the most recently evaluated UE
        self.uncovered_ue_list = [1] * (len(ue_list))  # U: per-UE flag, 1 = still uncovered
        self.cluster = cluster
        self.uav_list = []          # placed UAV positions, in placement order
        self.rhoidx_record = []     # UE index that triggered each UAV placement
        self.R = 20                 # coverage radius of one UAV
        self.du = 30                # a proposal closer than this to an existing UAV may reuse it
        self.server_set = []        # not used inside this class

    def reset_candidate(self):
        """Clear the candidate set E."""
        self.candidate_list = []

    def current_position(self, k):
        """Propose a UAV position for UE ``k`` and rebuild the candidate set E.

        E becomes every other UE whose ``cluster.disMatrix`` distance to ``k``
        is below ``2 * R``. The search radius is deliberately twice the coverage
        radius. The proposal is the centroid of those candidates plus the UE
        with the smallest entry in ``disMatrix[k]`` (presumably ``k`` itself
        when the diagonal is 0). Only ``candidate_list`` is modified; C and U
        are left untouched.

        Args:
            k: index of the UE to place a UAV for.

        Returns:
            tuple: ``(x, y)`` of the proposed UAV.
        """
        self.reset_candidate()
        distances = self.cluster.disMatrix[k]
        for i in range(self.cluster.numOfUE):
            if distances[i] < 2 * self.R and i != k:
                self.candidate_list.append(i)

        # First index of the smallest distance. Works for list rows and ndarray rows alike.
        closest_idx = min(range(len(distances)), key=distances.__getitem__)
        anchor = self.cluster.ue_list[closest_idx]
        selected_points = [self.ue_list[i] for i in self.candidate_list]
        count = len(selected_points) + 1
        uav = (
            (sum(point[0] for point in selected_points) + anchor[0]) / count,
            (sum(point[1] for point in selected_points) + anchor[1]) / count,
        )
        print('kkk', k, selected_points, uav, [get_distance_2d(uav, node) for node in selected_points])
        return uav

    def find_min_rho_and_process(self):
        """Pick the uncovered candidate in E with the smallest positive rho.

        Returns:
            ``(min_rho, min_rho_idx)``, or ``[]`` when E is empty.

        NOTE(review): the search is seeded with ``candidate_list[0]`` without
        checking that it is uncovered or has rho > 0. That UE is therefore
        returned whenever no other candidate beats it, even if it is already
        covered.
        """
        if not self.candidate_list:
            return []
        min_rho_idx = self.candidate_list[0]
        min_rho = self.cluster.rhoList[min_rho_idx]
        for i in self.candidate_list:  # lowest density among the candidates
            if (0 < self.cluster.rhoList[i] < min_rho) and self.uncovered_ue_list[i]:
                min_rho = self.cluster.rhoList[i]
                min_rho_idx = i
        print('res', min_rho_idx, self.cluster.rhoList, self.candidate_list)
        return min_rho, min_rho_idx

    def update_ue_based_uav(self, uav):
        """Mark every still-uncovered UE within ``R`` of ``uav`` as covered.

        Covered UEs get their U flag cleared and their index appended to C.
        """
        # Walk indices: U holds 0/1 flags, not UE indices.
        for i, uncovered in enumerate(self.uncovered_ue_list):
            if uncovered and get_distance_2d(uav, self.ue_list[i]) < self.R:
                self.uncovered_ue_list[i] = 0
                self.covered_ue_list.append(i)

    def add_uav(self, uav, min_rho_idx):
        """Place a UAV at ``uav`` and mark every UE within ``R`` of it.

        Marked UEs get ``rhoList`` set to ``numOfUE`` (presumably so they stop
        being the smallest-rho choice) and their U flag cleared. They are not
        appended to C.
        """
        self.uav_list.append(uav)
        self.rhoidx_record.append(min_rho_idx)
        for i in range(self.cluster.numOfUE):
            if get_distance_2d(uav, self.ue_list[i]) < self.R:
                self.cluster.rhoList[i] = self.cluster.numOfUE
                self.uncovered_ue_list[i] = 0

    def _serves_ue_and_candidates(self, uav, idx):
        """Return True if ``uav`` is within ``R`` of UE ``idx`` and of every UE in E."""
        return (get_distance_2d(self.ue_list[idx], uav) < self.R
                and all(get_distance_2d(self.ue_list[ue_idx], uav) < self.R
                        for ue_idx in self.candidate_list))

    def iter_deploy(self, max_iter=3):
        """Run the greedy placement starting from UE 0.

        Args:
            max_iter: number of selection rounds after the first UAV is placed
                (defaults to 3).
        """
        min_rho_idx = 0
        uav = self.current_position(min_rho_idx)
        if self._serves_ue_and_candidates(uav, min_rho_idx):
            # UE 0 and all its neighbours fit under one proposal; no UAV is placed here.
            print(1)
            self.covered_ue_list.append(min_rho_idx)
            return

        print(2)
        self.add_uav(uav, min_rho_idx)
        for _ in range(max_iter):
            result = self.find_min_rho_and_process()
            if not result:
                # E is empty: nothing left to evaluate (unpacking [] would raise).
                break
            _, min_rho_idx = result
            uav = self.current_position(min_rho_idx)
            print(2, uav, sum(self.uncovered_ue_list), self.uncovered_ue_list)
            # Too close to an existing UAV and fully served: treat the UE as covered
            # instead of placing another UAV.
            if (any(get_distance_2d(uav_i, uav) < self.du for uav_i in self.uav_list)
                    and self._serves_ue_and_candidates(uav, min_rho_idx)):
                self.covered_ue_list.append(min_rho_idx)
                self.uncovered_ue_list[min_rho_idx] = 0
                print(21)
            else:
                self.add_uav(uav, min_rho_idx)
                self.cluster.rhoList[min_rho_idx] = self.cluster.numOfUE
                print(22, sum(self.uncovered_ue_list))

    def show_result(self):
        """Plot the UEs (blue) and the placed UAVs (red)."""
        ue_points = np.asarray(self.ue_list)
        plt.figure()
        plt.scatter(ue_points[:, 0], ue_points[:, 1], c='b')
        if self.uav_list:  # an empty list cannot be indexed as [:, 0]
            uav_points = np.asarray(self.uav_list)
            plt.scatter(uav_points[:, 0], uav_points[:, 1], c='r')
        plt.show()

    def process(self):
        """Run the deployment, print the UAV positions and plot the result."""
        self.iter_deploy()
        print(self.uav_list)
        self.show_result()

    def dbscan(self, eps=20, min_samples=3):
        """Cluster the UE positions with DBSCAN, then print and plot the labels.

        Args:
            eps: DBSCAN neighbourhood radius (defaults to 20).
            min_samples: DBSCAN core-point threshold (defaults to 3).

        Returns:
            numpy.ndarray: one label per UE; ``-1`` marks noise.
        """
        dbscan = DBSCAN(eps=eps, min_samples=min_samples)
        # One (x, y) row per UE. Flattening with reshape(-1, 1) would cluster
        # x and y values as unrelated 1-D samples.
        X = np.asarray(self.ue_list)[:, :2]
        labels = dbscan.fit_predict(X)
        n_clusters = len(set(labels)) - (1 if -1 in labels else 0)
        print("聚类的类别数：", n_clusters)
        for ue, label in zip(self.ue_list, labels):
            print("UE", ue, "属于聚类", label)
        self.plot_clusters(X, labels)
        return labels

    def plot_clusters(self, X, labels):
        """Scatter-plot clustering output, one colour per cluster (noise in gray).

        A two-column ``X`` is drawn as positions (x vs y). A 1-D or
        single-column ``X`` is drawn as value vs cluster label.
        """
        X = np.asarray(X)
        labels = np.asarray(labels)
        unique_labels = set(labels.tolist())
        spatial = X.ndim == 2 and X.shape[1] >= 2
        for label in unique_labels:
            if label == -1:
                color = 'gray'  # noise points
            else:
                color = plt.cm.Spectral(label / len(unique_labels))
            mask = labels == label
            if spatial:
                plt.scatter(X[mask, 0], X[mask, 1], color=color)
            else:
                values = X[mask].ravel()
                plt.scatter(values, [label] * len(values), color=color)
        plt.title('DBSCAN Clustering')
        if spatial:
            plt.xlabel('x')
            plt.ylabel('y')
        else:
            plt.xlabel('UE')
            plt.ylabel('Cluster Label')
        plt.show()

if __name__ == '__main__':
    import os

    # Build the path portably; a backslash-separated literal only resolves on Windows.
    addr = os.path.join("data", "list_ue100_7.csv")

    ue_list = read_UElist(addr)[:10]  # use only the first 10 UEs
    c1 = Cluster(ue_list)
    c1.process()
    dep = Deploy(ue_list, c1)
    dep.dbscan()