
import os.path
import argparse
from graphs import compute_graph_nn_2
# from provider import *
from helper_ply import read_ply
import glob
import pickle
import os
import numpy as np
import sys
sys.path.append("partition/cut-pursuit/build/src")
sys.path.append("cut-pursuit/build/src")
sys.path.append("ply_c")
sys.path.append("./partition/ply_c")
sys.path.append("./partition")
import libcp
import libply_c

# Semantic3D cloud names (``.ply`` base names) grouped by split. Only the
# commented-out Semantic3D loop in this file referenced them; they are kept
# for that variant of the pipeline.

# Clouds used for training.
train_cloud_name_list = [
    'bildstein_station1_xyz_intensity_rgb',
    'bildstein_station5_xyz_intensity_rgb',
    'domfountain_station1_xyz_intensity_rgb',
    'domfountain_station2_xyz_intensity_rgb',
    'domfountain_station3_xyz_intensity_rgb',
    'neugasse_station1_xyz_intensity_rgb',
    'sg27_station1_intensity_rgb',
    'sg27_station4_intensity_rgb',
    'sg27_station5_intensity_rgb',
    'sg27_station9_intensity_rgb',
    'sg28_station4_intensity_rgb',
    'untermaederbrunnen_station1_xyz_intensity_rgb',
    'untermaederbrunnen_station3_xyz_intensity_rgb',
]

# Clouds held out for validation.
val_cloud_name_list = [
    'bildstein_station3_xyz_intensity_rgb',
    'sg27_station2_intensity_rgb',
]

# Test clouds (the "-reduced" variants).
test_cloud_name_list = [
    'MarketplaceFeldkirch_Station4_rgb_intensity-reduced',
    'sg27_station10_rgb_intensity-reduced',
    'sg28_Station2_rgb_intensity-reduced',
    'StGallenCathedral_station6_rgb_intensity-reduced',
]

def _dump_pickle(obj, path):
    """Pickle ``obj`` to ``path`` (binary mode; the file is closed on exit)."""
    with open(path, "wb") as f:
        pickle.dump(obj, f)


def _partition_cloud(xyz, args):
    """Run the cut-pursuit minimal partition on the points ``xyz``.

    Returns the ``(components, in_component)`` pair produced by
    ``libcp.cutpursuit``.
    """
    # Neighbourhood graph (the ``source``/``target``/``distances`` entries are
    # used below) plus the neighbour structure consumed by compute_geof.
    graph_nn, target_fea = compute_graph_nn_2(xyz, args.k_nn_adj, args.k_nn_geof)
    features = libply_c.compute_geof(xyz, target_fea, args.k_nn_geof).astype('float32')
    del target_fea  # no longer needed; free it before the partition step
    # Double the weight of feature column 3 (presumably verticality, as in the
    # original SPG pipeline -- confirm against libply_c.compute_geof).
    features[:, 3] = 2. * features[:, 3]

    # Edge weights decrease with neighbour distance, normalised by the mean
    # edge distance and offset by lambda_edge_weight.
    distances = graph_nn["distances"]
    edge_weight = np.array(
        1. / (args.lambda_edge_weight + distances / np.mean(distances)),
        dtype='float32')
    print("minimal partition...")
    return libcp.cutpursuit(features, graph_nn["source"], graph_nn["target"],
                            edge_weight, args.reg_strength)


def sensatUrban_superpoint(args,
                           sub_ply_file='/home/ncl/dataset/SensatUrbanSampledBySSDR/cambridge_block_28.ply',
                           output_dir='/home/ncl/dataset/SensatUrbanSampledBySSDR/grid_size_0.09',
                           cloud_name=None):
    """Partition one point cloud into superpoints and pickle the result.

    Reads the x/y/z fields of ``sub_ply_file`` and builds a neighbourhood graph
    and per-point geometric features. It then runs ``libcp.cutpursuit`` and
    writes three pickles into ``output_dir``:

    * ``<cloud_name>.superpoint``: a dict with ``"components"`` and
      ``"in_component"``. ``"components"`` is an object array with one entry
      per superpoint, holding the indices of its points. ``"in_component"``
      gives the superpoint id of every point.
    * ``<cloud_name>.gt``: an all-zero ``float32`` pseudo ground truth of
      shape ``(2, n_points)``.
    * ``total.pkl``: a summary dict for this single cloud. It holds
      ``"unlabeled"`` (cloud name -> ``arange(n_superpoints)``),
      ``"file_num"``, ``"sp_num"`` and ``"point_num"``. An existing file is
      overwritten.

    Args:
        args: namespace with ``k_nn_adj``, ``k_nn_geof``, ``lambda_edge_weight``
            and ``reg_strength``.
        sub_ply_file: input ``.ply`` file. The default is the block that used
            to be hard-coded here.
        output_dir: directory for the pickles; created if it does not exist.
        cloud_name: base name of the output files. Defaults to the input file
            name without its extension.
    """
    if cloud_name is None:
        cloud_name = os.path.splitext(os.path.basename(sub_ply_file))[0]

    data = read_ply(sub_ply_file)
    xyz = np.vstack((data['x'], data['y'], data['z'])).T

    components, in_component = _partition_cloud(xyz, args)
    components = np.array(components, dtype='object')

    # Create the output directory only once there is something to write.
    os.makedirs(output_dir, exist_ok=True)
    _dump_pickle({"components": components, "in_component": in_component},
                 os.path.join(output_dir, cloud_name + ".superpoint"))

    pseudo_gt = np.zeros([2, len(xyz)], dtype=np.float32)
    _dump_pickle(pseudo_gt, os.path.join(output_dir, cloud_name + ".gt"))

    file_num = 1
    sp_num = len(components)
    point_num = len(xyz)
    total_obj = {
        # Per scene: superpoint indices 0 .. sp_num-1.
        "unlabeled": {cloud_name: np.arange(sp_num)},
        "file_num": file_num,
        "sp_num": sp_num,
        "point_num": point_num,
    }
    _dump_pickle(total_obj, os.path.join(output_dir, "total.pkl"))

    print("file_num", file_num, "sp_num", sp_num, "point_num", point_num)


def test_superpoint_distribution(args, root='/home/ncl/dataset/Semantic3D'):
    """Print size statistics for the ``.superpoint`` files of one partition run.

    Scans ``<root>/<args.reg_strength>/superpoint/*.superpoint``. These are
    pickles whose ``"components"`` entry is a sequence of superpoints, each a
    sequence of point indices. The function reports:

    * a histogram of superpoint sizes in buckets of 10 points;
    * the total point and superpoint counts;
    * the mean superpoint size;
    * per-file superpoint and point counts.

    Args:
        args: namespace providing ``reg_strength``, which selects the
            sub-directory.
        root: dataset root containing one directory per ``reg_strength``.

    Returns:
        A dict with these keys:

        * ``point_count`` and ``sp_count``;
        * ``mean_size`` (``None`` when no superpoints were found);
        * ``histogram`` (bucket index ``i`` -> number of superpoints with
          ``10*i <= size < 10*(i+1)``);
        * ``sp_count_per_file`` and ``points_per_file``, in sorted file order.
    """
    pattern = os.path.join(root, str(args.reg_strength), 'superpoint', '*.superpoint')
    # glob returns files in arbitrary order; sort for reproducible output.
    all_files = sorted(glob.glob(pattern))

    sp_count = 0
    point_count = 0
    points_per_file = []
    sp_count_per_file = []
    # A dict grows as needed, unlike a fixed-size array that overflows on very
    # large superpoints and must be scanned in full when printing.
    histogram = {}

    for file_path in all_files:
        # NOTE(review): pickle.load can execute arbitrary code; only run this on
        # trusted partition outputs.
        with open(file_path, "rb") as f:
            superpoint = pickle.load(f)
        components = superpoint["components"]
        sp_count += len(components)
        sp_count_per_file.append(len(components))
        file_points = 0
        for sp in components:
            sp_size = len(sp)
            file_points += sp_size
            bucket = sp_size // 10
            histogram[bucket] = histogram.get(bucket, 0) + 1
        points_per_file.append(file_points)
        point_count += file_points

    # Avoid ZeroDivisionError when no files/superpoints were found.
    mean_size = point_count / sp_count if sp_count else None

    print("######### test_superpoint_less_than_5")
    for bucket in sorted(histogram):
        print(str(bucket * 10) + "-" + str((bucket + 1) * 10) + ": " + str(histogram[bucket]))
    print("point_count=" + str(point_count), "sp_count=" + str(sp_count), "mean_size=" + str(mean_size))
    print("every_scene_sp_count")
    print(sp_count_per_file)
    print("every input size: ")
    print(points_per_file)

    print("#####################################")
    return {
        "point_count": point_count,
        "sp_count": sp_count,
        "mean_size": mean_size,
        "histogram": histogram,
        "sp_count_per_file": sp_count_per_file,
        "points_per_file": points_per_file,
    }


# Despite its name this is a diagnostic helper, not a pytest test. Stop pytest
# from collecting it; collection would fail looking for an ``args`` fixture.
test_superpoint_distribution.__test__ = False

if __name__ == "__main__":
    # Command-line options for the superpoint partition.
    cli = argparse.ArgumentParser(
        description='Large-scale Point Cloud Semantic Segmentation with Superpoint Graphs')
    # NOTE(review): --dataset is parsed but not read by sensatUrban_superpoint.
    cli.add_argument('--dataset', default='semantic3d',
                     help='s3dis/semantic3d/your_dataset')
    # Neighbour count passed to compute_graph_nn_2 and libply_c.compute_geof.
    cli.add_argument('--k_nn_geof', default=45, type=int,
                     help='number of neighbors for the geometric features')
    # Neighbour count passed to compute_graph_nn_2 for the partition graph.
    cli.add_argument('--k_nn_adj', default=10, type=int,
                     help='adjacency structure for the minimal partition')
    cli.add_argument('--lambda_edge_weight', default=1., type=float,
                     help='parameter determine the edge weight for minimal part.')
    cli.add_argument('--reg_strength', default=0.012, type=float,
                     help='regularization strength for the minimal partition')

    parsed_args = cli.parse_args()
    sensatUrban_superpoint(parsed_args)
    # To print superpoint size statistics instead, call:
    # test_superpoint_distribution(parsed_args)