import torch
import torch.nn as nn
import math


class HybridCNN(nn.Module):
    """Conv1d + MLP network with a bounded continuous head and a binary head.

    Maps a flat input vector of length ``3 * n + 2 * user`` to an output
    vector of length ``2 * n + user * n``:

    * the first ``2 * n`` values come from a sigmoid scaled per position, so
      even positions lie in ``[0, 2*pi]`` and odd positions in ``[0, 20]``;
    * the remaining ``user * n`` values are hard 0/1 decisions
      (``logit > 0``). In training mode they carry a straight-through
      estimator (STE) gradient taken from the underlying logits.

    Args:
        n: Size parameter; the input has ``3 * n`` of its features tied to it
            and the continuous head emits ``2 * n`` values.
        user: Size parameter; the input has ``2 * user`` of its features tied
            to it and the binary head emits ``user * n`` values.
    """

    def __init__(self, n, user):
        super().__init__()
        self.n = n
        self.user = user
        # Flattened input width D; forward() rejects any other width.
        in_dim = 3 * n + 2 * user

        # Single-channel Conv1d over the feature axis: (batch, 1, D) -> (batch, 64, D).
        # kernel_size=3 with padding=1 keeps the length D unchanged.
        self.conv1 = nn.Conv1d(1, 64, kernel_size=3, padding=1)

        # Dense feature layer over the flattened conv activations (64 * D -> 64).
        self.fc = nn.Linear(64 * in_dim, 64)

        # Output heads.
        self.cont_head = nn.Linear(64, 2 * n)  # continuous outputs
        self.binary_head = nn.Linear(64, user * n)  # binary outputs

        # Per-position scale for the sigmoid outputs, length 2 * n:
        # 2*pi at even indices, 20 at odd indices. Registered as a buffer so it
        # follows .to(device)/.double() and is stored in the state_dict.
        self.register_buffer(
            "scale_template",
            torch.tensor([2 * math.pi, 20.0]).repeat(n),
        )

    def forward(self, x):
        """Run the network on a batch of flat feature vectors.

        Args:
            x: Tensor of shape ``(batch, 3 * n + 2 * user)``. Integer and
                floating dtypes are accepted; the input is cast to the dtype of
                the model's weights.

        Returns:
            Tensor of shape ``(batch, 2 * n + user * n)``: the scaled continuous
            outputs followed by the 0/1 binary outputs.

        Raises:
            ValueError: If the last dimension of ``x`` is not ``3 * n + 2 * user``.
        """
        expected_dim = 3 * self.n + self.user * 2
        if x.size(-1) != expected_dim:
            raise ValueError(f"输入维度应为 {expected_dim}，实际得到 {x.size(-1)}")

        # Match the weight dtype so inputs built from int lists or float64 NumPy
        # arrays do not make conv1d fail on a dtype mismatch.
        x = x.to(self.conv1.weight.dtype)

        x = x.unsqueeze(1)  # (batch, D) -> (batch, 1, D)
        x = torch.relu(self.conv1(x))  # (batch, 64, D)
        x = x.flatten(1)  # (batch, 64 * D)
        x = torch.relu(self.fc(x))  # (batch, 64)

        # Continuous outputs bounded to [0, scale] per position.
        cont_output = torch.sigmoid(self.cont_head(x)) * self.scale_template

        # Binary outputs: hard threshold of the logits.
        bin_logits = self.binary_head(x)
        bin_output = (bin_logits > 0).to(bin_logits.dtype)
        if self.training:
            # Straight-through estimator: forward value is the hard 0/1, the
            # gradient is that of the logits. The zero term must be formed
            # first: (logits - logits.detach()) is exactly 0, whereas
            # (hard + logits) - logits can round 1.0 to 0.99999994 in float32.
            bin_output = bin_output + (bin_logits - bin_logits.detach())

        return torch.cat([cont_output, bin_output], dim=1)

def step(n, user, array):
    """Build a freshly initialised HybridCNN and run one sample through it.

    Args:
        n: ``n`` argument forwarded to ``HybridCNN``.
        user: ``user`` argument forwarded to ``HybridCNN``.
        array: One input sample of length ``3 * n + 2 * user`` (a list,
            NumPy array or 1-D tensor).

    Returns:
        Tensor of shape ``(1, 2 * n + user * n)`` produced by the network.

    Note:
        A new network with random, untrained weights is created on every call
        and is left in training mode (the ``nn.Module`` default), so results
        differ between calls unless the global torch seed is fixed.
    """
    net = HybridCNN(n=n, user=user)

    # Cast to the default float dtype (the dtype nn.Module parameters use) so
    # int lists and float64 NumPy arrays work, then add the batch dimension:
    # (D,) -> (1, D).
    test_input = torch.as_tensor(array, dtype=torch.get_default_dtype()).unsqueeze(0)

    return net(test_input)



# 验证函数
def validate_output(output, n, user):
    """Check that a HybridCNN-style output satisfies its value constraints.

    Args:
        output: Tensor of shape ``(batch, 2 * n + user * n)``.
        n: Size parameter; the first ``2 * n`` columns are continuous.
        user: Size parameter; the last ``user * n`` columns are binary.

    Returns:
        0-dim bool tensor that is True iff the width is exactly
        ``2 * n + user * n``, the even continuous columns lie in
        ``[0, 2*pi]``, the odd continuous columns lie in ``[0, 20]``, and every
        binary column is exactly 0 or 1.
    """
    # A truncated or overlong output would otherwise slip through, because
    # torch.all over an empty slice is True.
    if output.size(-1) != 2 * n + user * n:
        return torch.tensor(False, device=output.device)

    # Continuous part: even columns in [0, 2*pi], odd columns in [0, 20].
    cont = output[:, :2 * n]
    even_check = torch.all((cont[:, ::2] >= 0) & (cont[:, ::2] <= 2 * math.pi))
    odd_check = torch.all((cont[:, 1::2] >= 0) & (cont[:, 1::2] <= 20))

    # Binary part: scalar comparisons work on any device and dtype, unlike
    # torch.isin against a lookup tensor built on the CPU.
    bin_part = output[:, 2 * n:]
    binary_check = torch.all((bin_part == 0) | (bin_part == 1))

    return even_check & odd_check & binary_check


