we want to maximize hidden state dimension without paying speed and memory costs.
• Note that the recurrent mode is more flexible than the convolution mode, since the latter (3) is derived from expanding the former (2) (Gu, Goel, and Ré 2022; Gu, Johnson, Goel, et al. 2021). However, this would require computing and materializing the latent state ℎ with shape (𝙱, 𝙻, 𝙳, 𝙽), much larger (by a factor of 𝑁, the SSM state dimension) than the input 𝑥 and output 𝑦 of shape (𝙱, 𝙻, 𝙳). Thus the more efficient convolution mode was introduced, which bypasses the state computation and materializes a convolution kernel (3a) of size only (𝙱, 𝙻, 𝙳).
• Prior LTI SSMs leverage the dual recurrent-convolutional forms to increase the effective state dimension by a factor of 𝑁 (≈ 10–100), much larger than traditional RNNs, without efficiency penalties.
3.3.2 Overview of Selective Scan: Hardware-Aware State Expansion
The selection mechanism is designed to overcome the limitations of LTI models; at the same time, it requires us to revisit the computation problem of SSMs. We address this with three classical techniques: kernel fusion, parallel scan, and recomputation. We make two main observations:
• The naive recurrent computation uses 𝑂(𝐵𝐿𝐷𝑁) FLOPs while the convolutional computation uses 𝑂(𝐵𝐿𝐷 log(𝐿)) FLOPs, and the former has a lower constant factor. Thus for long sequences and not-too-large state dimension 𝑁, the recurrent mode can actually use fewer FLOPs.
• The two challenges are the sequential nature of recurrence and the large memory usage. To address the latter, just like the convolutional mode, we can attempt to not actually materialize the full state ℎ (a naive reference implementation is sketched after this list to make the shapes concrete).
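To make the shapes and FLOP counts above concrete, the following is a minimal reference implementation of the naive sequential recurrence in PyTorch. It is a sketch for exposition, not the paper's kernel: it uses the selective SSM parameter shapes (∆: (𝙱, 𝙻, 𝙳), A: (𝙳, 𝙽), B and C: (𝙱, 𝙻, 𝙽)), discretizes A via exp(∆A) with the common first-order simplification ∆·B for the discretized B (the exact zero-order-hold formula differs slightly), and deliberately materializes the full (𝙱, 𝙻, 𝙳, 𝙽) state to expose the memory cost.

```python
import torch

def naive_selective_scan(delta, A, B, C, x):
    """Reference (unfused) selective scan -- a sketch for exposition only.

    Shapes (following the selective SSM parameterization):
      x:     (batch, length, d)   input sequence
      delta: (batch, length, d)   input-dependent step size
      A:     (d, n)               state matrix (acts diagonally per channel)
      B, C:  (batch, length, n)   input-dependent SSM parameters
    Returns y of shape (batch, length, d).
    """
    batch, length, d = x.shape

    # Discretize. Abar = exp(delta * A); for Bbar*x we use the common
    # first-order simplification delta * B * x.
    deltaA = torch.exp(delta.unsqueeze(-1) * A)                        # (batch, length, d, n)
    deltaBx = delta.unsqueeze(-1) * B.unsqueeze(2) * x.unsqueeze(-1)   # (batch, length, d, n)

    # Naive recurrence: materializes the full latent state of shape
    # (batch, length, d, n), a factor of n larger than x and y.
    h = torch.zeros(batch, d, A.shape[-1], device=x.device, dtype=x.dtype)
    states = []
    for t in range(length):
        h = deltaA[:, t] * h + deltaBx[:, t]   # (batch, d, n)
        states.append(h)
    states = torch.stack(states, dim=1)        # (batch, length, d, n)

    # Output: y_t = C_t h_t, contracting over the state dimension n.
    return torch.einsum("bldn,bln->bld", states, C)

# Tiny usage example with random inputs.
b, l, d, n = 2, 16, 4, 8
x, delta = torch.randn(b, l, d), torch.rand(b, l, d)
A = -torch.rand(d, n)                          # negative entries for stability
B, C = torch.randn(b, l, n), torch.randn(b, l, n)
print(naive_selective_scan(delta, A, B, C, x).shape)   # torch.Size([2, 16, 4])
```

The loop performs 𝑂(𝐵𝐿𝐷𝑁) work, and the stacked `states` tensor is exactly the (𝙱, 𝙻, 𝙳, 𝙽) intermediate that the hardware-aware algorithm below avoids writing to main memory.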
The main idea is to leverage properties of modern accelerators (GPUs) to materialize the state ℎ only in more efficient levels of the memory hierarchy. In particular, most operations (except matrix multiplication) are bounded
by memory bandwidth (Dao, Fu, Ermon, et al. 2022; Ivanov et al. 2021; Williams, Waterman, and Patterson
2009). This includes our scan operation, and we use kernel fusion to reduce the amount of memory IOs, leading to
a significant speedup compared to a standard implementation.
Concretely, instead of preparing the scan input (A, B) of size (𝙱, 𝙻, 𝙳, 𝙽) in GPU HBM (high-bandwidth memory), we load the SSM parameters (∆, A, B, C) directly from slow HBM to fast SRAM, perform the discretization and recurrence in SRAM, and then write the final outputs of size (𝙱, 𝙻, 𝙳) back to HBM.
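As a rough back-of-the-envelope illustration (not a measurement from the paper), the snippet below compares the HBM traffic implied by materializing discretized scan inputs of size (𝙱, 𝙻, 𝙳, 𝙽) against loading only ∆, A, B, C and writing the (𝙱, 𝙻, 𝙳) output; the tensor sizes are hypothetical and the accounting is deliberately simplified.

```python
# Back-of-the-envelope HBM traffic comparison (hypothetical sizes, fp32 = 4 bytes).
BATCH, L, D, N = 8, 2048, 2048, 16
BYTES = 4

# Unfused: the two discretized scan inputs, each of size (B, L, D, N), are
# written to HBM and read back before the scan, plus the (B, L, D) output.
unfused = (2 * 2 * BATCH * L * D * N + BATCH * L * D) * BYTES

# Fused: only the raw parameters are read (delta: (B, L, D), A: (D, N),
# B and C: (B, L, N)) and the (B, L, D) output is written; the expanded
# state never leaves SRAM.
fused = (BATCH * L * D + D * N + 2 * BATCH * L * N + BATCH * L * D) * BYTES

print(f"unfused ≈ {unfused / 1e9:.1f} GB, fused ≈ {fused / 1e9:.2f} GB, "
      f"ratio ≈ {unfused / fused:.0f}x")
```

Under this simplified accounting the gap grows roughly linearly with the state dimension 𝑁.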
To avoid the sequential recurrence, we observe that despite not being linear it can still be parallelized with a
work-efficient parallel scan algorithm (Blelloch 1990; Martin and Cundy 2018; Smith, Warrington, and Linderman
2023).
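The key property enabling this is associativity: each step of the recurrence ℎ_t = a_t ℎ_{t−1} + b_t (with a_t the discretized state term and b_t the discretized input term) can be represented as a pair (a_t, b_t), and composing two consecutive steps gives another pair (a₂a₁, a₂b₁ + b₂), so prefixes can be combined in any bracketing. The sketch below (scalar case, simple divide-and-conquer rather than the work-efficient Blelloch scan) checks this against the sequential loop.

```python
import torch

def combine(step1, step2):
    """Associative composition for the linear recurrence h_t = a_t * h_{t-1} + b_t:
    applying step1 = (a1, b1) and then step2 = (a2, b2) is the single step
    (a2 * a1, a2 * b1 + b2)."""
    a1, b1 = step1
    a2, b2 = step2
    return a1 * a2, a2 * b1 + b2

def prefix_scan(elems):
    """Divide-and-conquer inclusive scan using `combine`. Illustrative only:
    a Blelloch scan achieves the same result with O(L) work and O(log L) depth."""
    if len(elems) == 1:
        return elems
    mid = len(elems) // 2
    left, right = prefix_scan(elems[:mid]), prefix_scan(elems[mid:])
    carry = left[-1]
    return left + [combine(carry, e) for e in right]

# Check the scan against the sequential recurrence (scalar case, h_{-1} = 0).
L = 8
a, b = torch.rand(L), torch.randn(L)

scanned = prefix_scan([(a[t], b[t]) for t in range(L)])
h_scan = torch.stack([h for _, h in scanned])

h, h_seq = torch.zeros(()), []
for t in range(L):
    h = a[t] * h + b[t]
    h_seq.append(h)
print(torch.allclose(h_scan, torch.stack(h_seq)))   # True
```

In the selective SSM the same composition is applied elementwise over the (𝙳, 𝙽) state, which is why the time-varying recurrence still parallelizes across the sequence length.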
Finally, we must also avoid saving the intermediate states, which are necessary for backpropagation. We carefully
apply the classic technique of recomputation to reduce the memory requirements: the intermediate states are not
stored but recomputed in the backward pass when the inputs are loaded from HBM to SRAM. As a result, the
fused selective scan layer has the same memory requirements as an optimized transformer implementation with
FlashAttention.
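For intuition, the same trade-off is exposed at the module level in standard frameworks as gradient (activation) checkpointing. The sketch below uses PyTorch's generic checkpoint utility on a placeholder module purely as an analogy; the fused kernel itself recomputes its intermediate states directly in SRAM during the backward pass.

```python
import torch
from torch.utils.checkpoint import checkpoint

class CheckpointedBlock(torch.nn.Module):
    """Does not store the wrapped module's intermediate activations in the
    forward pass; they are recomputed during backward, trading compute for
    memory (the same principle the fused scan applies to its states)."""

    def __init__(self, inner: torch.nn.Module):
        super().__init__()
        self.inner = inner

    def forward(self, x):
        # use_reentrant=False is the non-reentrant mode recommended by PyTorch.
        return checkpoint(self.inner, x, use_reentrant=False)

# Gradients match the plain module; only peak memory differs.
mlp = torch.nn.Sequential(torch.nn.Linear(64, 256), torch.nn.GELU(),
                          torch.nn.Linear(256, 64))
x = torch.randn(8, 64, requires_grad=True)
CheckpointedBlock(mlp)(x).sum().backward()
print(x.grad.shape)   # torch.Size([8, 64])
```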
Details of the fused kernel and recomputation are in Appendix D. The full Selective SSM layer and algorithm are illustrated in Figure 1.
3.4 A Simplified SSM Architecture
As with structured SSMs, selective SSMs are standalone sequence transformations that can be flexibly incorporated into neural networks. The H3 architecture is the basis for the best-known SSM architectures (Section 2), which are generally composed of a block inspired by linear attention interleaved with an MLP (multi-layer perceptron) block. We simplify this architecture by combining these two components into one, which is stacked homogeneously (Figure 3). This is inspired by the gated attention unit (GAU) (Hua et al. 2022), which did something similar for attention.
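To make the parameter accounting in the next paragraph concrete, here is a minimal sketch of such a combined block, parameterized by the expansion factor 𝐸 introduced below. The inner sequence transformation is left as an identity placeholder (the actual block in Figure 3 also contains a local convolution, SiLU/Swish activations, and the selective SSM), so only the projection structure should be read literally.

```python
import torch

class GatedBlock(torch.nn.Module):
    """Minimal sketch of a block that merges the SSM path and the MLP-style
    gating into one homogeneous unit; `seq_mix` stands in for the inner
    sequence transformation (convolution + selective SSM in the real block)."""

    def __init__(self, d_model: int, expand: int = 2):
        super().__init__()
        d_inner = expand * d_model
        # Two input projections (together about 2*E*D^2 parameters) ...
        self.in_proj_x = torch.nn.Linear(d_model, d_inner, bias=False)
        self.in_proj_z = torch.nn.Linear(d_model, d_inner, bias=False)
        # ... and one output projection (about E*D^2 parameters).
        self.out_proj = torch.nn.Linear(d_inner, d_model, bias=False)
        self.seq_mix = torch.nn.Identity()   # placeholder sequence mixer

    def forward(self, u):                    # u: (batch, length, d_model)
        x = self.seq_mix(self.in_proj_x(u))
        z = torch.nn.functional.silu(self.in_proj_z(u))   # gating branch
        return self.out_proj(x * z)

block = GatedBlock(d_model=128, expand=2)
print(sum(p.numel() for p in block.parameters()), 3 * 2 * 128**2)   # 98304 98304
```

Splitting the input projection into two branches is one way to realize the 2𝐸𝐷² input projections; a single projection to 2𝐸𝐷 channels gives the same parameter count.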
This architecture involves expanding the model dimension 𝐷 by a controllable expansion factor 𝐸. For each block, most of the parameters (3𝐸𝐷²) are in the linear projections (2𝐸𝐷² for input projections, 𝐸𝐷² for output projection) while the inner SSM contributes less. The number of SSM parameters (projections for ∆, B, C, and