前言
不知读者发现没有,本文标题的信息含量很大,比如
- 出来了一个新的序列模型:Mamba,其基于SSM或S4(Structured State Space for Sequence Modeling,连起来4个S,故简称S4)发展为S6(S4 models with a selection mechanism and computed with a scan),其对应的论文为《Mamba: Linear-Time Sequence Modeling with Selective State Spaces》
- 该Mamba模型的提出者为Albert Gu、Tri Dao,前者现在是CMU助理教授,多年来一直推动SSM架构发展,曾在DeepMind 工作,后者则为鼎鼎大名的Flash Attention一作 换言之,除了论文中展示的效果确实不错之外,由于提出者的背景不一般,所以关注的人比较多
- Transformer统治各大领域近7年了,7年来,挑战Transformer的模型其实不少 (比如linear attention, gated convolution and recurrent models, and SSMs),该模型能否真正颠覆Transformer的霸权呢?对此,我们可以细究其原理细节,看看其创新到底是否靠谱、力度是否大
加之有一大模型项目开发营的朋友问道,可否在论文100课上解读下Mamba这篇论文,于此,便有了此文,且具备3个特点
- 清晰易懂:也为「不需要天天看paper的朋友」而写 在ChatGPT诞生后的一年来,以大模型为代表的技术发展特别快,经常一个月会出来很多新的技术、模型 而不一定非得是每天在实验室扎根于科研的人 才有资格去追踪前沿技术发展,还有一大帮可能是出于对前沿技术的了解、兴趣、热爱、应用而想追踪,可这帮朋友平时或因工作或事太多而不一定对每个新技术、新模型都去看一遍论文,即不可能天天看paper 那咋办呢?他们可能通过一些比如公众号之类的文章去了解,但有的公号文章写的不错,有的则写的不够清晰易懂甚至漏洞百出,会因此让读到这种文章的朋友对新技术、新模型产生畏难心理甚至被误导 故,我和我司来了,为帮助更多朋友更好、更快、更细致的了解大模型相关技术及其实践,我个人算是笔耕不辍(我自23年年初以来也史无前例的写了近30篇,详见:大模型与ChatGPT系列:原理、论文、代码、应用**)、团队和我算讲课不停
- 中英对比:部分关键的阐述中英文对照学习 考虑到这些新技术、新模型刚推出的时候,论文还是相对最严谨的参考,所以本文会延续前几篇文章的风格:对于一些关键的阐述会把原英文的表述用斜体且淡色的黑体表示,毕竟有的描述对其翻译相比,用原英文阐述更精准
- 足够细致:从SSM、HiPPO、S4起步,逐步推导到Mamba 目前介绍mamba模型的文章,少部分写得很不错,大部分不是这个细节没深入,便是那个细节没深入,考虑到如果很多关键细节没有介绍的话,那没法彻底理解mamba模型 因此,本文会尽可能兼顾所有必须写清楚的细节(比如如果不理解SSM和S4则无法理解mamba模型,故本文会从HiPPO、SSM、S4起步,逐步推导到mamba),尽可能一文通透mamba模型
更新:考虑到之前本文的早期版本介绍的mamba前置知识不够彻底的清晰易懂,故24年3.2-3.5这4天把前置知识特别是ssm/S4介绍的更加细致(过程中参考此文:A Visual Guide to Mamba and State Space Models,有些图来自该文,有些内容翻译自此文),以让文科生都能一眼看明白
总之,看本文之前,你可能看到的很多关于mamba的文章都不知所云,但看了本文之后,你再看那些文章你会有一种“他如果怎样怎样写,会更加清晰易懂”的感觉,毕竟“好懂的文章”只有一个标准:就是能一直不烧脑的读下去而不卡壳
第一部分 基础回顾:Transformer时间复杂度、RNN
1.1 Transformer的二次复杂度
通过之前本博客内的另一篇文章《通透理解FlashAttention与FlashAttention2:全面降低显存读写、加快计算速度》,可知
简单理解的话,计算复杂度和序列长度的平方成正比,可以看一个小例子,比如两个相乘的矩阵大小分别为() 和(),矩阵乘法的一种计算方式是使用第一个矩阵的每一行与第二个矩阵的每一列做点乘
因为我们需要拿第一个矩阵的每一行去与第二个矩阵的每一列做点乘,所以总共就需要 次点乘。而每次点乘又需要 次乘法,所以总复杂度就为
精确理解的话,当输入批次大小为 ,序列长度为 时,
层transformer模型的计算量为 ,则代表词向量的维度或者隐藏层的维度(隐藏层维度通常等于词向量维度)但这个结果是怎么一步一步计算得到的呢?请看原文
正因为现有的ChatGPT等大模型处理长文本算力消耗巨大,背后原因是Transformer架构中注意力机制的二次复杂度
- 一方面,有了针对注意力机制的各种所谓魔改,甚至也有S4、FlashAttention及其二代等
- 二方面,S4、FlashAttention等作者提出了新的序列模型:Mamba,在很多语言任务上击败/匹配Transformer性能,具有线性复杂度和5倍推理吞吐量,下文详述
1.2 RNN
关于什么是RNN,我之前博客内的这篇文章《如何从RNN起步,一步一步通俗理解LSTM》中做了详细介绍,每一个时刻的隐藏状态都是基于当前的输入和前一个时刻的隐藏状态计算得到的
总之,RNN在序列中的每个时间步需要两个输入,即时间步的输入和前一个时间步的隐藏状态(a hidden state of the previous time step),以生成时的隐藏状态,最终预测输出(to generate the next hidden state and predict the output)
这一点值得好好体会:先根据输入和前一时刻的隐藏状态计算出最新的隐藏状态,便可以根据最新的隐藏状态预测出了
至于为何要先介绍RNN呢,很快你就会明白了
比如下图,每个隐藏状态都是所有先前隐藏状态的聚合,然最后一个隐藏状态在生成名称“ Maarten”时不再包含有关单词“Hello”的信息。随着时间的推移,RNN 往往会忘记信息,且RNN没法并行训练,相当于推理快但训练慢
第二部分 从状态空间模型SSM到S4的升级之路
注,如本文开头所述,本部分的核心参考为Maarten Grootendorst所写的《A Visual Guide to Mamba and State Space Models》,不少图来自该文,不少内容翻译自此文,至于原英文中有些表述不准确的地方,我则都已修正
2.1 什么是状态空间与状态空间模型
2.1.1 什么是状态空间
想象一下我们正在穿过一个迷宫,图中每个小框代表迷宫中的一个位置,并附有某个隐式的信息,例如你距离出口有多远
而上述迷宫可以简化建模为一个“状态空间表示state space representation”,每一个小框显示你当前所在的位置(当前状态current state)、下一步可以去哪里(未来可能的状态possible future states),以及哪些变化会将你带到下一个状态(向右或向左)
而描述状态的变量(在我们的示例中为 X 和 Y 坐标以及到出口的距离)可以表示为“状态向量state vectors”
2.1.2 什么是状态空间模型SSM
SSM 是用于描述这些状态表示并根据某些输入预测其下一个状态可能是什么的模型
一般SSMs包括以下组成
- 映射输入序列x(t),比如在迷宫中向左和向下移动
- 到潜在状态表示h(t),比如距离出口距离和 x/y 坐标
- 并导出预测输出序列y(t),比如再次向左移动以更快到达出口
然而,它不使用离散**序列(如向左移动一次),而是将连续**序列作为输入并预测输出序列
SSM 假设系统(例如在 3D 空间中移动的物体)可以通过两个方程从其在时间 时的状态进行预测「当然,其实下面第一个方程表示成这样可能更好:,不然容易引发歧义」
通过求解这些方程,可以根据观察到的数据:输入序列和先前状态,去预测系统的未来状态
2.1.3 SSM的两个方程:状态方程与输出方程
总之,SSM的关键是找到:状态表示(state representation),以便根据输入序列预测输出序列
而这两个方程也是状态空间模型的核心,且矩阵
A、
B、
C、
D都是可以学习的参数
- 第一个方程:状态方程,矩阵与输入相乘之后,再加上矩阵与*前一个状态*相乘的结果 换言之,矩阵影响输入,矩阵影响前一个状态,而指的是任何给定时间的潜在状态表示(latent state representation),而指的是某个输入「当然,还是上面那句话,表示成这样更好:**」
- 第二个方程:输出方程,描述了状态如何转换为输出(通过矩阵 ),以及输入如何影响输出(通过矩阵 )
2.1.4 建立对SSM中两个核心方程的统一视角
最终,我们可以通过下图统一这两个方程
为了进一步加深对该图的理解,我们一步一步拆解下
- 假设我们有一些输入信号,该信号首先乘以矩阵
- 上面第一步的结果,加上:上一个状态与矩阵相乘(*矩阵*描述了所有内部状态如何连接)的结果,用来更新状态state
- 然后,使用矩阵**来将状态转换为输出
- 最后,再利用矩阵 D提供从输入到输出的直接信号,这通常也称为跳跃连接skip-connection
- 由于矩阵 D类似于跳跃连接,因此在没有跳跃连接的情况下,SSM 通常被视为如下
回到我们的简化视角,现在可以关注只矩阵A、B、C构建的SSM核心
总之,这两个方程共同旨在根据观测数据预测系统的状态,且考虑到输入一般都是连续的,因此SSM的主要表示是连续时间表示(continuous-time representation)
2.2 SSM到S4的三步升级:离散化SSM、循环/卷积表示、基于HiPPO处理长序列
2.2.1 离散数据的连续化:基于零阶保持技术做连续化并采样
由于除了连续的输入之外,还会通常碰到离散的输入(如文本序列),因此如果模型也能处理离散化数据则再好不过。怎么做到呢?好在可以利用零阶保持技术(Zero-order hold technique)
- 首先,每次收到离散信号时,我们都会保留其值,直到收到新的离散信号,如此操作导致的结果就是创建了 SSM 可以使用的连续信号
- 保持该值的时间由一个新的可学习参数表示,称为步长(siz)——** **,它代表输入的阶段性保持(resolution)
- 有了连续的输入信号后,便可以生成连续的输出,并且仅根据输入的时间步长对值进行采样
这些采样值就是我们的离散输出,且可以按如下方式做零阶保持
它们共同使我们能够从连续 SSM 转变为离散SSM,使得不再是函数到函数x(t) → *y(t),而是序列到序列*x**ₖ → yₖ,所以你看到离散化的SSM时,不再带参数t了
这里,矩阵A和 B现在表示模型的离散参数(且这里使用*** ***而不是 来表示离散时间步长)
注意:我们在保存时,仍然保存
矩阵 A的连续形式(而非离散化版本),只是在训练过程中,连续表示被离散化(During training, the continuous representation is discretized)
2.2.2 循环结构表示The Recurrent Representation
总之,离散 SSM 允许可以用离散时间步长重新表述问题
在每个时间步,都会涉及到隐藏状态的更新(比如取决于和的共同作用结果,然后通过预测输出)
为方便大家理解其中的细节,我再展开一下 ![y_2](https://latex.csdn.net/eq?y_2)
有没有眼前一亮?如此,便可以RNN的结构来处理
然后
可以这样展开(其中,始终是和的共同作用之下更新的)
2.2.3 卷积结构表示The Convolution Representation
在经典的图像识别任务中,我们用过滤器(*即卷积核kernels)*来导出聚合特征,而SSM也可以表示成卷积的形式
由于我们处理的是文本而不是图像,因此我们需要一维视角
而用来表示这个“过滤器”的内核源自 SSM 公式
- 与卷积一样,我们可以使用 SSM 内核来检查每组token并计算输出
- 内核将移动一次以执行下一步的计算
- 最后一步,我们可以看到内核的完整效果:至于上图中的是咋计算得到的,别忘了我上面推导出来的
总结一下,将 SSM 表示为卷积的一个主要好处是它可以像卷积神经网络CNN一样进行并行训练。然而,由于内核大小固定,它们的推理不如 RNN 那样快速
最终,SSM可以视为从输入信号到输出信号的参数化映射
- SSMs可以当做是RNN与CNN的结合「*These models can be interpreted as acombination of recurrent neural networks (RNNs) and convolutional neural networks (CNNs)*」,即推理用RNN,训练用CNN
- 总之,这类模型可以非常高效地计算为递归或卷积,在序列长度上具有线性或近线性缩放(This class of models can be computed very efficiently as either arecurrence or convolution, with linear or near-linear scaling in sequence length)
2.2.4 矩阵A的问题与其解决之道——HiPPO
如我们之前在循环表示中看到的那样,矩阵捕获先前previous状态的信息来构建新状态(,当k = 5时,则有)
其实,某种意义上,算是矩阵A产生了隐藏状态(matrix A produces the hidden state)
由于矩阵A只记住之前的几个token和捕获迄今为止看到的每个token之间的区别,特别是在循环表示的上下文中,因为它只回顾以前的状态
那么我们怎样才能以保留比较长的memory的方式创建矩阵A呢?
答案是可以使用Hungry Hungry Hippo((High-order Polynomial Projection Operator,简称H3),HiPPO尝试将当前看到的所有输入信号压缩为系数向量(HiPPO attempts to compress all input signals it has seen thus far into a vector of coefficients)
它使用矩阵构建一个“可以很好地捕获最近的token并衰减旧的token”状态表示(to build a state representation that captures recent tokens well and decays older tokens),其公式可以表示如下
具体表示可以如下图所示
正由于HiPPO 矩阵可以产生一个隐藏状态来记住其历史(从数学上讲,它是通过跟踪Legendre polynomial的系数来实现的,这使得它能够逼近所有以前的历史),使得在被应用于循环表示和卷积表示中时,可以处理远程依赖性
如此,S4的定义就出来了:序列的结构化状态空间——Structured State Space for Sequences,一类可以有效处理长序列的 SSM
2.3 SSM的问题:矩阵参数固定不变,无法针对输入做针对性推理
首先,Linear Time Invariance(LTI)规定 SSM中的A、B、C始终是固定不变的参数。这意味着
- 对于 SSM 生成的每个token,矩阵A 、B、C都是相同的(regardless of what sequence you give the SSM, the values of A,B,and C remain the same. We have a static representation that is not content-aware)
- 使得SSM无法针对输入做针对性的推理「*since it treats each token equally as a result of the fixed A, B, and C matrices. This is a problem as we want the SSM to reason about the input (prompt)*」
此外,如下图所示,无论输入x 是什么,矩阵 B都保持完全相同,因此与x无关
同样,无论输入如何,A和C也保持固定
第三部分(选读) Mamba一作Albert Gu对S4的阐述
注,本部分只作为选读,因为本部分要介绍的重点 上文已经介绍过了,但为何还是要增加这个选读部分呢,一者 本部分来自mamba论文的一作Albert Gu的解读,虽然其公式表达不如上文第二部分的表达顺眼(比如状态被他改写成x,输入被他改写成u),但有些论文的表达还是用的Albert Gu的这个表述,故权衡利弊,还是增加本部分
3.1 S4的前身:HiPPO
3.1.1 改进transformer不擅长处理超长的序列的问题:输入u到状态x
如本文开头所说,mamba论文的一作Albert Gu多年来一直在推动SSM的发展
- 他在SSM的基础上,通过此篇论文《Efficiently Modeling Long Sequences with Structured State Spaces》首次提出了结构化状态空间S4(这里有关于S4的更多论文),但这篇论文的可读性比较差
- 好在作者在YouTube上有一个关于这篇S4论文的精彩解读,下面便以他这个解读视频梳理一下(以下PPT截取自该解读视频中)
简单来讲,序列数据一般都是离散的数据 比如文本、图、DNA
- 但现实生活中还有很多连续的数据,比如音频、视频,对于音视频这种信号而言,其一个重要特点就是有极长的context window
- 而在transformer长context上往往会失败,或者注意力机制在有着超长上下文长度的任务上并不擅长(所以你才看到各种对注意力机制的改进,比如flashattention等等,即便如此一般也就32K的上下文长度,在面对100w的序列长度则无能为力),而S4擅长这类任务
为了方便大家更好的理解,Albert Gu举了一个金融领域的例子
- 即根据输入,计算其EMA(如下图所示,黑色的一直在跳跃着的曲线是输入x,输出y是蓝色的线) 由于EMA(Exponential Decaying Measure)有着unbounded context(无限长度),Transformers和Convolution因为都只有着有限的上下文窗口而不好计算
- Albert Gu发现EMA其实是整个signal的一个summary,相当于是过往所有信号历史的加权平均值,其权重呈指数衰减之势(下图中绿色的线即相当于投影到的指数衰减)
- 如果用**表示input*,且表示对应的summary(*可能你看到这里 觉得表示有点乱,包括很快你还会看到:**输入u、状态x、输出y,其实刚好就是**和上文第二部分的表述反过来了,上文第二部分是用的表示的summary,x表示原始输入) 那么该summary可以在常数时间内快速计算得到(即summary of entire context update in constant time): 这个summary作为对之前信息的一个总结,也可以认为是对“当前事物所处在一个什么样的状态”的建模,而随着新信息的不断输入,那么当前事物所处的状态也会不断更新
July注:其实如果用h表示对应的summary,会更清晰,如此,也和上文的第二部分的表达统一起来了,却非得用x 表示对应的summary..
3.1.2 HiPPO的定义与推导:state compresses the history of input
假设 时刻我们看到了原始输入信号 的之前部分:
- 我们希望在一个memory budget来压缩前面这一段的原始input来学习特征,一个很容易想到的方法是用多项式去近似这段input
- 在我们接收到更多signal的时候,我们希望仍然在这个memory budget内对整段signal进行压缩,自然,你得更新你的多项式的各项系数,如下图底部所示
- 以上,会涌现出两个问题: 1. 如何找到这些最优的近似? 2. 如何快速地更新多项式的参数? 为了解决这两个问题,我们需要一个measure去定义一个近似的好坏程度。例如,可以使用EDM
- 这就引出了HiPPO的正式定义,其为两个信号和两个矩阵的组合: 可能你已经看出来了,如果把上图的、改由、表示,原始输入改由表示,则不就是上文介绍过的下图这个表达式么?而且还是下图的表达更顺眼些,是不,^_^ 而这个矩阵A就是HiPPO矩阵,比如可以是这样:
- HiPPO相当于将函数映射到函数,这里给个通俗的例子解释一下: 如本部分上面所述,这里的**是原信号,是压缩后的信号(对应上文第二部分的状态)** 给定一个持续增长的,HiPPO允许online update压缩的。如果使用一个64unit的polynomial压缩器(完全表示需要10000unit,所以是非常高度的压缩),可以发现EDM很不错,保留了大量之前的信息: 其中红色的线相当于对输入的重建(*可以看出来,离当下最近时刻的 其刻画最准确,至于离当下最远的时刻 则其刻画的不那么准确 *) 这里要注意,HiPPO只需要看到这个时刻的多项式(polynomial)参数和在此之前的signal ,不需要看到之前的多项式参数..
- 上面都是用EDM这个measure的,但是我们在学习过程中用的往往不只一个measure(例如一个time-varying measure can change over time),这个时候如何去建模? 最终,作者得到了一个结论:HiPPO可以在各种measure上面成立
以下内容来自H3的论文:《Hungry Hungry Hippos: Towards Language Modeling with State Space Models》,仅作为扩展阅读
如下图所示
- 图左:H3利用移位和对角矩阵堆叠两个离散SSM,并通过“输入投影和它们的输出之间的乘法交互”来模拟序列中点之间的比较Left: H3 stacks two discrete SSMs with shift and diagonal matrices and uses multiplicative interactions between input projections and their outputs to model comparisons between points in a sequence. 具体而言 (i)为了记忆过去的token,我们希望状态xi从输入ui中复制,并将该信息传递到下一个状态xi+1。由于xi+1与xi相关联通过矩阵,我们使用离散SSM并引入移位矩阵A,该矩阵可以移动状态向量的元素(例如将[a,b,c]映射为[0,a,b]) (ii)为了比较序列中的token,使用乘法交互:将包含来自先前时间步骤信息的SSM输出与当前时间步骤输入相乘,从而衡量token之间的相似性 H3受到线性注意力的启发:通过投影输入u得到三个信号Q、K、V。然后,在非线性函数φ(K)处使用一个SSM替代它,并且在求和Si(SSMdiag)处使用带有对角线A的SSM进行替换。对于头维度的情况下输出结果为
3.2 S4的推出:Structured State Space for Sequences
3.2.1 HiPPO的高阶化(输入u到状态x最后输出y)
发现HiPPO在低阶信号上work后,我们希望将它扩展到高阶信号上。阶数越高,与LLM越相似,工作的价值就越大
- 但是我们不能直接堆叠HiPPO算子,因为不断增加维度会引起维数爆炸:
- 作者想到了非常精妙的一个方法,如下图所示,通过蓝色state 的线性组合得到最终的输出红色,至于 是skip connection,是绕开state 直接从**input **到输出 的一个连接 而如果改用上文第二部分的表达,则如下图所示(state 改由表达,input 改由表达)最终把这两个方程统一放到一块,便是上文第二部分所述的这个图
- 这样,我们通过两个方程定义S4 一个是之前定义的 (下一时刻的 ) 来将input 记忆成state,如下图左侧所示 现在又定义了 来将state 线性组合成一个输出,如下图右侧所示
- 有意思的是,推出来的这些公式组成了一个1960年在ASME会议上提出的State Space Machine! SSM由Kalman提出,原文在这:A New Approach to Linear Filtering and Prediction Problems 而我们关注的S4不就是基于「上图 + A B C D这4个矩阵」而发展出来的么(当然,下图是用的上文第二部分的表达)
3.2.2 Structured SSM
我们正式定义下S4
- 首先,有一个state space model,简称为SSM
- 其次,在下图所示的两个方程中插入特定的矩阵值
- 接着,学习对应的参数
3.3 S4的性质:连续的表示、用Recurrent快速infer、用Convolutional快速训练
接下来,我们来看下如下图所示的S4的三个性质
也可以用下图表示
3.3.1 SMM的连续表示与离散表示
第一个性质是连续的表示,且就算SSM在离散数据上训练,它仍能学习到底层蕴含的连续信息,因为在SSM眼里,sequence不过是连续信号signal的采样(离散形式),或者说连续的信号模型是离散的序列模型的概括
连续时间状态空间表示 通过状态变量定义了从输入信号(作为时间t的函数)到输出信号的线性映射
对于某些矩阵![\mathbf{A} \in \mathbb{R}^{m \times m}](https://latex.csdn.net/eq?%5Cmathbf%7BA%7D%20%5Cin%20%5Cmathbb%7BR%7D%5E%7Bm%20%5Ctimes%20m%7D)、![\mathbf{B} \in \mathbb{R}^{m \times 1}](https://latex.csdn.net/eq?%5Cmathbf%7BB%7D%20%5Cin%20%5Cmathbb%7BR%7D%5E%7Bm%20%5Ctimes%201%7D)、![\mathbf{C} \in \mathbb{R}^{1 \times m}](https://latex.csdn.net/eq?%5Cmathbf%7BC%7D%20%5Cin%20%5Cmathbb%7BR%7D%5E%7B1%20%5Ctimes%20m%7D)和![\mathbf{D} \in \mathbb{R}^{1 \times 1}](https://latex.csdn.net/eq?%5Cmathbf%7BD%7D%20%5Cin%20%5Cmathbb%7BR%7D%5E%7B1%20%5Ctimes%201%7D),满足以下微分方程:
- 同样地,离散时间状态空间表示 通过状态变量定义了从离散输入信号 (i=1, 2, ...)到离散输出信号的线性映射,
状态空间模型(SSM)将这些表示作为深度学习管道中的一层(A state-space model (SSM) uses these representations as a layer in a deep learning pipeline),并且矩阵是根据数据进行学习得到的(例如基于梯度优化),通常有个这样的SSM并行存在,每个对应一个隐藏维度
- 为了保留序列历史信息,在HiPPO[24]中采用正交多项式投影历史数据,并转换成具有特殊初始化矩阵A和B的SSM形式(*To preserve the sequence history, HiPPO [24] projects the history on a basis of orthogonal polynomials, which translates to having SSMs whose A, **B *matrices are initialized to some special matrices)
- SSM以循环方式允许高效推断(即生成):为了生成下一个时间步的输出,只需要当前时间步的状态而不是整个输入历史记录(This recurrent form of SSMs allows efficient inference (i.e., generation): to generate the output of the next time-step, one only needs the state of the current time-step, not the entire input history)
3.3.2 用Recurrent表示进行快速的infer
第二个性质是有效的online计算,这点之前在HiPPO提到了,就是计算下一时刻的state 只需要这一时刻的state 和全局输入
虽然需要全局输入,但是这个全局的计算是常数时间的,这与RNN相同,而与Transformer/CNN不同
之所以是常数时间,也与RNN相同,因为有state(中间这条蓝线),这导致下一个state的计算只需要上一个state + 全局的输入
3.3.3 用Convolutional表示进行快速的训练
SSM的一个问题是,当知道未来的signal的时候,训练是低效的。有没有办法并行化SSM?作者提出了使用一个卷积核 ,绕过状态 ,直接从输入 到输出 (而非先输入到状态、状态再到输出)
输入怎么到输出呢?相当于通过特定的卷积滤波器K对输入进行卷积(即you can involve the input by an exponentially decaying convolution kernel),该滤波器在上图中用绿色线表示
问题好像解决了,但SSM还是存在两个问题
- 一个是计算复杂度的问题,最终通过给SSM做结构化(比如使用HiPPO矩阵,相当于变成了S4),即structured state space can be computed faster
- 另一个是,作者意识到这个S4某种意义上就是一个很fancy的CNN(包括可以以不同的方式参数化卷积内核),但是context window有时是无限长的 而刚好convolutional kernel可以无限长(至于单纯的CNN则是有限长的窗口),那其如何设计以适应有时无限长的context window呢?如下图所示
第四部分 Mamba的组成结构与原理解析
mamba(其对应论文为:Mamba: Linear-Time Sequence Modeling with Selective State Spaces,这是其对应的GitHub代码地址),在语言、音频、DNA序列模态上都实现SOTA,在最受关注的语言任务上,Mamba-3B超越同等规模的Transformer,与两倍大的Transformer匹敌,并且相关代码、预训练模型checkpoint都已开源
简言之,Mamba是一种状态空间模型(SSM),建立在更现代的适用于深度学习的结构化SSM (简称S6)基础上,与经典架构RNN有相似之处
4.1 Mamba = 有选择处理信息 + 硬件感知算法 + 更简单的SSM架构
与先前的研究相比,Mamba主要有三点创新:
- 对输入信息有选择性处理(Selection Mechanism) 相比SSM压缩所有历史记录(当然,transformer则是不压缩所有历史记录),mamba设计了一个简单的选择机制,通过“参数化SSM的输入”,以便关注或忽略特定的输入。这样一来,模型能够过滤掉与问题无关的信息,并且可以长期记住与问题相关的信息
- 硬件感知的算法(Hardware-aware Algorithm) 该算法采用“并行扫描算法”而非“卷积”来进行模型的循环计算,但为了减少GPU内存层次结构中不同级别之间的IO访问,它没有具体化扩展状态 当然,这点也是受到了S5(Simplified State Space Layers for Sequence Modeling)的启发
- 更简单的架构 将SSM架构的设计与transformer的MLP块合并为一个块(combining the design of prior SSM architectures with the MLP block of Transformers into a single block),来简化过去的深度序列模型架构,从而得到一个包含selective state space的架构设计
4.1.1 选择性状态空间模型:从S4到S6
作者认为,序列建模的一个基础问题是把上下文压缩成更小的状态(We argue that a fundamental problem of sequence modeling is compressing context into a smaller state),从这个角度来看
- 注意力机制虽然有效果但效率不算很高,毕竟其需要显式地存储整个上下文(storing the entire context,也就是KV缓存),直接导致训练和推理消耗算力大 好比,Transformer就像人类每写一个字之前,都把前面的所有字+输入都复习一遍,所以写的慢
- RNN的推理和训练效率高,但性能容易受到对上下文压缩程度的限制On the other hand, recurrent models are efficient because they have a finite state, implying constant-time inference and linear-time training. However, their effectiveness is limited by how well this state has compressed the context. 好比,RNN每次只参考前面固定的字数(仔细体会这句话:When generating the output, the RNN only needs to consider the previous hidden state and current input. It prevents recalculating all previous hidden states which is what a Transformer would do),写的快是快,但容易忘掉更前面的内容
- 而SMM的问题在于其中的矩阵A B C始终是不变的,无法针对不同的输入针对性的推理,详见上文的2.4节
- 最终,Mamba的解决办法是,让模型对信息有选择性处理,可以关注或忽略特定的内容,即使状态大小固定也能压缩上下文 好比,Mamba每次参考前面所有内容的一个概括,越往后写对前面内容概括得越狠,丢掉细节、保留大意
总之,序列模型的效率与效果的权衡点在于它们对状态的压缩程度:
- 高效的模型必须有一个小的状态(比如RNN或S4)
- 而有效的模型必须有一个包含来自上下文的所有必要信息的状态(比如transformer)
而mamba为了兼顾效率和效果,选择性的关注必须关注的、过滤掉可以忽略的
为方便大家理解,再进一步阐述mamba与其前身结构化空间模型S4的优势
首先,在其前身S4中,其有4个参数(∆, A, B, C)
且它们都是固定的,不随输入变化(即与输入无关),这些参数控制了以下两个阶段
- **第一阶段(1a 1b)*,通常采用固定公式A = 𝑓𝐴(∆, A)和B = 𝑓𝐵(∆, A, B),将“连续参数”(∆,A,B)转化为“离散参数”(A,B),其中(𝑓𝐴, 𝑓𝐵) 称为离散化规则,且可以使用多种规则来实现这一转换The first stage transforms the “continuous parameters” (∆, A, B) to “discrete parameters” (A, B) through fixed formulas **A **= 𝑓𝐴(∆, A) and **B = 𝑓𝐵(∆, A, B), where the pair (𝑓𝐴, 𝑓𝐵) is called a discretization rule* 例如下述方程中定义的零阶保持(ZOH)*Various rules can be used such as the zero-order hold (ZOH) defined in equation (4).*
- **第二阶段(2a 2b,和3a 3b)**,在参数由(∆,A, B, C)变换为(A, B, C)后,模型可以用两种方式计算,即线性递归(2)或全局卷积(3) After the parameters have been transformed from (∆, A, B, C) ↦ (A, B, C), the model can be computed in two ways, either as a linear recurrence (2) or a global convolution (3) 如之前所说的 模型通常使用卷积模式(3)可以进行高效的并行化训练「 *其中整个输入序列提前看到,为何可以做高效的并行化呢,因为该模式能够绕过状态计算,并实现仅包含(B, L, D)的卷积核(3a),即Thus the more efficient convolution mode wasintroduced which could bypass the state computation and materializes a convolution kernel (3a) of only (𝙱, 𝙻, 𝙳)*」 并切换到循环模式(2)以高效的自回归推理(其中输入每次只看到一个时间步) the model uses the convolutional mode (3) for efficient parallelizable training (where the whole input sequence is seen ahead of time), and switched into recurrent mode (2) for efficient autoregressive inference (wheret he inputs are seen one timestep at a time)
下面,再分析下各个变量的含义
- ,一个标量,类似遗忘门 如sonta所说,这个量跟RNN里的gating有着深刻的联系(∆ in SSMs can be seen to play a generalized role of the RNN gating mechanism) 即data dependent的 Δ 跟RNN的forget gate的功能类似(***step size *Δ that represents the resolution of the input discretization of SSMs is the principled foundation of heuristic gating mechanisms.)
- ,起到的作用类似于:进RNN的memory,起到的作用类似于:取RNN的memory咋理解?我拿出上文第二部分的这个图 一摆,就一目了然了 所以有人说,data dependent的的功能跟RNN的input/output gate类似
- ,意味着对应这个维度的SSM来说,A在每个hidden state维度上的作用可以不相同,起到multi-scale/fine-grained gating的作用,这也是LSTM网络里面用element-wise product的原因
其次,通过之前的讲解,可知矩阵都可以由个数字表示(*the **A **∈ ℝ𝑁×𝑁, **B **∈ ℝ𝑁×1 , **C *∈ ℝ1×𝑁 matrices can all be represented by 𝑁 numbers.)
为了对批量大小为、长度为、具有个通道(类似R G B三个通道)的输入序列进行操作,SSM被独立地应用于每个通道(To operate over an input sequence 𝑥 of batch size 𝐵 and length 𝐿 with 𝐷 channels, the SSM is applied independently to each channel)
请注意,在这种情况下,每个输入的总隐藏状态具有维,在序列长度上计算它需要时间和内存(the total hidden state has dimension 𝐷𝑁 per input, and computing it over the sequence length requires 𝑂(𝐵𝐿𝐷𝑁) time and memory)
最后,在Mamaba中,作者让这些参数矩阵、矩阵、成为输入的函数(即可学习或可训练的),让模型能够根据输入内容自适应地调整其行为
- 从S4到S6的过程中 影响输入的矩阵、影响状态的矩阵的大小从原来的(D,N)「其中,D指的是输入向量的维度,比如一个颜色的变量一般有R G B三个维度,N**指SSM的隐藏层维度hidden dimension,当然 一般设的比较小」 变成了(B,L,N)「这三个参数分别对应batch size、sequence length、hidden state size」 且的大小由原来的D变成了(B,L,D) 且每个位置的矩阵、矩阵、都不相同,这意味着对于每个输入token,现在有不同的矩阵、矩阵,可以解决内容感知问题 进一步,咱们通过 来逐一将数据依赖化(data dependent)「其中的这个代表把D维的输入向量经过一个线性层映射到维」
- 虽然没有变成data dependent,但是通过SSM的离散化操作之后,会经过outer product变成(B, L, N, D)的data dependent张量,算是以一种parameter efficient的方式来达到data dependent的目的当然,到底效果变好的最大原因是哪一块,可以参考这篇做下相关的实验:Gated Linear Attention Transformers with Hardware-Efficient Training
总之,Mamba通过合并输入的序列长度和批量大小来使矩阵B和C,甚至步长Δ取决于输入(其意味着对于每个输入token,现在有不同的B和C矩阵,可以解决内容感知问题),从而达到选择性地选择将哪些内容保留在隐藏状态以及忽略哪些内容的目标
至于步长Δ,较小的步长Δ会也能做到忽略特定单词,而更多地使用先前的上下文,而较大的步长Δ会更多地关注输入单词而不是上下文
4.1.2 并行扫描(parallel scan)算法
由于A B C这些矩阵现在是动态的了,因此无法使用卷积表示来计算它们(CNN需要固定的内核),因此,我们只能使用循环表示,如此也就而失去了卷积提供的并行训练能力
so,为了实现并行化,让我们探讨如何使用循环计算输出
每个状态比如都是前一个状态比如乘以,加上当前输入乘以的总和,这就叫扫描操作(scan operation),可以使用 for 循环轻松计算,然这种状态之下想并行化是不可能的(因为只有在获取到前一个状态的情况下才能计算当前的每个状态)
好在mamba通过并行扫描(parallel scan)算法使得最终并行化成为可能,其假设我们执行操作的顺序与关联属性无关
因此,我们可以分段计算序列并迭代地组合它们,即动态矩阵B和C以及并行扫描算法一起创建选择性扫描算法(selective scan algorithm)
为了方便大家更好的理解,我把相关推导再拆解一下,以更一目了然
- 首先,和的计算很简单,如下所示
- 其次,可以由直接计算得来,也可以由甚至计算得来
- 最后,最终包含了之前、以及、、的信息,只是做了整体的压缩
4.1.3 硬件感知的状态扩展:借鉴Flash Attention
为了让传统的SSM在现代GPU上也能高效计算,Mamba中也使用了Flash Attention技术
- 简而言之,利用内存的不同层级结构处理SSM的状态,减少高带宽但慢速的HBM内存反复读写这个瓶颈
- 具体而言,就是限制需要从 DRAM 到 SRAM 的次数(通过内核融合kernel fusion来实现),避免一有个结果便从SRAM写入到DRAM,而是待SRAM中有一批结果再集中写入DRAM中,从而降低来回读写的次数*(更多详见:通透理解FlashAttention与FlashAttention2:全面降低显存读写、加快计算速度*)
4.1.4 简化的SSM架构
将大多数SSM架构比如H3的基础块,与现代神经网络比如transformer中普遍存在的门控MLP相结合,组成新的Mamba块,重复这个块,与归一化和残差连接结合,便构成了Mamba架构
顺带提一嘴,transformer quality in linear time以及mega moving average equipped gated attention的这两个工作,也用了类似的结构:即删除transformer的ffn/glu结构
最终流程如下(图源自mamba原论文)
- 在更高速的SRAM内存中执行离散化和递归操作,再将输出写回HBM 具体来说,我们不是在GPU HBM(高带宽内存)中将大小为(B,L,D,N)的扫描输入进「*instead of preparing the scan input (A, B) of size (𝙱, 𝙻, 𝙳, 𝙽) in GPU HBM (high-bandwidth memory)*」,而是 首先,直接将SSM参数从慢速HBM加载到快速SRAM中 然后,在SRAM中进行离散化和递归计算 最后,将大小为(B,L,D)的最终输出写回HBM
- 通过上一节4.1.2节介绍的并行扫描算法实现并行化
- 当输入从HBM加载到SRAM时,中间状态不被保存,而是在反向传播中重新计算the intermediate states are not stored but recomputed in the backward pass when the inputs are loaded from HBM to SRAM
4.2 通过mamba预测下一个token的示例
从线性投影开始,以扩展输入嵌入。然后,在选择性 SSM之前应用卷积以防止独立的token计算
其中的“选择性SSM(即Selective SSM)”具有以下属性
- Recurrent SSM通过离散化创建循环SSM
- HiPPO对矩阵A进行初始化A以捕获长程依赖性
- 选择性扫描算法(Selective scan algorithm)选择性压缩信息
- 硬件感知算法(Hardware-aware algorithm)加速计算
最后,包含归一化层和用于选择“预测的token”的softmax
4.3 对Improving SSMs with Selection的进一步阐述
4.3.1 三个任务的对比:copying、selective copying、induction heads
如下图所示,有三个任务
- (左)复制任务的标准版本涉及输入和输出元素之间的固定间距,可以通过线性递归和全局卷积等时不变模型轻松解决*(Left) The standard version of the Copying task involves constant spacing between input and output elements and is easily solved by time-invariant models such as linear recurrences and global convolutions.*
- (右上)选择性复制任务在输入之间具有随机间距,需要使用时变模型,在内容上能够灵活地选择记忆或忽略输入*(Right Top) The Selective Copying task has random spacing in between inputs and requires time-varying models that can selectively remember or ignore inputs depending on their content.* 相当于选择性复制任务通过改变“要记忆的tokens的位置”来改进纯粹的复制任务(Arjovsky, Shah和Bengio 2016)。它需要内容感知推理,以便能够记住相关的标记(有色),并过滤掉不相关的标记(白色)The Selective Copying task modifies the popular Copying task (Arjovsky, Shah, and Bengio 2016) by varying the position of the tokens to memorize. It requires content-aware reasoning to be able to memorize the relevant
- (右下)归纳头部任务是联想回忆的一个例子,需要根据上下文检索答案,这是LLM关键的能力*(Right Bottom) The Induction Heads task is an example of associative recall that requires retrieving an answer*based on context, a key ability for LLMs. 其实,归纳头部任务是一种众所周知的机制,据推测可以解释LLMs的大部分上下文学习能力(Olsson et al. 2022)。它需要上下文感知的推理,以便知道何时在适当的上下文中产生正确的输出(黑色) The Induction Heads task is a well-known mechanism hypothesized to explain the majority of in-context learning abilities of LLMs (Olsson et al. 2022). It requires context-aware reasoning to know when to produce the correct output in the appropriate context (black)
// 待更
4.4 实验结果
Mamba在Chinchilla缩放定律下预训练时,语言任务优于同类开源模型
下游任务上,每个规模尺寸的Mamba都是同类最佳,并且通常与两倍规模的基线性能匹配,特别是当序列长度增加到512k时,相比使用FlashAttention-2的Transformer快几个数量级,而且不会内存不足
最后,有的新闻稿会说Mamba是第一个实现匹配Transformer性能的线性时间序列模型,其实第一个是TransNormerLLM
参考文献与推荐阅读
- Transformer挑战者出现!FlashAttention作者参与,模型代码都开源,公司已创办
- Hungry Hungry Hippos: Towards Language Modeling with State Space Models
- [线性RNN系列] Mamba: S4史诗级升级
- Structured State Spaces for Sequence Modeling (S4)
- S4: 使用结构化状态空间序列进行高效建模
- 《Efficiently Modeling Long Sequences with Structured State Spaces》 首次提出了结构化状态空间S4
- S4作者在YouTube上对S4论文的精彩解读
- Transformers are RNNs: Fast Autoregressive Transformers with Linear Attention
- RWKV: Reinventing RNNs for the Transformer Era(下载地址2),这是其翻译,这是其解读之一
- 【手撕LLM-RWKV】重塑RNN 效率
- 挑战Transformer的Mamba是什么来头?作者博士论文理清SSM进化路径
- Mamba论文为什么没被ICLR接收?AI社区沸腾了
- openreview上对mamba论文的审稿意见:https://openreview.net/forum?id=AL1fq05o7H
- A Visual Guide to Mamba and State Space Models An Alternative to Transformers for Language Modeling,by MAARTEN GROOTENDORST
- The Annotated S4,包含对S4的实现,比如对矩阵B、C的学习
- Mamba - a replacement for Transformers?,YouTube上Samuel Albanie关于mamba非常精彩的解读,目前该视频20多万的播放量
- Mamba No. 5 (A Little Bit Of...)
- Mamba: The Easy Way,Oxford, UK — February 23, 2024Mamba原理通俗介绍
- 大模型相关论文100篇短笔记
- ..
创作、修订、完善记录
- 第一版的完成过程 12.11,开写,且发现Google抓的也是真快(当天用Google搜:mamba模型,本文已排第一)
- 12.12,考虑到想理解好mamba,则需要先理解好SSM,故全力完善这几节的内容:“1.2 状态空间模型SSM”、“1.3 S4的前身:HiPPO”、“1.4 S4的推出:Structured State Space Models”
- 12.13,完善此节:“1.5 S4的性质:连续的表示、用Recurrent快速infer、用Convolutional快速训练”
- 12.14,结合mamba论文,开始精修“第二部分 Mamba的组成结构与原理解析” 特别是以下这两节 2.1.1 选择性状态空间模型:从S4到S6 2.1.2 硬件感知的状态扩展:借鉴Flash Attention
- 12.15,开始写:“第三部分 Mamba近似工作之线性Transformer:从AFT、RWKV谈到TransnormerLLM” 特别是此节:“3.2 RWKV:试图在Transformer时代重塑RNN”
- 12.17,修正1.4节中的一个笔误,已修正为:“作者想到了非常精妙的一个方法:不考虑input 到state ,而是直接从state 到output y ”
- 12.19,在TransNormer的提出者qinzhen的建议之下,补充关于线性transformer的一些解释说明,特别是关键的这一句 “考虑到矩阵乘法有结合律,softmax只能左乘,linear可以右乘,而右乘更快,正因为矩阵乘积的这个属性可以实现注意力操作的线性复杂度”
- 12.23,根据友人钟博士的反馈,在文中强调:第一个实现匹配Transformer性能的线性时间序列模型是TransNormerLLM..
- 24年2.2,新增一节的内容,即 1.6 (选读)Hungry Hungry Hippos:基于状态空间模型的语言建模
- 第二版的修订过程(质量相比第一版提高2-3倍) 3.2-3.5,全面大幅修订本文近2/3的内容,特别是:第二部分 从状态空间模型SSM到S4的升级之路
- 3.6,为了更好的阐述清楚mamba本身,把原属于本文的这部分内容 “第五部分 Mamba近似工作之线性Transformer:从TransnormerLLM到RWKV” 转到另一篇文章中《七月论文审稿GPT第1版:通过3万多篇paper和10多万的review数据微调RWKV》 故本文标题由原来的 一文通透想颠覆Transformer的Mamba:从SSM、S4到mamba、线性transformer(含RWKV解析) 改成 一文通透想颠覆Transformer的Mamba:从SSM、HiPPO、S4到Mamba
版权归原作者 v_JULY_v 所有, 如有侵权,请联系我们删除。