from sklearn.cluster import KMeans, DBSCAN
import numpy as np
import matplotlib.pyplot as plt
from functions import *

class ClusterResult:
    """Cluster a list of point coordinates with K-means and DBSCAN.

    Results are stored on the instance. ``process`` runs both algorithms;
    the ``show_*`` methods plot the stored results (they use the first two
    coordinates of each point).
    """

    def __init__(self, ue_list, K, eps):
        """
        Args:
            ue_list: sequence of point coordinates, e.g. ``[[x, y], ...]``.
            K: number of clusters for K-means.
            eps: starting DBSCAN neighbourhood radius; ``dbscan_clustering``
                may use a larger value internally but never overwrites this.
        """
        self.ue_list = ue_list

        # K-means configuration and results (filled by kmeans_clustering).
        self.K = K
        self.X_Kmeans = []
        self.labels_Kmeans = []
        self.result_Kmeans = []

        # DBSCAN configuration and results (filled by dbscan_clustering).
        self.eps = eps
        self.X_DBS = []
        self.labels_DBS = []
        self.result_DBS = []

    def kmeans_clustering(self):
        """Run K-means with ``self.K`` clusters and store labels and centroids."""
        kmeans = KMeans(n_clusters=self.K)
        # Convert the point list to a 2-D array, the input format sklearn expects.
        self.X_Kmeans = np.array(self.ue_list)
        self.labels_Kmeans = kmeans.fit_predict(self.X_Kmeans)
        self.result_Kmeans = self.get_centroid(self.labels_Kmeans)

    def show_kmeans_result(self):
        """Plot the points coloured by traffic type, plus the K-means centroids.

        Point ``i`` is drawn blue ('tele') when
        ``calculate_gamma_pdf(i, 1) < calculate_gamma_pdf(i, 2)`` and black
        ('video') otherwise. Centroids from ``kmeans_clustering`` are drawn
        as red triangles ('uav'), so that method must have run first.
        """
        tele_points = []
        video_points = []
        for i, point in enumerate(self.ue_list):
            r1 = calculate_gamma_pdf(i, 1)
            r2 = calculate_gamma_pdf(i, 2)
            (tele_points if r1 < r2 else video_points).append(point)

        # One scatter call per group: every point is drawn exactly once (so
        # alpha blending is uniform) and each legend entry appears once.
        if tele_points:
            tele = np.array(tele_points)
            plt.scatter(tele[:, 0], tele[:, 1], c='b', alpha=0.5, label='tele')
        if video_points:
            video = np.array(video_points)
            plt.scatter(video[:, 0], video[:, 1], c='k', alpha=0.5, label='video')

        centers = self.result_Kmeans
        plt.scatter(centers[:, 0], centers[:, 1], c='r', marker='v', label='uav')
        plt.legend(loc='lower right', frameon=False)
        plt.xlabel('X')
        plt.ylabel('Y')
        plt.title('K-means Clustering')
        plt.show()

    def get_centroid(self, labels):
        """Return the mean coordinate of every cluster.

        Args:
            labels: cluster label for each entry of ``self.ue_list``, in the
                same order.

        Returns:
            ``np.ndarray`` with one row per distinct label, ordered by the
            label's first appearance in ``labels`` (not by label value).

        Raises:
            ValueError: if ``labels`` and ``self.ue_list`` differ in length.
        """
        if len(labels) != len(self.ue_list):
            raise ValueError(
                f"got {len(labels)} labels for {len(self.ue_list)} points"
            )
        # Group points by label; dicts keep insertion (first-appearance) order.
        members = {}
        for label, point in zip(labels, self.ue_list):
            members.setdefault(label, []).append(point)
        return np.array([np.mean(points, axis=0) for points in members.values()])

    def dbscan_clustering(self):
        """Run DBSCAN (``min_samples=1``) and store points, labels and centroids.

        Clustering starts at radius ``self.eps``. While any point is labelled
        as noise (-1), the radius is increased by 0.1 and clustering is
        repeated. ``self.eps`` itself is left unchanged.
        """
        X = np.array(self.ue_list)
        eps = self.eps
        labels = DBSCAN(eps=eps, min_samples=1).fit_predict(X)
        # sklearn counts the point itself toward min_samples, so with
        # min_samples=1 every point is a core point and noise should not occur.
        # The loop is kept as a safeguard in case that parameter changes.
        while -1 in labels:
            eps += 0.1
            labels = DBSCAN(eps=eps, min_samples=1).fit_predict(X)
        self.result_DBS = self.get_centroid(labels)
        self.X_DBS = X
        self.labels_DBS = labels

    def show_dbscan_result(self):
        """Plot points coloured by DBSCAN label with centroids as black squares."""
        plt.figure()
        plt.scatter(self.X_DBS[:, 0], self.X_DBS[:, 1], c=self.labels_DBS)
        plt.scatter(self.result_DBS[:, 0], self.result_DBS[:, 1], c='k', marker='s')
        plt.xlim(0, 100)
        plt.xlabel('X')
        plt.ylabel('Y')
        plt.title('Deployment')
        plt.show()

    def process(self):
        """Run K-means and then DBSCAN; plotting is left to the caller."""
        self.kmeans_clustering()
        self.dbscan_clustering()

if __name__ == '__main__':
    import os

    # Build the path portably; a hard-coded backslash only works on Windows.
    addr = os.path.join("data", "list_ue100_7.csv")

    # Cluster only the first 30 entries returned by read_UElist.
    ue_list = read_UElist(addr)[:30]
    K = 3
    eps = 20
    c1 = ClusterResult(ue_list, K, eps)
    c1.process()

    plt.show()

