0


[LLM性能优化]聊聊长文本推理性能优化方向

[LLM性能优化]聊聊长文本推理性能优化方向

原创 阿杰 大模型新视界 2024年07月09日 08:00 四川

原文:https://zhuanlan.zhihu.com/p/698308542

近期,LLM 的长文本能力越来越受到关注。LLM 处理长文本的能力可以应用在多个应用场景中,例如 LLM Agent 场景:假设 Agent 会调用不同的工具解决用户给出的任务,所以当用户对 Agent 提出一个任务时,Agent 会先调用一次 LLM,对给定的任务生成一系列的 Funtion Call,然后依次调用不同的 Funtion,Agent 将 Funtion 的所有输出结果作为输入,再调用一次 LLM,生成最终呈现给用户的自然语言。其中,通过 Function 返回的结果可能很长,多个Function结果拼起来可能是一个很长的输入,这样 Agent 模型就需要具备长文本处理能力。除了 Agent 以外,RAG、文本摘要都需要 LLM 模型具备长文本处理能力,这些应用在落地时需要 LLM 推理服务具备很高的长文本推理效率。笔者之前在介绍 vLLM 的文章中 介绍过与 LLM 推理服务性能最关键的因素:KVCache 显存大小。KVCache 显存占用的计算公式为:

图片

其中:

图片

通过公式(1)我们可以看到, LLM 推理服务为每个 Token 分配了显存空间,其显存空间大小与模型层数、头维度、头数量以及KVCache存储的数据类型大小四个维度相关。通过公式(2)我们可以看到,KVCache 的总显存开销与 LLM 推理服务处理的总 Token 数以及每个 Token 所占的显存开销相关。在长文本推理服务场景中,我们的目标是在给定显存空间,推理服务要处理的 Token 数尽可能的多,即 T 尽可能的大。通过公式 2 可知,提升 T 有两种策略:

图片

下面将分别介绍这两种策略相关的工作。

1. 压缩 KVCache 长度

LLM 最核心的模块是 Attention:通过 Q 、 K 两个矩阵进行矩阵乘,计算O(T2)大小的注意力得分矩阵,并使用注意力得分矩阵对 V 矩阵求加权平均值,得到注意力层的输出矩阵。其中注意力得分矩阵每一行表示 Q 中每个 Token 和 K 中每个 Token 的相关性得分,注意力层输出矩阵的每一行表示每个 Token 与其他 Token 的加权平均值。即然是加权平均,说明每个 Token 的重要性并不相同,有的 Token 权重更大,而有的 Token 权重更小。那么为了压缩 KVCache 长度,可以将 KVCache 中一些权重小的 Token 剔除,不参与注意力计算,在保证模型效果的前提下压缩 KVCache 长度,从而在一定量的显存下保存更多的 Token,提高长文本的推理效率。那么压缩 KVCache 长度的问题就可以转化为寻找重要性 Token,通过算法设计找出一个序列中相关性更高的 Token。

1.1 静态 Token 稀疏化

图片

去年底 MIT 提出一种叫 StreamingLLM 的压缩 KVCache 长度的方法,并分析了四种 Attention 实现:

  • Dense Attention:即最原始的 Attention 实现。其计算复杂度为 O(T2),KVCache 存储复杂度为O(T)。由于复杂度比较高,T 在预训练的时候会比较小。在推理的时候,当文本长度超过了预训练时的长度,模型效果就会大幅降级,所以表现出来的 PPL 值也比较大;
  • Window Attention:通过我们平时语言习惯中可以知道,一段话中每个字之间的相关性差别很大,一般来说越相近的字相关性越强。基于这个假设,Window Attention 就被提出。每个 Token 只和邻近的 Token 做 Attention 计算,所以计算复杂度为 O(TL) ,KVCache 存储复杂度为O(L),其中 L 为窗口大小,是一个常数。这种方法极大的降低了 KVCache 的存储开销,从线性复杂度降低到常数复杂度。虽然KVCache的长度被压缩了,但是模型效果却不好,主要原因是最初始的 Token 被丢弃了。作者通过一些统计方法发现这种 Token 的重要性其实非常高,丢弃会严重影响模型效果。
  • Sliding Window Attention with recomputation:这种 Attention 与 Window Attention 类似,区别是 Sliding Window 不缓存窗口的 Tokens,而是重新计算窗口内的 KVCache。这种方法的计算复杂度为 O(TL2),KVCache 存储复杂度为O(L)。计算效率下降了,但模型效果对比Window Attention 高,PPL 值远低于 Window Attention。主要原因是重计算把窗口中的 Tokens 作为初始 Tokens 了,这样既保留初始 Tokens 又保证计算只在一个窗口内,大大降低了KVCache 存储复杂度。
  • StreamingLLM:这种策略在 Window Attention 的基础上,保留整个序列的初始 Tokens。每个 token 只和窗口内的 Tokens 以及序列的初始 Tokens 进行 Attention 计算,这样既保留 Window Attention的特性,即保证 KVCache 存储复杂度为O(L),计算复杂度为 O(TL) ,同时保证了模型效果不因丢失初始 Tokens 而大幅下降,PPL 和 Sliding Window Attention with recomputation 相似。

总的来说,上述提到的稀疏化方法都是静态的,即 Token 之间的相关性是一种固定的范式,每个 Token 的相关的 Token 都是固定距离的。

1.2 动态 Token 稀疏化

图片

动态 Token 稀疏化本质是为当前处理 Token 维护一个相关性高的历史 Token 集合,但与静态 Token 稀疏化不同,这个集合的构造不再由 Token 距离或者固定 Token 决定,而是设计一种算法去筛选历史的 Token。H2O 就提出了一种贪心的历史 Token 驱逐算法,具体流程如下:

  • 初始化相关性历史 Token 集合 S0 ,设置集合最大容量 k 。
  • 开始迭代更新历史 Token 集合:- 当历史 Token 集合大小等于最大容量 k 时,将当前 Token 加入历史 Token 集合中;- 当历史 Token 集合大小大于等于最大容量时,计算当前 Token 和历史 Token 集合中所有 Token 的相关性得分,分数越高相关性越低,将最低分的 Token 从集合中剔除,将当前 Token 加入到集合中。

相关性得分的计算开销很小,所以该方法的计算复杂度为 O(TL) ,KVCache 存储复杂度为 O(L)。作者也和StreamingLLM 对比了模型效果,在 QA 任务以及文档摘要任务均优于 StreamingLLM

图片

后续与动态 Token 稀疏化相关的工作可能会聚焦于模型效果方面,比如得分函数的设置以及剔除策略的改进。计算复杂度已经到线性复杂度,KVCache 存储复杂度已经到了常量复杂度,难以继续优化。

1.3 Prefix Caching 工程优化

Prefix Caching 是最近比较火的工程优化,其原理是复用 Tokens 的KVCache。LLM 推理服务在完成一个 request 计算后,不会马上将其 KVCache 释放,而是先缓存。当下一个请求到达时,不会直接进行 Prefill 计算,而是先在缓存中寻找最长公共前缀的 Tokens。如果存在,则直接复用缓存中的 KVCache,仅对剩余 Tokens 计算注意力以及 KVCache。通过复用 KVCache,可以达到两大目的

  1. 提升 Prefill 效率。由于参与 Prefill 的 Tokens 数减少,所以计算量下降,Prefill 的延时也就下降,直接提升 TTFT 性能。特别适合优化多轮对话场景的性能。
  2. 节省显存。当前大部分 LLM 应用在构造输入时一般遵循 system prompt + user prompt 范式。大部分的 system prompt 都是一致,这样不同的请求也就有相同的前缀,可以避免出现多个冗余的 system prompt KVCache,能提高服务的极限吞吐。

在实现高效的 Prefix Caching 功能时,需要考虑两个问题:

  1. 如何管理缓存 Tokens?应该保留哪些前缀?
  2. 能否针对 Prefix Caching 实现高效的 CUDA 算子?

Radix Attention

图片

针对第一个问题,lmsys 提出了 Radix Attention,设计了一种基于 LRU 的 Radix Tree 去管理前缀 Tokens。如上图所示,新到达的请求会先从 Radix Tree 中匹配最长前缀,寻找最远的一个节点。当请求仍有 Tokens 没匹配到已有的前缀,则会新增 Token 子节点。每次新增 Token 时,将新增 Token 节点标记为绿色,并将其前缀的节点标记为蓝色。这种标记主要是更新节点的访问时间,当缓存空间不足时,会将最近没被访问的节点清除。

Cascade Inference

针对第二个问题,FlashInfer 团队提出了 Cascade Inference,用于提升多个请求共享前缀的 Batch Decoding 过程。在介绍 Cascade Inference 前,我们先了解下 Batch Decoding Attention 的 CUDA 实现原理。

【 朴素 Batch Decoding Attention CUDA实现原理】先考虑 MHA。对于 Attention 运算,设 Q 在 T 维度上长度为 q_len , K 、 V 在 T 维度上长度为 kv_len ,那么 Attention 的三个输入矩阵的 Shape 分别为 [B,N,q_len,H]、[B,N,kv_len,H]、[B,N,kv_len,H]。其中 [B,N] 两个维度是 Batch 维度,可以并行计算,所以 Attention 算子的 CUDA Kernel 会将 Grid Size 设置为 [B,N],一个 Thread Block 计算一个序列的单个头注意力,所以一个 Thread Block 的三个输入矩阵 q,k,v 的 Shape 为 [q_len,H]、[kv_len,H]、[kv_len,H]。Attention 计算可以简单分解为以下三步:

图片

在 Decoding 阶段,q_len 为 1,所以上述提到的中间变量大小较小,可以放到 Shared Memory 中保存,那么 Batch Decoding Attention 的算术强度可以表示为:

Arithmetic intensity=2∗q_len∗kv_len∗H+q_len∗kv_len2∗q_len∗H+2∗kv_len∗H通过上述公式可以观察到,Batch Decoding Attention 的算术强度接近 1,属于 Memory Bound 计算。其主要的访存开销在 KVCache 上。当 Thread Block 在加载 [1,H] 大小的 q 矩阵时,需要加载 2 个 [kv_len,H] 大小的 、k、v 矩阵。所以,为了提升 Batch Decoding Attention 的性能,需要提升其算术强度。从公式可以观察到,当 kv_len 固定时,算术强度随着 q_len 上升而上升。所以只要可以在一个 ThreadBlock 中加载相同大小的 、k、v 矩阵情况下,加载多个 q 矩阵,并计算多个q 矩阵对应的结果,就能提升性能。那么针对多个请求共享前缀的场景,能否利用这个特性提高 CUDA Kernel 的性能呢?在解决共享前缀问题前,我们先看看 MQA 的优化思路。

如上文介绍,MHA 算子是 Q 矩阵一个序列的一个头与 K 、 V 矩阵一个序列的一个头进行计算。而 MQA 算子则不同,是 Q 矩阵一个序列的多个头与 K 、 V 矩阵一个序列的一个头进行计算,其输入的 Q 、K 、 V 三个矩阵的 Shape 分别为 [B,N,q_len,H]、[B,1,kv_len,H]、[B,1,kv_len,H]。假如将 MQA Kernel 的 Grid Size 设置为 [B,N],那么 K 、 V 矩阵在第二维度则需要 “expand” 操作,即将 1 份 [kv_len,H] 大小的 、k、v 矩阵拷贝 N 份,这样算术强度与 MHA 一样,在CUDA Kernel 层面没有任何性能提升。但是,如果我们利用 MQA 共享头的特性,将 Grid Size 设置为 [B] ,那么一个 Thread Block 的三个输入矩阵 q,k,v 的 Shape 将为 [N∗q_len,H]、[kv_len,H]、[kv_len,H], q_len 相比 MHA 实现"提升"了 N 倍,算术强度也提升接近N 倍,可以降低访存开销。并且由于q_len 提升 N 倍, Attention 中的两个 GEMV 变为 GEMM,可以使用 TensorCore 进一步提升 Kernel 性能。

【 共享前缀 Batch Decoding Attention CUDA实现原理】回到共享前缀的场景。当 Batch Decoding Attention 所有输入请求的前缀完全一致时,输入的 Q 、K 、 V 三个矩阵的 Shape 分别为 [B,N,q_len,H]、[1,N,shared_kv_len,H]、[1,N,shared_kv_len,H]。类似 MQA ,GridSize 可以设置为 [N] ,那么一个 Thread Block 的三个输入矩阵 q,k,v 的 Shape 将为 [B∗q_len,H]、[shared_kv_len,H]、[shared_kv_len,H], q_len 相比 MHA 实现"提升"了 B 倍,算术强度也提升接近B 倍,同样可以降低访存开销。以下是 Decoding 阶段的 MHA、MQA 以及 BatchPrefixAttention(BPA)的对比:
Attention 类型Grid Size设置算术强度MHA[B, N]O(1)MQA[B]O(N)BPA[N]O(B)
但 BPA 这个算子只适用于前缀完全一样时的 Decoding,当输入多个请求的后缀不同时,仍然需要使用常规的 Decoding Attention 算子,计算后缀的注意力,然后使用类似 FlashAttention 的将不同前缀、后缀的注意力结果进行合并。这就是 Cascade Inference 的做法。具体流程如下图所示:

图片

图中上半部分描述了常规的 Decoding Attention 的做法,下半部分描述了 Cascade Inference 的思想。Cascade Inference 主要包含三个步骤:

  1. MQA。对共享前缀部分使用 MQA 算子,提高算术强度;
  2. Batch Decode Attention。对不同的后缀使用常规的 Decoding Attention 算子,计算后缀的注意力;
  3. Merge State。以上两步仅完成序列中一部分的 Attention 结果,需要将 Attention 结果进行合并,合并方式与FlashAttention 类似。

下面是 Cascade Attention 的源码:

class BatchDecodeWithSharedPrefixPagedKVCacheWrapper:
    def forward(
        self,
        q: torch.Tensor,
        k_shared: torch.Tensor,
        v_shared: torch.Tensor,
        unique_kv_data: torch.Tensor,
        allow_fp16_qk_reduction=False,
        sm_scale: Optional[float] = None,
        rope_scale: Optional[float] = None,
        rope_theta: Optional[float] = None,
    ):
        # MQA
        V_shared, S_shared = single_prefill_with_kv_cache_return_lse(
            q,
            k_shared,
            v_shared,
            causal=False,
            pos_encoding_mode="NONE",
            kv_layout=self._kv_layout,
            allow_fp16_qk_reduction=allow_fp16_qk_reduction,
            sm_scale=sm_scale,
            rope_scale=rope_scale,
            rope_theta=rope_theta,
        )
        # Batch Decode Attention
        V_unique, S_unique = self._batch_decode_wrapper.forward_return_lse(
            q,
            unique_kv_data,
            pos_encoding_mode="NONE",
            sm_scale=sm_scale,
            rope_scale=rope_scale,
            rope_theta=rope_theta,
        )
        # Merge State
        merge_state_in_place(V_shared, S_shared, V_unique, S_unique)
        return V_shared

1.4 小结

标签: 性能优化

本文转载自: https://blog.csdn.net/sinat_37574187/article/details/140480497
版权归原作者 AI生成曾小健 所有, 如有侵权,请联系我们删除。

“[LLM性能优化]聊聊长文本推理性能优化方向”的评论:

还没有评论