大模型原理--多头注意力机制的代码实现(MQA和GQA)

举报
剑指南天 发表于 2026/05/04 17:45:16 2026/05/04
【摘要】 KV Cache显存占用的因素有大模型的层数、MHA中的头数、kv向量的长度、数字表示的精度、上下文长度以及推理请求的个数。其中上下文长度,推理请求的个数会动态的影响KV Cache,会造成KV Cache缓存规模成倍的增加,物理显存限制和显存带宽会显著影响计算效率。

1.概述

KV Cache显存占用的因素有大模型的层数、MHA中的头数、kv向量的长度、数字表示的精度、上下文长度以及推理请求的个数。其中上下文长度,推理请求的个数会动态的影响KV Cache,会造成KV Cache缓存规模成倍的增加,物理显存限制和显存带宽会显著影响计算效率。

2. MQA降低KV Cache显存占用的原理

MQA 的核心思想是让多个注意力头共享同一套 Key 和 Value,而不是像传统 MHA 那样为每个头分别维护独立的 K/V。

3. 代码实现

import torch
import torch.nn as nn

# 在推理阶段将历史 token  Key  Value 缓存
past_key_value = {}

# 自注意力机制
class DecoderSelfAttentionHead(nn.Module):
    def __init__(self, d_model, head_size):
        super().__init__()
        self.head_size = head_size
        self.n_head = d_model // head_size
        self.query = nn.Linear(d_model, head_size)
        self.key = nn.Linear(d_model, head_size)
        self.value = nn.Linear(d_model, head_size)

    def forward(self, x, use_cache, masked, layer_num, head_index):
        print("layer_num: ", layer_num)
        print("head_index: ", head_index)
        # 计算Q, K, V
        Q = self.query(x)
        K = self.key(x)
        V = self.value(x)
        # 如果use_cacheTrue,则代表当前是推理阶段,需要进行KV的缓存
        if use_cache:
            # 最后一个头更新KV cache
            if head_index == self.n_head - 1:
                if layer_num in past_key_value:
                    # 当前层历史KV存在,拼接历史KV,更新KV
                    K = torch.cat([past_key_value[layer_num][0], K], dim=1)
                    V = torch.cat([past_key_value[layer_num][1], V], dim=1)
                past_key_value[layer_num] = (K, V)
            else:
                # 当前层历史KV存在,拼接历史KV
                if layer_num in past_key_value:
                    K = torch.cat([past_key_value[layer_num][0], K], dim=1)
                    V = torch.cat([past_key_value[layer_num][1], V], dim=1)
            print("past_key_value: ", past_key_value)

        # QK的点积
        attention = Q @ K.transpose(-2, -1)
        # 缩放
        attention = attention / (self.head_size ** 0.5)
        # 训练的时候使用
        if masked:
            mask = torch.triu(torch.ones_like(attention), diagonal=1).bool()
            attention.masked_fill(mask, float('-inf'))
        # 打分
        attention = torch.softmax(attention, dim=-1)
        # 加权汇总
        attention = attention @ V
        return attention

class MaskedMultiHeadAttention(nn.Module):
    def __init__(self, d_model, n_heads):
        super().__init__()
        # 注意力头的个数,这块也可以看出d_model和注意力头数目由整倍数关系
        self.head_size = d_model // n_heads
        self.heads = nn.ModuleList([
            DecoderSelfAttentionHead(d_model, self.head_size) for _ in range(n_heads)
        ])
        self.multi_header = nn.Linear(d_model, d_model)

    # 引入layer_num,表示模型层数. 需要为每一层的每一个自注意力头缓存KV
    def forward(self, x, use_cache=False, masked=False, layer_num=0):
        # 拼接多个注意力头
        multi_header = torch.cat(
            [head(x, use_cache, masked, layer_num, head_index) for head_index, head in enumerate(self.heads)], dim=-1)
        out = self.multi_header(multi_header)
        return out
if __name__ == '__main__':
    x1 = torch.randn((1, 5, 12))
    x2 = torch.randn((1, 1, 12))
    x3 = torch.randn((1, 1, 12))

    model = MaskedMultiHeadAttention(12, 4)
    y1 = model(x1, use_cache=True)
    y2 = model(x2, use_cache=True)
    y3 = model(x3, use_cache=True)

4. GQA降低KV Cache显存占用的原理

将注意力头划分为多个组(Group),每组内部的多个 Query 共享同一套 Key 和 Value,而不同组之间则使用独立的 K/V

5. 代码实现

import torch
import torch.nn as nn

# 在推理阶段将历史 token  Key  Value 缓存
past_key_value = {}


# 自注意力机制
class DecoderSelfAttentionHead(nn.Module):
    def __init__(self, d_model, head_size):
        super().__init__()
        self.head_size = head_size
        self.n_head = d_model // head_size
        self.query = nn.Linear(d_model, head_size)
        self.key = nn.Linear(d_model, head_size)
        self.value = nn.Linear(d_model, head_size)

    def forward(self, x, use_cache, masked, layer_num, head_index, n_head_per_group):
        if n_head_per_group is None: n_head_per_group = self.n_head
        assert self.n_head % n_head_per_group == 0, "注意力头数除以一个组具有的头数,必须是整数"
        print("layer_num: ", layer_num)
        print("head_index: ", head_index)
        # 计算Q, K, V
        Q = self.query(x)
        K = self.key(x)
        V = self.value(x)
        # 如果use_cacheTrue,则代表当前是推理阶段,需要进行KV的缓存
        if use_cache:
            # 把一层的所有注意力头按照每组n_head_per_group个分为一组,group_index是组的编号
            group_index = head_index // n_head_per_group
            print("group_index: ", group_index)
            # 注意力头在组内的编号
            group_internal_id = head_index % n_head_per_group
            print("group_internal_id: ", group_internal_id)
            
            # 在组内最后一个注意力头更新KV cache
            if group_internal_id == n_head_per_group - 1:
                if layer_num not in past_key_value:
                    past_key_value[layer_num] = {}
                if group_index in past_key_value[layer_num]:
                    K = torch.cat([past_key_value[layer_num][group_index][0], K], dim=1)
                    V = torch.cat([past_key_value[layer_num][group_index][1], V], dim=1)
                past_key_value[layer_num][group_index] = (K, V)
            else:
                if layer_num in past_key_value and group_index in past_key_value[layer_num]:
                    K = torch.cat([past_key_value[layer_num][group_index][0], K], dim=1)
                    V = torch.cat([past_key_value[layer_num][group_index][1], V], dim=1)
            print("past_key_value: ", past_key_value)

        # QK的点积
        attention = Q @ K.transpose(-2, -1)
        # 缩放
        attention = attention / (self.head_size ** 0.5)
        # 训练的时候使用
        if masked:
            mask = torch.triu(torch.ones_like(attention), diagonal=1).bool()
            attention.masked_fill(mask, float('-inf'))
        # 打分
        attention = torch.softmax(attention, dim=-1)
        # 加权汇总
        attention = attention @ V
        return attention


class MaskedMultiHeadAttention(nn.Module):
    def __init__(self, d_model, n_heads, n_head_per_group=None):
        super().__init__()
        # 注意力头的个数,这块也可以看出d_model和注意力头数目由整倍数关系
        self.head_size = d_model // n_heads
        self.n_head_per_group = n_head_per_group
        assert d_model % n_heads == 0, "词向量维度除以头数,必须是整数"
        self.heads = nn.ModuleList([
            DecoderSelfAttentionHead(d_model, self.head_size) for _ in range(n_heads)
        ])
        self.multi_header = nn.Linear(d_model, d_model)

    # 引入layer_num,表示模型层数. 需要为每一层的每一个自注意力头缓存KV
    def forward(self, x, use_cache=False, masked=False, layer_num=0):
        # 拼接多个注意力头
        multi_header = torch.cat(
            [head(x, use_cache, masked, layer_num, head_index, self.n_head_per_group) for head_index, head in
             enumerate(self.heads)], dim=-1)
        out = self.multi_header(multi_header)
        return out
if __name__ == '__main__':
    x1 = torch.randn((1, 5, 32))
    x2 = torch.randn((1, 1, 32))
    x3 = torch.randn((1, 1, 32))

    model = MaskedMultiHeadAttention(32, 8, n_head_per_group=2)
    y1 = model(x1, use_cache=True)
    y2 = model(x2, use_cache=True)
    y3 = model(x3, use_cache=True)

6. 总结:GQA 在 推理效率和表达能力之间实现了更优平衡。因此被主流大模型所采用,在长序列推理和高并发场景中表现极为出色。


【版权声明】本文为华为云社区用户原创内容,未经允许不得转载,如需转载请自行联系原作者进行授权。如果您发现本社区中有涉嫌抄袭的内容,欢迎发送邮件进行举报,并提供相关证据,一经查实,本社区将立刻删除涉嫌侵权内容,举报邮箱: cloudbbs@huaweicloud.com
  • 点赞
  • 收藏
  • 关注作者

评论(0

0/1000
抱歉,系统识别当前为高风险访问,暂不支持该操作

全部回复

上滑加载中

设置昵称

在此一键设置昵称,即可参与社区互动!

*长度不超过10个汉字或20个英文字符,设置后3个月内不可修改。

*长度不超过10个汉字或20个英文字符,设置后3个月内不可修改。