KV Cache 节省的计算次数推演

举报
晋红轻 发表于 2026/04/28 10:19:38 2026/04/28
【摘要】 本文介绍Transformer架构的语言模型在做推理时KV cache是如何节省计算次数,供参考

我们知道 LLM 语言模型的基础架构源自 Transformer。而 LLM 在推理时基本只使用 Decoder 模块。Decoder 的核心部分可简单分为 2 块:1. Attention 2. MLP。我们的 K,V Cache 发生在 Attention 阶段。

我们做如下假设:

  1. 输入 xx 只有 1 条(Batch=1),维度为 T×CT \times CTT 表示输入的 token 个数,CC 表示每个 token 用长度为 CC 的向量表示。
  2. 权重 Q,K,VQ, K, V 矩阵的维度为 head_size×Chead\_size \times Chead_sizehead\_size 表示多头 Attention 中每头的维度,一般 head_size<Chead\_size < C

为论证简单,我们以计算 KK 为例:

K=xWTK = x W^T

其中:

  • xRT×Cx \in \mathbb{R}^{T \times C}
  • WRhead_size×CW \in \mathbb{R}^{head\_size \times C},实际计算中使用 WTRC×head_sizeW^T \in \mathbb{R}^{C \times head\_size}
  • 根据矩阵乘法,KRT×head_sizeK \in \mathbb{R}^{T \times head\_size}

这一步本质是对输入 xx 进行一次线性投影,简单理解在空间中找到另一个投影点,维度从 CC维 变成 head_sizehead\_size 维。

为方便说明,先从 T=1T=1 开始。

首次计算:

K=xWTK = x W^T

具体表示为:

[x11,x12,,x1C][w11w12w1head_sizew21w22w2head_sizewC1wC2wChead_size]=[k11,k12,,k1head_size][x_{11}, x_{12}, \dots, x_{1C}] \cdot \begin{bmatrix} w_{11} & w_{12} & \cdots & w_{1\,head\_size} \\ w_{21} & w_{22} & \cdots & w_{2\,head\_size} \\ \vdots & \vdots & \ddots & \vdots \\ w_{C1} & w_{C2} & \cdots & w_{C\,head\_size} \end{bmatrix} = [k_{11}, k_{12}, \dots, k_{1\,head\_size}]

其中:

k11=i=1Cx1iwi1k12=i=1Cx1iwi2k1head_size=i=1Cx1iwihead_sizek_{11} = \sum_{i=1}^{C} x_{1i} w_{i1} \\ k_{12} = \sum_{i=1}^{C} x_{1i} w_{i2} \\ \vdots \\ k_{1\,head\_size} = \sum_{i=1}^{C} x_{1i} w_{i\,head\_size}

如果我们将一次内积乘法算作 1 次计算(此处将整个向量内积抽象为一次计算,便于复杂度分析),那么 T=1T=1 时共做了 head_sizehead\_size 次计算。

T=2T=2 时:

[x11x12x1Cx21x22x2C][w11w12w1head_sizew21w22w2head_sizewC1wC2wChead_size]=[k11k12k1head_sizek21k22k2head_size]\begin{bmatrix} x_{11} & x_{12} & \cdots & x_{1C} \\ x_{21} & x_{22} & \cdots & x_{2C} \end{bmatrix} \cdot \begin{bmatrix} w_{11} & w_{12} & \cdots & w_{1\,head\_size} \\ w_{21} & w_{22} & \cdots & w_{2\,head\_size} \\ \vdots & \vdots & \ddots & \vdots \\ w_{C1} & w_{C2} & \cdots & w_{C\,head\_size} \end{bmatrix} = \begin{bmatrix} k_{11} & k_{12} & \cdots & k_{1\,head\_size} \\ k_{21} & k_{22} & \cdots & k_{2\,head\_size} \end{bmatrix}

此时总计算次数为 2×head_size2 \times head\_size 次。

对比不同输入长度 TT 下的计算次数:

TT 开 KV Cache 计算次数 关 KV Cache 计算次数
1 head_sizehead\_size head_sizehead\_size
2 head_sizehead\_size 2×head_size2 \times head\_size
3 head_sizehead\_size 3×head_size3 \times head\_size
n head_sizehead\_size n×head_sizen \times head\_size

所以,当从第 1 个 token 一直推理到第 nn 个 token 时,仅以 KK 的计算为例——

  • 关 KV Cache 的累计计算次数 = (1+2+3++n)×head_size=n(n+1)2×head_size(1+2+3+\dots+n) \times head\_size = \frac{n(n+1)}{2} \times head\_size
  • 开 KV Cache 的累计计算次数 = n×head_sizen \times head\_size

线性投影阶段,计算开销从 O(n2)O(n^2) 降低到 O(n)O(n)对于 VV 的计算,情况完全一样,本文以 KK 为代表。

总结:KV Cache 推理在线性投影阶段计算开销对比

维度 无 KV Cache(每次全量重算) 有 KV Cache(仅计算增量)
nn 步推理计算量 n×head_sizen \times head\_size 1×head_size1 \times head\_size
累计总计算次数 n(n+1)2×head_size\frac{n(n+1)}{2} \times head\_size n×head_sizen \times head\_size
算法复杂度 O(n2)O(n^2) O(n)O(n)
核心逻辑 每一轮都要重新计算过去所有 Token 的 K,VK, V 矩阵(QQ 每步只算当前 token,无需重算历史)。 仅计算当前最新 Token 的 Q,K,VQ, K, V,历史的 K,VK, V 直接从缓存读取。
【声明】本内容来自华为云开发者社区博主,不代表华为云及华为云开发者社区的观点和立场。转载时必须标注文章的来源(华为云社区)、文章链接、文章作者等基本信息,否则作者和本社区有权追究责任。如果您发现本社区中有涉嫌抄袭的内容,欢迎发送邮件进行举报,并提供相关证据,一经查实,本社区将立刻删除涉嫌侵权内容,举报邮箱: cloudbbs@huaweicloud.com
  • 点赞
  • 收藏
  • 关注作者

评论(0

0/1000
抱歉,系统识别当前为高风险访问,暂不支持该操作

全部回复

上滑加载中

设置昵称

在此一键设置昵称,即可参与社区互动!

*长度不超过10个汉字或20个英文字符,设置后3个月内不可修改。

*长度不超过10个汉字或20个英文字符,设置后3个月内不可修改。