from os import makedirs
import time
from os import makedirs

# import tensorflow as tf
import tensorflow.compat.v1 as tf
tf.disable_v2_behavior()
from sklearn.metrics import confusion_matrix

import helper_tf_util

from semantic3d_dataset_train import *
from semantic3d_dataset_test3 import *
from helper_ply import write_ply

def log_out(out_str, f_out):
    """Append ``out_str`` plus a newline to ``f_out``, flush, and echo it to stdout.

    Args:
        out_str: message text (must be a ``str``; it is concatenated with a newline).
        f_out: writable text stream, e.g. the training log file.
    """
    line = out_str + '\n'
    f_out.write(line)
    # Flush right away so the log file can be tailed while training is running.
    f_out.flush()
    print(out_str)

# Presumably an early-stopping patience (number of evaluations without improvement).
# NOTE(review): nothing in the visible part of this file reads it (Network.train2 never
# consults it) — confirm whether another module imports it before relying on it.
early_stop_count = 6

class  Network:
    def __init__(self, config, dataset_name, sampler_args, test_area_idx, reg_strength):
        self.config = config
        self.dataset_name = dataset_name
        self.sampler_args = sampler_args
        self.test_area_idx = test_area_idx
        self.reg_strength = reg_strength

        self.Log_file = open(join("/home/ncl/projects/SSDR-AL-main/SSRD_AL_semantic3d/record_log", 'log_train_' + dataset_name + "_" + str(test_area_idx) + "_" + get_sampler_args_str(sampler_args) +"_"+str(reg_strength)+'_all_train_all_val_sp1000_nomerge'+ '.txt'), 'a')
        '''调用self.init_input()方法初始化模型的输入'''
        self.init_input()
        self.training_epoch = 0
        self.correct_prediction = 0
        self.accuracy = 0
        '''
        weight = num_per_class / float(sum(num_per_class))
        ce_label_weight = 1 / (weight + 0.02)
        值越大说明 占比越小
        '''
        self.class_weights = DP.get_class_weights(dataset_name)
        print("DP.get_class_weights done.")

        self.training_step = 1

        '''在TensorFlow的作用域'layers'中构建模型。定义占位符self.is_training用于指示模型是否处于训练模式。然后调用
        self.inference(self.is_training)方法构建模型的推理部分,生成3D logits(预测概率?)和最后一层特征。'''
        with tf.variable_scope('layers', reuse=tf.AUTO_REUSE):
            self.is_training = tf.placeholder(tf.bool, shape=())
            self.logits_3d, self.last_second_features = self.inference(self.is_training)

        #####################################################################
        # Ignore the invalid point (unlabeled) when calculating the loss #
        #####################################################################
        '''在'loss'作用域中,通过忽略无效(未标记)的点来计算损失。有效的logits和激活被选出来用于损失计算。
        此外,处理了标签,以便它们与logits的维度匹配。最后,调用self.get_loss方法计算损失'''
        with tf.variable_scope('loss', reuse=tf.AUTO_REUSE):
            self.logits = tf.reshape(self.logits_3d, [-1, config.num_classes])
            self.last_second_features = tf.reshape(self.last_second_features, [-1, 32])

            self.labels = tf.reshape(self.input_labels, [-1])
            self.activation = tf.reshape(self.input_activation, [-1])
            self.pseudo = tf.reshape(self.input_pseudo, [-1])

            # Boolean mask of points that should be ignored
            ignored_bool = tf.zeros_like(self.labels, dtype=tf.bool)
            for ign_label in self.config.ignored_label_inds:   
                # ignored_label_inds = [] 但在本文工作中数据预处理时0标签数据就已经被剔除了
                ignored_bool = tf.logical_or(ignored_bool, tf.equal(self.labels, ign_label))

            # Collect logits and labels that are not ignored
            valid_idx_init = tf.squeeze(tf.where(tf.logical_not(ignored_bool)))
            valid_idx = tf.reshape(valid_idx_init, [-1])
            # 实则还是 self.logits_3d
            valid_logits = tf.gather(self.logits, valid_idx, axis=0)
            valid_activation = tf.gather(self.activation, valid_idx, axis=0)

            valid_labels_init = tf.gather(self.labels, valid_idx, axis=0)
            valid_pseudo_init = tf.gather(self.pseudo, valid_idx, axis=0)
            # Reduce label values in the range of logit shape
            reducing_list = tf.range(self.config.num_classes, dtype=tf.int32)
            inserted_value = tf.zeros((1,), dtype=tf.int32)
            for ign_label in self.config.ignored_label_inds:
                reducing_list = tf.concat([reducing_list[:ign_label], inserted_value, reducing_list[ign_label:]], 0)
            valid_labels = tf.gather(reducing_list, valid_labels_init)
            valid_pseudo = tf.gather(reducing_list, valid_pseudo_init)

            # print('valid_activation:', valid_activation)
            self.loss = self.get_loss(valid_logits, valid_pseudo, valid_activation, self.class_weights)

        with tf.variable_scope('optimizer', reuse=tf.AUTO_REUSE):
            self.learning_rate = tf.Variable(config.learning_rate, trainable=False, name='learning_rate')
            self.train_op = tf.train.AdamOptimizer(self.learning_rate).minimize(self.loss)
            self.extra_update_ops = tf.get_collection(tf.GraphKeys.UPDATE_OPS)

        with tf.variable_scope('results', reuse=tf.AUTO_REUSE):
            self.correct_prediction = tf.nn.in_top_k(valid_logits, valid_labels, 1)
            self.accuracy = tf.reduce_mean(tf.cast(self.correct_prediction, tf.float32))
            self.prob_logits = tf.nn.softmax(self.logits)

            tf.summary.scalar('learning_rate', self.learning_rate)
            tf.summary.scalar('loss', self.loss)
            tf.summary.scalar('accuracy', self.accuracy)

        '''tf.GraphKeys.GLOBAL_VARIABLES是TensorFlow中的一个特殊集合,
        用于存储计算图中所有的全局变量。这个集合中的变量是跨多个训练步骤保持状态的变量,
        通常包括模型的权重和偏置等参数'''
        self.saving_path = join("./data", dataset_name, str(reg_strength), "saver", get_sampler_args_str(self.sampler_args), "snapshots")
        makedirs(self.saving_path) if not exists(self.saving_path) else None
        my_vars = tf.get_collection(tf.GraphKeys.GLOBAL_VARIABLES)
        self.saver = tf.train.Saver(my_vars, max_to_keep=100)

        c_proto = tf.ConfigProto()
        c_proto.gpu_options.allow_growth = True
        c_proto.gpu_options.visible_device_list = '0'
        c_proto.gpu_options.per_process_gpu_memory_fraction = 0.8
        self.sess = tf.Session(config=c_proto)
        self.merged = tf.summary.merge_all()

        self.tensorboard_path = join("./data", dataset_name, str(reg_strength), "saver", get_sampler_args_str(self.sampler_args), "tensorboard")
        makedirs(self.tensorboard_path) if not exists(self.tensorboard_path) else None

        self.train_writer = tf.summary.FileWriter(self.tensorboard_path, self.sess.graph)
        self.sess.run(tf.global_variables_initializer())
        print("network init done.")

    def restore_model(self, round_num):
        if round_num == 1:

            restore_snap = join("./data", self.dataset_name, str(self.reg_strength), "saver", "seed", "snapshots", 'snap-{:d}'.format(1))
            self.saver.restore(self.sess, restore_snap)
            print("Model restored from seed")

        # Load trained model
        else:
            restore_snap = join(self.saving_path, 'snap-{:d}'.format(round_num))
            self.saver.restore(self.sess, restore_snap)
            print("Model restored from " + restore_snap)

    def restore_baseline_model(self):

        restore_snap = join("./data", self.dataset_name, str(self.reg_strength), "saver", "baseline", "snapshots", 'snap-{:d}'.format(1))
        self.saver.restore(self.sess, restore_snap)
        print("Model restored from baseline")

    def init_input(self):
        with tf.variable_scope('inputs', reuse=tf.AUTO_REUSE):
            self.input_xyz, self.input_neigh_idx, self.input_sub_idx, self.input_interp_idx = [], [], [], []
            for i in range(self.config.num_layers):
                self.input_xyz.append(tf.placeholder(tf.float32, shape=[None, None, 3]))  #[batch, point, 3]
                self.input_neigh_idx.append(tf.placeholder(tf.int32, shape=[None, None, self.config.k_n]))  #[batch, point, 16]
                self.input_sub_idx.append(tf.placeholder(tf.int32, shape=[None, None, self.config.k_n]))  #[batch, point, 16]
                self.input_interp_idx.append(tf.placeholder(tf.int32, shape=[None, None, 1]))  #[batch, point, 3]
            self.input_features = tf.placeholder(tf.float32, shape=[None, None, 6])  #[batch, point, 3+3] #stacked_features = np.concatenate([transformed_xyz, rgb], axis=-1)
            self.input_labels = tf.placeholder(tf.int32, shape=[None, None])  # [batch, point]
            self.input_activation = tf.placeholder(tf.int32, shape=[None, None])  # [batch, point]
            self.input_pseudo = tf.placeholder(tf.int32, shape=[None, None])  # [batch, point]
            self.input_input_inds = tf.placeholder(tf.int32, shape=[None, None])  # [batch, point]
            self.input_cloud_inds = tf.placeholder(tf.int32, shape=[None])  # [batch]
        print("network init_input done.")

    def inference(self, is_training):
        '''此方法通过构建一个编码器-解码器结构处理点云数据，其中编码器逐渐减少空间维度同时增加特征维度，
        解码器则逐步恢复空间维度并减少特征维度，最终输出每个点的类别预测。'''
        # 调用 self.inference(self.is_training) 方法构建模型的推理部分,生成3D logits和最后一层特征。
        d_out = self.config.d_out # d_out = [16, 64, 128, 256, 512]  # feature dimension
        feature = self.input_features  # 就是  shape=[batch_size, point_num, 6]
        # # 通过一个全连接层扩展特征维度到8
        feature = tf.layers.dense(feature, 8, activation=None, name='fc0')
        # 应用LeakyReLU激活函数和批量归一化
        feature = tf.nn.leaky_relu(tf.layers.batch_normalization(feature, -1, 0.99, 1e-6, training=is_training))
        # 在第二维度上增加一个维度，为了后续的卷积操作
        feature = tf.expand_dims(feature, axis=2)

        # 编码器部分，构建多层编码器网络
        f_encoder_list = []
        for i in range(self.config.num_layers): # 5
            # 对每层输入应用扩张残差块，得到编码后的特征 #  扩张卷积用于增加后面网络层的感知野，以补偿去除下采样而引起的感知野减少。
            f_encoder_i = self.dilated_res_block(feature, self.input_xyz[i], self.input_neigh_idx[i], d_out[i],
                                                 'Encoder_layer_' + str(i), is_training)
            # 对编码后的特征进行随机采样
            f_sampled_i = self.random_sample(f_encoder_i, self.input_sub_idx[i])
            feature = f_sampled_i
            # 保存编码层的输出，用于后续的解码过程
            if i == 0:
                f_encoder_list.append(f_encoder_i)
            f_encoder_list.append(f_sampled_i)
        # 解码器部分，从编码器的最后一层开始解码
        feature = helper_tf_util.conv2d(f_encoder_list[-1], f_encoder_list[-1].get_shape()[3].value, [1, 1],
                                        'decoder_0',
                                        [1, 1], 'VALID', True, is_training)

        f_decoder_list = []
        for j in range(self.config.num_layers):
            # 对特征进行最近邻插值，用于上采样
            f_interp_i = self.nearest_interpolation(feature, self.input_interp_idx[-j - 1])
            # 将上采样的特征与对应的编码器层的特征进行拼接，并应用转置卷积
            f_decoder_i = helper_tf_util.conv2d_transpose(tf.concat([f_encoder_list[-j - 2], f_interp_i], axis=3),
                                                          f_encoder_list[-j - 2].get_shape()[-1].value, [1, 1],
                                                          'Decoder_layer_' + str(j), [1, 1], 'VALID', bn=True,
                                                          is_training=is_training)
            feature = f_decoder_i
            f_decoder_list.append(f_decoder_i)
        # 应用两个卷积层进一步提取特征
        f_layer_fc1 = helper_tf_util.conv2d(f_decoder_list[-1], 64, [1, 1], 'fc1', [1, 1], 'VALID', True, is_training)
        f_layer_fc2 = helper_tf_util.conv2d(f_layer_fc1, 32, [1, 1], 'fc2', [1, 1], 'VALID', True, is_training)
        # 应用dropout防止过拟合
        f_layer_drop = helper_tf_util.dropout(f_layer_fc2, keep_prob=0.5, is_training=is_training, scope='dp1')
        # 最后一个卷积层输出最终的分类结果
        f_layer_fc3 = helper_tf_util.conv2d(f_layer_drop, self.config.num_classes, [1, 1], 'fc', [1, 1], 'VALID', False,
                                            is_training, activation_fn=None)
        # 压缩维度，去除多余的维度
        f_out = tf.squeeze(f_layer_fc3, [2])
        # 返回最终的输出和倒数第二层的特征
        return f_out, f_layer_fc2

    def get_feed_dict(self, dat, is_training):
        feed_dict = {self.is_training: is_training}
        # print("----------------------")
        for j in range(self.config.num_layers):
            feed_dict[self.input_xyz[j]] = np.squeeze(dat[j].numpy(), axis=1)  # [batch, point, 3]
            feed_dict[self.input_neigh_idx[j]] = np.squeeze(dat[self.config.num_layers + j].numpy(), axis=1)  # [batch, point, 16]
            feed_dict[self.input_sub_idx[j]] = np.squeeze(dat[2 * self.config.num_layers + j].numpy(), axis=1)  # [batch, point, 16]
            feed_dict[self.input_interp_idx[j]] = np.squeeze(dat[3 * self.config.num_layers + j].numpy(), axis=1)  # [batch, point, 3]



        feed_dict[self.input_features] = np.squeeze(dat[4 * self.config.num_layers+0].numpy(), axis=1)  # [batch, point, 3+3]
        feed_dict[self.input_labels] = np.squeeze(dat[4 * self.config.num_layers+1].numpy(), axis=1)  # [batch, point]
        feed_dict[self.input_activation] = np.squeeze(dat[4 * self.config.num_layers+2].numpy(), axis=1)  # [batch, point]
        feed_dict[self.input_pseudo] = np.squeeze(dat[4 * self.config.num_layers+3].numpy(), axis=1)  # [batch, point]
        feed_dict[self.input_input_inds] = np.squeeze(dat[4 * self.config.num_layers+4].numpy(), axis=1)  # [batch, point]
        feed_dict[self.input_cloud_inds] = np.squeeze(dat[4 * self.config.num_layers+5].numpy(), axis=1)  # [batch]



        return feed_dict

    def get_feed_dict_sub(self, dat_sub):
        feed_dict = {self.is_training: False}
        # print("----------------------")
        for j in range(self.config.num_layers):
            feed_dict[self.input_xyz[j]] = np.squeeze(dat_sub[j].numpy(), axis=1)  # [batch, point, 3]
            feed_dict[self.input_neigh_idx[j]] = np.squeeze(dat_sub[self.config.num_layers + j].numpy(), axis=1)  # [batch, point, 16]
            feed_dict[self.input_sub_idx[j]] = np.squeeze(dat_sub[2 * self.config.num_layers + j].numpy(), axis=1)  # [batch, point, 16]
            feed_dict[self.input_interp_idx[j]] = np.squeeze(dat_sub[3 * self.config.num_layers + j].numpy(), axis=1)  # [batch, point, 3]

        feed_dict[self.input_features] = np.squeeze(dat_sub[4 * self.config.num_layers+0].numpy(), axis=1)  # [batch, point, 3+3]
        feed_dict[self.input_labels] = np.squeeze(dat_sub[4 * self.config.num_layers+1].numpy(), axis=1)  # [batch, point]
        feed_dict[self.input_activation] = np.squeeze(dat_sub[4 * self.config.num_layers+2].numpy(), axis=1)  # [batch, point]
        feed_dict[self.input_pseudo] = np.squeeze(dat_sub[4 * self.config.num_layers+3].numpy(), axis=1)  # [batch, point]
        feed_dict[self.input_input_inds] = np.squeeze(dat_sub[4 * self.config.num_layers+4].numpy(), axis=1)  # [batch, point]
        feed_dict[self.input_cloud_inds] = np.squeeze(dat_sub[4 * self.config.num_layers+5].numpy(), axis=1)  # [batch]

        return feed_dict



    def get_feed_dict_train(self, dat):
        feed_dict = {self.is_training: True}
        for j in range(self.config.num_layers):
            feed_dict[self.input_xyz[j]] = dat[j]  # [batch, point, 3]
            feed_dict[self.input_neigh_idx[j]] = dat[self.config.num_layers + j]  # [batch, point, 16]
            feed_dict[self.input_sub_idx[j]] = dat[2 * self.config.num_layers + j]  # [batch, point, 16]
            feed_dict[self.input_interp_idx[j]] = dat[3 * self.config.num_layers + j]  # [batch, point, 3]
        feed_dict[self.input_features] = dat[4 * self.config.num_layers + 0]  # [batch, point, 3+3]
        feed_dict[self.input_labels] = dat[4 * self.config.num_layers + 1]  # [batch, point]
        feed_dict[self.input_activation] = dat[4 * self.config.num_layers + 2]  # [batch, point]
        feed_dict[self.input_pseudo] = dat[4 * self.config.num_layers + 3]  # [batch, point]
        feed_dict[self.input_input_inds] = dat[4 * self.config.num_layers + 4]  # [batch, point]
        feed_dict[self.input_cloud_inds] = dat[4 * self.config.num_layers + 5]  # [batch]

        return feed_dict

    def get_feed_dict_test(self, dat):
        feed_dict = {self.is_training: False}
        for j in range(self.config.num_layers):
            feed_dict[self.input_xyz[j]] = dat[j]  # [batch, point, 3]
            feed_dict[self.input_neigh_idx[j]] = dat[self.config.num_layers + j]  # [batch, point, 16]
            feed_dict[self.input_sub_idx[j]] = dat[2 * self.config.num_layers + j]  # [batch, point, 16]
            feed_dict[self.input_interp_idx[j]] = dat[3 * self.config.num_layers + j]  # [batch, point, 3]
        feed_dict[self.input_features] = dat[4 * self.config.num_layers+0]  # [batch, point, 3+3]
        feed_dict[self.input_labels] = dat[4 * self.config.num_layers+1]  # [batch, point]
        feed_dict[self.input_input_inds] = dat[4 * self.config.num_layers+2]  # [batch, point]
        feed_dict[self.input_cloud_inds] = dat[4 * self.config.num_layers+3]  # [batch]

        return feed_dict

    def reset_lr(self):
        op = self.learning_rate.assign(self.config.learning_rate)
        self.sess.run(op)


    def train2(self, round_num):
        self.reset_lr()

        self.training_epoch = 0
        log_out("Round " + str(round_num) + ' | ****EPOCH {}****'.format(self.training_epoch), self.Log_file)
        best_miou = 0
        best_OA = 0
        train_data, test_data, test_probs = None, None, None
        if self.dataset_name == "Semantic3D":
            # train_data 包含每个训练场景下采样后的 input_trees 、input_colors、 input_labels 、input_names、 input_activations += [pseudo_gt[0]]、
            # input_pseudos += [pseudo_gt[1]] 、possibility、  min_possibility
            train_data = Semantic3D_Dataset_Train(reg_strength=self.reg_strength, sampler_args=self.sampler_args,
                                                  round_num=round_num)
            test_data = Semantic3D_Dataset_Test()
            test_probs = [np.zeros(shape=[l.shape[0], self.config.num_classes], dtype=np.float32) for l in
                          test_data.input_labels]

        while self.training_epoch < self.config.max_epoch:
            t_start = time.time()
            activation_sum = 0
            # print("train_data.get_batch()")
            dat = train_data.get_batch()
            while len(dat) > 0:
                '''
                self.train_op：执行模型的训练步骤。
                self.extra_update_ops：执行额外的更新操作，如批量归一化的统计更新。
                self.merged：合并所有的TensorBoard摘要。
                self.loss：计算当前批次的损失。
                self.logits：模型输出的原始预测值。
                self.labels：当前批次的真实标签。
                self.activation：激活状态，具体含义根据上下文可能有所不同。
                self.accuracy：当前批次的准确率
                '''
                ops = [self.train_op,
                       self.extra_update_ops,
                       self.merged,
                       self.loss,
                       self.logits,
                       self.labels,
                       self.activation,
                       self.accuracy]
                _, _, summary, l_out, probs, labels, acti, acc = self.sess.run(ops, feed_dict=self.get_feed_dict_train(dat))

                activation_sum = activation_sum + np.sum(acti)

                self.train_writer.add_summary(summary, self.training_step)
                t_end = time.time()
                if self.training_step % 50 == 0:
                    message = 'Step {:08d} L_out={:5.3f} Acc={:4.2f} ''---{:8.2f} ms/batch'
                    log_out(message.format(self.training_step, l_out, acc, 1000 * (t_end - t_start)), self.Log_file)
                self.training_step += 1

                dat = train_data.get_batch()

            # 记录本轮训练的信息
            log_out("Round " + str(round_num) + ' | epoch=' + str(self.training_epoch) + ", train costTime=" + str(
                time.time() - t_start) + ", | total_activation_sum=" + str(activation_sum), self.Log_file)
            self.training_epoch += 1
            # Update learning rate
            op = self.learning_rate.assign(tf.multiply(self.learning_rate,
                                                       self.config.lr_decays[self.training_epoch]))
            self.sess.run(op)
            # 在训练进入后期时，评估模型性能
            # if self.training_epoch >= int(self.config.max_epoch * 0.6):
            if self.training_epoch >= int(self.config.max_epoch * 0):
                tt12 = time.time()
                if self.dataset_name == "Semantic3D":
                    m_iou, OA = self.evaluate_test_semantic3d(dataset=test_data, test_probs=test_probs)

                log_out("Round " + str(round_num) + ' | epoch=' + str(self.training_epoch) + ", current m_iou=" + str(m_iou), self.Log_file)
                if m_iou > best_miou:
                    # Save the best model
                    snapshot_directory = join(self.saving_path)
                    makedirs(snapshot_directory) if not exists(snapshot_directory) else None
                    self.saver.save(self.sess, join(self.saving_path, "snap"), global_step=round_num)
                    early_count = 0
                    best_miou = m_iou
                    best_OA = OA

                log_out("Round " + str(round_num) + ' | Best m_IoU is: {:5.3f}'.format(best_miou) +
                        ', OA is: {:5.3f}'.format(best_OA) +
                        " | val costTime=" + str(time.time() - tt12), self.Log_file)

            log_out("Round " + str(round_num) + ' | ****EPOCH {}****'.format(self.training_epoch), self.Log_file)

        return best_miou, best_OA

    def close(self):
        print('finished')
        self.train_writer.close()
        self.sess.close()
        self.Log_file.close()

    def evaluate_test_semantic3d(self, dataset, test_probs):
        num_votes = 100
        dataset.init_possibility()
        # Smoothing parameter for votes
        test_smooth = 0.98
        val_proportions = np.zeros(self.config.num_classes, dtype=np.float32)
        i = 0
        for label_val in dataset.label_values:
            if label_val not in dataset.ignored_labels:
                val_proportions[i] = np.sum([np.sum(labels == label_val) for labels in dataset.val_labels])
                i += 1

        step_id = 0
        epoch_id = 0
        last_min = -0.5

        while last_min < num_votes:
            dat = dataset.get_batch()
            while len(dat) > 0:
                ops = (self.prob_logits,
                       self.labels,
                       self.input_input_inds,
                       self.input_cloud_inds,
                       )

                stacked_probs, stacked_labels, point_idx, cloud_idx = self.sess.run(ops,
                                                                                    feed_dict=self.get_feed_dict_test(
                                                                                        dat))

                correct = np.sum(np.argmax(stacked_probs, axis=1) == stacked_labels)
                acc = correct / float(np.prod(np.shape(stacked_labels)))

                stacked_probs = np.reshape(stacked_probs,
                                           [self.config.val_batch_size, self.config.num_points,
                                            self.config.num_classes])

                for j in range(np.shape(stacked_probs)[0]):
                    probs = stacked_probs[j, :, :]
                    inds = point_idx[j, :]
                    c_i = cloud_idx[j]
                    test_probs[c_i][inds] = test_smooth * test_probs[c_i][inds] + (1 - test_smooth) * probs
                step_id += 1

                dat = dataset.get_batch()

            dataset.reset_current_batch()
            # Save predicted cloud
            new_min = np.min(dataset.min_possibility)
            # log_string('Epoch {:3d}, end. Min possibility = {:.1f}'.format(epoch_id, new_min), self.log_out)

            test_probs_insert = test_probs

            if last_min + 4 < new_min:
                # Update last_min
                last_min = new_min
                confusion_list = []

                num_val = len(dataset.input_labels)

                for i_test in range(num_val):
                    probs = test_probs_insert[i_test]
                    preds = dataset.label_values[np.argmax(probs, axis=1)].astype(np.int32)
                    labels = dataset.input_labels[i_test]

                    if not self.config.ignored_label_inds or len(self.config.ignored_label_inds) == 0:
                        pred_valid = preds
                        labels_valid = labels
                    else:

                        invalid_idx = np.where(labels == self.config.ignored_label_inds)[0]
                        labels_valid = np.delete(labels, invalid_idx)
                        labels_valid = labels_valid - 1
                        pred_valid = np.delete(preds, invalid_idx)


                    confusion_list += [
                        confusion_matrix(labels_valid, pred_valid, np.arange(0, ConfigSemantic3D.num_classes, 1))]

                # Regroup confusions
                C = np.sum(np.stack(confusion_list), axis=0).astype(np.float32)

                # Rescale with the right number of point per class
                C *= np.expand_dims(val_proportions / (np.sum(C, axis=1) + 1e-6), 1)

                # Compute IoUs
                IoUs = DP.IoU_from_confusions(C)
                m_IoU = np.mean(IoUs)
                s = '{:5.2f} | '.format(100 * m_IoU)
                for IoU in IoUs:
                    s += '{:5.2f} '.format(100 * IoU)
                # log_out(s + '\n', self.Log_file)

                if int(np.ceil(new_min)) % 1 == 0:

                    # Project predictions
                    proj_probs_list = []

                    for i_val in range(num_val):
                        # Reproject probs back to the evaluations points
                        proj_idx = dataset.val_proj[i_val]
                        probs = test_probs_insert[i_val][proj_idx, :]
                        proj_probs_list += [probs]

                    val_total_correct = 0
                    val_total_seen = 0
                    confusion_list = []
                    for i_test in range(num_val):
                        # Get the predicted labels
                        preds = dataset.label_values[np.argmax(proj_probs_list[i_test], axis=1)].astype(np.uint8)
                        labels = dataset.val_labels[i_test]

                        if not self.config.ignored_label_inds or len(self.config.ignored_label_inds) == 0:
                            pred_valid = preds
                            labels_valid = labels
                        else:

                            invalid_idx = np.where(labels == self.config.ignored_label_inds)[0]
                            labels_valid = np.delete(labels, invalid_idx)
                            labels_valid = labels_valid - 1
                            pred_valid = np.delete(preds, invalid_idx)



                        correct = np.sum(pred_valid == labels_valid)
                        val_total_correct += correct
                        val_total_seen += len(labels_valid)

                        confusion_list += [
                            confusion_matrix(labels_valid, pred_valid, np.arange(0, ConfigSemantic3D.num_classes, 1))]

                    # Regroup confusions
                    C = np.sum(np.stack(confusion_list), axis=0)

                    OA = val_total_correct / float(val_total_seen)
                    IoUs = DP.IoU_from_confusions(C)
                    m_IoU = np.mean(IoUs)
                    s = '{:5.2f} | '.format(100 * m_IoU)
                    for IoU in IoUs:
                        s += '{:5.2f} '.format(100 * IoU)
                    print('finished \n')
                    return m_IoU, OA

            epoch_id += 1
            step_id = 0
            continue

        return m_IoU, 0

    def get_loss(self, logits, labels, activation, pre_cal_weights):
        # calculate the weighted cross entropy according to the inverse frequency
        class_weights = tf.convert_to_tensor(pre_cal_weights, dtype=tf.float32)
        one_hot_labels = tf.one_hot(labels, depth=self.config.num_classes)


        self.ssdr_one_hot_labels = one_hot_labels
        self.ssdr_logits = logits


        unweighted_losses = tf.nn.softmax_cross_entropy_with_logits(logits=logits, labels=one_hot_labels)

        weights = tf.reduce_sum(class_weights * one_hot_labels, axis=1)

        weighted_losses = unweighted_losses * weights
        weighted_losses_acti = weighted_losses * tf.cast(activation, dtype=tf.float32)
        output_loss = tf.reduce_mean(weighted_losses_acti)
        return output_loss

    def dilated_res_block(self, feature, xyz, neigh_idx, d_out, name, is_training):
        f_pc = helper_tf_util.conv2d(feature, d_out // 2, [1, 1], name + 'mlp1', [1, 1], 'VALID', True, is_training)
        f_pc = self.building_block(xyz, f_pc, neigh_idx, d_out, name + 'LFA', is_training)
        f_pc = helper_tf_util.conv2d(f_pc, d_out * 2, [1, 1], name + 'mlp2', [1, 1], 'VALID', True, is_training,
                                     activation_fn=None)
        shortcut = helper_tf_util.conv2d(feature, d_out * 2, [1, 1], name + 'shortcut', [1, 1], 'VALID',
                                         activation_fn=None, bn=True, is_training=is_training)
        return tf.nn.leaky_relu(f_pc + shortcut)

    def building_block(self, xyz, feature, neigh_idx, d_out, name, is_training):

        d_in = feature.get_shape()[-1].value
        f_xyz = self.relative_pos_encoding(xyz, neigh_idx)
        f_xyz = helper_tf_util.conv2d(f_xyz, d_in, [1, 1], name + 'mlp1', [1, 1], 'VALID', True, is_training)
        f_neighbours = self.gather_neighbour(tf.squeeze(feature, axis=2), neigh_idx)
        f_concat = tf.concat([f_neighbours, f_xyz], axis=-1)
        f_pc_agg = self.att_pooling(f_concat, d_out // 2, name + 'att_pooling_1', is_training)

        f_xyz = helper_tf_util.conv2d(f_xyz, d_out // 2, [1, 1], name + 'mlp2', [1, 1], 'VALID', True, is_training)
        f_neighbours = self.gather_neighbour(tf.squeeze(f_pc_agg, axis=2), neigh_idx)
        f_concat = tf.concat([f_neighbours, f_xyz], axis=-1)
        f_pc_agg = self.att_pooling(f_concat, d_out, name + 'att_pooling_2', is_training)
        return f_pc_agg

    def relative_pos_encoding(self, xyz, neigh_idx):
        neighbor_xyz = self.gather_neighbour(xyz, neigh_idx)
        xyz_tile = tf.tile(tf.expand_dims(xyz, axis=2), [1, 1, tf.shape(neigh_idx)[-1], 1])
        relative_xyz = xyz_tile - neighbor_xyz
        relative_dis = tf.sqrt(tf.reduce_sum(tf.square(relative_xyz), axis=-1, keepdims=True))
        relative_feature = tf.concat([relative_dis, relative_xyz, xyz_tile, neighbor_xyz], axis=-1)
        return relative_feature

    @staticmethod
    def random_sample(feature, pool_idx):

        feature = tf.squeeze(feature, axis=2)
        num_neigh = tf.shape(pool_idx)[-1]
        d = feature.get_shape()[-1]
        batch_size = tf.shape(pool_idx)[0]
        pool_idx = tf.reshape(pool_idx, [batch_size, -1])
        pool_features = tf.batch_gather(feature, pool_idx)
        pool_features = tf.reshape(pool_features, [batch_size, -1, num_neigh, d])
        pool_features = tf.reduce_max(pool_features, axis=2, keepdims=True)
        return pool_features

    @staticmethod
    def nearest_interpolation(feature, interp_idx):

        feature = tf.squeeze(feature, axis=2)
        batch_size = tf.shape(interp_idx)[0]
        up_num_points = tf.shape(interp_idx)[1]
        interp_idx = tf.reshape(interp_idx, [batch_size, up_num_points])
        interpolated_features = tf.batch_gather(feature, interp_idx)
        interpolated_features = tf.expand_dims(interpolated_features, axis=2)
        return interpolated_features

    @staticmethod
    def gather_neighbour(pc, neighbor_idx):
        # gather the coordinates or features of neighboring points
        batch_size = tf.shape(pc)[0]
        num_points = tf.shape(pc)[1]
        d = pc.get_shape()[2].value
        index_input = tf.reshape(neighbor_idx, shape=[batch_size, -1])
        features = tf.batch_gather(pc, index_input)
        features = tf.reshape(features, [batch_size, num_points, tf.shape(neighbor_idx)[-1], d])
        return features

    @staticmethod
    def att_pooling(feature_set, d_out, name, is_training):
        batch_size = tf.shape(feature_set)[0]
        num_points = tf.shape(feature_set)[1]
        num_neigh = tf.shape(feature_set)[2]
        d = feature_set.get_shape()[3].value
        f_reshaped = tf.reshape(feature_set, shape=[-1, num_neigh, d])
        att_activation = tf.layers.dense(f_reshaped, d, activation=None, use_bias=False, name=name + 'fc')
        att_scores = tf.nn.softmax(att_activation, axis=1)
        f_agg = f_reshaped * att_scores
        f_agg = tf.reduce_sum(f_agg, axis=1)
        f_agg = tf.reshape(f_agg, [batch_size, num_points, 1, d])
        f_agg = helper_tf_util.conv2d(f_agg, d_out, [1, 1], name + 'mlp', [1, 1], 'VALID', True, is_training)
        return f_agg

if __name__ == "__main__":
    # Minimal smoke check that TensorFlow (v1-compat mode) is usable.
    # tf.ones requires an explicit `shape`; calling it without one raises TypeError.
    b = tf.ones(shape=[1])