一种语言建模中 Transformer 的替代方案
Transformer 架构是大语言模型(LLMs)成功的关键组成部分。几乎所有今天使用的大语言模型都采用了该架构,从开源模型如 Mistral 到闭源模型如 ChatGPT。
为了进一步改进大语言模型,新的架构被开发出来,这些架构可能甚至会超越 Transformer 架构。其中一种方法是 Mamba,一种 状态空间模型。

在这篇文章中,我将介绍在语言建模背景下的状态空间模型领域,并逐步探索各个概念,以帮助理解这一领域。然后,我们将讨论 Mamba 如何可能挑战 Transformer 架构。
作为一个视觉指南,本文将通过许多可视化内容来帮助理解 Mamba 和状态空间模型!
第 1 部分:Transformers 的问题
为了说明 Mamba 是多么有趣的架构,我们首先简要回顾一下 Transformers,并探索其一个缺点。
Transformer 将任何文本输入视为由 tokens 组成的 sequence。

Transformers 的核心组件
记住,Transformer 由两个结构组成,一个用于表示文本的编码器块集和一个用于生成文本的解码器块集。结合起来,这些结构可用于多个任务,包括翻译。


训练中的祝福…
单个解码器块由两个主要组件组成,遮蔽的自注意力机制和前馈神经网络。

那么它是如何工作的呢?
它创建了一个矩阵,将每个 Token 与之前的每个 Token 进行比较。矩阵中的权重取决于 Token 对彼此的相关性。

它实现了并行化,这极大地加快了训练速度!
推理中的问题!
然而存在一个缺陷。当生成下一个 Token 时,我们需要重新计算整个序列的注意力,即使我们已经生成了一些 Token。


让我们看看一种“经典”技术——循环神经网络(Recurrent Neural Networks, RNN)是如何解决这个推理速度慢的问题。
RNN 是解决方案吗?
循环神经网络(RNN)是一种基于序列的网络。它在每个时间步中接收两个输入,即时间步t的输入和上一个时间步t-1的隐藏状态,用以生成下一个隐藏状态并预测输出。
RNN 有一个循环机制,允许它将信息从前一步传递到下一步。我们可以将这个可视化过程“展开”,使其更加明确。

换句话说,RNN 能够快速进行推理,因为它随着序列长度线性扩展!理论上,它甚至可以拥有无限的上下文长度。
为了说明这一点,让我们将 RNN 应用于我们之前使用的输入文本。

然而,这里有一个问题……
注意,当生成名字“Maarten”时,最后一个隐藏状态已经不再包含关于单词“Hello”的信息了。RNN 随时间推移往往会忘记信息,因为它们只考虑上一个状态。
虽然 RNN 在训练和推理方面速度较快,但它们缺乏 Transformer 模型所能提供的精度。
因此,我们研究状态空间模型 (State Space Models) 来高效地使用 RNN(有时还会使用卷积)。
第 2 部分:状态空间模型 (SSM)
状态空间模型 (SSM) 像 Transformer 和 RNN 一样处理信息序列,比如文本和信号。在本节中,我们将介绍 SSM 的基本概念以及它们与文本数据的关系。
什么是状态空间?
状态空间包含完全描述一个系统所需的最少变量数。它是一种通过定义系统可能状态来数学表示问题的方法。
让我们简化一下。想象我们正在穿越迷宫。“状态空间”就是所有可能位置(状态)的地图。每个点代表迷宫中的一个独特位置,并带有特定的细节,比如你距离出口有多远。
“状态空间表示”是对这张地图的简化描述。它显示了你当前的位置(当前状态)、你可以前往的下一个位置(可能的未来状态)、以及你如何到达下一个位置(向右或向左移动)。

变量描述了一个状态,在我们的例子中,X 和 Y 坐标,以及到出口的距离,可以表示为“状态向量”。


什么是状态空间模型?
SSM(状态空间模型)是一类用于描述这些状态表示并预测其下一个状态的模型,预测基于某些输入。
传统上,在时间 t,SSM:
- 将输入序列 x(t)(例如,在迷宫中向左和向下移动)映射到潜在状态表示 h(t)(例如,到出口的距离和 x/y 坐标)
- 并推导出预测的输出序列 y(t)(例如,再次向左移动以更快到达出口)
然而,SSM 并不是使用离散序列(如向左移动一次),而是接受连续序列作为输入并预测输出序列。


其目标是找到这个状态表示 h(t),使我们能够从输入到输出序列。

这两个方程将在整个指南中引用。为了让它们更直观,我们使用颜色编码,这样你可以快速引用它们。
状态方程 描述了状态如何根据输入(通过 矩阵 B)影响状态(通过 矩阵 A)而变化。

输出方程 描述了状态如何通过 矩阵 C 转换为输出,以及输入如何通过 矩阵 D 影响输出。

注意:矩阵 A、B、C 和 D 也常被称为 参数,因为它们是可学习的。
将这两个方程可视化,我们得到以下架构:

假设我们有一些输入信号 x(t),该信号首先与 矩阵 B 相乘,矩阵 B 描述了输入如何影响系统。


然后,我们使用 矩阵 C 来描述状态如何转换为输出。





从连续信号到离散信号
如果你有一个连续信号,找到状态表示 h(t) 在解析上具有挑战性。此外,由于我们通常有离散输入(例如文本序列),我们希望对模型进行离散化。
为此,我们使用了 零阶保持技术。其工作原理如下:首先,每当我们接收到离散信号时,我们保持其值,直到我们接收到新的离散信号为止。这个过程会生成一个连续信号,供 SSM 使用:

现在我们为输入生成了一个连续信号,接下来我们可以生成连续输出,并根据输入的时间步长对这些值进行采样。

从数学上讲,我们可以如下应用零阶保持:


我们用 k 来代替 t,以区分我们何时在谈论连续的 SSM 与离散的 SSM。
注意: 在训练期间,我们仍然保留 矩阵 A 的连续形式,而不是其离散化版本。在训练过程中,连续表示会被离散化。
现在我们已经有了一个离散化表示的公式,让我们探讨如何实际 计算 该模型。
递归表示
我们的离散化 SSM 使我们能够在特定的时间步中,而不是连续信号中,构建问题。正如我们之前在 RNN 中看到的,递归方法在这里非常有用。
如果我们考虑离散的时间步而不是连续信号,我们可以使用时间步重新表述这个问题:




卷积表示
我们可以使用卷积来表示 SSM。记住在经典图像识别任务中,我们应用滤波器(kernels)来提取聚合特征:




在下一步中,内核会移动一次以执行计算的下一步:


三种表示方式
这三种表示方式——连续、递归 和 卷积 各有不同的优缺点:

利用这些表示方式,我们可以使用一个巧妙的技巧,即根据任务选择表示方式。在训练过程中,我们使用可以并行化的卷积表示,而在推理过程中,我们使用高效的递归表示:

这些表示方式共享一个重要属性,即线性时间不变性(LTI)。LTI 表示 SSM 的参数 A、B 和 C 在所有时间步中都是固定的。这意味着矩阵 A、B 和 C 对每个生成的 Token 都是相同的。
换句话说,无论你给 SSM 任何序列,A、B 和 C 的值都保持不变。我们有一个不关心内容的静态表示方式。
在探讨 Mamba 如何解决这个问题之前,让我们探讨最后一个拼图碎片——矩阵 A。
矩阵 A 的重要性
可以说,SSM 公式中最重要的方面之一是 矩阵 A。正如我们之前在递归表示中看到的,它捕捉了有关 上一个 状态的信息,以构建 新的 状态。


如何以保留大量记忆(上下文大小)的方式创建 matrix A?
我们使用 Hungry Hungry Hippo!或者 HiPPO3 来实现高阶 多项式 投影 运算符。HiPPO 试图将其至今看到的所有输入信号压缩为一个系数向量。



HiPPO 矩阵的核心思想是生成一个记忆其历史的隐藏状态。
从数学上讲,它通过跟踪 Legendre 多项式 的系数来实现这一点,这使得它能够近似所有的历史记录。4
然后 HiPPO 被应用于我们之前看到的递归和卷积表示,以处理长距离依赖关系。结果是 Structured State Space for Sequences (S4),一种可以高效处理长序列的 SSM 类。5
它由三部分组成:
- 状态空间模型
- HiPPO 用于处理 长距离依赖
- 离散化用于创建 递归 和 卷积 表示

注意:如果你想深入了解如何计算 HiPPO 矩阵并自己构建 S4 模型,我强烈建议你阅读 Annotated S4。
第 3 部分:Mamba – 一种选择性状态空间模型
我们终于覆盖了理解 Mamba 特殊之处所需的所有基础知识。状态空间模型可以用于建模文本序列,但仍然有一系列我们希望避免的缺点。
在这一部分,我们将讨论 Mamba 的两个主要贡献:
- 一种 选择性扫描算法,允许模型筛选(不)相关信息
- 一种 硬件感知算法,通过 并行扫描、内核融合 和 重新计算 来高效存储(中间)结果。
这两者共同创建了 选择性状态空间模型 或 S6 模型,可以像自注意力一样用于创建 Mamba 块。
在探索这两个主要贡献之前,让我们首先探讨一下它们为何必要。
试图解决什么问题?
状态空间模型,甚至 S4(结构化状态空间模型),在语言建模和生成中某些关键任务上表现不佳,即 关注或忽略特定输入的能力。
我们可以用两个合成任务来说明这一点,即 选择性复制 和 诱导头。
在 选择性复制 任务中,SSM 的目标是复制输入的部分并按顺序输出:

因此,SSM 无法执行 内容感知推理,因为它由于固定的 A、B 和 C 矩阵而平等对待每个 token。这是一个问题,因为我们希望 SSM 对输入(提示)进行推理。
SSM 在另一个任务上表现不佳,即 诱导头,其目标是重现输入中找到的模式:

让我们通过关注 矩阵 B 来说明这一点。无论输入 x 是什么,矩阵 B 始终保持不变,因此与 x 无关:


SSM 在这些任务上的较差表现说明了时间不变 SSM 的潜在问题,矩阵 A、B 和 C 的静态特性导致了 内容感知 的问题。
选择性保留信息
SSM 的递归表示创建了一个较小的状态,这种状态非常高效,因为它压缩了整个历史。然而,与 Transformer 模型相比,Transformer 模型不会对历史进行压缩(通过注意力矩阵),因此它的能力更强。
Mamba 旨在兼具两者的优势。一个与 Transformer 状态一样强大的小状态:

为了选择性地压缩信息,我们需要参数依赖于输入。为此,让我们首先探索 SSM 在训练过程中输入和输出的维度:



注意:矩阵 A 保持不变,因为我们希望状态本身保持静态,但它的影响方式(通过 B 和 C)是动态的。
它们共同 选择性地 选择保留在隐藏状态中的内容和忽略的内容,因为它们现在依赖于输入。
较小的 步长 ∆ 导致忽略特定的词汇,而更多地使用之前的上下文,而较大的 步长 ∆ 则更加关注输入词汇,而不是上下文:
扫描操作
由于这些矩阵现在是 动态 的,它们不能使用卷积表示进行计算,因为卷积表示假设一个 固定 的卷积核。我们只能使用递归表示,这样就失去了卷积提供的并行化优势。
为了实现并行化,让我们探索如何使用递归计算输出:

相比之下,并行化似乎是不可能的,因为每个状态只能在有了前一个状态之后才可以计算。然而,Mamba 通过 并行扫描 算法使这一点成为可能。
它假设我们进行操作的顺序不重要,利用了结合律的属性。因此,我们可以将序列分成若干部分进行计算,然后迭代地将它们组合起来:

硬件感知算法
最近的 GPU 的一个缺点是它们的小型但高效的 SRAM 与大型但稍微不那么高效的 DRAM 之间的传输(IO)速度有限。频繁地在 SRAM 和 DRAM 之间复制信息会成为瓶颈。



- 离散化步骤与 step size ∆
- 选择性扫描算法
- 与 C 的乘法
硬件感知算法的最后一部分是 recomputation。
中间状态不会被保存,但在反向传递中计算梯度时是必要的。相反,作者在反向传递期间重新计算这些中间状态。
尽管这看起来可能效率低下,但比从相对较慢的 DRAM 中读取所有这些中间状态要便宜得多。
我们现在已经涵盖了其架构的所有组件,该架构在其文章中的图像如下所示:

这个架构通常被称为 选择性 SSM 或 S6 模型,因为它本质上是使用选择性扫描算法计算的 S4 模型。
Mamba 模块
我们迄今为止探索的 选择性 SSM 可以作为一个模块进行实现,就像我们可以在解码器模块中表示自注意力一样。


选择性 SSM 具有以下特性:
- 通过 离散化 创建的 递归 SSM
- 在矩阵 A 上进行 HiPPO 初始化,以捕捉 长程依赖性
- 选择性扫描算法 以选择性地压缩信息
- 硬件感知算法 以加快计算速度
当我们查看代码实现时,我们可以对这个架构进行更多扩展,并探索一个端到端的示例会是什么样的:

当我们将所有内容整合在一起时,我们得到了快速的推理和训练,甚至是无限上下文。使用这种架构,作者发现其性能与同等大小的 Transformer 模型相匹配,有时甚至超过!
结论
这就结束了我们对状态空间模型和令人难以置信的 Mamba 架构使用选择性状态空间模型的探索。希望这篇文章能让你更好地理解状态空间模型,特别是 Mamba。谁知道这是否会取代 Transformers,但现在,看到如此不同的架构获得应有的关注真是令人惊叹!
要查看更多与大语言模型相关的可视化内容并支持这份通讯,请查看我与 Jay Alammar 合著的书籍。



扫描操作

































