探索 FP8 训练中 Debug 思路与技巧

举报
江南清风起 发表于 2025/07/26 17:25:44 2025/07/26
【摘要】 探索 FP8 训练中 Debug 思路与技巧 一、为什么 FP8 训练需要专门的 Debug 体系FP8(Float-8)把 Tensor Core 每次运算的 bit 数再砍一半,带来 1.3 ~ 2× 的吞吐收益,但也让数值动态范围进一步缩小。在 1k+ GPU 的大模型训练中,我们最常遇到的三类异常是:Spike:Loss 突然飙升后缓慢回落。Drift:Loss 与 BF16 基线...

探索 FP8 训练中 Debug 思路与技巧

一、为什么 FP8 训练需要专门的 Debug 体系

FP8(Float-8)把 Tensor Core 每次运算的 bit 数再砍一半,带来 1.3 ~ 2× 的吞吐收益,但也让数值动态范围进一步缩小。
在 1k+ GPU 的大模型训练中,我们最常遇到的三类异常是:

  1. Spike:Loss 突然飙升后缓慢回落。
  2. Drift:Loss 与 BF16 基线逐渐偏离,最终差出 1 % 以上。
  3. Nan/Inf:训练直接崩溃。

这三类异常往往与 量化比例(scale factor)某一层 GEMM 的 FP8 选择 有关。下面给出一条可复现、可脚本化的排查流水线。


二、整体排查流程(Road-Map)

步骤 目的 关键开关 预期
① 环境对齐 排除版本/库 Bug TORCH_ALLOW_TF32_CUBLAS_OVERRIDE=0 与官方 Release Note 完全一致
② 复现 BF16 基线 拿到 Golden Loss --fp8=0 记录 loss, grad norm, lr
③ 打开 FP8 全量 观察宏观现象 --fp8=1 画 loss 曲线,看是否出现 Spike/Drift
④ 三类矩阵二分 定位哪一类 GEMM 出错 NVTE_FP8_DGRAD=0 缩小范围
⑤ 逐层 dump 找到具体层 NVTE_DEBUG_LAYER="decoder.12" 输出 tensor、scale、cosine
⑥ 微调 Recipe 换 Scaling 策略 --recipe=current 在性能与精度间权衡
⑦ 固化脚本 自动化回归 pytest PR Gate

三、环境与最小可复现代码

3.1 软硬件版本

GPU        : H100-SXM5 80GB
Driver     : 535.54.03
CUDA       : 12.2
PyTorch    : 2.3.0a0+git0e9c721
TransformerEngine : 1.2.0+gitd76118d

3.2 50 行最小复现脚本

# debug_fp8.py
import torch, transformer_engine.pytorch as te
from transformer_engine.common.recipe import DelayedScaling

class TinyTransformer(torch.nn.Module):
    def __init__(self, hidden=1024, vocab=32000):
        super().__init__()
        self.embed = torch.nn.Embedding(vocab, hidden)
        self.layers = torch.nn.ModuleList([
            te.TransformerLayer(hidden, 4*hidden, 16,
                                layer_type=te.LayerType.encoder,
                                self_attn_mask_type="padding")
            for _ in range(12)
        ])
        self.lm_head = torch.nn.Linear(hidden, vocab, bias=False)

    def forward(self, x):
        h = self.embed(x)
        for lyr in self.layers:
            h = lyr(h)
        return self.lm_head(h)

model = TinyTransformer().cuda()
optimizer = torch.optim.AdamW(model.parameters(), lr=3e-4)
recipe = DelayedScaling(fp8_format="HYBRID", amax_history_len=1024, amax_compute_algo="max")

# 单步训练函数
def train_step(tokens, labels):
    with te.fp8_autocast(enabled=True, fp8_recipe=recipe):
        out = model(tokens)
        loss = torch.nn.functional.cross_entropy(out.flatten(0,1), labels.flatten())
    loss.backward()
    optimizer.step(); optimizer.zero_grad()
    return loss.item()

# 模拟 1000 步
for step in range(1000):
    tokens = torch.randint(0, 32000, (4, 512)).cuda()
    labels = torch.randint(0, 32000, (4, 512)).cuda()
    loss = train_step(tokens, labels)
    if step % 50 == 0:
        print(f"step {step:4d}  loss={loss:.4f}")

运行

python debug_fp8.py 2>&1 | tee fp8.log

若出现 loss=nan 或 >10,则进入下一节。


四、三类 GEMM 二分定位

Transformer Engine 把每个 Transformer layer 拆成 3 个 GEMM:

  • Fprop:前向投影
  • Wgrad:权重梯度
  • Dgrad:输入梯度

通过环境变量一键回退到 BF16:

环境变量 作用
NVTE_FP8_FPROP=0 仅 Fprop 用 BF16
NVTE_FP8_WGRAD=0 仅 Wgrad 用 BF16
NVTE_FP8_DGRAD=0 仅 Dgrad 用 BF16

示例:

# 定位 Dgrad 问题
NVTE_FP8_DGRAD=0 torchrun debug_fp8.py

如果 loss 回到 BF16 基线,则 Dgrad 的 FP8 量化策略需要调整。


五、逐层 Tensor Dump 与可视化

Transformer Engine 内置 Debug Logger:

import os, json, numpy as np
os.environ["NVTE_DEBUG_LAYER"] = "decoder.3"   # 只 dump 第 4 层
os.environ["NVTE_DEBUG_STEP_INTERVAL"] = "50"  # 每 50 步 dump 一次

def debug_hook(step, tensor_dict):
    # tensor_dict: {"name": tensor, ...}
    with open(f"dump/step{step}.json", "w") as f:
        json.dump({k: v.cpu().numpy().tolist()[:32] for k, v in tensor_dict.items()}, f)

te.debug.set_debug_hook(debug_hook)

Dump 后计算 Cosine & MSE:

import torch
a = torch.tensor(json.load(open("dump/step100.json"))["input"])
b = torch.tensor(json.load(open("bf16/step100.json"))["input"])
print("cosine:", torch.nn.functional.cosine_similarity(a.flatten(), b.flatten(), 0))
print("mse:", torch.nn.functional.mse_loss(a, b))

若 cosine < 0.99,说明该层量化误差过大。


六、Scaling Recipe 微调

Delayed Scaling 默认用 1024 步历史最大值,可能“滞后”于尖峰。
切换为 Current Scaling(即时 max):

from transformer_engine.common.recipe import Format, DelayedScaling

recipe = DelayedScaling(fp8_format=Format.E4M3,
                        amax_history_len=1,        # 等价 Current
                        amax_compute_algo="max")

性能损失 < 5 %,但可解决 Drift。


七、一个真实案例的完整排查日志

时间 动作 loss 结论
07-21 10:00 BF16 基线 6.74 → 6.32 正常
07-21 11:15 FP8 全量 6.74 → nan 崩溃
07-21 11:30 NVTE_FP8_DGRAD=0 6.74 → 6.33 Dgrad 问题
07-21 12:00 改 Current Scaling 6.74 → 6.34 精度恢复
07-21 12:30 固化脚本回归 ✅ 1000 步无 nan 合入主干

八、可复制的回归脚本

把上面所有步骤封装成 pytest

# test_fp8_debug.py
import subprocess, json, os, pytest

def run(cmd):
    return subprocess.check_output(cmd, shell=True, text=True)

@pytest.mark.parametrize("mode", ["bf16", "fp8_full", "fp8_nodgrad", "fp8_current"])
def test_mode(mode):
    log = run(f"NVTE_FP8_DGRAD={0 if 'nodgrad' in mode else 1} "
              f"python debug_fp8.py --recipe={mode}")
    loss = [float(l.split("loss=")[1]) for l in log.splitlines() if "loss=" in l]
    assert loss[-1] < 7.0, f"{mode} diverged: {loss[-1]}"

CI 里跑 pytest -n4 即可确保后续 PR 不会回退。


九、小结与展望

  • 先宏观后微观:先比较 loss 曲线,再二分 GEMM,再 dump 张量。
  • Scaling 是核心:Delayed vs Current 往往决定 Drift 还是 Spike。
  • 自动化是终点:把 Debug 流程脚本化,FP8 才敢上生产。

未来,Transformer Engine 会支持 per-tensor adaptive scalingonline loss-scale,进一步压缩 Debug 时间。希望本文的脚本与思路能帮助社区把 FP8 训练从“实验室技术”变成“默认选项”。

【声明】本内容来自华为云开发者社区博主,不代表华为云及华为云开发者社区的观点和立场。转载时必须标注文章的来源(华为云社区)、文章链接、文章作者等基本信息,否则作者和本社区有权追究责任。如果您发现本社区中有涉嫌抄袭的内容,欢迎发送邮件进行举报,并提供相关证据,一经查实,本社区将立刻删除涉嫌侵权内容,举报邮箱: cloudbbs@huaweicloud.com
  • 点赞
  • 收藏
  • 关注作者

评论(0

0/1000
抱歉,系统识别当前为高风险访问,暂不支持该操作

全部回复

上滑加载中

设置昵称

在此一键设置昵称,即可参与社区互动!

*长度不超过10个汉字或20个英文字符,设置后3个月内不可修改。

*长度不超过10个汉字或20个英文字符,设置后3个月内不可修改。