大模型原理--多头注意力机制的代码实现(MHA引入KV cache)

举报
剑指南天 发表于 2026/05/04 12:59:29 2026/05/04
【摘要】 基于Decoder-only的大模型在自回归生成过程中,模型的输出使逐token。每一个新token,模型都会与上下文的所有token进行注意力计算,这将造成巨大的重复计算。

1.概述

基于Decoder-only的大模型在自回归生成过程中,模型的输出使逐token。每一个新token,模型都会与上下文的所有token进行注意力计算,这将造成巨大的重复计算。为了避免这种重复计算,在推理阶段将历史 token 的 Key 和 Value 缓存下来,供后续步骤直接使用。这一机制就是 KV Cache。

2. 工作流程图

3. 手工代码实现

import torch
import torch.nn as nn

# 在推理阶段将历史 token  Key  Value 缓存
past_key_value = {}

# 自注意力机制
class DecoderSelfAttentionHead(nn.Module):
    def __init__(self, d_model, head_size):
        super().__init__()
        self.head_size = head_size
        self.query = nn.Linear(d_model, head_size)
        self.key = nn.Linear(d_model, head_size)
        self.value = nn.Linear(d_model, head_size)

    def forward(self, x, use_cache, masked, layer_num, head_index):
        print("layer_num: ",layer_num)
        print("head_index: ",head_index)
        # 计算Q, K, V
        Q = self.query(x)
        K = self.key(x)
        V = self.value(x)
        # 如果use_cacheTrue,则代表当前是推理阶段,需要进行KV的缓存
        if use_cache:
            if layer_num in past_key_value and head_index in past_key_value[layer_num]:
                # 拼接KV
                K = torch.cat([past_key_value[layer_num][head_index][0], K], dim=1)
                V = torch.cat([past_key_value[layer_num][head_index][1], V], dim=1)
                # 更新KV缓存
                past_key_value[layer_num][head_index] = (K, V)
            elif layer_num not in past_key_value:
                # 每层的初始KV缓存
                past_key_value[layer_num] = {head_index:(K,V)}
            else:
                past_key_value[layer_num][head_index] = (K, V)
            print(past_key_value)
        # QK的点积
        attention = Q @ K.transpose(-2, -1)
        # 缩放
        attention = attention / (self.head_size ** 0.5)
        # 训练的时候使用
        if masked:
            mask = torch.triu(torch.ones_like(attention), diagonal=1).bool()
            attention.masked_fill(mask, float('-inf'))
        # 打分
        attention = torch.softmax(attention, dim=-1)
        # 加权汇总
        attention = attention @ V
        return attention

#
class MaskedMultiHeadAttention(nn.Module):
    def __init__(self, d_model, n_heads):
        super().__init__()
        # 注意力头的个数,这块也可以看出d_model和注意力头数目由整倍数关系
        self.head_size = d_model // n_heads
        self.heads = nn.ModuleList([
            DecoderSelfAttentionHead(d_model, self.head_size) for _ in range(n_heads)
        ])
        self.multi_header = nn.Linear(d_model, d_model)
    # 引入layer_num,表示模型层数. 需要为每一层的每一个自注意力头缓存KV
    def forward(self, x, use_cache=False, masked=False, layer_num=0):
        # 拼接多个注意力头
        multi_header = torch.cat(
            [head(x, use_cache, masked, layer_num, head_index) for head_index, head in enumerate(self.heads)], dim=-1)
        out = self.multi_header(multi_header)
        return out
if __name__ == '__main__':
    x1 = torch.randn((1, 5, 12))
    x2 = torch.randn((1, 1, 12))
    model = MaskedMultiHeadAttention(12, 4)
    y1 = model(x1, use_cache=True)
    y2 = model(x2, use_cache=True)

4. 总结:主要是理解KV cache的缓存过程,更加深刻理解训练过程和推理过程的不同。

【版权声明】本文为华为云社区用户原创内容,未经允许不得转载,如需转载请自行联系原作者进行授权。如果您发现本社区中有涉嫌抄袭的内容,欢迎发送邮件进行举报,并提供相关证据,一经查实,本社区将立刻删除涉嫌侵权内容,举报邮箱: cloudbbs@huaweicloud.com
  • 点赞
  • 收藏
  • 关注作者

评论(0

0/1000
抱歉,系统识别当前为高风险访问,暂不支持该操作

全部回复

上滑加载中

设置昵称

在此一键设置昵称,即可参与社区互动!

*长度不超过10个汉字或20个英文字符,设置后3个月内不可修改。

*长度不超过10个汉字或20个英文字符,设置后3个月内不可修改。