我们知道 LLM 语言模型的基础架构源自 Transformer。而 LLM 在推理时基本只使用 Decoder 模块。Decoder 的核心部分可简单分为 2 块:1. Attention 2. MLP。我们的 K,V Cache 发生在 Attention 阶段。
我们做如下假设:
- 输入 x 只有 1 条(Batch=1),维度为 T×C,T 表示输入的 token 个数,C 表示每个 token 用长度为 C 的向量表示。
- 权重 Q,K,V 矩阵的维度为 head_size×C,head_size 表示多头 Attention 中每头的维度,一般 head_size<C。
为论证简单,我们以计算 K 为例:
K=xWT
其中:
- x∈RT×C
- W∈Rhead_size×C,实际计算中使用 WT∈RC×head_size
- 根据矩阵乘法,K∈RT×head_size
这一步本质是对输入 x 进行一次线性投影,简单理解在空间中找到另一个投影点,维度从 C维 变成 head_size 维。
为方便说明,先从 T=1 开始。
首次计算:
K=xWT
具体表示为:
[x11,x12,…,x1C]⋅⎣⎢⎢⎢⎢⎡w11w21⋮wC1w12w22⋮wC2⋯⋯⋱⋯w1head_sizew2head_size⋮wChead_size⎦⎥⎥⎥⎥⎤=[k11,k12,…,k1head_size]
其中:
k11=i=1∑Cx1iwi1k12=i=1∑Cx1iwi2⋮k1head_size=i=1∑Cx1iwihead_size
如果我们将一次内积乘法算作 1 次计算(此处将整个向量内积抽象为一次计算,便于复杂度分析),那么 T=1 时共做了 head_size 次计算。
当 T=2 时:
[x11x21x12x22⋯⋯x1Cx2C]⋅⎣⎢⎢⎢⎢⎡w11w21⋮wC1w12w22⋮wC2⋯⋯⋱⋯w1head_sizew2head_size⋮wChead_size⎦⎥⎥⎥⎥⎤=[k11k21k12k22⋯⋯k1head_sizek2head_size]
此时总计算次数为 2×head_size 次。
对比不同输入长度 T 下的计算次数:
| T |
开 KV Cache 计算次数 |
关 KV Cache 计算次数 |
| 1 |
head_size 次 |
head_size 次 |
| 2 |
head_size 次 |
2×head_size 次 |
| 3 |
head_size 次 |
3×head_size 次 |
| … |
… |
… |
| n |
head_size 次 |
n×head_size 次 |
所以,当从第 1 个 token 一直推理到第 n 个 token 时,仅以 K 的计算为例——
- 关 KV Cache 的累计计算次数 = (1+2+3+⋯+n)×head_size=2n(n+1)×head_size
- 开 KV Cache 的累计计算次数 = n×head_size
在线性投影阶段,计算开销从 O(n2) 降低到 O(n)。对于 V 的计算,情况完全一样,本文以 K 为代表。
总结:KV Cache 推理在线性投影阶段计算开销对比
| 维度 |
无 KV Cache(每次全量重算) |
有 KV Cache(仅计算增量) |
| 第 n 步推理计算量 |
n×head_size |
1×head_size |
| 累计总计算次数 |
2n(n+1)×head_size |
n×head_size |
| 算法复杂度 |
O(n2) |
O(n) |
| 核心逻辑 |
每一轮都要重新计算过去所有 Token 的 K,V 矩阵(Q 每步只算当前 token,无需重算历史)。 |
仅计算当前最新 Token 的 Q,K,V,历史的 K,V 直接从缓存读取。 |
【声明】本内容来自华为云开发者社区博主,不代表华为云及华为云开发者社区的观点和立场。转载时必须标注文章的来源(华为云社区)、文章链接、文章作者等基本信息,否则作者和本社区有权追究责任。如果您发现本社区中有涉嫌抄袭的内容,欢迎发送邮件进行举报,并提供相关证据,一经查实,本社区将立刻删除涉嫌侵权内容,举报邮箱:
cloudbbs@huaweicloud.com
评论(0)