BERT 量化实战分析

举报
九年义务漏网鲨鱼 发表于 2025/09/15 15:01:26 2025/09/15
【摘要】 BERT 量化实战分析

BERT 量化实战分析

前言:基于BERT实现了情感分析系统以及量化的实现,但是量化的结果导致了模型的精确度急剧下降,从90%降到了54%,为此,在本章中,尽可能的分析导致量化后模型精度下降的原因

上期问题

🔴 在量化过程中,发现无法采用export量化,但是 Eager Mode 成功了, Eager Mode 只对线性层进行了量化,而没有对embedding层进行量化; EXPORT 量化仅支持对所有的层进行量化 (所以量化结果只剩下0.01M),无法支持指定层的量化;

🔴 目前pytorch提供的FX在维护中,无法使用;

🟢 因此,本节的实验依然采用 Eager Mode 动态量化方式进行量化,只对权重进行量化;

# export报错信息:forward() missing 203 required positional arguments: 'p_bert_embeddings_position_embeddings_weight', 'p_bert_embeddings_layernorm_weight', 'p_bert_embeddings_layernorm_bias',  

#⚠️ BERT 模型包含了 nn.Embedding 层,而当前 PT2E 导出流程默认将这些参数导出为必须手动传入的动态参数(如 p_bert_embeddings_position_embeddings_weight),导致你在前向推理时必须手动传入 embedding 权重,否则就会报错。

量化分析方法

为了进一步的优化量化模型,可以从以下方法进行分析:

🟢 Calibration Range 分析

🟢 逐层敏感性分析

🟢 层级 fallback 到 FP32

🟢 误差传播分析

🟢 具体样本误差对比

🔍 Calibration Range 分析

# 权重分析可视化代码——观察量化前后的分布情况
def plot_distribution(fp32_tensor, quant_tensor, layer_name):
    print(type(quant_tensor))  # 应为 torch.quantized.QTensor
    print(quant_tensor.dtype)  # 应为 torch.qint8 或 torch.quint8
    plt.figure(figsize=(10, 4))

    # FP32原始分布
    plt.subplot(121)
    plt.hist(fp32_tensor.flatten(), bins=100, alpha=0.5, label='FP32', color='blue')
    plt.axvline(quant_tensor.q_scale() * (127 - quant_tensor.q_zero_point()), color='red')  # 上界
    plt.axvline(quant_tensor.q_scale() * (-128 - quant_tensor.q_zero_point()), color='red')  # 下界
    plt.title(f"{layer_name} - FP32 vs Quant Bounds")

    # 量化后反量化分布
    plt.subplot(122)
    dequant_tensor = quant_tensor.dequantize()
    plt.hist(dequant_tensor.flatten(), bins=100, alpha=0.5, label='Dequantized', color='orange')
    plt.title("Dequantized Distribution")

    plt.tight_layout()
    plt.show()

🟢 未出现截断情况(即分布区域超过量化上下限)、分布近似

🔴 scale过大

image.png

scale的计算如下所示:scale=max(w)min(w)255scale=\frac{max(w)-min(w)}{255}, 个别层的权重有离群值,会导致scale非常大,严重丢失精度。为此对权重进行裁剪操作:

with torch.no_grad():
    for name, module in model.named_modules():
        if isinstance(module, torch.nn.Linear):
            module.weight.clamp_(-3.0, 3.0) 
# Original FP32 model accuracy: 0.9300
# Quantized INT8 model accuracy: 0.5482 → 0.9151

🔴 量化后出现锯齿

image.png

  • 可能的原因

① 权重分布本身就不光滑(有离群值)

② 权重量化导致连续输入映射为不连续输出

  • 解决

✅ 方法1:替换激活函数 GELU → ReLU

✅ 方法2:尝试采用 QAT

🧪 逐层敏感性分析

核心思想:将原模型逐层量化,观察产生精度下降的原因;

import torch
import copy
from transformers import AutoModelForSequenceClassification, AutoTokenizer
from datasets import load_dataset
import torch.nn as nn
import numpy as np
from sklearn.metrics import accuracy_score
from tqdm import tqdm


def preprocess(tokenizer, example):
    return tokenizer(example["sentence"], truncation=True, padding="max_length", max_length=128)

def evaluate(model, dataloader, device="cpu"):
    model.eval()
    preds = []
    labels = []

    with torch.no_grad():
        for batch in dataloader:
            input_ids = batch["input_ids"].to(device)
            attention_mask = batch["attention_mask"].to(device)
            label = batch["label"].to(device)

            outputs = model(input_ids=input_ids, attention_mask=attention_mask)
            logits = outputs.logits
            batch_preds = torch.argmax(logits, dim=1)

            preds.extend(batch_preds.cpu().numpy())
            labels.extend(label.cpu().numpy())

    return accuracy_score(labels, preds)


def get_linear_layers(model):
    return [(name, module) for name, module in model.named_modules() if isinstance(module, nn.Linear)]


def run_sensitivity_analysis(model_fp32, tokenizer):
    print("Loading SST-2 validation dataset...")
    dataset = load_dataset("glue", "sst2")["validation"]
    dataset = dataset.map(lambda x: preprocess(tokenizer, x), batched=True)
    dataset.set_format(type="torch", columns=["input_ids", "attention_mask", "label"])
    dataloader = torch.utils.data.DataLoader(dataset, batch_size=32)

    device = "cpu"
    model_fp32.to(device)

    print("Evaluating FP32 baseline...")
    acc_fp32 = evaluate(model_fp32, dataloader, device)
    print(f"Baseline FP32 Accuracy: {acc_fp32:.4f}")

    print("Evaluating fully quantized model...")
    model_full_quant = torch.quantization.quantize_dynamic(
        copy.deepcopy(model_fp32),
        {torch.nn.Linear},
        dtype=torch.qint8
    )
    acc_full_quant = evaluate(model_full_quant, dataloader, device)
    print(f"Fully Quantized Accuracy: {acc_full_quant:.4f}")

    results = []
    linear_layers = get_linear_layers(model_fp32)

    print("\nPerforming per-layer sensitivity analysis...\n")
    for name, _ in tqdm(linear_layers):
        # 复制模型
        model_copy = copy.deepcopy(model_fp32)
        
        # 遍历并只量化当前层
        for n, module in model_copy.named_modules():
            if isinstance(module, nn.Linear):
                if n == name:
                    quantized = torch.quantization.quantize_dynamic(module, {nn.Linear}, dtype=torch.qint8)
                    setattr(model_copy, name.split(".")[0], quantized)  # 如果层在 nn.Sequential 可直接这样设置
        acc = evaluate(model_copy, dataloader, device)
        delta = acc_fp32 - acc
        print(f"Layer: {name:40s} | Acc: {acc:.4f} | ΔAcc: {delta:.4f}")
        results.append((name, acc, delta))

    results.sort(key=lambda x: x[2], reverse=True)
    print("\nTop-5 Most Sensitive Layers:")
    for r in results[:5]:
        print(f"{r[0]:40s} | Acc: {r[1]:.4f} | ΔAcc: {r[2]:.4f}")

    return results

🧠 其他分析方法

层级 fallback 到 FP32

与敏感性分析相关,该方法是将原模型逐层量化,观察精度下降情况

误差传播分析

对 float32 模型 和 量化模型,输入相同的样本;

逐层提取中间层输出;

对每层输出计算误差(如 MSE、Cosine 距离等);

画出误差随层数变化的曲线 → 看是否有层明显放大了误差;

具体样本误差对比

目标:某个具体输入,FP32 模型 vs INT8 模型输出差异有多大

【声明】本内容来自华为云开发者社区博主,不代表华为云及华为云开发者社区的观点和立场。转载时必须标注文章的来源(华为云社区)、文章链接、文章作者等基本信息,否则作者和本社区有权追究责任。如果您发现本社区中有涉嫌抄袭的内容,欢迎发送邮件进行举报,并提供相关证据,一经查实,本社区将立刻删除涉嫌侵权内容,举报邮箱: cloudbbs@huaweicloud.com
  • 点赞
  • 收藏
  • 关注作者

评论(0

0/1000
抱歉,系统识别当前为高风险访问,暂不支持该操作

全部回复

上滑加载中

设置昵称

在此一键设置昵称,即可参与社区互动!

*长度不超过10个汉字或20个英文字符,设置后3个月内不可修改。

*长度不超过10个汉字或20个英文字符,设置后3个月内不可修改。