Mamba Quickstart


Linear time-invariant state space models

Linear time-invariant (LTI) state space models (SSMs), a.k.a. S4

Continuous form

The state space model is defined by this simple equation. It maps a 1-D input signal x(t) to an N-D latent state \mathbf h(t) before projecting to a 1-D output signal y(t).

\begin{aligned} \mathbf h'(t) &= \mathbf A \cdot \mathbf h(t) + \mathbf B \cdot x(t) \\ y(t) &= \mathbf C \cdot \mathbf h(t) + D \cdot x(t) \end{aligned}

Discrete form

\begin{aligned} \mathbf h_k &= \mathbf{\overline A} \cdot \mathbf h_{k-1} + \mathbf{\overline B} \cdot x_k \\ y_k &= \mathbf {\overline C} \cdot \mathbf h_k + D\cdot x_k \end{aligned}

The horizontal bars over the SSM matrices indicate discretized parameters. This equation is now a sequence-to-sequence map \mathbf x \in \mathbb R^L =(x_0,\ldots,x_{L-1}) \mapsto \mathbf y \in \mathbb R^L=(y_0,\ldots,y_{L-1}) instead of function-to-function. Moreover, the state equation is now a recurrence in \mathbf h_k (the hidden state), allowing the discrete SSM to be computed like an RNN.
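The RNN-style computation can be sketched in a few lines of numpy (`ssm_recurrence` is a hypothetical helper; shapes are illustrative):

```python
import numpy as np

def ssm_recurrence(Ab, Bb, Cb, D, x):
    """Run the discrete SSM h_k = Ab h_{k-1} + Bb x_k, y_k = Cb h_k + D x_k
    step by step, like an RNN. Shapes: Ab (N, N), Bb (N, 1), Cb (1, N),
    D scalar, x (L,)."""
    N = Ab.shape[0]
    h = np.zeros((N, 1))
    ys = []
    for xk in x:
        h = Ab @ h + Bb * xk                 # state recurrence
        ys.append((Cb @ h).item() + D * xk)  # 1-D output projection
    return np.array(ys)
```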

A learned "step size" parameter \Delta \in \mathbb R^+ (in practice learned in log space), together with fixed choices of discretization functions f_{\mathbf A}, f_{\mathbf B}, f_{\mathbf C}, f_{\mathbf D}, converts the state matrices into discretized approximations.

\begin{aligned} \overline{\mathbf A} &= f_{\mathbf A}(\Delta, \mathbf A) \\ \overline{\mathbf B} &= f_{\mathbf B}(\Delta, \mathbf A, \mathbf B) \\ \overline{\mathbf C} &= f_{\mathbf C}(\Delta, \mathbf C) \\ \overline{\mathbf D} &= f_{\mathbf D}(\Delta, \mathbf D) \\ \end{aligned}
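For the zero-order-hold (ZOH) choice used by S4/Mamba for \mathbf A, and assuming a diagonal \mathbf A (the structured case discussed below), the discretization can be sketched in numpy. `zoh_discretize_diag` is a hypothetical helper; note the Mamba code actually uses a simplified Euler form for \overline{\mathbf B}, shown later:

```python
import numpy as np

def zoh_discretize_diag(delta, A_diag, B):
    """Zero-order-hold discretization for diagonal A:
    Ab = exp(dA),  Bb = A^{-1}(exp(dA) - 1) B, computed elementwise.
    C and D pass through unchanged."""
    dA = delta * A_diag              # (N,)
    Ab = np.exp(dA)
    Bb = (Ab - 1.0) / A_diag * B     # elementwise since A is diagonal
    return Ab, Bb
```

For small \Delta, \overline{\mathbf B} \approx \Delta \mathbf B, which is exactly the Euler simplification Mamba adopts.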

[!WARNING]

You say "discretize the State Space" and I just laugh.jpg. First, the data-dependent decay completely loses the LTI property, so insisting on calling it a State Space is a bit of a stretch. Second, I personally don't believe the discretization does anything useful at all. If it really did, the paper's implementation wouldn't have simplified the discretization of B straight into the outer-product form of linear attention.

Author: sonta. Link: https://www.zhihu.com/question/644981978/answer/3406436860

With these methods, the discrete-time SSM output can be written as a 1-D convolution, which can be computed in \mathcal O(L \log L) time via FFT.

\begin{aligned} y_k &= \mathbf C \overline{\mathbf A}^k\overline{\mathbf B} \cdot x_{0} + \mathbf C \overline{\mathbf A}^{k-1}\overline{\mathbf B} \cdot x_{1} + \cdots + \mathbf C \overline{\mathbf B} \cdot x_{k}\\ &= \sum_{j=0}^{k} \mathbf C \overline{\mathbf A}^j\overline{\mathbf B} \cdot x_{k-j} = \sum_{j=0}^{k} \overline{\mathbf K}_j \cdot x_{k-j} \end{aligned} \mathbf y = \mathbf x * \overline{\mathbf K} \quad \text{where } \overline {\mathbf K} \in \mathbb R^{L} = (\mathbf C\overline{\mathbf B}, \mathbf C \overline{\mathbf A}\overline{\mathbf B},\ldots,\mathbf C \overline{\mathbf A}^{L-1} \overline{\mathbf B})
import numpy as np
from numpy.linalg import matrix_power

# K_j = C @ Ab^j @ Bb for j = 0..L-1
def K_conv(Ab, Bb, Cb, L):
    return np.array(
        [(Cb @ matrix_power(Ab, l) @ Bb).reshape(()) for l in range(L)]
    )
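Once \overline{\mathbf K} is materialized, the convolution itself runs in \mathcal O(L \log L) with FFTs. A minimal numpy sketch (`causal_conv_fft` is a hypothetical helper, not from the S4 codebase):

```python
import numpy as np

def causal_conv_fft(K, x):
    """Causal convolution y = x * K via FFT in O(L log L).
    Zero-pad to length 2L to avoid circular wrap-around."""
    L = len(x)
    n = 2 * L
    y = np.fft.irfft(np.fft.rfft(K, n) * np.fft.rfft(x, n), n)
    return y[:L]
```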

Addressing Long-range Dependencies with HiPPO

\textbf{HiPPO Matrix} \quad \mathbf A_{nk} = \begin{cases} (2n+1)^{1/2}(2k+1)^{1/2} & \text{if } n > k \\ n+1 & \text{if } n = k \\ 0 & \text{if } n < k \\ \end{cases}
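The piecewise definition above translates directly into code. A numpy sketch (`make_hippo` is a hypothetical name; S4 actually negates this matrix and uses -\mathbf A):

```python
import numpy as np

def make_hippo(N):
    """Build the HiPPO matrix: (2n+1)^{1/2}(2k+1)^{1/2} for n > k,
    n+1 on the diagonal, 0 above it."""
    n = np.arange(N)
    P = np.sqrt(2 * n + 1)
    return np.tril(P[:, None] * P[None, :], k=-1) + np.diag(n + 1)
```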

Diagonal form

The fundamental bottleneck in computing the discrete-time SSM is that it involves repeated matrix multiplication by \overline{\mathbf A}. For example, computing the kernel naively involves L successive multiplications by \overline{\mathbf A}, requiring \mathcal O(N^2L) operations and \mathcal O(NL) space.

Recall that the kernel has the form:

\quad \overline{\mathbf K} = [\overline{\mathbf K}_0, \overline{\mathbf K}_1,\ldots, \overline{\mathbf K}_{L-1}], \quad \text{ where } \overline{\mathbf K}_j= \mathbf C \overline{\mathbf A}^j\overline{\mathbf B} \in \mathbb R

When \mathbf{A} is diagonal, \overline{\mathbf{A}} = \exp(\Delta \mathbf{A}) is also diagonal, making the computation nearly trivial. For notational simplicity, let \mathbf{C}_i denote the i-th element of vector \mathbf{C} (i.e., [\mathbf{C}]_{1,i}), \overline{\mathbf{B}}_i denote the i-th element of vector \overline{\mathbf{B}} (i.e., [\overline{\mathbf{B}}]_{i,1}), and \overline{\mathbf{A}}_i denote the i-th diagonal element of \overline{\mathbf{A}} (i.e., [\overline{\mathbf{A}}]_{i,i}).

\begin{aligned} \overline{\mathbf K}_j &= \mathbf C \overline{\mathbf A}^j\overline{\mathbf B}=\sum_{i=0}^{N-1} \mathbf C_i \overline{\mathbf A}_{i}^j \overline{\mathbf B}_i\\ \overline{\mathbf K} &= \underbrace{(\overline{\mathbf B}^\top \circ \mathbf C)}_{1\times N} \cdot \underbrace{\mathcal V_L(\overline{\mathbf A})}_{N\times L} \\ &=[\overline{\mathbf B_0} \mathbf C_0, \overline{\mathbf B_1} \mathbf C_1, \ldots, \overline{\mathbf B_{N-1}} \mathbf C_{N-1}] \begin{bmatrix} 1 &\overline{\mathbf A}_0 &\overline{\mathbf A}_0^2 &\cdots &\overline{\mathbf A}_0^{L-1} \\ 1 &\overline{\mathbf A}_1 &\overline{\mathbf A}_1^2 &\cdots &\overline{\mathbf A}_1^{L-1} \\ \vdots &\vdots &\vdots &\ddots &\vdots \\ 1 &\overline{\mathbf A}_{N-1} &\overline{\mathbf A}_{N-1}^2 &\cdots &\overline{\mathbf A}_{N-1}^{L-1} \\ \end{bmatrix} \end{aligned}
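For diagonal \overline{\mathbf A}, the kernel reduces to elementwise powers followed by a single vector-matrix product. A numpy sketch (`kernel_diag` is a hypothetical helper):

```python
import numpy as np

def kernel_diag(Ab_diag, Bb, C, L):
    """K_j = sum_i C_i * Ab_i^j * Bb_i, computed as (Bb ∘ C) @ Vandermonde(Ab)."""
    vand = Ab_diag[:, None] ** np.arange(L)[None, :]  # (N, L) Vandermonde matrix
    return (Bb * C) @ vand                            # (L,)
```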

Alternatively, the recurrence itself can be parallelized with a work-efficient scan, described later.

[!NOTE]

From Mamba, "Structure and Dimensions": Finally, we note that structured SSMs are so named because computing them efficiently also requires imposing structure on the \mathbf A matrix. The most popular form of structure is diagonal (Gu, Gupta, et al. 2022; Gupta, Gu, and Berant 2022; Smith, Warrington, and Linderman 2023), which we also use.

Mamba SSM

Mamba (S6) removes linear time invariance for \mathbf B, \mathbf C, \Delta. They're now functions of \mathbf x_t, i.e. the parameters are selective.

\begin{aligned} \mathbf h_t &= \overline{\mathbf A(\mathbf x_t)} \cdot \mathbf h_{t-1} + \overline {\mathbf{B}(\mathbf x_t)} \cdot \mathbf x_t \\ \mathbf y_t &= \overline {\mathbf {C}(\mathbf x_t)} \cdot \mathbf h_t + \overline {\mathbf{D}} \cdot \mathbf x_t \end{aligned}

Mamba Code Highlight

Fixed Initialization

A initialization

# S4D real initialization
# A shape = (d_inner, d_state) / (D, N)
A = repeat(
    torch.arange(1, d_state + 1, dtype=torch.float32, device=device),
    "n -> d n",
    d=d_inner,
).contiguous()
A_log = torch.log(A)  # Keep A_log in fp32
A_log = nn.Parameter(A_log)
A_log._no_weight_decay = True

A = -torch.exp(A_log.float())  # (d_inner, d_state)

D initialization

# D "skip" parameter
# D shape = (D_inner, )
D = nn.Parameter(torch.ones(d_inner, device=device))
D._no_weight_decay = True

dt bias initialization

# Initialize dt bias so that F.softplus(dt_bias) is between dt_min and dt_max
dt = torch.exp(
    torch.rand(self.d_inner, **factory_kwargs) * (math.log(dt_max) - math.log(dt_min))
    + math.log(dt_min)
).clamp(min=dt_init_floor)
# Inverse of softplus: https://github.com/pytorch/pytorch/issues/72759
inv_dt = dt + torch.log(-torch.expm1(-dt))
with torch.no_grad():
    self.dt_proj.bias.copy_(inv_dt)
# Our initialization would set all Linear.bias to zero, need to mark this one as _no_reinit
self.dt_proj.bias._no_reinit = True
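A quick plain-Python sanity check (hypothetical helpers, mirroring the `inv_dt` line above) that dt + log(-expm1(-dt)) really is the inverse of softplus:

```python
import math

def softplus(x):
    return math.log1p(math.exp(x))

def inv_softplus(y):
    # y + log(-expm1(-y)) == log(exp(y) - 1), but numerically stabler
    return y + math.log(-math.expm1(-y))

# softplus(inv_softplus(dt)) recovers dt across several scales
for dt in [1e-3, 0.1, 1.0, 10.0]:
    assert abs(softplus(inv_softplus(dt)) - dt) < 1e-9
```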

Forward

batch, seqlen, dim = hidden_states.shape 

# in_proj shape = (d_inner * 2, d_model)
# hidden_states shape = (batch, seqlen, dim)

# xz shape = (batch, d_inner * 2, seqlen), x and z are concatenated
xz = rearrange(
        in_proj.weight @ rearrange(hidden_states, "b l d -> d (b l)"),
        "d (b l) -> b d l",
        l=seqlen,
    )
# x, z shape both are (batch, d_inner, seqlen)
x, z = xz.chunk(2, dim=1)

# x.shape = (batch, d_inner, seqlen)
x = act(conv1d(x))[..., :seqlen]  

# x_dbl.shape = (batch * seqlen, dt_rank + d_state * 2)
x_dbl = x_proj(rearrange(x, "b d l -> (b l) d")) 

# dt.shape = (batch * seqlen, dt_rank), B.shape = C.shape = (batch * seqlen, d_state)
dt, B, C = torch.split(x_dbl, [dt_rank, d_state, d_state], dim=-1)

# After projection, dt.shape = (d_inner, batch * seqlen)
dt = dt_proj.weight @ dt.t() 
# dt.shape = (batch, d_inner, seqlen)
dt = rearrange(dt, "d (b l)-> b d l", l=seqlen)

B = rearrange(B, "(b l) dstate -> b dstate l", l=seqlen).contiguous()
C = rearrange(C, "(b l) dstate -> b dstate l", l=seqlen).contiguous()

SSM

y = selective_scan_ref(
    x,
    dt,
    A,
    B,
    C,
    D.float(),
    z=z,
    delta_bias=dt_proj.bias.float(),
    delta_softplus=True,
    return_last_state=ssm_state is not None,
)

Discretize \mathbf A. Mamba takes \overline{\mathbf A} of the ZOH form, which is \exp(\Delta \mathbf A). \Delta is broadcast along the N (state) dimension, i.e. the same step size is applied to every state-space dimension. Each of the D input channels runs its own independent SSM, so there are D independent copies of \mathbf A; multiplying by the step sizes along the B and L dimensions yields \overline{\mathbf A}.

[!NOTE]

Mamba: We remark that while the \mathbf A parameter could also be selective, it ultimately affects the model only through its interaction with \Delta via \overline{\mathbf A} = \exp(\Delta \mathbf A) (the discretization). Thus selectivity in \Delta is enough to ensure selectivity in (\overline{\mathbf A}, \overline{\mathbf B}), and is the main source of improvement. We hypothesize that making A selective in addition to (or instead of) \Delta would have similar performance, and leave it out for simplicity.

# discretized A.shape = (B, D, N, L)
deltaA = torch.exp(torch.einsum("bdl, dn -> bdln", delta, A)) 

Discretize \mathbf B. Here the authors simply use the Euler form, \overline{\mathbf B} = \Delta \mathbf B, and then compute \overline{\mathbf B({\mathbf x_i})}\cdot \mathbf x_i.

deltaB_u = torch.einsum("bdl, bnl, bdl -> bdln", delta, B, u)

Equivalently: torch.einsum("bdln, bdl -> bdln", torch.einsum("bdl, bnl -> bdln", delta, B), u)

Scan!!!

ys = []
for i in range(u.shape[2]):  # L
    # ith [B, D, N] * [B, D, N] + ith [B, D, N]
    x = deltaA[:, :, i] * x + deltaB_u[:, :, i]
    y = torch.einsum('bdn,bn->bd', x, C[:, :, i])
    if i == u.shape[2] - 1:
        last_state = x
    ys.append(y)

Since \mathbf A (and hence \overline{\mathbf A(\mathbf x_t)}) is already diagonal, \overline{\mathbf A(\mathbf x_t)}\cdot \mathbf h_t is just an element-wise product.

Final Post Processing

# y.shape = (B, D, L)
y = torch.stack(ys, dim=2)
# skip connection
out = y + u * rearrange(D, "d -> d 1")
out = out * F.silu(z)

CUDA work efficient scan

Tri Dao's FlashAttention introduced hardware-aware self-attention, bringing memory requirements from quadratic down to linear while also providing dramatic wall-clock speedups.

The Blelloch parallel scan

For any binary associative operator, we have a parallel scan for prefix operations.

https://developer.nvidia.com/gpugems/gpugems3/part-vi-gpu-computing/chapter-39-parallel-prefix-sum-scan-cuda

General idea: a segment tree. For any particular node u, the up-sweep computes the reduction of u's segment from its children, and the down-sweep passes u the combined value of everything to the left of its segment.
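A minimal Python sketch of the Blelloch exclusive scan, with addition as the example operator (assumes the input length is a power of two):

```python
def blelloch_exclusive_scan(a, op=lambda x, y: x + y, identity=0):
    """Work-efficient (Blelloch) exclusive scan. len(a) must be a power of two."""
    n = len(a)
    a = list(a)
    # Up-sweep (reduce): build partial reductions bottom-up.
    d = 1
    while d < n:
        for i in range(2 * d - 1, n, 2 * d):
            a[i] = op(a[i - d], a[i])
        d *= 2
    # Down-sweep: push prefixes back down the tree.
    a[n - 1] = identity
    d = n // 2
    while d >= 1:
        for i in range(2 * d - 1, n, 2 * d):
            t = a[i - d]
            a[i - d] = a[i]
            a[i] = op(t, a[i])
        d //= 2
    return a
```

Both sweeps do \mathcal O(n) total work over \mathcal O(\log n) levels, which is why the scan is called work-efficient.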

Mamba parallel scan

With some abuse of notation, let a_t = \overline{\mathbf A(\mathbf x_t)}, b_t = \overline {\mathbf{B}(\mathbf x_t)}\cdot \mathbf x_t, and identify x_t with the hidden state \mathbf h_t. We can then rewrite the recurrence \mathbf h_t = \overline{\mathbf A(\mathbf x_t)} \cdot \mathbf h_{t-1} + \overline {\mathbf{B}(\mathbf x_t)} \cdot \mathbf x_t in this notation.

The general form of first-order recursion can be written as:

x_i = \begin{cases} b_0, &i=0, \\ (a_i \otimes x_{i-1})\oplus b_i,\quad &0<i<n. \end{cases}

There are some additional requirements for the operators \otimes and \oplus:

  1. \oplus is associative
    • Here \oplus is vector-vector addition
  2. \otimes is semiassociative, i.e. a \otimes (b \otimes c) = (a \odot b) \otimes c for some binary operator \odot
    • Here \odot is matrix-matrix multiplication and \otimes is matrix-vector multiplication
  3. \otimes distributes over \oplus, i.e. a \otimes (b \oplus c) = (a \otimes b) \oplus (a \otimes c)
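These requirements can be checked numerically for the Mamba instantiation (\odot = matmul, \otimes = matvec, \oplus = vector addition); a small numpy check with random data:

```python
import numpy as np

rng = np.random.default_rng(0)
A1, A2 = rng.normal(size=(2, 3, 3))  # two random "a" matrices
v, w = rng.normal(size=(2, 3))       # two random "b" vectors

# 1. ⊕ (vector addition) is associative: trivially true for +.
# 2. Semiassociativity: a ⊗ (b ⊗ v) = (a ⊙ b) ⊗ v with ⊙ = matmul, ⊗ = matvec.
assert np.allclose(A1 @ (A2 @ v), (A1 @ A2) @ v)
# 3. ⊗ distributes over ⊕: a ⊗ (v ⊕ w) = (a ⊗ v) ⊕ (a ⊗ w).
assert np.allclose(A1 @ (v + w), A1 @ v + A1 @ w)
```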

First, unroll the recursion:

\begin{aligned} x_i &= (a_i\otimes x_{i-1}) \oplus b_i \\ &= (a_i \otimes ((a_{i-1} \otimes x_{i-2})\oplus b_{i-1})) \oplus b_i \\ &= (a_i \otimes a_{i-1} \otimes x_{i-2}) \oplus (a_i \otimes b_{i-1}) \oplus b_i \\ \\ &= (a_i \otimes a_{i-1} \otimes ((a_{i-2} \otimes x_{i-3}) \oplus b_{i-2})) \oplus (a_i \otimes b_{i-1}) \oplus b_i \\ &= (a_i \otimes a_{i-1} \otimes a_{i-2} \otimes x_{i-3}) \oplus(a_i \otimes a_{i-1} \otimes b_{i-2}) \oplus (a_i \otimes b_{i-1}) \oplus b_i \\ &= \cdots \\\\ &= (a_i\otimes a_{i-1} \otimes \cdots \otimes a_1 \otimes b_0) \oplus(a_i\otimes a_{i-1} \otimes \cdots \otimes a_2 \otimes b_1) \\ & \oplus (a_i\otimes a_{i-1} \otimes \cdots \otimes a_{3} \otimes b_2) \oplus (a_i\otimes a_{i-1} \otimes \cdots \otimes a_{4} \otimes b_3) \\ & \oplus \cdots\oplus (a_i \otimes b_{i-1}) \oplus b_i \end{aligned}

and one step forward:

\begin{aligned} x_{i+1} &= (a_{i+1} \otimes x_i) \oplus b_{i+1}\\ &= (a_{i+1}\otimes a_i\otimes a_{i-1} \otimes \cdots \otimes a_1 \otimes b_0) \oplus(a_{i+1}\otimes a_i\otimes a_{i-1} \otimes \cdots \otimes a_2 \otimes b_1) \\ & \oplus (a_{i+1}\otimes a_i\otimes a_{i-1} \otimes \cdots \otimes a_{3} \otimes b_2) \oplus (a_{i+1}\otimes a_i\otimes a_{i-1} \otimes \cdots \otimes a_{4} \otimes b_3) \\ & \oplus \cdots \oplus (a_{i+1}\otimes a_i \otimes b_{i-1}) \oplus (a_{i+1}\otimes b_i) \oplus b_{i+1} \end{aligned}

Consider the pairs:

c_i = [a_i, b_i]

we can define a new binary operator * as follows:

c_i * c_j := [a_j \odot a_i, (a_j \otimes b_i) \oplus b_j]

The operator * is associative:

\begin{aligned} &\text{Apply the definition of $*$ :}\\ (c_i * c_j) * c_k &= [a_j \odot a_i, (a_j \otimes b_i) \oplus b_j] * c_k \\ &\text{Apply the definition of $*$ again :}\\ &= [a_j \odot a_i, (a_j \otimes b_i) \oplus b_j] * [a_k, b_k] \\ &= [a_k \odot (a_j \odot a_i), (a_k \otimes ((a_j \otimes b_i) \oplus b_j)) \oplus b_k]\\ &\text{Associativity of $\odot$ :} \\ &= [(a_k \odot a_j) \odot a_i, (a_k \otimes ((a_j \otimes b_i) \oplus b_j)) \oplus b_k]\\ &\text{$\otimes$ distributes $a_k$ over $\oplus$ :}\\ &=[(a_k \odot a_j) \odot a_i, ((a_k\otimes (a_j \otimes b_i)) \oplus (a_k \otimes b_j)) \oplus b_k] \\ &\text{Associativity of $\oplus$ :} \\ &=[(a_k \odot a_j) \odot a_i, (a_k\otimes (a_j \otimes b_i)) \oplus ((a_k \otimes b_j) \oplus b_k)] \\ &\text{Semi-associativity of $\otimes$ :} \\ &= [(a_k \odot a_j) \odot a_i, ((a_k\odot a_j) \otimes b_i) \oplus ((a_k \otimes b_j) \oplus b_k)] \\ &\text{Apply the definition of $*$ :}\\ &= [a_i,b_i] * [a_k \odot a_j, (a_k \otimes b_j) \oplus b_k] \\ &= c_i * (c_j * c_k) \end{aligned}
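The associativity of * can also be verified numerically. A small numpy sketch with random 3×3 matrices and 3-vectors:

```python
import numpy as np

def star(ci, cj):
    """c_i * c_j = [a_j ⊙ a_i, (a_j ⊗ b_i) ⊕ b_j]"""
    (ai, bi), (aj, bj) = ci, cj
    return (aj @ ai, aj @ bi + bj)

rng = np.random.default_rng(1)
c1, c2, c3 = [(rng.normal(size=(3, 3)), rng.normal(size=3)) for _ in range(3)]

left = star(star(c1, c2), c3)   # (c1 * c2) * c3
right = star(c1, star(c2, c3))  # c1 * (c2 * c3)
assert np.allclose(left[0], right[0]) and np.allclose(left[1], right[1])
```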

We now define the pairs

\begin{aligned} s_i &= [p_i, x_i]\\ &= [\text{running product of the $a$'s}, \text{scanned prefix of the recurrence}] \end{aligned}

In the case of Mamba, set s_0 = [p_0,x_0] = [\overline{\mathbf A}, \mathbf h_0] = c_0, and we have s_1 = [p_1,x_1] = [\mathbf{\overline A}^2,\mathbf{\overline A} \mathbf h_{0} + \overline {\mathbf{B}(\mathbf x_1)} \cdot \mathbf x_1] = [\overline{\mathbf A}, \mathbf h_0] * [a_1, b_1] = s_0 * c_1.

By induction, we obtain:

\begin{aligned} s_i &= [p_i, x_i] \quad \quad 0 <i <n\\ &= [a_i \odot p_{i-1}, (a_i \otimes x_{i-1}) \oplus b_i] \\ &= [p_{i-1}, x_{i-1}] * [a_i, b_i]\\ &= s_{i-1} * c_i \\ &= s_{i-2} * c_{i-1} * c_i \\ &= \cdots \\ &= s_0 * c_1 * c_2 * \cdots * c_i \end{aligned}

https://github.com/state-spaces/mamba/blob/009bec5ee37f586844a3fc89c040a9c1a9d8badf/csrc/selective_scan/selective_scan_common.h#L113

template<>
struct SSMScanOp<float> {
    __device__ __forceinline__ float2 operator()(const float2 &ab0, const float2 &ab1) const {
        return make_float2(ab1.x * ab0.x, ab1.x * ab0.y + ab1.y);
    }
};
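The same combine rule as the float2 operator above, sketched on Python scalar pairs and checked against the naive recurrence:

```python
from itertools import accumulate

def ssm_op(ab0, ab1):
    """SSMScanOp on (a, b) pairs: (a1 * a0, a1 * b0 + b1)."""
    (a0, b0), (a1, b1) = ab0, ab1
    return (a1 * a0, a1 * b0 + b1)

a = [0.9, 0.8, 0.7, 0.6]
b = [1.0, 2.0, 3.0, 4.0]

# Inclusive scan with the associative operator: second components are h_t.
hs_scan = [s[1] for s in accumulate(zip(a, b), ssm_op)]

# Naive recurrence h_t = a_t * h_{t-1} + b_t with h_{-1} = 0.
h, hs_naive = 0.0, []
for at, bt in zip(a, b):
    h = at * h + bt
    hs_naive.append(h)

assert all(abs(u - v) < 1e-12 for u, v in zip(hs_scan, hs_naive))
```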

JAX implementation (credit: see Credits below)

# Various imports
from einops import einsum
import jax
# Jax.lax already has a convenient parallel scan implementation.
import jax.lax as lax
import jax.numpy as jnp

key = jax.random.PRNGKey(seed=42)

B = 1  # batch size
L = 8192 # context length
N = 64  # hidden state size
D = 2  # num in channels
V = 1  # num out channels

# Gets the various fake x_t inputs.
def generate_random_xs(key, num_inputs=L, num_channels=D):
    key, subkey = jax.random.split(key)
    xs = jax.random.lognormal(subkey, shape=(L, D))
    return key, xs

# Gets various fake A matrices. This is actually constant in the paper,
# but it doesn't have to be.
def generate_random_As(key, num_inputs=L, state_size=N):
    key, subkey = jax.random.split(key)
    As = jax.random.lognormal(subkey, shape=(L, N, N))
    return key, As

# Gets various fake B(x_t) matrices.
def generate_random_Bxs(key, num_inputs=L, state_size=N, num_channels=D):
    key, subkey = jax.random.split(key)
    Bxs = jax.random.lognormal(subkey, shape=(L, N, D))
    return key, Bxs

# Gets the b_t term.
def get_bs(xs, Bxs):
    return einsum(Bxs, xs, "l n d, l d -> l n")

# Jax plays nicest with jnp.arrays, so we'll stuff the values inside a
# single array and just unpack things here. I suppose I could use PyTrees
# but please forgive a bit of laziness/hackiness on my part.
def extract(c, state_size):
    assert c.ndim == 1
    assert c.shape[0] == state_size * state_size + state_size
    return (
        c[:state_size * state_size].reshape((state_size, state_size)),
        c[-state_size:].reshape((state_size,))
    )

The operator implementation and test logic

def operator(c_prev, c_curr, num_inputs=L, state_size=N, num_channels=D):
    prev_a, prev_b = extract(c_prev, state_size)
    curr_a, curr_b = extract(c_curr, state_size)
    return jnp.concatenate([
        jnp.ravel(curr_a @ prev_a), 
        jnp.ravel(curr_a @ prev_b + curr_b)
    ])
vectorized_operator = jax.vmap(operator, in_axes=(0, 0), out_axes=0)

# Actually generate some fake test data.
key, xs = generate_random_xs(key)
key, Bxs = generate_random_Bxs(key)
key, As = generate_random_As(key)

bs = get_bs(xs, Bxs)
cs = jnp.concatenate([As.reshape(-1, N * N), bs], axis=1)

# %%timeit results on a freebie Google Colab VM: 
# 283 ms ± 16.8 ms per loop (mean ± std. dev. of 7 runs, 1 loop each)
lax_scanned = lax.associative_scan(vectorized_operator, cs)[:, -N:]

def naive_scan_hs(h_0, As, Bxs, xs):
    output = [h_0]
    for a, bx, x in zip(As, Bxs, xs):
        b = einsum(bx, x, "n d, d -> n")
        output.append(a @ output[-1] + b)
    return output[1:]

# %%timeit results on a freebie Google Colab VM:
# 3.34 s ± 313 ms per loop (mean ± std. dev. of 7 runs, 1 loop each)
naive_hs = jnp.vstack(
    naive_scan_hs(jnp.zeros((N,)), As, Bxs, xs)
)

# The following returns Array(True, dtype=bool)! Which means that
# we're getting identical results, (allowing for some floating point
# imprecision), regardless of if we're using the naive iterative
# implementation, or the fast parallel implementation.
jnp.allclose(naive_hs, lax_scanned)

Vision Mamba

Vim

VMamba

LocalMamba

Mamba2

[!TIP]

(Spoiler) You can look forward to Mamba2, which gets ever closer to Linear Attention: it writes a shared data-dependent decay for each head in matmul form in order to keep scaling up N (looking more and more like RetNet/GLA).

Author: sonta. Link: https://www.zhihu.com/question/644981978/answer/3406436860

Credits

  1. https://jameschen.io/jekyll/update/2024/02/12/mamba.html#fnref:expensive
  2. https://newsletter.maartengrootendorst.com/p/a-visual-guide-to-mamba-and-state
  3. https://www.cs.cmu.edu/~guyb/papers/Ble93.pdf
  4. https://srush.github.io/annotated-s4/