import shutil
import time
import pickle
import argparse

from RandLANet import Network, log_out
from sampler2 import *

import numpy as np

from base_op import *
from gcn import *
from semantic3d_dataset_sampling import *
from fps_gcn_cuda import GCN_FPS_sampling
import os
from numpy.linalg import svd
from fps_gcn_cuda import *
import open3d as o3d 

# train_cloud_name_list = [
#     # 'neugasse_station1_xyz_intensity_rgb',
#     'sg27_station1_intensity_rgb',
#     'sg27_station4_intensity_rgb',
#     'sg27_station5_intensity_rgb',
#     'sg27_station9_intensity_rgb',
#     'sg28_station4_intensity_rgb',
#     'untermaederbrunnen_station1_xyz_intensity_rgb',
#     'untermaederbrunnen_station3_xyz_intensity_rgb',
# ]

# Semantic3D training scans (base file names, no extension) handled by this script.
train_cloud_name_list = [
    'bildstein_station1_xyz_intensity_rgb',
    'bildstein_station5_xyz_intensity_rgb',
    'domfountain_station1_xyz_intensity_rgb',
    'domfountain_station2_xyz_intensity_rgb',
    'domfountain_station3_xyz_intensity_rgb',
    'neugasse_station1_xyz_intensity_rgb',
    'sg27_station1_intensity_rgb',
    'sg27_station4_intensity_rgb',
    'sg27_station5_intensity_rgb',
    'sg27_station9_intensity_rgb',
    'sg28_station4_intensity_rgb',
    'untermaederbrunnen_station1_xyz_intensity_rgb',
    'untermaederbrunnen_station3_xyz_intensity_rgb',
]


# train_cloud_name_list = ['bildstein_station1_xyz_intensity_rgb']


def load_superpoints_and_ply(data_path, cloud_name):
    """Load one cloud's superpoint partition and its point coordinates.

    Args:
        data_path: directory containing ``<cloud_name>.superpoint`` and
            ``<cloud_name>.ply``.
        cloud_name: base file name of the cloud, without extension.

    Returns:
        ``(sp_data, points)``: the unpickled ``.superpoint`` object and the
        ``.ply`` point coordinates (read with open3d) as a numpy array.

    Raises:
        FileNotFoundError: if either file does not exist.
    """
    # NOTE: pickle.load can execute arbitrary code; only use on trusted local data.
    with open(os.path.join(data_path, cloud_name + ".superpoint"), "rb") as f:
        sp_data = pickle.load(f)
    ply_path = os.path.join(data_path, cloud_name + ".ply")
    # open3d's reader may only warn and hand back an empty cloud for a missing
    # file; fail loudly instead of silently returning zero points.
    if not os.path.isfile(ply_path):
        raise FileNotFoundError(ply_path)
    ply_data = o3d.io.read_point_cloud(ply_path)
    return sp_data, np.asarray(ply_data.points)

def compute_centroids_and_points(sp_data, ply_points):
    """Gather each superpoint's points together with their mean position.

    Args:
        sp_data: mapping whose ``'components'`` entry is an iterable of index
            arrays/lists into ``ply_points``, one per superpoint.
        ply_points: array of point coordinates addressed by those indices.

    Returns:
        ``(centroids, points_list)`` where ``points_list[k]`` holds the points
        of superpoint ``k`` and ``centroids`` stacks their axis-0 means.
    """
    points_list = [ply_points[indices] for indices in sp_data['components']]
    centroids = np.array([np.mean(pts, axis=0) for pts in points_list])
    return centroids, points_list

def chamfer_distance(points_set1, points_set2):
    """Symmetric Chamfer distance between two point sets (KD-tree based).

    The result is the mean nearest-neighbour distance from ``points_set2`` to
    ``points_set1`` plus the mean nearest-neighbour distance in the other
    direction.
    """
    from scipy.spatial import cKDTree
    nn_dist_2_to_1, _ = cKDTree(points_set1).query(points_set2)
    nn_dist_1_to_2, _ = cKDTree(points_set2).query(points_set1)
    return np.mean(nn_dist_2_to_1) + np.mean(nn_dist_1_to_2)

def fps_with_combined_distances(centroids, points_list, num_samples, start_idx=None,
                                euclidean_weight=1.0, chamfer_weight=1.0,
                                set_distance=None):
    """Farthest point sampling of superpoints under a combined distance.

    The distance between superpoints ``a`` and ``b`` is
    ``euclidean_weight * ||centroids[a] - centroids[b]||
    + chamfer_weight * set_distance(points_list[a], points_list[b])``.
    Starting from a seed, each step picks the not-yet-selected superpoint whose
    smallest distance to the selected set is largest.

    Args:
        centroids: sequence/array of per-superpoint centre vectors.
        points_list: per-superpoint point arrays, aligned with ``centroids``.
        num_samples: number of superpoints to select; clamped to
            ``len(centroids)``.
        start_idx: optional seed index; when None it is drawn with
            ``np.random.randint`` (the original behaviour).
        euclidean_weight, chamfer_weight: weights of the two distance terms.
        set_distance: callable ``(points_a, points_b) -> float``; defaults to
            this module's ``chamfer_distance``.

    Returns:
        List of distinct indices in selection order (``[]`` when there is
        nothing to select).
    """
    if set_distance is None:
        set_distance = chamfer_distance
    n = len(centroids)
    num_samples = min(int(num_samples), n)
    if num_samples <= 0:
        return []

    first = np.random.randint(n) if start_idx is None else int(start_idx)
    selected = [first]
    is_selected = np.zeros(n, dtype=bool)
    is_selected[first] = True
    # Smallest combined distance from each item to the selected set. Selected
    # items are pinned to -inf so argmax can never pick them again.
    distances = np.full(n, np.inf)
    distances[first] = -np.inf

    for _ in range(1, num_samples):
        last = selected[-1]
        # Only the newest pick can lower an item's distance to the selected set.
        for j in np.flatnonzero(~is_selected):
            euclidean_dist = np.linalg.norm(centroids[last] - centroids[j])
            chamfer_dist = set_distance(points_list[last], points_list[j])
            combined = euclidean_weight * euclidean_dist + chamfer_weight * chamfer_dist
            distances[j] = min(distances[j], combined)
        next_idx = int(np.argmax(distances))
        selected.append(next_idx)
        is_selected[next_idx] = True
        distances[next_idx] = -np.inf

    return selected

def fps(distance_matrix, num_samples, start_idx=None):
    """Farthest point sampling over a precomputed pairwise distance matrix.

    Starting from a seed, each step picks the item whose smallest distance to
    the already-selected set is largest. Row ``r`` of ``distance_matrix`` is
    read as the distances from item ``r`` to every item.

    Args:
        distance_matrix: square (N, N) array-like of pairwise distances.
        num_samples: number of indices to select; clamped to N.
        start_idx: optional seed index; when None it is drawn with
            ``np.random.randint`` (the original behaviour).

    Returns:
        List of distinct indices in selection order (``[]`` when there is
        nothing to select).
    """
    distance_matrix = np.asarray(distance_matrix)
    n = distance_matrix.shape[0]
    num_samples = min(int(num_samples), n)
    if num_samples <= 0:
        return []

    first = np.random.randint(n) if start_idx is None else int(start_idx)
    selected = [first]
    # Running min over the selected rows, updated incrementally (equivalent to
    # np.min(distance_matrix[selected], axis=0) but O(N) per step). astype
    # copies, so the caller's matrix is never modified.
    min_dist = distance_matrix[first].astype(np.float64)
    min_dist[first] = -np.inf  # never re-select an already chosen item

    for _ in range(1, num_samples):
        next_idx = int(np.argmax(min_dist))
        selected.append(next_idx)
        min_dist = np.minimum(min_dist, distance_matrix[next_idx])
        min_dist[next_idx] = -np.inf

    return selected


def superpoint_bbox_centers(xyz, components):
    """Collect every superpoint's points and the centre of its bounding box.

    Args:
        xyz: (P, 3) array of point coordinates.
        components: sequence of index arrays/lists into ``xyz``, one per
            superpoint.

    Returns:
        ``(sp_points, centers)`` where ``sp_points[j]`` is
        ``xyz[components[j]]`` and ``centers`` is an (N, 3) array of the
        axis-aligned bounding-box midpoints ``(min + max) / 2`` (not means).
    """
    sp_points = [xyz[comp] for comp in components]
    centers = np.zeros([len(sp_points), 3])
    for j, pts in enumerate(sp_points):
        centers[j] = (np.min(pts, axis=0) + np.max(pts, axis=0)) / 2.0
    return sp_points, centers


def center_distance_matrix(centers):
    """Return the (N, N) float64 matrix of Euclidean distances between centres.

    Filled row by row so the only large allocation is the output itself
    (no (N, N, 3) difference tensor).
    """
    n = len(centers)
    a_ed = np.empty((n, n), dtype=np.float64)
    for j in range(n):
        diff = centers - centers[j]
        a_ed[j] = np.sqrt(np.sum(diff * diff, axis=1))
    return a_ed


def sampling_budget(cloud_sp_count, sp_num, seed_percent):
    """Split a global budget of ``sp_num * seed_percent`` superpoints across clouds.

    Each cloud's share is proportional to ``cloud_sp_count / sp_num``.
    """
    weight = np.asarray(cloud_sp_count) / sp_num
    return sp_num * seed_percent * weight


if __name__ == '__main__':
    # Precompute per-cloud superpoint distance matrices (centre Euclidean
    # distance and Chamfer distance) used to pick seed samples.

    parser = argparse.ArgumentParser()
    parser.add_argument('--gpu', type=int, default=0, help='index of the GPU to use [default: 0]')
    # The default 'Semantic3D' must be a valid choice; lowercase is accepted
    # for backward compatibility and normalised below.
    parser.add_argument('--dataset', type=str, default='Semantic3D',
                        choices=["S3DIS", "Semantic3D", "semantic3d", "SemanticKITTI"])
    parser.add_argument('--seed_percent', type=float, default=0.008, help='seed percent')  # originally 0.01
    parser.add_argument('--reg_strength', default=0.012, type=float,
                        help='regularization strength for the minimal partition')
    parser.add_argument('--data_root', type=str, default='/data1/ncl/dataset/',
                        help='root holding <dataset>/<reg_strength>/superpoint, <dataset>/input_* etc.')
    parser.add_argument('--start_cloud_idx', type=int, default=6,
                        help='skip clouds of train_cloud_name_list before this index [default: 6]')

    FLAGS = parser.parse_args()
    # NOTE(review): if a wildcard-imported module initialises CUDA at import
    # time, setting CUDA_VISIBLE_DEVICES here is too late — confirm.
    os.environ["CUDA_DEVICE_ORDER"] = "PCI_BUS_ID"
    os.environ['CUDA_VISIBLE_DEVICES'] = str(FLAGS.gpu)
    os.environ['TF_CPP_MIN_LOG_LEVEL'] = '3'  # '2'

    dataset_name = "Semantic3D" if FLAGS.dataset == "semantic3d" else FLAGS.dataset
    seed_percent = FLAGS.seed_percent
    reg_strength = FLAGS.reg_strength

    if dataset_name != "Semantic3D":
        parser.error("only Semantic3D is supported by this script, got " + dataset_name)
    input_ = "input_0.060"
    cfg = ConfigSemantic3D  # NOTE(review): currently unused in this script

    # NOTE(review): `chamfer3D` and `fps_gpu` are not defined in this file;
    # presumably they come from one of the wildcard imports — confirm.
    chamLoss = chamfer3D.dist_chamfer_3D.chamfer_3DDist().cuda(device=fps_gpu)

    dataset_dir = os.path.join(FLAGS.data_root, dataset_name)
    superpoint_dir = os.path.join(dataset_dir, str(reg_strength), "superpoint")
    ply_dir = os.path.join(dataset_dir, input_)
    analyse_dir = os.path.join(dataset_dir, "analyse_data")

    # total.pkl holds the global superpoint count; each cloud's share of the
    # sampling budget is proportional to its number of unlabeled superpoints.
    # NOTE: pickle.load can execute arbitrary code; only load trusted local files.
    with open(os.path.join(superpoint_dir, "total.pkl"), "rb") as f:
        total_obj = pickle.load(f)

    sp_num = total_obj["sp_num"]
    cloud_sp_count = np.array([len(total_obj["unlabeled"][name]) for name in train_cloud_name_list])
    # Previously cloud_sp_count * (cloud_sp_count / sp_num), which ignored --seed_percent.
    cloud_sp_sampling_count = sampling_budget(cloud_sp_count, sp_num, seed_percent)

    os.makedirs(analyse_dir, exist_ok=True)

    for i, cloud_name in enumerate(train_cloud_name_list):
        if i < FLAGS.start_cloud_idx:
            continue

        with open(os.path.join(superpoint_dir, cloud_name + ".superpoint"), "rb") as f:
            sp = pickle.load(f)
        components = sp["components"]
        sp_count = len(components)
        data = read_ply(os.path.join(ply_dir, '{:s}.ply'.format(cloud_name)))
        xyz = np.vstack((data['x'], data['y'], data['z'])).T  # shape=[point_number, 3]
        one_cloud_candidate_superpoints, one_cloud_center_xyz = superpoint_bbox_centers(xyz, components)

        # Chamfer distances between all superpoints of the cloud: the slow step.
        print("当前处理场景: " + cloud_name + ", begin create_cd_cuda")
        begin_time = time.perf_counter()
        one_cloud_cd_dist = create_cd_cuda(superpoint_list=one_cloud_candidate_superpoints,
                                           superpoint_centroid_list=one_cloud_center_xyz,
                                           chamLoss=chamLoss)
        elapsed_time_minutes = (time.perf_counter() - begin_time) / 60
        print(f"{cloud_name}: create_cd_cuda finished, during time: {elapsed_time_minutes:.2f} minutes")

        A_ed = center_distance_matrix(one_cloud_center_xyz)  # Euclidean distance
        A_cd = np.full((sp_count, sp_count), 1e10, dtype=np.float64)  # Chamfer distance
        for j in range(sp_count):
            # Row j holds create_cd_cuda's distances from superpoint j.
            A_cd[j] = one_cloud_cd_dist[j]

        with open(os.path.join(analyse_dir, cloud_name + '_A_ed.pkl'), 'wb') as f:
            pickle.dump(A_ed, f)
        with open(os.path.join(analyse_dir, cloud_name + '_A_cd.pkl'), 'wb') as f:
            pickle.dump(A_cd, f)
        print("finish dump : ", cloud_name)

        # TODO: select seeds with fps(A_ed + A_cd, int(cloud_sp_sampling_count[i]))
        # and update the corresponding .gt labels, then train the model.

        # Release the O(sp_count^2) matrices before the next cloud allocates its own.
        del A_ed, A_cd, one_cloud_cd_dist

