理解大语言模型推理的 KVCache

大语言模型的一个重要方向是“推理”优化,即如何在有限的硬件环境中提升推理的效率。对于所有的 MaaS 服务提供方,这都是至关重要的。一方面关乎用户的使用体验(诸如TTFT,time to first token)、另一方面关于服务提供的成本(有限的GPU如何提供更高的吞吐量)。

1. 概述

从 Transformer 架构的 Decoder 阶段原理来看,一个常见的、自然的优化就是使用“KV Cache”大大减少推理(自回归阶段)过程需要计算量,实现以显存换效率,从而加速推理过程。

2. Decoder 模型的自回归计算

在了解了“Attention”、“mask attention”、“autoregression”计算之后,比较自然可以注意到在 Q、K、V 矩阵在“autoregression”的过程中,有很多的部分是无需额外计算的。

这里依旧继续使用《理解大语言模型的核心:Attention》中的示例,这里考虑在文章中的提示词“It’s very hot in summer. Swimming is”,生成新的Token为 “ a”,那么我们看看这个自回归过程某个Head中的计算。完成的代码可以参考:autoregression-of-attention.ipynb

相比与在 prefill 阶段,需要额外计算的,在后续使用黄色标识出来。

2. 1 Token Embedding 和 Positional Embedding

Token Embedding

+

Positional Embedding

这里,只需要计算最新的Token(即这里的“ a”)的Embedding即可。事实上,上面矩阵白色部分再自回归阶段完全不再需要使用了。所以,上述内容计算完成后,内存即可释放,无需缓存。

2. 2 Normalize

即,将每一个token的embedding 进行正规化,将其均值变为0,方差变为1

与前面类似,这里计算完成并推进到下一步后,内存即可释放,无需缓存。

2. 3 Attention 层的参数矩阵

\(W^Q\,,W^K\,,W^V \)

2. 4 矩阵 Q K V的计算

\(Q = XW^Q \)

\(K = XW^K \)

\(V = XW^V \)

2. 5 计算 Attention Score

\(\text{Attention Score} \)

\(= \frac{QK^T}{\sqrt{d}} \)

特别需要注意的,这一步中,“Attention Score Matrix”最后一行的计算,需要前面的Q的最后一行,此外还需要整个 K 矩阵。这就是为什么 K 矩阵是需要缓存的。

2. 6 计算 Masked Attention Score

\(\text{Masked Attention Score} \)

\(= \frac{QK^T}{\sqrt{d}} + \text{mask} \)

2.7 计算 Softmax Masked Attention Score

\(\text{Softmax Masked Attention Score} \)

\(= \text{softmax}(\frac{QK^T}{\sqrt{d}} + \text{mask}) \)

2. 8 计算 Contextual Embeddings

\(\text{Contextual Embeddings} \)

\(= \text{softmax}(\frac{QK^T}{\sqrt{d}} + \text{mask})V \)

所以,这一步中,“Contextual Embeddings”最后一行的计算,需要前面 Softmax Masked Attention Score Matrix 的最后一行,此外还需要整个 V 矩阵。这就是为什么 V 矩阵是需要缓存的。

此外可以看到,在这个自回归的计算中,Q 矩阵前面的所有行(即上一轮计算的Q矩阵)都用不上,这也是为什么 Q 矩阵不需要缓存,即我们需要的“KV Cache”,而不是“QKV Cache”的原因。

3. 计算图示

这里依据使用了图示的方式展示了在“自回归”过程中的数学计算。在下图中,第一个生成的 Token 为“ a”,该 Token 在进入 Decoder 模型再次进行计算时(即“自回归”),下图中:

  • 粉红色背景部分为新的、需要计算的部分;
  • 灰色背景部分为虽然不需要计算,但在计算新的内容时,需要使用的部分。

灰色部分即为“KV Cache”需要缓存的部分。即,每一个 Token 对应的 “K”、“V” 矩阵都需要在后续的计算中使用。亦即,每一个 Token 的 Key 向量都需要保存,用于与新的 Token 的 Query 向量进行点击计算“关注度”值;每一个 Token 的 V 向量也需要保持

在上述的计算中,注意到,在一次的新的“自回归”中,最终需要额外计算的就是新Token(这里是“ a”)对应的 Centextual Embedding,该内容计算,需要使用前述所有 Token 对应的 K、V 值,即这里的 K 和 V 矩阵。

所以,在一次自回归推理中,最好上一次计算的所有 Token 的 K、V 向量都缓存起来,避免重复计算。本次自回归中计算新Token的对应的 K、V 向量也需要缓存,以供后续使用。

4. KV Cache 的内存消耗

在推理优化中,一个重要硬限制便是GPU卡的显存(memory)大小。当前,主流的企业级显卡H100显存为80GB,高端显卡 H200 显存为141 GB。现在的 LLM 参数量通常巨大,参数加载就需要耗费巨大的显存,以最新的 llama 4 17B为例,考虑 FP16 (半精度)考虑,则需要消耗约 30+ GB 。卡片上剩余的内存,才是用于实际的推理使用。而每次推理,例如提示词是1000个Token,输出也是1000个Token,那么,在生成最后一个Token的时候,需要的内存(按5%的经验值计算)约为1.5GB。这时候,单个H100的显卡也只能支持约33个并发,实际的情况则要考虑系统内存等,会比这个预估多很多。

在这篇文章:Mastering LLM Techniques: Inference Optimization@developer.nvidia.com 中也类似的估算:

  • 7B 的模型(如Llama 2 7B),参数是16位(FP16 or BF16)则参数需要消耗约 14 GB 显存
  • Token 数为4096的推理(decoder),则需要约 2 GB KV Cache

从上述粗略的预估可以看得出来,高效使用显存资源对于 LLM 推理来说至关重要。所以,各推理框架则会通过各种方法尝试去优化“KV Cache”以降低显存使用。这些方法包括“量化”(Quantization)、MQA/MGA 等。

5. Multi-Query Attention/Group-Query Attention

可以看到,无论是在模型参数加载的时候,还是推理 KV Cache 阶段,都需要大量的显存。关于 MQA 和 GQA 的经典论文是:GQA: Training Generalized Multi-Query Transformer Models from Multi-Head Checkpoints

5.1 关于MQA与GQA

Multi-Query Attention 则尝试通过减少 \(W^K \, W^V \) 参数的数量来减少上述显存,从而增加推理速度与并发能力。参考下图,可以看到在每一个 Layer 中,所有的 Head 共享一组 \(W^K \, W^V \) 参数,那么这两个相关参数就减少到了原来的 \(\frac{1}{h} \)。

更进一步的,为了减少上述方法(MQA)对于模型效果的影响,另一个优化是 Group-Query Attention。即如下图,一组 Heads 共享一组 \(W^K \, W^V \) 。可以依照分组的大小,以平衡模型效果与资源使用。如果一个 Head 一组 \(W^K \, W^V \) 则退化到普通的 Multi-Head Attention;如果所有 Heads 分到一组,则退化到普通的 Multi-Query Attention。

5.2 模型训练 Uptraining

此外,比较关键的,论文提出了一些关于 GQA 架构的训练优化。

例如,从一个 MHA 架构开始训练,然后从某个 checkpoint 开始,将MHA模型改成GQA模型,在初始化分组参数时,则使用原 MHA 模型中参数去求一个均值的方式初始化GQA中对应的 \(W^K \, W^V \) 。然后继续使用语料库对于该新模型训练。

论文指出,这时候只需要使用非常少的计算资源就可以训练处效果还不错的GQA新模型。新的GQA模型,则可以使用更少的显存资源,有更好的并发吞吐能力,同时也达到还比较好的效果。

参考

Leave a Reply

Your email address will not be published. Required fields are marked *