大模型原理--多头注意力机制的代码实现(MHA引入KV cache)
【摘要】 基于Decoder-only的大模型在自回归生成过程中,模型的输出使逐token。每一个新token,模型都会与上下文的所有token进行注意力计算,这将造成巨大的重复计算。
1.概述
基于Decoder-only的大模型在自回归生成过程中,模型的输出使逐token。每一个新token,模型都会与上下文的所有token进行注意力计算,这将造成巨大的重复计算。为了避免这种重复计算,在推理阶段将历史 token 的 Key 和 Value 缓存下来,供后续步骤直接使用。这一机制就是 KV Cache。
2. 工作流程图

3. 手工代码实现
import torch
import torch.nn as nn
# 在推理阶段将历史 token 的 Key 和 Value 缓存
past_key_value = {}
# 自注意力机制
class DecoderSelfAttentionHead(nn.Module):
def __init__(self, d_model, head_size):
super().__init__()
self.head_size = head_size
self.query = nn.Linear(d_model, head_size)
self.key = nn.Linear(d_model, head_size)
self.value = nn.Linear(d_model, head_size)
def forward(self, x, use_cache, masked, layer_num, head_index):
print("layer_num: ",layer_num)
print("head_index: ",head_index)
# 计算Q, K, V
Q = self.query(x)
K = self.key(x)
V = self.value(x)
# 如果use_cache是True,则代表当前是推理阶段,需要进行KV的缓存
if use_cache:
if layer_num in past_key_value and head_index in past_key_value[layer_num]:
# 拼接KV
K = torch.cat([past_key_value[layer_num][head_index][0], K], dim=1)
V = torch.cat([past_key_value[layer_num][head_index][1], V], dim=1)
# 更新KV缓存
past_key_value[layer_num][head_index] = (K, V)
elif layer_num not in past_key_value:
# 每层的初始KV缓存
past_key_value[layer_num] = {head_index:(K,V)}
else:
past_key_value[layer_num][head_index] = (K, V)
print(past_key_value)
# Q和K的点积
attention = Q @ K.transpose(-2, -1)
# 缩放
attention = attention / (self.head_size ** 0.5)
# 训练的时候使用
if masked:
mask = torch.triu(torch.ones_like(attention), diagonal=1).bool()
attention.masked_fill(mask, float('-inf'))
# 打分
attention = torch.softmax(attention, dim=-1)
# 加权汇总
attention = attention @ V
return attention
#
class MaskedMultiHeadAttention(nn.Module):
def __init__(self, d_model, n_heads):
super().__init__()
# 注意力头的个数,这块也可以看出d_model和注意力头数目由整倍数关系
self.head_size = d_model // n_heads
self.heads = nn.ModuleList([
DecoderSelfAttentionHead(d_model, self.head_size) for _ in range(n_heads)
])
self.multi_header = nn.Linear(d_model, d_model)
# 引入layer_num,表示模型层数. 需要为每一层的每一个自注意力头缓存KV
def forward(self, x, use_cache=False, masked=False, layer_num=0):
# 拼接多个注意力头
multi_header = torch.cat(
[head(x, use_cache, masked, layer_num, head_index) for head_index, head in enumerate(self.heads)], dim=-1)
out = self.multi_header(multi_header)
return out
if __name__ == '__main__':
x1 = torch.randn((1, 5, 12))
x2 = torch.randn((1, 1, 12))
model = MaskedMultiHeadAttention(12, 4)
y1 = model(x1, use_cache=True)
y2 = model(x2, use_cache=True)
4. 总结:主要是理解KV cache的缓存过程,更加深刻理解训练过程和推理过程的不同。
【版权声明】本文为华为云社区用户原创内容,未经允许不得转载,如需转载请自行联系原作者进行授权。如果您发现本社区中有涉嫌抄袭的内容,欢迎发送邮件进行举报,并提供相关证据,一经查实,本社区将立刻删除涉嫌侵权内容,举报邮箱:
cloudbbs@huaweicloud.com
- 点赞
- 收藏
- 关注作者
评论(0)