import torch
import torch.nn as nn
import time
from thop import profile  # 需要安装：pip install thop


class Actor(nn.Module):
    """Example actor network built from fully connected layers.

    The network is a stack of three linear layers with ReLU activations
    between them and a final Tanh, so every output component lies in
    the range [-1, 1].

    Args:
        state_dim: Number of features in the input state vector.
        action_dim: Number of components in the output action vector.
    """

    def __init__(self, state_dim=24, action_dim=4):
        super().__init__()
        hidden_dim = 256  # width shared by both hidden layers
        # Layers are created in input-to-output order so that parameter
        # initialisation order and state-dict keys (net.0, net.2, net.4)
        # stay stable.
        layers = [
            nn.Linear(state_dim, hidden_dim),
            nn.ReLU(),
            nn.Linear(hidden_dim, hidden_dim),
            nn.ReLU(),
            nn.Linear(hidden_dim, action_dim),
            nn.Tanh(),
        ]
        self.net = nn.Sequential(*layers)

    def forward(self, state):
        """Run ``state`` through the layer stack and return the result."""
        action = self.net(state)
        return action


# Build the model and a dummy input for profiling and timing.
device = torch.device("cpu")
model = Actor().to(device)
# Put the model in inference mode before measuring. This network has no
# dropout or batch-norm layers, so outputs are unaffected, but it keeps the
# benchmark correct if such layers are added later.
model.eval()
dummy_input = torch.randn(1, 24).to(device)  # batch size 1, state dimension 24

# ===== Operation count and parameter count =====
# NOTE(review): thop's documentation names the first return value "macs"
# (multiply-accumulate operations); confirm whether it should be doubled
# before being reported as FLOPs.
flops, params = profile(model, inputs=(dummy_input,))
print(f"参数量(Parameters): {params / 1e6:.2f} M")
print(f"FLOPs: {flops / 1e6:.5f} M")

# ===== Inference latency measurement =====
warmup_steps = 100  # untimed iterations run before measuring
test_steps = 1000  # timed iterations

# Warm-up: run the same no-grad forward pass that is timed below so that
# one-time initialisation and caching costs are excluded from the result.
with torch.no_grad():
    for _ in range(warmup_steps):
        _ = model(dummy_input)

# On CUDA, kernels launch asynchronously; drain any pending warm-up work
# so it is not attributed to the timed section.
if device.type == "cuda":
    torch.cuda.synchronize()

# Timed section. perf_counter() is a monotonic, high-resolution clock meant
# for measuring short durations, unlike the wall-clock time.time().
start_time = time.perf_counter()
with torch.no_grad():  # gradients are not needed for inference
    for _ in range(test_steps):
        _ = model(dummy_input)

        # On CUDA, wait for the forward pass to finish so each iteration
        # accounts for its full execution time.
        if device.type == "cuda":
            torch.cuda.synchronize()
end_time = time.perf_counter()

avg_latency = (end_time - start_time) * 1000 / test_steps  # milliseconds per call
print(f"平均推理延迟: {avg_latency:.4f} ms")

# ===== Per-layer details =====
print("\n各层详细信息：")
# NOTE(review): this repeats the profile() call above with verbose=True and
# stores its return value without printing it. Whether thop emits a
# per-layer breakdown here depends on the installed version — confirm, and
# consider requesting layer info explicitly if the version supports it.
model_summary = profile(model, inputs=(dummy_input,), verbose=True)