探索 FP8 训练中 Debug 思路与技巧
探索 FP8 训练中 Debug 思路与技巧
一、为什么 FP8 训练需要专门的 Debug 体系
FP8(Float-8)把 Tensor Core 每次运算的 bit 数再砍一半,带来 1.3 ~ 2× 的吞吐收益,但也让数值动态范围进一步缩小。
在 1k+ GPU 的大模型训练中,我们最常遇到的三类异常是:
- Spike:Loss 突然飙升后缓慢回落。
- Drift:Loss 与 BF16 基线逐渐偏离,最终差出 1 % 以上。
- Nan/Inf:训练直接崩溃。
这三类异常往往与 量化比例(scale factor) 或 某一层 GEMM 的 FP8 选择 有关。下面给出一条可复现、可脚本化的排查流水线。
二、整体排查流程(Road-Map)
步骤 | 目的 | 关键开关 | 预期 |
---|---|---|---|
① 环境对齐 | 排除版本/库 Bug | TORCH_ALLOW_TF32_CUBLAS_OVERRIDE=0 |
与官方 Release Note 完全一致 |
② 复现 BF16 基线 | 拿到 Golden Loss | --fp8=0 |
记录 loss, grad norm, lr |
③ 打开 FP8 全量 | 观察宏观现象 | --fp8=1 |
画 loss 曲线,看是否出现 Spike/Drift |
④ 三类矩阵二分 | 定位哪一类 GEMM 出错 | NVTE_FP8_DGRAD=0 等 |
缩小范围 |
⑤ 逐层 dump | 找到具体层 | NVTE_DEBUG_LAYER="decoder.12" |
输出 tensor、scale、cosine |
⑥ 微调 Recipe | 换 Scaling 策略 | --recipe=current |
在性能与精度间权衡 |
⑦ 固化脚本 | 自动化回归 | pytest |
PR Gate |
三、环境与最小可复现代码
3.1 软硬件版本
GPU : H100-SXM5 80GB
Driver : 535.54.03
CUDA : 12.2
PyTorch : 2.3.0a0+git0e9c721
TransformerEngine : 1.2.0+gitd76118d
3.2 50 行最小复现脚本
# debug_fp8.py
import torch, transformer_engine.pytorch as te
from transformer_engine.common.recipe import DelayedScaling
class TinyTransformer(torch.nn.Module):
def __init__(self, hidden=1024, vocab=32000):
super().__init__()
self.embed = torch.nn.Embedding(vocab, hidden)
self.layers = torch.nn.ModuleList([
te.TransformerLayer(hidden, 4*hidden, 16,
layer_type=te.LayerType.encoder,
self_attn_mask_type="padding")
for _ in range(12)
])
self.lm_head = torch.nn.Linear(hidden, vocab, bias=False)
def forward(self, x):
h = self.embed(x)
for lyr in self.layers:
h = lyr(h)
return self.lm_head(h)
model = TinyTransformer().cuda()
optimizer = torch.optim.AdamW(model.parameters(), lr=3e-4)
recipe = DelayedScaling(fp8_format="HYBRID", amax_history_len=1024, amax_compute_algo="max")
# 单步训练函数
def train_step(tokens, labels):
with te.fp8_autocast(enabled=True, fp8_recipe=recipe):
out = model(tokens)
loss = torch.nn.functional.cross_entropy(out.flatten(0,1), labels.flatten())
loss.backward()
optimizer.step(); optimizer.zero_grad()
return loss.item()
# 模拟 1000 步
for step in range(1000):
tokens = torch.randint(0, 32000, (4, 512)).cuda()
labels = torch.randint(0, 32000, (4, 512)).cuda()
loss = train_step(tokens, labels)
if step % 50 == 0:
print(f"step {step:4d} loss={loss:.4f}")
运行
python debug_fp8.py 2>&1 | tee fp8.log
若出现 loss=nan 或 >10,则进入下一节。
四、三类 GEMM 二分定位
Transformer Engine 把每个 Transformer layer 拆成 3 个 GEMM:
- Fprop:前向投影
- Wgrad:权重梯度
- Dgrad:输入梯度
通过环境变量一键回退到 BF16:
环境变量 | 作用 |
---|---|
NVTE_FP8_FPROP=0 |
仅 Fprop 用 BF16 |
NVTE_FP8_WGRAD=0 |
仅 Wgrad 用 BF16 |
NVTE_FP8_DGRAD=0 |
仅 Dgrad 用 BF16 |
示例:
# 定位 Dgrad 问题
NVTE_FP8_DGRAD=0 torchrun debug_fp8.py
如果 loss 回到 BF16 基线,则 Dgrad 的 FP8 量化策略需要调整。
五、逐层 Tensor Dump 与可视化
Transformer Engine 内置 Debug Logger:
import os, json, numpy as np
os.environ["NVTE_DEBUG_LAYER"] = "decoder.3" # 只 dump 第 4 层
os.environ["NVTE_DEBUG_STEP_INTERVAL"] = "50" # 每 50 步 dump 一次
def debug_hook(step, tensor_dict):
# tensor_dict: {"name": tensor, ...}
with open(f"dump/step{step}.json", "w") as f:
json.dump({k: v.cpu().numpy().tolist()[:32] for k, v in tensor_dict.items()}, f)
te.debug.set_debug_hook(debug_hook)
Dump 后计算 Cosine & MSE:
import torch
a = torch.tensor(json.load(open("dump/step100.json"))["input"])
b = torch.tensor(json.load(open("bf16/step100.json"))["input"])
print("cosine:", torch.nn.functional.cosine_similarity(a.flatten(), b.flatten(), 0))
print("mse:", torch.nn.functional.mse_loss(a, b))
若 cosine < 0.99,说明该层量化误差过大。
六、Scaling Recipe 微调
Delayed Scaling 默认用 1024 步历史最大值,可能“滞后”于尖峰。
切换为 Current Scaling(即时 max):
from transformer_engine.common.recipe import Format, DelayedScaling
recipe = DelayedScaling(fp8_format=Format.E4M3,
amax_history_len=1, # 等价 Current
amax_compute_algo="max")
性能损失 < 5 %,但可解决 Drift。
七、一个真实案例的完整排查日志
时间 | 动作 | loss | 结论 |
---|---|---|---|
07-21 10:00 | BF16 基线 | 6.74 → 6.32 | 正常 |
07-21 11:15 | FP8 全量 | 6.74 → nan | 崩溃 |
07-21 11:30 | NVTE_FP8_DGRAD=0 |
6.74 → 6.33 | Dgrad 问题 |
07-21 12:00 | 改 Current Scaling | 6.74 → 6.34 | 精度恢复 |
07-21 12:30 | 固化脚本回归 | ✅ 1000 步无 nan | 合入主干 |
八、可复制的回归脚本
把上面所有步骤封装成 pytest
:
# test_fp8_debug.py
import subprocess, json, os, pytest
def run(cmd):
return subprocess.check_output(cmd, shell=True, text=True)
@pytest.mark.parametrize("mode", ["bf16", "fp8_full", "fp8_nodgrad", "fp8_current"])
def test_mode(mode):
log = run(f"NVTE_FP8_DGRAD={0 if 'nodgrad' in mode else 1} "
f"python debug_fp8.py --recipe={mode}")
loss = [float(l.split("loss=")[1]) for l in log.splitlines() if "loss=" in l]
assert loss[-1] < 7.0, f"{mode} diverged: {loss[-1]}"
CI 里跑 pytest -n4
即可确保后续 PR 不会回退。
九、小结与展望
- 先宏观后微观:先比较 loss 曲线,再二分 GEMM,再 dump 张量。
- Scaling 是核心:Delayed vs Current 往往决定 Drift 还是 Spike。
- 自动化是终点:把 Debug 流程脚本化,FP8 才敢上生产。
未来,Transformer Engine 会支持 per-tensor adaptive scaling 与 online loss-scale,进一步压缩 Debug 时间。希望本文的脚本与思路能帮助社区把 FP8 训练从“实验室技术”变成“默认选项”。
- 点赞
- 收藏
- 关注作者
评论(0)