import copy

import numpy as np
from functions import *
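# Local region partitioning — a form of clustering (局部区域划分 一种聚类):
#   1. build the pairwise distance matrix (得到距离矩阵)
#   2. compute local density and local distance (局部密度和局部距离)
#   3. split the points into clusters (分簇)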
class Cluster:
    """Group UEs by clustering their (rho, delta) density features.

    Pipeline (run by ``process``):
      1. ``getDisMatrix`` - pairwise ``get_distance_2d`` between UEs.
      2. ``get_rho``      - per-UE local density (rho) and local distance (delta).
      3. ``get_clusters`` - greedily group the (rho, delta) points.

    Results:
      * ``self.clusters``        - list of clusters; each is a list of
        (rho, delta) tuples whose first element is the cluster's centroid.
      * ``self.cluster_indices`` - same shape as ``clusters`` but holding
        the corresponding indices into ``ue_list``.
      * ``self.centroids``       - the centroid (rho, delta) of each cluster.

    ``get_distance_2d`` is not defined here; presumably it comes from the
    ``functions`` module via the star import.
    """

    def __init__(self, ue_list, *, dc=20, cluster_radius=10, max_cluster_size=8):
        """Store the UE list and allocate per-UE working buffers.

        Args:
            ue_list: UEs to cluster; each item must be accepted by
                ``get_distance_2d``.
            dc: cut-off distance; UEs closer than this count toward a UE's
                local density rho.
            cluster_radius: maximum ``get_distance_2d`` distance, measured in
                (rho, delta) space, between a centroid and a cluster member.
            max_cluster_size: maximum number of points per cluster,
                centroid included.
        """
        self.centroids = []
        self.ue_list = ue_list
        self.numOfUE = len(ue_list)
        self.disMatrix = [[0] * self.numOfUE for _ in range(self.numOfUE)]
        self.dc = dc  # cut-off distance
        self.cluster_radius = cluster_radius
        self.max_cluster_size = max_cluster_size
        self.rhoList = [0] * self.numOfUE
        self.deltaList = [0] * self.numOfUE
        # Working copies; get_clusters() consumes them.
        self.rhoList_tmp = [0] * self.numOfUE
        self.deltaList_tmp = [0] * self.numOfUE

        self.clusters = []
        self.cluster_indices = []

    def reset_ue_list(self, ue_list):
        """Re-initialise with a new UE list and immediately re-run ``process``.

        The tuning parameters (``dc``, ``cluster_radius``,
        ``max_cluster_size``) are carried over from the current instance.
        """
        self.__init__(
            ue_list,
            dc=self.dc,
            cluster_radius=self.cluster_radius,
            max_cluster_size=self.max_cluster_size,
        )
        self.process()

    def getDisMatrix(self):
        """Fill ``disMatrix[i][j]`` with ``get_distance_2d(ue_list[i], ue_list[j])``."""
        for i in range(self.numOfUE):
            for j in range(self.numOfUE):
                self.disMatrix[i][j] = get_distance_2d(self.ue_list[i], self.ue_list[j])

    def get_rho(self):
        """Compute local density (rho) and local distance (delta) for every UE.

        ``rhoList[i]`` counts the entries of row i of ``disMatrix`` that are
        below ``dc``. This normally includes UE i itself, whose self-distance
        is 0.

        ``deltaList[i]`` is the smallest strictly positive entry of that row.
        Zero distances (UE i itself and any co-located UE) are skipped. When
        no positive distance exists (e.g. a single UE), delta is 0 instead of
        raising ``ValueError`` from ``min([])``.

        Also refreshes the ``*_tmp`` working copies used by ``get_clusters``.
        """
        for i in range(self.numOfUE):
            row = self.disMatrix[i]
            self.rhoList[i] = sum(1 for d in row if d < self.dc)
            self.deltaList[i] = min((d for d in row if d > 0), default=0)
        # Shallow copies suffice: the lists hold immutable numbers.
        self.rhoList_tmp = list(self.rhoList)
        self.deltaList_tmp = list(self.deltaList)

    def _centroid_index(self):
        """Return the index in the ``*_tmp`` lists of the best remaining point.

        A point's score is ``min(rho, delta)``. The first point with the
        highest score wins. Raises ``ValueError`` when no points remain.
        """
        rho, delta = self.rhoList_tmp, self.deltaList_tmp
        return max(range(len(rho)), key=lambda k: min(rho[k], delta[k]))

    def get_centroid(self):
        """Return the (rho, delta) pair of the best remaining point.

        The point is the one with the highest ``min(rho, delta)`` score; on
        ties, the first such point is chosen.
        """
        k = self._centroid_index()
        return (self.rhoList_tmp[k], self.deltaList_tmp[k])

    def get_clusters(self):
        """Greedily partition the remaining (rho, delta) points into clusters.

        Each round picks the best remaining point as a centroid. It then
        scans the other points in order and absorbs every one within
        ``cluster_radius`` of the centroid (via ``get_distance_2d``). The scan
        stops when the cluster holds ``max_cluster_size`` points. Every point
        ends up in exactly one cluster.

        Empties ``rhoList_tmp``/``deltaList_tmp``. Sets ``self.clusters``,
        ``self.centroids`` and ``self.cluster_indices``. The indices refer to
        ``ue_list`` when the ``*_tmp`` lists were filled by ``get_rho``, as
        ``process`` does.
        """
        rho, delta = self.rhoList_tmp, self.deltaList_tmp
        # Original positions of the remaining points. Kept in lockstep with
        # rho/delta so each cluster can be mapped back to UE indices.
        remaining = list(range(len(rho)))
        clusters, cluster_indices, centroids = [], [], []

        while rho:
            k = self._centroid_index()
            centroid = (rho.pop(k), delta.pop(k))
            cluster, members = [centroid], [remaining.pop(k)]

            i = 0
            while i < len(rho) and len(cluster) < self.max_cluster_size:
                point = (rho[i], delta[i])
                if get_distance_2d(centroid, point) <= self.cluster_radius:
                    cluster.append(point)
                    members.append(remaining.pop(i))
                    # Delete by position, never by value: list.remove(value)
                    # drops the first *equal* value, which can pair one
                    # point's rho with a different point's delta.
                    del rho[i]
                    del delta[i]
                else:
                    i += 1

            clusters.append(cluster)
            cluster_indices.append(members)
            centroids.append(centroid)

        self.clusters = clusters
        self.cluster_indices = cluster_indices
        self.centroids = centroids

    def process(self):
        """Run the full pipeline: distance matrix, rho/delta, then clustering."""
        self.getDisMatrix()
        self.get_rho()
        self.get_clusters()

if __name__ == '__main__':
    import os

    # os.path.join yields the right separator on both Windows and POSIX;
    # a hard-coded backslash literal only resolves correctly on Windows.
    addr = os.path.join("data", "list_ue100_7.csv")

    ue_list = read_UElist(addr)
    c1 = Cluster(ue_list)
    c1.process()