大模型原理--多头注意力机制的代码实现(MQA和GQA)
【摘要】 KV Cache显存占用的因素有大模型的层数、MHA中的头数、kv向量的长度、数字表示的精度、上下文长度以及推理请求的个数。其中上下文长度,推理请求的个数会动态的影响KV Cache,会造成KV Cache缓存规模成倍的增加,物理显存限制和显存带宽会显著影响计算效率。
1.概述
KV Cache显存占用的因素有大模型的层数、MHA中的头数、kv向量的长度、数字表示的精度、上下文长度以及推理请求的个数。其中上下文长度,推理请求的个数会动态的影响KV Cache,会造成KV Cache缓存规模成倍的增加,物理显存限制和显存带宽会显著影响计算效率。
2. MQA降低KV Cache显存占用的原理
MQA 的核心思想是让多个注意力头共享同一套 Key 和 Value,而不是像传统 MHA 那样为每个头分别维护独立的 K/V。

3. 代码实现
import torch
import torch.nn as nn
# 在推理阶段将历史 token 的 Key 和 Value 缓存
past_key_value = {}
# 自注意力机制
class DecoderSelfAttentionHead(nn.Module):
def __init__(self, d_model, head_size):
super().__init__()
self.head_size = head_size
self.n_head = d_model // head_size
self.query = nn.Linear(d_model, head_size)
self.key = nn.Linear(d_model, head_size)
self.value = nn.Linear(d_model, head_size)
def forward(self, x, use_cache, masked, layer_num, head_index):
print("layer_num: ", layer_num)
print("head_index: ", head_index)
# 计算Q, K, V
Q = self.query(x)
K = self.key(x)
V = self.value(x)
# 如果use_cache是True,则代表当前是推理阶段,需要进行KV的缓存
if use_cache:
# 最后一个头更新KV cache
if head_index == self.n_head - 1:
if layer_num in past_key_value:
# 当前层历史KV存在,拼接历史KV,更新KV
K = torch.cat([past_key_value[layer_num][0], K], dim=1)
V = torch.cat([past_key_value[layer_num][1], V], dim=1)
past_key_value[layer_num] = (K, V)
else:
# 当前层历史KV存在,拼接历史KV
if layer_num in past_key_value:
K = torch.cat([past_key_value[layer_num][0], K], dim=1)
V = torch.cat([past_key_value[layer_num][1], V], dim=1)
print("past_key_value: ", past_key_value)
# Q和K的点积
attention = Q @ K.transpose(-2, -1)
# 缩放
attention = attention / (self.head_size ** 0.5)
# 训练的时候使用
if masked:
mask = torch.triu(torch.ones_like(attention), diagonal=1).bool()
attention.masked_fill(mask, float('-inf'))
# 打分
attention = torch.softmax(attention, dim=-1)
# 加权汇总
attention = attention @ V
return attention
class MaskedMultiHeadAttention(nn.Module):
def __init__(self, d_model, n_heads):
super().__init__()
# 注意力头的个数,这块也可以看出d_model和注意力头数目由整倍数关系
self.head_size = d_model // n_heads
self.heads = nn.ModuleList([
DecoderSelfAttentionHead(d_model, self.head_size) for _ in range(n_heads)
])
self.multi_header = nn.Linear(d_model, d_model)
# 引入layer_num,表示模型层数. 需要为每一层的每一个自注意力头缓存KV
def forward(self, x, use_cache=False, masked=False, layer_num=0):
# 拼接多个注意力头
multi_header = torch.cat(
[head(x, use_cache, masked, layer_num, head_index) for head_index, head in enumerate(self.heads)], dim=-1)
out = self.multi_header(multi_header)
return out
if __name__ == '__main__':
x1 = torch.randn((1, 5, 12))
x2 = torch.randn((1, 1, 12))
x3 = torch.randn((1, 1, 12))
model = MaskedMultiHeadAttention(12, 4)
y1 = model(x1, use_cache=True)
y2 = model(x2, use_cache=True)
y3 = model(x3, use_cache=True)
4. GQA降低KV Cache显存占用的原理
将注意力头划分为多个组(Group),每组内部的多个 Query 共享同一套 Key 和 Value,而不同组之间则使用独立的 K/V

5. 代码实现
import torch
import torch.nn as nn
# 在推理阶段将历史 token 的 Key 和 Value 缓存
past_key_value = {}
# 自注意力机制
class DecoderSelfAttentionHead(nn.Module):
def __init__(self, d_model, head_size):
super().__init__()
self.head_size = head_size
self.n_head = d_model // head_size
self.query = nn.Linear(d_model, head_size)
self.key = nn.Linear(d_model, head_size)
self.value = nn.Linear(d_model, head_size)
def forward(self, x, use_cache, masked, layer_num, head_index, n_head_per_group):
if n_head_per_group is None: n_head_per_group = self.n_head
assert self.n_head % n_head_per_group == 0, "注意力头数除以一个组具有的头数,必须是整数"
print("layer_num: ", layer_num)
print("head_index: ", head_index)
# 计算Q, K, V
Q = self.query(x)
K = self.key(x)
V = self.value(x)
# 如果use_cache是True,则代表当前是推理阶段,需要进行KV的缓存
if use_cache:
# 把一层的所有注意力头按照每组n_head_per_group个分为一组,group_index是组的编号
group_index = head_index // n_head_per_group
print("group_index: ", group_index)
# 注意力头在组内的编号
group_internal_id = head_index % n_head_per_group
print("group_internal_id: ", group_internal_id)
# 在组内最后一个注意力头更新KV cache
if group_internal_id == n_head_per_group - 1:
if layer_num not in past_key_value:
past_key_value[layer_num] = {}
if group_index in past_key_value[layer_num]:
K = torch.cat([past_key_value[layer_num][group_index][0], K], dim=1)
V = torch.cat([past_key_value[layer_num][group_index][1], V], dim=1)
past_key_value[layer_num][group_index] = (K, V)
else:
if layer_num in past_key_value and group_index in past_key_value[layer_num]:
K = torch.cat([past_key_value[layer_num][group_index][0], K], dim=1)
V = torch.cat([past_key_value[layer_num][group_index][1], V], dim=1)
print("past_key_value: ", past_key_value)
# Q和K的点积
attention = Q @ K.transpose(-2, -1)
# 缩放
attention = attention / (self.head_size ** 0.5)
# 训练的时候使用
if masked:
mask = torch.triu(torch.ones_like(attention), diagonal=1).bool()
attention.masked_fill(mask, float('-inf'))
# 打分
attention = torch.softmax(attention, dim=-1)
# 加权汇总
attention = attention @ V
return attention
class MaskedMultiHeadAttention(nn.Module):
def __init__(self, d_model, n_heads, n_head_per_group=None):
super().__init__()
# 注意力头的个数,这块也可以看出d_model和注意力头数目由整倍数关系
self.head_size = d_model // n_heads
self.n_head_per_group = n_head_per_group
assert d_model % n_heads == 0, "词向量维度除以头数,必须是整数"
self.heads = nn.ModuleList([
DecoderSelfAttentionHead(d_model, self.head_size) for _ in range(n_heads)
])
self.multi_header = nn.Linear(d_model, d_model)
# 引入layer_num,表示模型层数. 需要为每一层的每一个自注意力头缓存KV
def forward(self, x, use_cache=False, masked=False, layer_num=0):
# 拼接多个注意力头
multi_header = torch.cat(
[head(x, use_cache, masked, layer_num, head_index, self.n_head_per_group) for head_index, head in
enumerate(self.heads)], dim=-1)
out = self.multi_header(multi_header)
return out
if __name__ == '__main__':
x1 = torch.randn((1, 5, 32))
x2 = torch.randn((1, 1, 32))
x3 = torch.randn((1, 1, 32))
model = MaskedMultiHeadAttention(32, 8, n_head_per_group=2)
y1 = model(x1, use_cache=True)
y2 = model(x2, use_cache=True)
y3 = model(x3, use_cache=True)
6. 总结:GQA 在 推理效率和表达能力之间实现了更优平衡。因此被主流大模型所采用,在长序列推理和高并发场景中表现极为出色。
【版权声明】本文为华为云社区用户原创内容,未经允许不得转载,如需转载请自行联系原作者进行授权。如果您发现本社区中有涉嫌抄袭的内容,欢迎发送邮件进行举报,并提供相关证据,一经查实,本社区将立刻删除涉嫌侵权内容,举报邮箱:
cloudbbs@huaweicloud.com
- 点赞
- 收藏
- 关注作者
评论(0)