大语言模型的一个重要方向是“推理”优化,即如何在有限的硬件环境中提升推理的效率。对于所有的 MaaS 服务提供方,这都是至关重要的。一方面关乎用户的使用体验(诸如TTFT,time to first token)、另一方面关于服务提供的成本(有限的GPU如何提供更高的吞吐量)。
1. 概述
从 Transformer 架构的 Decoder 阶段原理来看,一个常见的、自然的优化就是使用“KV Cache”大大减少推理(自回归阶段)过程需要计算量,实现以显存换效率,从而加速推理过程。
2. Decoder 模型的自回归计算
在了解了“Attention”、“mask attention”、“autoregression”计算之后,比较自然可以注意到在 Q、K、V 矩阵在“autoregression”的过程中,有很多的部分是无需额外计算的。
这里依旧继续使用《理解大语言模型的核心:Attention》中的示例,这里考虑在文章中的提示词“It’s very hot in summer. Swimming is”,生成新的Token为 “ a”,那么我们看看这个自回归过程某个Head中的计算。完成的代码可以参考:autoregression-of-attention.ipynb。
相比与在 prefill 阶段,需要额外计算的,在后续使用黄色标识出来。
2. 1 Token Embedding 和 Positional Embedding
Token Embedding
+
Positional Embedding
----------------------------------------------------------------------------------------------------------------------------------
| Token | Token ID | Token Embeddings(first 3 of 768 ) | Positional Embeddings | Token Embedding + Positional |
----------------------------------------------------------------------------------------------------------------------------------
| It | 1026 | [ 0.0390, -0.0869, 0.0662, ...] | [ -0.0188, -0.1974, 0.0040, ...] | [ 0.0202, -0.2844, 0.0702, ...] |
| âĢ | 447 | [ -0.0750, 0.0948, -0.0034, ...] | [ 0.0240, -0.0538, -0.0949, ...] | [ -0.0510, 0.0410, -0.0982, ...] |
| Ļ | 247 | [ -0.0223, 0.0182, 0.2631, ...] | [ 0.0042, -0.0848, 0.0545, ...] | [ -0.0181, -0.0666, 0.3176, ...] |
| s | 82 | [ -0.0640, -0.0469, 0.2061, ...] | [ -0.0003, -0.0738, 0.1055, ...] | [ -0.0643, -0.1207, 0.3116, ...] |
| Ġvery | 845 | [ -0.0553, -0.0348, 0.0606, ...] | [ 0.0076, -0.0251, 0.1270, ...] | [ -0.0477, -0.0599, 0.1876, ...] |
| Ġhot | 3024 | [ 0.0399, -0.0053, 0.0742, ...] | [ 0.0096, -0.0339, 0.1312, ...] | [ 0.0495, -0.0392, 0.2054, ...] |
| Ġin | 287 | [ -0.0337, 0.0108, 0.0293, ...] | [ 0.0027, -0.0205, 0.1196, ...] | [ -0.0310, -0.0098, 0.1490, ...] |
| Ġsummer | 3931 | [ 0.0422, 0.0138, -0.0213, ...] | [ 0.0025, -0.0032, 0.1174, ...] | [ 0.0448, 0.0106, 0.0961, ...] |
| . | 13 | [ 0.0466, -0.0113, 0.0283, ...] | [ -0.0012, -0.0018, 0.1110, ...] | [ 0.0454, -0.0131, 0.1394, ...] |
| ĠSw | 2451 | [ 0.0617, 0.0373, 0.1018, ...] | [ 0.0049, 0.0021, 0.1178, ...] | [ 0.0666, 0.0395, 0.2196, ...] |
| imming | 27428 | [ -0.1385, -0.1774, -0.0181, ...] | [ 0.0016, 0.0062, 0.1004, ...] | [ -0.1369, -0.1711, 0.0823, ...] |
| Ġis | 318 | [ -0.0097, 0.0101, 0.0556, ...] | [ -0.0036, 0.0175, 0.1068, ...] | [ -0.0133, 0.0275, 0.1623, ...] |
| Ġa | 257 | [ -0.0506, 0.0056, 0.0471, ...] | [ 0.0001, 0.0172, 0.0969, ...] | [ -0.0506, 0.0228, 0.1440, ...] |
----------------------------------------------------------------------------------------------------------------------------------
这里,只需要计算最新的Token(即这里的“ a”)的Embedding即可。事实上,上面矩阵白色部分再自回归阶段完全不再需要使用了。所以,上述内容计算完成后,内存即可释放,无需缓存。
2. 2 Normalize
即,将每一个token的embedding 进行正规化,将其均值变为0,方差变为1
------------------------------------------------------
| Token | Token ID | Normalized(first 3 of 768 ) |
------------------------------------------------------
| It | 1026 | [ 0.0129 , -0.1104 , -0.0317] |
| âĢ | 447 | [-0.0530 , 0.0588 , -0.1290] |
| Ļ | 247 | [-0.0170 , -0.0242 , 0.1639] |
| s | 82 | [-0.0754 , -0.0842 , 0.1842] |
| Ġvery | 845 | [-0.0566 , -0.0280 , 0.0953] |
| Ġhot | 3024 | [ 0.0587 , -0.0086 , 0.1073] |
| Ġin | 287 | [-0.0391 , 0.0209 , 0.0731] |
| Ġsummer | 3931 | [ 0.0532 , 0.0397 , 0.0181] |
| . | 13 | [ 0.0553 , 0.0152 , 0.0579] |
| ĠSw | 2451 | [ 0.0807 , 0.0691 , 0.1216] |
| imming | 27428 | [-0.1528 , -0.1249 , -0.0017] |
| Ġis | 318 | [-0.0175 , 0.0605 , 0.0880] |
| Ġa | 257 | [-0.0688 , 0.0540 , 0.0697] |
------------------------------------------------------
与前面类似,这里计算完成并推进到下一步后,内存即可释放,无需缓存。
2. 3 Attention 层的参数矩阵
\(W^Q\,,W^K\,,W^V \)
W^Q [:3] shape (768 x 64) W^K [:3] shape (768 x 64) W^V [:3] shape (768 x 64)
------------------------------------- -------------------------------- --------------------------------
[-0.4738, -0.2614, -0.0978, ...] | [ 0.3660, 0.0771, 0.2226, ...] [ 0.1421, 0.0329, -0.0667, ...]
[ 0.0874, 0.1473, 0.2387, ...] | [-0.4380, -0.1446, -0.4717, ...] [ 0.0162, -0.0633, -0.0636, ...]
[ 0.0039, 0.0695, 0.3668, ...] | [ 0.1237, 0.0174, 0.1181, ...] [ 0.0229, -0.0828, 0.0437, ...]
[ 0.2215, -0.1884, -0.0141, ...] 64 [-0.2247, 0.0148, -0.1859, ...] [-0.0106, 0.0070, 0.0565, ...]
[-0.0947, 0.1678, -0.0143, ...] rows [-0.2001, -0.1052, -0.1743, ...] [ 0.0416, 0.0938, -0.1792, ...]
... | ... ...
[-0.4100, -0.1924, -0.2400, ...] | [,0.1567, 0.2664, 0.1851, ...] [-0.0341, 0.0034, 0.0203, ...]
------------------------------------- -------------------------------- --------------------------------
|<------- columns: 768 ------->| |<------- columns: 768 ------->| |<------- columns: 768 ------->|
这是三个权重矩阵,总是需要常驻内存的,并且可以被多个“推理”共享使用。
2. 4 矩阵 Q K V的计算
\(Q = XW^Q \)
\(K = XW^K \)
\(V = XW^V \)
Q [:3] shape (12 x 64) K [:3] shape (12 x 64) V [:3] shape (12 x 64)
------------------------------------- --------------------------------- --------------------------------
[ 0.4207, -0.9178, 0.1760, ...] | [ -1.4202, 1.6791, 0.9837, ...] [ 0.0452, 0.0628, 0.1463, ...]
[ 0.7757, 0.2485, 0.7349, ...] | [ -2.5320, 2.2932, 1.5592, ...] [-0.1361, 0.1379, 0.0150, ...]
[ 0.4481, 0.0206, -0.0825, ...] | [ -2.2571, 2.7764, 1.8401, ...] [ 0.0039, -0.1295, -0.0311, ...]
[ 0.9500, 0.1481, 0.3469, ...] 12 [ -2.4322, 3.1454, 2.0600, ...] [-0.0391, 0.0581, 0.0511, ...]
[ 0.4989, -0.4376, 0.1678, ...] rows [ -3.5428, 2.1485, 2.0414, ...] [ 0.0963, 0.3563, -0.1477, ...]
... | ... ...
[ 0.4429, -1.1997, 0.5611, ...] | [ -2.2559, 2.0384, 2.2542, ...] [ 0.2759, -0.2783, 0.3240, ...]
[ 0.4989, -0.4376, 0.1678, ...] | [ -2.6703, 2.3629, 1.7493, ...] [ -0.0633, 0.0431, -0.0422, ...]
------------------------------------- --------------------------------- --------------------------------
|<------- columns: 64 ------->| |<------- columns: 64 ------->| |<------- columns: 64 ------->|
计算 Q、K、V 矩阵,这里只有最后一行(即对应最后一个Token “ a”)。这里的矩阵 K 、V 需要进行缓存,在后续每一次自回归的过程都需要完整的使用 K V 矩阵中所有值,下一步会说明原因。Q 矩阵在完成后矩阵计算,就可以释放。
2. 5 计算 Attention Score
\(\text{Attention Score} \)
\(= \frac{QK^T}{\sqrt{d}} \)
|-----------------------------------------------------------------------------------------------------|
| | Attention Score Matrix shape (13 x 13) |
| Token |---------------------------------------------------------------------------------------------|
| | It âĢ Ļ s Ġvery Ġhot Ġin Ġsummer . ĠSw imming Ġis Ġa |
|-------|---------------------------------------------------------------------------------------------|----
|It | [ 0.14, -1.53, -1.45, -1.71, -1.69, -1.74, -2.36, -2.27, -2.37, -1.33, -0.58, -2.40, / ]| |
|âĢ | [ 0.70, -0.93, -1.72, -1.02, -1.52, -2.24, -1.90, -2.19, -1.63, -2.13, -1.66, -2.14, / ]| |
|Ļ | [-0.60, -1.81, -1.99, -1.96, -2.57, -1.84, -1.62, -2.04, -0.98, -1.18, -2.23, -2.25, / ]| |
|s | [-0.46, -1.33, -1.60, -2.65, -2.24, -1.99, -2.89, -1.44, -2.05, -2.77, -2.09, -2.74, / ]| |
|Ġvery | [ 0.29, -1.42, -1.77, -1.15, -0.94, -1.14, -1.81, -1.04, -1.77, -2.13, -0.60, -0.82, / ]| |
|Ġhot | [ 0.03, -0.68, -0.59, -0.95, -1.78, -0.10, -0.95, -0.14, -1.32, -0.57, 0.06, -1.07, / ]| 13
|Ġin | [-0.71, -1.72, -1.53, -2.18, -1.67, -1.93, -3.41, -1.69, -2.74, -1.89, -1.17, -2.02, / ]| rows
|Ġsummer| [-0.34, -1.49, -1.35, -1.31, -1.12, -0.89, -1.49, -1.11, -1.51, -1.15, -1.45, -1.20, / ]| |
|. | [-0.89, -1.73, -2.67, -2.80, -2.45, -2.37, -4.39, -2.33, -4.42, -2.73, -1.82, -3.21, / ]| |
|ĠSw | [-0.05, -1.15, -1.76, -1.15, -1.68, -0.74, -1.15, -1.35, -1.36, -1.29, -0.43, -1.51, / ]| |
|imming | [-0.02, -1.65, -0.87, -0.35, -1.18, -0.65, -0.33, -1.25, -0.38, -1.68, -2.15, -1.08, / ]| |
|Ġis | [-0.97, -2.03, -2.56, -2.94, -1.96, -2.71, -4.07, -2.46, -3.51, -2.68, -1.88, -2.99, / ]| |
|Ġa | [-1.10, -1.95, -2.12, -3.12, -2.72, -2.17, -3.88, -2.06, -3.57, -2.49, -1.86, -2.83, -3.40 ]| |
|-------|---------------------------------------------------------------------------------------------|----
|<------------------------------------ columns: 13 ------------------------------------------>|
特别需要注意的,这一步中,“Attention Score Matrix”最后一行的计算,需要前面的Q的最后一行,此外还需要整个 K 矩阵。这就是为什么 K 矩阵是需要缓存的。
2. 6 计算 Masked Attention Score
\(\text{Masked Attention Score} \)
\(= \frac{QK^T}{\sqrt{d}} + \text{mask} \)
|-----------------------------------------------------------------------------------------------------|
| | Attention Score Matrix shape (13 x 13) |
| Token |---------------------------------------------------------------------------------------------|
| | It âĢ Ļ s Ġvery Ġhot Ġin Ġsummer . ĠSw imming Ġis Ġa |
|-------|---------------------------------------------------------------------------------------------|----
|It | [ 0.14, -inf, -inf, -inf, -inf, -inf, -inf, -inf, -inf, -inf, -inf, -inf, -inf ]| |
|âĢ | [ 0.70, -0.93, -inf, -inf, -inf, -inf, -inf, -inf, -inf, -inf, -inf, -inf, -inf ]| |
|Ļ | [-0.60, -1.81, -1.99, -inf, -inf, -inf, -inf, -inf, -inf, -inf, -inf, -inf, -inf ]| |
|s | [-0.46, -1.33, -1.60, -2.65, -inf, -inf, -inf, -inf, -inf, -inf, -inf, -inf, -inf ]| |
|Ġvery | [ 0.29, -1.42, -1.77, -1.15, -0.94, -inf, -inf, -inf, -inf, -inf, -inf, -inf, -inf ]| |
|Ġhot | [ 0.03, -0.68, -0.59, -0.95, -1.78, -0.10, -inf, -inf, -inf, -inf, -inf, -inf, -inf ]| 13
|Ġin | [-0.71, -1.72, -1.53, -2.18, -1.67, -1.93, -3.41, -inf, -inf, -inf, -inf, -inf, -inf ]| rows
|Ġsummer| [-0.34, -1.49, -1.35, -1.31, -1.12, -0.89, -1.49, -1.11, -inf, -inf, -inf, -inf, -inf ]| |
|. | [-0.89, -1.73, -2.67, -2.80, -2.45, -2.37, -4.39, -2.33, -4.42, -inf, -inf, -inf, -inf ]| |
|ĠSw | [-0.05, -1.15, -1.76, -1.15, -1.68, -0.74, -1.15, -1.35, -1.36, -1.29, -inf, -inf, -inf ]| |
|imming | [-0.02, -1.65, -0.87, -0.35, -1.18, -0.65, -0.33, -1.25, -0.38, -1.68, -2.15, -inf, -inf ]| |
|Ġis | [-0.97, -2.03, -2.56, -2.94, -1.96, -2.71, -4.07, -2.46, -3.51, -2.68, -1.88, -2.99, -inf ]| |
|Ġa | [-1.10, -1.95, -2.12, -3.12, -2.72, -2.17, -3.88, -2.06, -3.57, -2.49, -1.86, -2.83, -3.40 ]| |
|-------|---------------------------------------------------------------------------------------------|----
|<------------------------------------ columns: 13 ------------------------------------------>|
2.7 计算 Softmax Masked Attention Score
\(\text{Softmax Masked Attention Score} \)
\(= \text{softmax}(\frac{QK^T}{\sqrt{d}} + \text{mask}) \)
|---------------------------------------------------------------------------------------|
| | Softmax Masked Attention Score Matrix shape (13 x 13) |
| Token |-------------------------------------------------------------------------------|
| | It âĢ Ļ s Ġvery Ġhot Ġin Ġsummer . ĠSw imming Ġis |
|-------|-------------------------------------------------------------------------------|----
|It | [1.00 0.00 0.00 0.00 0.00 0.00 0.00 0.00 0.00 0.00 0.00 0.00 0.00]| |
|âĢ | [0.84 0.16 0.00 0.00 0.00 0.00 0.00 0.00 0.00 0.00 0.00 0.00 0.00]| | V [:3] shape (12 x 64)
|Ļ | [0.65 0.19 0.16 0.00 0.00 0.00 0.00 0.00 0.00 0.00 0.00 0.00 0.00]| | --------------------------------
|s | [0.54 0.23 0.17 0.06 0.00 0.00 0.00 0.00 0.00 0.00 0.00 0.00 0.00]| | [ 0.0452, 0.0628, 0.1463, ...]
|Ġvery | [0.54 0.10 0.07 0.13 0.16 0.00 0.00 0.00 0.00 0.00 0.00 0.00 0.00]| | [-0.1361, 0.1379, 0.0150, ...]
|Ġhot | [0.29 0.14 0.16 0.11 0.05 0.25 0.00 0.00 0.00 0.00 0.00 0.00 0.00]| 13 [ 0.0039, -0.1295, -0.0311, ...]
|Ġin | [0.36 0.13 0.16 0.08 0.14 0.11 0.02 0.00 0.00 0.00 0.00 0.00 0.00]| rows [-0.0391, 0.0581, 0.0511, ...]
|Ġsummer| [0.26 0.08 0.09 0.10 0.12 0.15 0.08 0.12 0.00 0.00 0.00 0.00 0.00]| | [ 0.0963, 0.3563, -0.1477, ...]
|. | [0.40 0.17 0.07 0.06 0.08 0.09 0.01 0.10 0.01 0.00 0.00 0.00 0.00]| | ...
|ĠSw | [0.27 0.09 0.05 0.09 0.05 0.14 0.09 0.07 0.07 0.08 0.00 0.00 0.00]| | [ 0.2759, -0.2783, 0.3240, ...]
|imming | [0.19 0.04 0.08 0.14 0.06 0.10 0.14 0.06 0.13 0.04 0.02 0.00 0.00]| | [-0.0633, 0.0431, -0.0422, ...]
|Ġis | [0.30 0.10 0.06 0.04 0.11 0.05 0.01 0.07 0.02 0.05 0.12 0.04 0.00]| | --------------------------------
|Ġa | [0.25 0.11 0.09 0.03 0.05 0.09 0.02 0.10 0.02 0.06 0.12 0.04 0.03]| | |<------- columns: 64 ------->|
|-------|-------------------------------------------------------------------------------|----
|<---------------------------------- columns: 13 ------------------------------>|
2. 8 计算 Contextual Embeddings
\(\text{Contextual Embeddings} \)
\(= \text{softmax}(\frac{QK^T}{\sqrt{d}} + \text{mask})V \)
Token | Contextual Embedding (12 x 768)
--------------------------------------------
It | [ 0.0452, 0.0628, 0.1463,...]
âĢ | [ 0.0153, 0.0752, 0.1247,...]
Ļ | [ 0.0034, 0.0464, 0.0923,...]
s | [-0.0082, 0.0464, 0.0801,...]
Ġvery | [ 0.0218, 0.1029, 0.0621,...]
Ġhot | [ 0.0327, 0.0892, 0.0409,...]
Ġin | [ 0.0249, 0.0964, 0.0329,...]
Ġsummer | [ 0.0583, 0.1195, 0.0068,...]
. | [ 0.0334, 0.1100, 0.0366,...]
ĠSw | [ 0.0086, 0.0846, 0.0074,...]
imming | [-0.0049, 0.0841, -0.0339,...]
Ġis | [ 0.0410, 0.0706, 0.0077,...]
Ġa | [ 0.0427 , 0.0503 , 0.0080,...]
所以,这一步中,“Contextual Embeddings”最后一行的计算,需要前面 Softmax Masked Attention Score Matrix 的最后一行,此外还需要整个 V 矩阵。这就是为什么 V 矩阵是需要缓存的。
此外可以看到,在这个自回归的计算中,Q 矩阵前面的所有行(即上一轮计算的Q矩阵)都用不上,这也是为什么 Q 矩阵不需要缓存,即我们需要的“KV Cache”,而不是“QKV Cache”的原因。
3. 计算图示
这里依据使用了图示的方式展示了在“自回归”过程中的数学计算。在下图中,第一个生成的 Token 为“ a”,该 Token 在进入 Decoder 模型再次进行计算时(即“自回归”),下图中:
- 粉红色背景部分为新的、需要计算的部分;
- 灰色背景部分为虽然不需要计算,但在计算新的内容时,需要使用的部分。
灰色部分即为“KV Cache”需要缓存的部分。即,每一个 Token 对应的 “K”、“V” 矩阵都需要在后续的计算中使用。亦即,每一个 Token 的 Key 向量都需要保存,用于与新的 Token 的 Query 向量进行点击计算“关注度”值;每一个 Token 的 V 向量也需要保持
在上述的计算中,注意到,在一次的新的“自回归”中,最终需要额外计算的就是新Token(这里是“ a”)对应的 Centextual Embedding,该内容计算,需要使用前述所有 Token 对应的 K、V 值,即这里的 K 和 V 矩阵。
所以,在一次自回归推理中,最好上一次计算的所有 Token 的 K、V 向量都缓存起来,避免重复计算。本次自回归中计算新Token的对应的 K、V 向量也需要缓存,以供后续使用。
4. KV Cache 的内存消耗
在推理优化中,一个重要硬限制便是GPU卡的显存(memory)大小。当前,主流的企业级显卡H100显存为80GB,高端显卡 H200 显存为141 GB。现在的 LLM 参数量通常巨大,参数加载就需要耗费巨大的显存,以最新的 llama 4 17B为例,考虑 FP16 (半精度)考虑,则需要消耗约 30+ GB 。卡片上剩余的内存,才是用于实际的推理使用。而每次推理,例如提示词是1000个Token,输出也是1000个Token,那么,在生成最后一个Token的时候,需要的内存(按5%的经验值计算)约为1.5GB。这时候,单个H100的显卡也只能支持约33个并发,实际的情况则要考虑系统内存等,会比这个预估多很多。
在这篇文章:Mastering LLM Techniques: Inference Optimization@developer.nvidia.com 中也类似的估算:
- 7B 的模型(如Llama 2 7B),参数是16位(FP16 or BF16)则参数需要消耗约 14 GB 显存
- Token 数为4096的推理(decoder),则需要约 2 GB KV Cache
从上述粗略的预估可以看得出来,高效使用显存资源对于 LLM 推理来说至关重要。所以,各推理框架则会通过各种方法尝试去优化“KV Cache”以降低显存使用。这些方法包括“量化”(Quantization)、MQA/MGA 等。
5. Multi-Query Attention/Group-Query Attention
可以看到,无论是在模型参数加载的时候,还是推理 KV Cache 阶段,都需要大量的显存。关于 MQA 和 GQA 的经典论文是:GQA: Training Generalized Multi-Query Transformer Models from Multi-Head Checkpoints。
5.1 关于MQA与GQA
Multi-Query Attention 则尝试通过减少 \(W^K \, W^V \) 参数的数量来减少上述显存,从而增加推理速度与并发能力。参考下图,可以看到在每一个 Layer 中,所有的 Head 共享一组 \(W^K \, W^V \) 参数,那么这两个相关参数就减少到了原来的 \(\frac{1}{h} \)。
更进一步的,为了减少上述方法(MQA)对于模型效果的影响,另一个优化是 Group-Query Attention。即如下图,一组 Heads 共享一组 \(W^K \, W^V \) 。可以依照分组的大小,以平衡模型效果与资源使用。如果一个 Head 一组 \(W^K \, W^V \) 则退化到普通的 Multi-Head Attention;如果所有 Heads 分到一组,则退化到普通的 Multi-Query Attention。
5.2 模型训练 Uptraining
此外,比较关键的,论文提出了一些关于 GQA 架构的训练优化。
例如,从一个 MHA 架构开始训练,然后从某个 checkpoint 开始,将MHA模型改成GQA模型,在初始化分组参数时,则使用原 MHA 模型中参数去求一个均值的方式初始化GQA中对应的 \(W^K \, W^V \) 。然后继续使用语料库对于该新模型训练。
论文指出,这时候只需要使用非常少的计算资源就可以训练处效果还不错的GQA新模型。新的GQA模型,则可以使用更少的显存资源,有更好的并发吞吐能力,同时也达到还比较好的效果。































