import torch
from transformers import GPT2Config, GPT2Model
from fvcore.nn import FlopCountAnalysis
n_embd = 256  # hidden size shared by the model config and the random input below

# Build the GPT-2 configuration.
config = GPT2Config(
    n_embd=n_embd,  # embedding (hidden) dimension
    n_layer=4,  # number of Transformer layers
    n_head=4,  # number of attention heads
    n_positions=256,  # maximum supported sequence length
    vocab_size=0,  # no vocabulary: the model is fed inputs_embeds instead of token ids
)

# Instantiate the GPT-2 model from the configuration.
model = GPT2Model(config)

# Random input embeddings of shape (batch_size=1, sequence_length=n_positions,
# hidden_size=n_embd). The sequence length is read from the config so the input
# cannot silently drift from the model's maximum supported length.
input_data = torch.randn(1, config.n_positions, n_embd)


# 手动计算参数量
def count_parameters(model):
    """Return the total number of elements across the trainable parameters of ``model``.

    A parameter is counted only when its ``requires_grad`` flag is set, so
    frozen parameters do not contribute to the result.
    """
    trainable_total = 0
    for param in model.parameters():
        if not param.requires_grad:
            continue
        trainable_total += param.numel()
    return trainable_total


# Report the trainable-parameter count, both as an exact figure and in millions.
params = count_parameters(model)
params_in_millions = params / 1e6
print(f"参数量: {params:,} ({params_in_millions:.2f} M)")


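# FLOP counting with fvcore.
# The wrapper below lets the model be called with the embeddings tensor as a
# single positional argument and returns only the final hidden states.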
class GPT2Wrapper(torch.nn.Module):
    """Adapter that gives the wrapped model a single-tensor calling convention.

    ``forward`` accepts the input embeddings as one positional argument,
    forwards them to the wrapped model through its ``inputs_embeds`` keyword,
    and returns only the ``last_hidden_state`` attribute of the result.
    """

    def __init__(self, model):
        super().__init__()
        # The wrapped model is called as ``model(inputs_embeds=...)``.
        self.model = model

    def forward(self, inputs_embeds):
        """Run the wrapped model on ``inputs_embeds`` and return its last hidden state."""
        outputs = self.model(inputs_embeds=inputs_embeds)
        return outputs.last_hidden_state


# Wrap the model and switch it to inference mode before profiling. A freshly
# constructed module is in training mode, which would leave GPT-2's dropout
# layers active while the forward pass is analysed.
wrapped_model = GPT2Wrapper(model)
wrapped_model.eval()

# Count FLOPs with fvcore.
flop_analyzer = FlopCountAnalysis(wrapped_model, input_data)
flops = flop_analyzer.total()

print(f"\n总 FLOPs: {flops:,}")
print(f"FLOPs (百万): {flops / 1e6:.2f} M")
print(f"FLOPs (十亿): {flops / 1e9:.2f} G")


# 理论计算验证
def calculate_gpt2_flops(batch_size, seq_len, n_embd, n_layer, n_head, flops_per_mac=2):
    """Estimate the matrix-multiplication cost of a GPT-2 style Transformer stack.

    Six terms are counted per layer: the Q/K/V projections, the attention
    scores (Q.K^T), the attention-weighted sum (softmax.V), the attention
    output projection, and the two feed-forward layers. Nothing else is
    included in the estimate.

    Args:
        batch_size: Number of sequences in the batch.
        seq_len: Number of positions per sequence.
        n_embd: Embedding (hidden) dimension.
        n_layer: Number of Transformer layers.
        n_head: Number of attention heads; the per-head dimension is
            ``n_embd // n_head``.
        flops_per_mac: FLOPs charged for one multiply-accumulate. The default
            of 2 counts the multiply and the add separately. Pass 1 to get a
            multiply-accumulate count, which is the convention used by tools
            that treat a fused multiply-add as a single FLOP.

    Returns:
        The estimated cost of all ``n_layer`` layers, in the unit selected by
        ``flops_per_mac``.
    """
    d_head = n_embd // n_head  # dimension of each attention head
    tokens = batch_size * seq_len

    # 1. Q, K and V projections: three (n_embd x n_embd) matmuls per token.
    qkv_macs = 3 * tokens * n_embd * n_embd

    # 2. Attention scores (Q.K^T): every head compares each position with every other.
    score_macs = batch_size * n_head * seq_len * d_head * seq_len

    # 3. Attention output (softmax.V): same shape of work as the score computation.
    context_macs = batch_size * n_head * seq_len * seq_len * d_head

    # 4. Attention output projection.
    out_proj_macs = tokens * n_embd * n_embd

    # 5-6. Feed-forward network: expand to 4 * n_embd, then project back.
    ffn_dim = 4 * n_embd
    ffn_macs = tokens * n_embd * ffn_dim + tokens * ffn_dim * n_embd

    macs_per_layer = qkv_macs + score_macs + context_macs + out_proj_macs + ffn_macs

    return n_layer * macs_per_layer * flops_per_mac


# Theoretical estimate for the SAME configuration that was profiled above.
# The hyper-parameters are read back from the input tensor and the model config
# instead of being re-typed, so the estimate cannot drift from the measured model.
batch_size, seq_len, n_embd = input_data.shape
n_layer = config.n_layer
n_head = config.n_head

theoretical_flops = calculate_gpt2_flops(batch_size, seq_len, n_embd, n_layer, n_head)

print(f"\n理论计算总FLOPs: {theoretical_flops:,}")
print(f"FLOPs (十亿): {theoretical_flops / 1e9:.2f} G")

# The estimate charges 2 FLOPs per multiply-accumulate, whereas fvcore counts one
# fused multiply-add as a single FLOP. Halve the estimate so both numbers use the
# same unit before computing the relative difference.
# NOTE(review): fvcore skips operators it has no handler for, so some gap may
# remain depending on how the attention is implemented - confirm against its warnings.
theoretical_macs = theoretical_flops // 2
print(f"理论计算 (fvcore 口径, 1 MAC = 1 FLOP): {theoretical_macs:,}")
print(f"与实际值差异: {(flops - theoretical_macs) / flops * 100:.2f}%")