import argparse
import os
import pickle

from RandLANet import Network, log_out
from sampler2 import *

# Datasets this script has settings for; each one needs a branch in
# _setup_dataset below before it can be added here.
SUPPORTED_DATASETS = ("Semantic3D",)


def build_arg_parser():
    """Return the command-line parser for the baseline training run.

    Only datasets listed in ``SUPPORTED_DATASETS`` are accepted, so an
    unconfigured dataset fails at argument parsing instead of later on.
    ``--data_root`` and ``--record_dir`` default to the paths this script
    previously hard-coded.
    """
    parser = argparse.ArgumentParser()
    parser.add_argument('--gpu', type=int, default=0,
                        help='index of the GPU exposed through CUDA_VISIBLE_DEVICES [default: 0]')
    parser.add_argument('--dataset', type=str, default='Semantic3D', choices=SUPPORTED_DATASETS)
    parser.add_argument('--reg_strength', default=0.012, type=float,
                        help='regularization strength for the minimal partition')
    parser.add_argument('--epoch', default=30, type=int)
    parser.add_argument('--lr_decay', default=0.92, type=float)
    parser.add_argument('--data_root', default='/home/ncl/dataset', type=str,
                        help='directory containing <dataset>/<input dir> and <dataset>/<reg_strength>/superpoint')
    parser.add_argument('--record_dir',
                        default='/home/ncl/projects/SSDR-AL-main/SSRD_AL_semantic3d/record_round',
                        type=str, help='directory the per-round result log is appended to')
    return parser


def _setup_dataset(dataset_name, epoch, lr_decay):
    """Return ``(test_area_idx, input_dir_name, cfg)`` for *dataset_name*.

    ``cfg.max_epoch`` and ``cfg.lr_decays`` are overwritten in place with the
    command-line schedule (the same decay factor for every epoch 0-499).

    Raises:
        ValueError: if *dataset_name* has no settings in this function.
    """
    if dataset_name != "Semantic3D":
        raise ValueError("unsupported dataset: " + dataset_name)
    # NOTE(review): Semantic3D has no test areas; 0 looks like a placeholder
    # (it is only used in the record file name and passed to Network) — confirm.
    test_area_idx = 0
    input_ = "input_0.060"
    cfg = ConfigSemantic3D
    cfg.max_epoch = epoch
    cfg.lr_decays = {i: lr_decay for i in range(0, 500)}
    return test_area_idx, input_, cfg


def main(argv=None):
    """Create seed samples and model weights (baseline, single round).

    Requests every superpoint from ``SeedSampler`` in one call, trains the
    network once and appends the labeled region/point counts and the best
    mIoU/OA to the round log file.

    Args:
        argv: argument list to parse; ``None`` means ``sys.argv[1:]``.
    """
    parser = build_arg_parser()
    flags = parser.parse_args(argv)
    # Set before the model is built so the device selection takes effect.
    os.environ["CUDA_DEVICE_ORDER"] = "PCI_BUS_ID"
    os.environ['CUDA_VISIBLE_DEVICES'] = str(flags.gpu)
    os.environ['TF_CPP_MIN_LOG_LEVEL'] = '3'  # suppress TensorFlow C++ log output

    dataset_name = flags.dataset
    reg_strength = flags.reg_strength
    round_num = 1
    sampler_args = ["baseline", str(flags.epoch), str(flags.lr_decay)]

    try:
        test_area_idx, input_, cfg = _setup_dataset(dataset_name, flags.epoch, flags.lr_decay)
    except ValueError as err:
        parser.error(str(err))  # exits with a usage message

    data_dir = os.path.join(flags.data_root, dataset_name)
    # Log file, e.g. record_round/Semantic3D_0_baseline-30-0.92_0.012.txt
    record_path = os.path.join(
        flags.record_dir,
        dataset_name + "_" + str(test_area_idx) + "_" + get_sampler_args_str(sampler_args)
        + "_" + str(reg_strength) + '.txt')

    with open(record_path, 'a') as round_result_file:
        # pickle can execute arbitrary code: only load superpoint files you generated yourself.
        with open(os.path.join(data_dir, str(reg_strength), "superpoint", "total.pkl"), "rb") as f:
            total_obj = pickle.load(f)
        total_sp_num = total_obj["sp_num"]

        sampler = SeedSampler(os.path.join(data_dir, input_),
                              os.path.join(data_dir, str(reg_strength)),
                              total_sp_num, sampler_args)
        # Counters presumably filled in by sampling(); read back right below.
        w = {"sp_num": 0, "p_num": 0, "p_num_list": [], "sp_id_list": [], "sub_num": 0, "sub_p_num": 0}
        # Baseline: request all superpoints in a single sampling call.
        # NOTE(review): per the original author, sampling() sets every gt
        # activation to 1, so the pseudo labels equal the ground truth — confirm in sampler2.
        sampler.sampling(None, total_sp_num, last_round=round_num - 1, w=w)
        labeling_region_num = w["sp_num"] + w["sub_num"]
        labeling_point_num = w["p_num"] + w["sub_p_num"]
        # Guard against an empty selection instead of raising ZeroDivisionError.
        mean_points = labeling_point_num / labeling_region_num if labeling_region_num else 0.0
        log_out("round= " + str(round_num) + " |                    labeling_region_num=" + str(
            labeling_region_num) + ", labeling_point_num=" +
            str(labeling_point_num) + ", mean_points=" + str(mean_points),
            round_result_file)

        model = Network(cfg, dataset_name, sampler_args, test_area_idx, reg_strength=reg_strength)
        try:
            best_miou, best_OA = model.train2(round_num=round_num)
            log_out("round= " + str(round_num) + " | best_miou= " + str(best_miou) + ", best_OA= " + str(best_OA), round_result_file)
        finally:
            # Release the model even if training fails.
            model.close()


if __name__ == '__main__':
    main()
