import torch
import torch.nn as nn
import math


class HybridCNN(nn.Module):
    def __init__(self, n, user):
        super(HybridCNN, self).__init__()
        self.n = n
        self.user = user
        input_dim = 3 * n + 2 * user  # 输入维度
        output_dim = 2 * n  # 输出维度

        # 输入处理层 (batch, input_dim) -> (batch, 1, input_dim)
        self.conv1 = nn.Conv1d(1, 64, kernel_size=3, padding=1)

        # 特征提取层
        self.fc1 = nn.Linear(64 * input_dim, 128)
        self.fc2 = nn.Linear(128, 64)

        # 输出层 - 3*n维输出
        self.output_head = nn.Linear(64, output_dim)

        # 缩放参数，确保输出在[-10, 10]范围内
        self.scale_factor = 10.0

    def forward(self, x):
        # 输入验证
        expected_dim = 3 * self.n + 2 * self.user
        if x.size(-1) != expected_dim:
            raise ValueError(f"输入维度应为 {expected_dim}，实际得到 {x.size(-1)}")

        # 输入处理 (batch, D) -> (batch, 1, D)
        x = x.unsqueeze(1)

        # 卷积层
        x = torch.relu(self.conv1(x))  # (batch, 64, D)

        # 展平
        x = x.view(x.size(0), -1)  # (batch, 64*D)

        # 全连接层
        x = torch.relu(self.fc1(x))
        x = torch.relu(self.fc2(x))

        # 输出层 - 使用tanh将输出限制在[-1,1]范围内
        output = torch.tanh(self.output_head(x))

        # 缩放输出到[-10,10]范围
        output = output * self.scale_factor

        return output


def step(n, user, array):
    # 检查输入数组长度
    expected_length = 3 * n + 2 * user
    if len(array) != expected_length:
        raise ValueError(f"输入数组长度应为 {expected_length}，实际得到 {len(array)}")

    # 初始化网络
    net = HybridCNN(n=n, user=user)

    # 转换为张量
    test_input = torch.tensor([array], dtype=torch.float32)

    # 前向传播
    output = net(test_input)

    # 转换为列表并返回
    return output.detach().numpy().flatten().tolist()