import os
import glob
import math
import openpyxl
from pprint import pprint
import torch
import torch.nn as nn
from sklearn.preprocessing import MinMaxScaler
import numpy as np
import pandas as pd
import random
import matplotlib.pyplot as plt
from torch.cuda.amp import autocast as autocast
from torch.cuda.amp import GradScaler
import torch.utils.data as data
# ranger
from ranger import Ranger
from cvae_base import CVAE


class GRU_attention(nn.Module):
    """Two-branch (GRU + CNN) feature extractor feeding a conditional VAE.

    Both branches read the same input sequence and each produces 30 values
    per sample, viewed as ``(OUT_STEPS, OUT_DIMS) = (10, 3)``.  The two
    results are concatenated on the last axis (10 x 6) and passed, together
    with ``labels``, to the ``CVAE`` instance stored in ``self.cvae``.

    Despite the class name, no attention layer is defined in this class.

    Args:
        input_size: Number of features per time step.
        hidden_layer_size: Hidden size of the GRU.
        num_layers: Number of stacked GRU layers.
        seq_len: Number of time steps in every input sequence.  It sizes the
            fully connected layer that follows the CNN branch, so inputs
            must have exactly this many steps.
    """

    # Each branch emits OUT_STEPS * OUT_DIMS values per sample.
    OUT_STEPS = 10
    OUT_DIMS = 3

    def __init__(self, input_size=10, hidden_layer_size=50, num_layers=1, seq_len=200):
        super().__init__()
        self.hidden_layer_size = hidden_layer_size
        self.num_layers = num_layers
        self.seq_len = seq_len
        head_size = self.OUT_STEPS * self.OUT_DIMS

        # Sub-modules are created in the same order as before so that seeded
        # initialisation and state_dict keys stay unchanged.
        self.gru = nn.GRU(input_size, hidden_layer_size, num_layers, batch_first=True)
        self.fc1 = nn.Linear(hidden_layer_size, head_size)

        # (5, 1) kernels with (2, 0) padding convolve along the time axis only
        # and keep the (time, feature) extent of the input unchanged.
        self.conv1 = nn.Conv2d(in_channels=1, out_channels=16, kernel_size=(5, 1), stride=(1, 1), padding=(2, 0))
        self.conv2 = nn.Conv2d(in_channels=16, out_channels=32, kernel_size=(5, 1), stride=(1, 1), padding=(2, 0))
        self.conv3 = nn.Conv2d(in_channels=32, out_channels=1, kernel_size=(5, 1), stride=(1, 1), padding=(2, 0))
        # Sized from the actual input geometry instead of a literal 200 * 10.
        self.fc2 = nn.Linear(seq_len * input_size, head_size)
        self.relu = nn.ReLU()

        encoder_layer_sizes = [6, 256, 128]
        latent_size = 32
        decoder_layer_sizes = [128, 256, 3]
        conditional = True
        num_labels = 3

        # Build the conditional VAE head.
        self.cvae = CVAE(encoder_layer_sizes, latent_size, decoder_layer_sizes, conditional, num_labels)

    def forward(self, input_seq, labels):
        """Run both branches and the CVAE head.

        Args:
            input_seq: Tensor of shape ``(batch, seq_len, input_size)``.
            labels: Conditioning tensor forwarded unchanged to ``self.cvae``.

        Returns:
            The four values returned by ``self.cvae``, unpacked here as
            ``(H_yy, means, log_var, z)``.
        """
        batch_size = input_seq.size(0)

        # GRU branch: the initial hidden state defaults to zeros on the
        # input's device, so no explicit h_0 is needed.
        gru_out, _ = self.gru(input_seq)                     # (batch, seq_len, hidden)
        gru_feat = self.fc1(gru_out[:, -1, :])               # (batch, 30)
        gru_feat = gru_feat.view(batch_size, self.OUT_STEPS, self.OUT_DIMS)

        # CNN branch: add a channel axis to get (batch, 1, seq_len, input_size).
        cnn_out = input_seq.unsqueeze(1)
        cnn_out = self.relu(self.conv1(cnn_out))             # (batch, 16, seq_len, input_size)
        cnn_out = self.relu(self.conv2(cnn_out))             # (batch, 32, seq_len, input_size)
        cnn_out = self.conv3(cnn_out)                        # (batch, 1, seq_len, input_size)
        # Flatten per sample; a wrong sequence length now fails loudly in fc2
        # instead of silently mixing samples as view(-1, 2000) could.
        cnn_out = cnn_out.flatten(start_dim=1)               # (batch, seq_len * input_size)
        cnn_feat = self.fc2(cnn_out)                         # (batch, 30)
        cnn_feat = cnn_feat.view(batch_size, self.OUT_STEPS, self.OUT_DIMS)

        out = torch.cat((gru_feat, cnn_feat), dim=2)         # (batch, 10, 6)

        H_yy, means, log_var, z = self.cvae(out, labels)

        return H_yy, means, log_var, z

# Convert accelerations into absolute coordinates.
def acc_to_abs(acc, obs, delta=1):
    """Integrate second differences (accelerations) into absolute values.

    ``acc`` is permuted with ``(2, 1, 0)`` so that its last axis becomes the
    step axis, the recurrence ``p[t] = 2 * p[t-1] - p[t-2] + a[t]`` is applied
    along it, and the result is permuted back to the layout of ``acc``.

    Args:
        acc: 3-D tensor whose last axis holds at least two steps.
        obs: Observed values indexed along axis 0; ``obs[-1]`` and ``obs[0]``
            seed the recurrence.
        delta: Accepted for interface compatibility; not used.

    Returns:
        Tensor with the same shape and dtype as ``acc``.
    """
    steps = acc.permute(2, 1, 0)
    positions = torch.empty_like(steps)

    newest_obs = obs[-1]
    # NOTE(review): obs[0] is used as the sample before obs[-1]; that is the
    # second-to-last observation only when obs holds exactly two entries
    # along axis 0 -- confirm this is intended.
    positions[0] = 2 * newest_obs - obs[0] + steps[0]
    positions[1] = 2 * positions[0] - newest_obs + steps[1]

    for t, acc_t in enumerate(steps[2:], start=2):
        positions[t] = 2 * positions[t - 1] - positions[t - 2] + acc_t

    return positions.permute(2, 1, 0)

# Compute the root-mean-square error.
def rmse(y1, y2):
    """Return the root-mean-square error between ``y1`` and ``y2``.

    The squared error is averaged over all elements before the square root
    is taken, so the result is a zero-dimensional tensor.
    """
    mean_squared_error = nn.functional.mse_loss(y1, y2)
    return mean_squared_error.sqrt()

def mae(y1, y2):
    """Return the mean absolute error between ``y1`` and ``y2``.

    The absolute differences are averaged over all elements, giving a
    zero-dimensional tensor.
    """
    return nn.functional.l1_loss(y1, y2)

def loss_func(recon_y, y, mean, log_var):
    """Return ``(total_loss, traj_loss)`` for a VAE-style objective.

    ``traj_loss`` is the mean absolute error between ``recon_y`` and ``y``.
    The KL term is summed (not averaged) over every element of ``mean`` and
    ``log_var`` and added to ``traj_loss`` to form ``total_loss``.
    """
    reconstruction = mae(recon_y, y)
    # Closed-form KL divergence between N(mean, exp(log_var)) and N(0, 1).
    kl_terms = 1 + log_var - mean.pow(2) - log_var.exp()
    kl_divergence = -0.5 * torch.sum(kl_terms)
    total = reconstruction + kl_divergence
    return total, reconstruction


def rmse_loss(y_pred, y_true):
    """Return the root-mean-square error of ``y_pred`` against ``y_true``.

    The squared residuals are averaged over all elements before the square
    root, so the result is a zero-dimensional tensor.
    """
    residual = y_pred - y_true
    return residual.pow(2).mean().sqrt()


class Mydataset(data.Dataset):
    """Sliding-window dataset built from the ``*.xlsx`` files of one folder.

    For every workbook the active sheet is read, keeping only the rows from
    ``int(num_rows * 0.15) + 1`` to ``int(num_rows * 0.95)`` (1-based,
    inclusive).  Worksheet columns 1..10 (zero-based) form the feature vector
    ``(tarEli, tar_x, tar_y, tar_z, position_change, v, v_change, r,
    eli_change, h_change)``.  Each sample pairs ``in_len`` consecutive feature
    rows with the ``(tar_x, tar_y, tar_z)`` values of the next ``out_len``
    rows.

    Both scalers are fitted here on all windows of the dataset, i.e. before
    any train/test split done by the caller.

    Args:
        in_len: Number of time steps in each input window.
        out_len: Number of time steps in each label window.
        scaler_seq: Object with ``fit_transform``; fitted on the inputs.
        scaler_label: Object with ``fit_transform``; fitted on the labels.
        folder_path: Folder that is searched (non-recursively) for workbooks.
        stride: Step, in rows, between the starts of consecutive windows.

    Raises:
        ValueError: If no window can be built from the folder, or if a sheet
            does not provide the ten expected feature columns.
    """

    # Zero-based positions of the ten feature columns inside a worksheet row.
    _FEATURE_COLUMNS = slice(1, 11)
    _NUM_FEATURES = 10
    # Positions of tar_x, tar_y, tar_z inside a feature vector (the label).
    _LABEL_FEATURES = slice(1, 4)

    def __init__(self, in_len, out_len, scaler_seq, scaler_label,
                 folder_path="total1", stride=30):
        # Sorted so that the sample order does not depend on directory order.
        excel_files = sorted(glob.glob(os.path.join(folder_path, '*.xlsx')))

        all_train_seq, all_train_label = [], []
        for file in excel_files:
            features = self._read_features(file)
            seqs, labels = self._build_windows(features, in_len, out_len, stride)
            all_train_seq.extend(seqs)
            all_train_label.extend(labels)

        if not all_train_seq:
            raise ValueError(
                f"No training windows could be built from {folder_path!r}: "
                f"found {len(excel_files)} workbook(s), none with at least "
                f"{in_len + out_len} usable rows."
            )

        # Arrays of shape (num_samples, window_len, num_columns).
        self.normalized_all_train_seq = self._fit_scale(scaler_seq, all_train_seq)
        self.normalized_all_train_label = self._fit_scale(scaler_label, all_train_label)
        # Per-sample lists kept for backward compatibility with older code.
        self.normalized_all_seq = list(self.normalized_all_train_seq)
        self.normalized_all_label = list(self.normalized_all_train_label)

    @classmethod
    def _read_features(cls, file):
        """Return the kept rows of one workbook as a ``(rows, 10)`` float array."""
        workbook = openpyxl.load_workbook(file)
        sheet = workbook.active
        num_rows = sheet.max_row
        # openpyxl rows are 1-based, hence the +1 on the lower bound.
        first_row = int(num_rows * 0.15) + 1
        last_row = int(num_rows * 0.95)

        rows = [
            row[cls._FEATURE_COLUMNS]
            for row in sheet.iter_rows(min_row=first_row, max_row=last_row, values_only=True)
        ]
        if not rows:
            return np.empty((0, cls._NUM_FEATURES))

        # Empty cells (None) become NaN.
        features = np.asarray(rows, dtype=np.float64)
        if features.ndim != 2 or features.shape[1] != cls._NUM_FEATURES:
            raise ValueError(
                f"{file!r}: expected {cls._NUM_FEATURES} feature columns "
                f"(worksheet columns 1-10), got array of shape {features.shape}."
            )
        return features

    @staticmethod
    def _build_windows(features, in_len, out_len, stride=30):
        """Cut ``features`` into (input window, label window) pairs.

        Returns two lists of equal length: inputs of shape ``(in_len, 10)``
        and labels of shape ``(out_len, 3)``.
        """
        total = len(features)
        seqs, labels = [], []
        # +1 so that the window ending exactly on the last row is kept.
        for start in range(0, total - in_len - out_len + 1, stride):
            split = start + in_len
            seqs.append(features[start:split])
            labels.append(features[split:split + out_len, Mydataset._LABEL_FEATURES])
        return seqs, labels

    @staticmethod
    def _fit_scale(scaler, windows):
        """Fit ``scaler`` on all rows of ``windows`` and return them scaled.

        The windows are flattened to 2-D for the scaler and reshaped back to
        ``(num_windows, window_len, num_columns)`` afterwards.
        """
        stacked = np.stack(windows)
        flat = scaler.fit_transform(stacked.reshape(-1, stacked.shape[-1]))
        return flat.reshape(stacked.shape)

    def __getitem__(self, index):
        """Return the ``(input_window, label_window)`` pair at ``index``."""
        input_data = self.normalized_all_train_seq[index]
        target = self.normalized_all_train_label[index]
        return input_data, target

    def __len__(self):
        """Return the number of samples."""
        return len(self.normalized_all_train_seq)

    
    
    
    
    
def evaluate(model, test_loader, device, label_scaler=None):
    """Evaluate ``model`` on ``test_loader`` in the original label units.

    Every batch is run without gradient tracking.  Predictions and labels
    are flattened to ``(-1, last_dim)``, mapped back with
    ``label_scaler.inverse_transform`` and compared with ``rmse_loss``.

    Args:
        model: Module called as ``model(inputs, labels)`` and returning four
            values, the first being the prediction.
        test_loader: Iterable of ``(inputs, labels)`` tensor batches.
        device: Device on which the model is run.
        label_scaler: Fitted scaler providing ``inverse_transform``.  When
            omitted, the module-level ``scaler_label`` (created by the script
            entry point) is used, as before.

    Returns:
        ``(labels, predictions, mean_rmse)``: the de-normalised labels and
        predictions of the LAST batch only (NumPy arrays), and the mean of
        the per-batch RMSE values over all batches.

    Raises:
        ValueError: If no scaler is available or the loader yields no batch.
    """
    if label_scaler is None:
        # Backward compatibility with callers that rely on the script global.
        label_scaler = globals().get("scaler_label")
        if label_scaler is None:
            raise ValueError(
                "evaluate() needs a fitted label scaler: pass label_scaler=... "
                "or define a module-level scaler_label."
            )

    model.eval()  # evaluation mode
    total_loss = 0.0
    num_batches = 0
    labels_denorm = preds_denorm = None

    with torch.no_grad():  # no gradient tracking during evaluation
        for inputs, labels in test_loader:
            inputs = inputs.to(device=device, dtype=torch.float32)
            labels = labels.to(device=device, dtype=torch.float32)
            y_pred, _, _, _ = model(inputs, labels)

            # The scaler works on 2-D NumPy arrays with one column per label.
            preds_denorm = label_scaler.inverse_transform(
                y_pred.detach().cpu().numpy().reshape(-1, y_pred.shape[-1])
            )
            labels_denorm = label_scaler.inverse_transform(
                labels.detach().cpu().numpy().reshape(-1, labels.shape[-1])
            )

            loss = rmse_loss(
                torch.as_tensor(preds_denorm, dtype=torch.float32),
                torch.as_tensor(labels_denorm, dtype=torch.float32),
            )
            total_loss += loss.item()
            num_batches += 1

    if num_batches == 0:
        raise ValueError("test_loader produced no batches; nothing to evaluate.")

    return labels_denorm, preds_denorm, total_loss / num_batches




# Script entry point: build the dataset, train the model and report the test
# RMSE every 10 epochs.
if __name__ == '__main__':
    device = torch.device('cuda:0' if torch.cuda.is_available() else 'cpu')
    print(device)
    # Mixed precision is only enabled on CUDA; on CPU the loop runs in plain
    # float32 with autocast and the gradient scaler disabled.
    use_amp = device.type == 'cuda'

    # Normalisation.  scaler_label stays a module-level name because
    # evaluate() falls back to it when no scaler is passed explicitly.
    scaler_seq, scaler_label = MinMaxScaler(feature_range=(-1, 1)), MinMaxScaler(feature_range=(-1, 1))
    datasets = Mydataset(200, 10, scaler_seq, scaler_label)
    print(len(datasets))
    test_size = int(len(datasets) * 0.3)
    train_size = len(datasets) - test_size
    train_dataset, test_dataset = data.random_split(datasets, [train_size, test_size])
    print(type(train_dataset), len(train_dataset))
    print(type(test_dataset), len(test_dataset))

    batch_size = 16
    train_dataloader = data.DataLoader(train_dataset, batch_size=batch_size, num_workers=2, shuffle=True)
    # No shuffling for evaluation: keeps the reported metric and the returned
    # last batch reproducible between evaluations.
    test_dataloader = data.DataLoader(test_dataset, batch_size=batch_size, num_workers=1, shuffle=False)

    model = GRU_attention().to(device)
    print(model)

    loss_function = nn.L1Loss()
    # optimizer = torch.optim.Adam(model.parameters(), lr=0.00008)
    optimizer = Ranger(model.parameters(), lr = 0.0002)

    print(
        "enhance data + GRU + CNN + LR = 0.0002 + MAEE + batch_size = 16 + in_len = 200 + out_len = 10 + ranger")

    epochs = 1000
    scaler2 = GradScaler(enabled=use_amp)

    for epoch in range(epochs):
        model.train()  # make sure the model is in training mode
        running_loss = 0.0
        running_num = 0

        for inputs, labels in train_dataloader:
            # Keep inputs and targets in float32; autocast picks the reduced
            # precision per operation, so targets are not rounded to float16.
            inputs = inputs.to(device=device, dtype=torch.float32)
            labels = labels.to(device=device, dtype=torch.float32)
            optimizer.zero_grad()

            with autocast(enabled=use_amp):
                y_pred, _, _, _ = model(inputs, labels)
                single_loss = loss_function(y_pred, labels)

            scaler2.scale(single_loss).backward()
            scaler2.step(optimizer)
            scaler2.update()

            running_loss += single_loss.item()
            running_num += 1

        if epoch % 10 == 0:
            # Every 10th epoch: report the average training loss ...
            avg_train_loss = running_loss / running_num
            print(f'Epoch {epoch+1}/{epochs}, Train Loss: {avg_train_loss:.8f}')
            # ... and evaluate on the test split.
            labels1,y_pred1,test_loss = evaluate(model, test_dataloader, device)
            print(f'Epoch {epoch+1}/{epochs}, Test Loss: {test_loss:.8f}')

    # Uncomment to persist the trained weights.
    #torch.save(model.state_dict(), 'model_CNN_GRU_CVAE.pt')