上一篇Mamba的文章提到,S6 models这个名称的由来是:S4 models with a selection mechanism and computed with a scan。
所以,S6模型首先是选择机制:先前模型的一个关键限制对选择性复制和归纳等重要合成任务不够适用,所以设计选择机制,将SSM参数参数化为输入的函数。这允许模型过滤掉无关信息,并无限期地记住相关信息。
引入选择机制后,就需要硬件感知算法:所有先前的SSM模型必须是时间和输入不变的,才能实现计算效率;为了计算效率,引入选择机制后,就需要通过硬件感知算法克服了这一点,该算法使用扫描而不是卷积来递归计算模型。
选择机制相对容易理解。而对于S6模型的硬件感知设计,尤其是所谓的并行扫描,看论文没有看清楚,查了相关博客,再进行一下梳理。
硬件感知设计:并行扫描(parallel scan) + Flash Attention
并行扫描,同时还需要借鉴Flash Attention的方式,整体见图一,可以看到哪些处理在GPU SRAM中,哪些在HBM(高带宽内存)。
朴素的递归计算使用了O(BLDN) FLOPs,而卷积计算使用了O(BLD log(L)) FLOPs,前者具有较低的常数因子。因此,对于长序列和不太大的状态维度N,递归模式实际上可以使用更少的FLOPs。
这里的挑战就是,递归的顺序性导致计算效率低和大量的内存使用。需要像卷积模式一样的并行出来,但是因为不是时不变系统,卷积核随着输入在变,没法使用卷积的并行方式,文章给出了并行扫描的方式。
主要思路是利用现代加速器(GPU)的特性,在内存层次结构中更有效的级别中具体化状态h。特别是,大多数操作(除了矩阵乘法)都受到内存带宽的限制。这包括扫描操作,使用内核融合来减少内存IO的数量,与标准实现相比,可以显著加速。
具体来说,不是在HBM中加载数据大小为(B,L,D,N)的输入(A,B),而是直接将SSM参数(∆,A,B,C)从慢速HBM加载到快速SRAM中,在SRAM中进行离散化和递归,然后将数据大小为(B,L,D)的最终输出写回HBM。
图1:(概述)结构化SSM通过更高维的潜在状态h(例如N=4)独立地将输入x的每个通道(例如D=5)映射到输出y。先前的SSM通过需要时间不变性的巧妙替代计算路径来避免实现这种大的有效状态(DN,乘以批量大小B和序列长度L):(Δ,𝑨,𝑩,𝑪)参数在整个时间内是恒定的。我们的选择机制增加了与输入相关的动态,这也需要一个谨慎的硬件感知算法,只在GPU内存层次结构的更有效级别中实现扩展状态。
并行扫描
针对并行扫描,主要看了两篇博客。
第一篇博客
不使用卷积的并行扫描,具体是怎样操作?
参考这篇文章:https://jackcook.com/2024/02/23/mamba.html
Mamba 的涉及非常快速的 RNN 模式训练。递归与扫描算法(也称为前缀和)非常相似。 要计算前缀和,需要采用一个输入数组[x1,x2,x3,…,xn]并返回一个输出数组,其中每个元素是该项与之前项之和。 换句话说,输出的第一个元素将是x1,则第二个元素将为x1+x2,第三个x1+x2+x3等。 示例如下所示。
然后画出在 RNN 模式下更新 Mamba 隐藏状态的过程:
如果我们必须正式化前缀和,可以将其写出如下方程:
这个方程形成一个递归:在每个步骤中,通过将先前存储的值添加到当前输入来计算新值。 现在,再次看一下更新 Mamba 隐藏状态的递归。
虽然计算前缀和本质上似乎是连续的,但实际上有高效的并行算法来完成这项任务! 在下图中,我们可以看到一个并行前缀和算法,其中每条垂直线代表数组中的一个项目。
来看看这个算法是有效的:选择任何垂直线,从顶部开始,然后向下,将每个添加项追溯到数组的前几个项目。 当到达底部时,就有行左侧所有项目的总和。 例如,可以看到,在将第一个元素添加到开头的第二个元素之后,数组的第三个元素在末尾接收第二个元素的附加值。 因此,在并行扫描完成时,第三个元素包含第一个、第二个和第三个元素的总和。
如果在单个线程中运行这个算法,没有并行性,那么它比我们只按顺序将值相加需要更长的时间。 但 GPU 有很多处理器,允许高度并行计算。 因此,我们可以大致计算这个前缀求和(或扫描)操作O(logn)时间!
因此,Mamba 的作者意识到,如果想在 RNN 模式下高效训练,可以使用并行扫描。
就是说,对于递归 就只能单线程算, 就比较慢。其实加法具有结合律,先算前后都可以,就可以多线程并行计算,也就是所谓的并行扫描。
第二篇博客
上面的博客还是说的不够明白,再看第二篇:
(https://www.zhihu.com/question/644981978/answer/3405813530)的分析。
这篇文章指出,加法服从结合率,可以通过调整运算次序,实现并行加速。但是,Mamba的情况复杂一些,需要定义一种新的运算:
所以,之前是串行扫描,从H0到H3需要串行计算。
根据上面的交换律,可以多个进程并行运算。
下面手动推导一下,针对状态H3的计算,进行展开如下:
根据上面定义的二元运算,展开如下:
先并行计算1、2项和3、4项,然后将这两组并行运算的结果再次进行二元运算,因为其实没有A0,展开获得的四项与上面的推导一致。
版权归原作者 bylander 所有, 如有侵权,请联系我们删除。