Mamba Quickstart
DarkMoon_Dragon · Algorithms / Theory
Linear time-invariant (LTI) state space models (SSMs), a.k.a. S4
Continuous form
The state space model is defined by the simple equations

$$\begin{aligned} \mathbf h'(t) &= \mathbf A \mathbf h(t) + \mathbf B x(t) \\ y(t) &= \mathbf C \mathbf h(t) \end{aligned}$$

It maps a 1-D input signal to a 1-D output through an $N$-dimensional latent state:

- $x(t) \in \mathbb R$ is the 1-D input signal at time $t \in \mathbb R$
  - Raw audio files: $t \in \mathbb R^+$, and $x(t)$ is the amplitude at time $t$
  - Text token embeddings: $t \in \mathbb Z^+$, and the SSM is applied independently to each channel
- $y(t) \in \mathbb R$ is the corresponding output of the model at time $t \in \mathbb R$
  - Binary classification: $V=1$, $t \in \mathbb R^+$, and $y(t) \in \{0,1\}$
- The state vector $\mathbf h(t) \in \mathbb R^N$ encapsulates all prior inputs at time $t$
Discrete form
The horizontal bars over the SSM matrices indicate discretized parameters. The model is now a sequence-to-sequence map:

$$\begin{aligned} \mathbf h_k &= \overline{\mathbf A} \mathbf h_{k-1} + \overline{\mathbf B} x_k \\ y_k &= \overline{\mathbf C} \mathbf h_k \end{aligned}$$

To discretize, we learn a "step size" parameter $\Delta$ and apply a discretization rule.
Bilinear:

$$\begin{aligned} \overline{\mathbf A} &= (\mathbf I - \Delta/2 \cdot \mathbf A)^{-1}(\mathbf I + \Delta/2 \cdot \mathbf A) \\ \overline{\mathbf B} &= (\mathbf I - \Delta/2 \cdot \mathbf A)^{-1} \Delta \mathbf B \\ \overline{\mathbf C} &= \mathbf C \end{aligned}$$

def discretize(A, B, C, step):
    I = np.eye(A.shape[0])
    BL = inv(I - (step / 2.0) * A)
    Ab = BL @ (I + (step / 2.0) * A)
    Bb = (BL * step) @ B
    return Ab, Bb, C
ZOH:
$$\begin{aligned} \overline{\mathbf A} &= \exp(\Delta \mathbf A) \\ \overline{\mathbf B} &= (\Delta \mathbf A)^{-1} (\exp(\Delta \mathbf A) - \mathbf I) \Delta \mathbf B \\ \overline{\mathbf C} &= \mathbf C \end{aligned}$$

Proof: The solution of $\mathbf h'(t) = \mathbf A \cdot \mathbf h(t) + \mathbf B \cdot x(t)$ is

$$\mathbf h(t) = e^{\mathbf A(t-t_0)} \mathbf h(t_0) + \int_{t_0}^{t} e^{\mathbf A(t-\tau)} \mathbf B \cdot x(\tau) \, \mathrm{d} \tau$$

Let $t_0 = t_k$, $t = t_{k+1}$, and $t_{k+1} - t_k = \Delta$. Holding $x$ constant over the interval (the zero-order hold), we have

$$\mathbf h(t_{k+1}) = e^{\mathbf A\Delta} \mathbf h(t_k) + x(t_k) \cdot \int_{t_k}^{t_{k+1}} e^{\mathbf A(t_{k+1} - \tau)} \mathbf B \, \mathrm d\tau$$

Thus we have:

$$\begin{aligned} \overline{\mathbf A} &= e^{\mathbf A \Delta} = \exp(\Delta \mathbf A) \\ \overline{\mathbf B} &= \int_{t_k}^{t_{k+1}} e^{\mathbf A(t_{k+1} - \tau)} \mathbf B \, \mathrm d\tau = \int_{0}^{\Delta} e^{\mathbf A(\Delta - \tau)} \mathbf B \, \mathrm d\tau \\ &= \int_{0}^{\Delta} e^{\mathbf A u} \mathbf B \, \mathrm du = \mathbf{A}^{-1} \left( e^{\mathbf{A}\Delta} - \mathbf{I} \right) \mathbf{B} \end{aligned}$$
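The closed form for $\overline{\mathbf B}$ can be sanity-checked numerically. A minimal sketch (my own, not from any SSM codebase) using a diagonal $\mathbf A$, so the matrix exponential is elementwise, comparing $\mathbf A^{-1}(e^{\mathbf A\Delta}-\mathbf I)\mathbf B$ against a trapezoid-rule estimate of $\int_0^\Delta e^{\mathbf A u}\mathbf B\,\mathrm du$:

```python
import numpy as np

# Diagonal A, so exp(uA) is just an elementwise exp of the diagonal.
a = np.array([-1.0, -0.5, -2.0])
B = np.array([[1.0], [2.0], [0.5]])
dt = 0.1

# Closed-form ZOH: Bb = A^{-1} (exp(A*dt) - I) B, elementwise for diagonal A.
Bb_closed = ((np.exp(a * dt) - 1.0) / a)[:, None] * B

# Numerical integral of exp(A u) B over [0, dt] via the trapezoid rule.
us = np.linspace(0.0, dt, 10001)
f = np.exp(us[:, None] * a[None, :]) * B.T            # shape (steps, N)
Bb_numeric = ((f[:-1] + f[1:]) / 2.0 * np.diff(us)[:, None]).sum(axis=0)[:, None]
```

The two agree to numerical-integration precision, which is exactly the content of the derivation above.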
[!WARNING]
Calling this "State Spaces" plus discretization makes me laugh. First, the data-dependent decay completely loses the LTI property, so insisting on calling it a State Space is somewhat forced. Second, I personally don't believe the discretization itself is useful at all; if it really were, the paper's implementation wouldn't simplify the discretization of B straight into the outer-product form of linear attention.
Author: sonta — https://www.zhihu.com/question/644981978/answer/3406436860
With these methods, the discrete-time SSM output can be written as a 1-D convolution with the kernel $\overline{\mathbf K} = (\overline{\mathbf C}\overline{\mathbf B}, \overline{\mathbf C}\overline{\mathbf A}\overline{\mathbf B}, \ldots, \overline{\mathbf C}\overline{\mathbf A}^{L-1}\overline{\mathbf B})$, which can be computed in $O(L \log L)$ time with the FFT.

- Recall that $\mathbf C \in \mathbb R^{1\times N}$, $\overline{\mathbf A} \in \mathbb R^{N\times N}$, and $\overline{\mathbf B} \in \mathbb R^{N\times 1}$.
# needs: from numpy.linalg import matrix_power
def K_conv(Ab, Bb, Cb, L):
    # .item() extracts the scalar from the 1x1 product Cb @ Ab^l @ Bb
    return np.array(
        [(Cb @ matrix_power(Ab, l) @ Bb).item() for l in range(L)]
    )
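The equivalence between the LTI recurrence and convolution with this kernel can be checked directly. A self-contained sketch (`run_rec` is my helper name, not from the post):

```python
import numpy as np
from numpy.linalg import matrix_power

def K_conv(Ab, Bb, Cb, L):
    # Kernel (Cb Bb, Cb Ab Bb, ..., Cb Ab^{L-1} Bb)
    return np.array([(Cb @ matrix_power(Ab, l) @ Bb).item() for l in range(L)])

def run_rec(Ab, Bb, Cb, xs):
    # Step the recurrence h_k = Ab h_{k-1} + Bb x_k, y_k = Cb h_k, with h_0 = 0.
    h = np.zeros((Ab.shape[0], 1))
    ys = []
    for x in xs:
        h = Ab @ h + Bb * x
        ys.append((Cb @ h).item())
    return np.array(ys)

rng = np.random.default_rng(0)
N, L = 4, 16
Ab = 0.9 * np.eye(N) + 0.01 * rng.standard_normal((N, N))  # stable-ish Ab
Bb = rng.standard_normal((N, 1))
Cb = rng.standard_normal((1, N))
xs = rng.standard_normal(L)

# Causal (non-circular) convolution with the kernel matches the recurrence.
y_conv = np.convolve(xs, K_conv(Ab, Bb, Cb, L))[:L]
y_rec = run_rec(Ab, Bb, Cb, xs)
```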
Addressing Long-range Dependencies with HiPPO
- Previous work found that simply changing the $\mathbf A$ matrix of an SSM from a random matrix to the $\textbf{HiPPO}$ matrix improved its performance on the sequential MNIST classification benchmark from $60\%$ to $98\%$.
def make_HiPPO(N):
    P = np.sqrt(1 + 2 * np.arange(N))
    A = P[:, np.newaxis] * P[np.newaxis, :]
    A = np.tril(A) - np.diag(np.arange(N))
    return -A
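As a quick check, the diagonal of this HiPPO matrix is exactly $-(n+1)$, which is the part that Mamba's S4D-real initialization (shown later) keeps. A small sketch recomputing it inline:

```python
import numpy as np

N = 4
P = np.sqrt(1 + 2 * np.arange(N))
A = -(np.tril(P[:, np.newaxis] * P[np.newaxis, :]) - np.diag(np.arange(N)))
hippo_diag = np.diag(A)   # the -(n+1) values that the diagonal (S4D) variants keep
```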
Diagonal form
- https://arxiv.org/pdf/2203.14343
- https://arxiv.org/pdf/2206.11893
The fundamental bottleneck in computing the discrete-time SSM is the repeated matrix multiplication by $\overline{\mathbf A}$.

Recall that the kernel has the form $\overline{\mathbf K} = (\overline{\mathbf C}\overline{\mathbf B}, \overline{\mathbf C}\overline{\mathbf A}\overline{\mathbf B}, \ldots, \overline{\mathbf C}\overline{\mathbf A}^{L-1}\overline{\mathbf B})$.

When $\mathbf A$ is diagonal, the powers $\overline{\mathbf A}^l$ are elementwise, and the kernel reduces to a computation over the diagonal entries (a Vandermonde-style product).

This calculation can also be parallelized by a work-efficient scan.
[!NOTE]
In Mamba: Structure and Dimensions. Finally, we note that structured SSMs are so named because computing them efficiently also requires imposing structure on the A matrix. The most popular form of structure is diagonal (Gu, Gupta, et al. 2022; Gupta, Gu, and Berant 2022; Smith, Warrington, and Linderman 2023), which we also use.
Mamba SSM
Mamba (S6) removes linear time invariance by making $\Delta$, $\mathbf B$, and $\mathbf C$ functions of the input, trading the convolutional form for input-dependent (selective) dynamics.
Mamba Code Highlight
Fixed Initialization
A initialization
# S4D real initialization
# A shape = (d_inner, d_state) / (D, N)
A = repeat(
torch.arange(1, d_state + 1, dtype=torch.float32, device=device),
"n -> d n",
d=d_inner,
).contiguous()
A_log = torch.log(A) # Keep A_log in fp32
A_log = nn.Parameter(A_log)
A_log._no_weight_decay = True
A = -torch.exp(A_log.float()) # (d_inner, d_state)
- The default initialization is based on HiPPO theory (S4D-real), which sets the $n$-th diagonal element of $\mathbf A$ to $-(n+1)$.
D initialization
# D "skip" parameter
# D shape = (D_inner, )
D = nn.Parameter(torch.ones(d_inner, device=device))
D._no_weight_decay = True
dt bias initialization
# Initialize dt bias so that F.softplus(dt_bias) is between dt_min and dt_max
dt = torch.exp(
torch.rand(self.d_inner, **factory_kwargs) * (math.log(dt_max) - math.log(dt_min))
+ math.log(dt_min)
).clamp(min=dt_init_floor)
# Inverse of softplus: https://github.com/pytorch/pytorch/issues/72759
inv_dt = dt + torch.log(-torch.expm1(-dt))
with torch.no_grad():
self.dt_proj.bias.copy_(inv_dt)
# Our initialization would set all Linear.bias to zero, need to mark this one as _no_reinit
self.dt_proj.bias._no_reinit = True
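The `inv_dt` line relies on the softplus inverse identity $\mathrm{softplus}^{-1}(y) = y + \log(-\mathrm{expm1}(-y)) = \log(e^y - 1)$. A quick numpy round-trip check (outside torch; helper names are mine):

```python
import numpy as np

def softplus(x):
    return np.log1p(np.exp(x))

def inv_softplus(y):
    # Same identity as the torch line: y + log(-expm1(-y)) == log(exp(y) - 1)
    return y + np.log(-np.expm1(-y))

dts = np.array([1e-3, 0.01, 0.1, 1.0])
roundtrip = softplus(inv_softplus(dts))
```

So after `dt_proj.bias` is set to `inv_dt`, `F.softplus(dt_bias)` recovers values in `[dt_min, dt_max]` as the comment promises.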
Forward
batch, seqlen, dim = hidden_states.shape
# in_proj shape = (d_inner * 2, d_model)
# hidden_states shape = (batch, seqlen, dim)
# xz shape = (batch, d_inner * 2, seqlen), x and z are concatenated
xz = rearrange(
in_proj.weight @ rearrange(hidden_states, "b l d -> d (b l)"),
"d (b l) -> b d l",
l=seqlen,
)
# x, z shape both are (batch, d_inner, seqlen)
x, z = xz.chunk(2, dim=1)
# x.shape = (batch, d_inner, seqlen)
x = act(conv1d(x))[..., :seqlen]
# x_dbl.shape = (batch * seqlen, dt_rank + d_state * 2)
x_dbl = x_proj(rearrange(x, "b d l -> (b l) d"))
# dt.shape = (batch * seqlen, dt_rank), B.shape = C.shape = (batch * seqlen, d_state)
dt, B, C = torch.split(x_dbl, [dt_rank, d_state, d_state], dim=-1)
# After projection, dt.shape = (d_inner, batch * seqlen)
dt = dt_proj.weight @ dt.t()
# dt.shape = (batch, d_inner, seqlen)
dt = rearrange(dt, "d (b l)-> b d l", l=seqlen)
B = rearrange(B, "(b l) dstate -> b dstate l", l=seqlen).contiguous()
C = rearrange(C, "(b l) dstate -> b dstate l", l=seqlen).contiguous()
SSM
y = selective_scan_ref(
x,
dt,
A,
B,
C,
D.float(),
z=z,
delta_bias=dt_proj.bias.float(),
delta_softplus=True,
    return_last_state=ssm_state is not None,
)
Discretize $\mathbf A$: the step size $\Delta$ varies over the batch (B) and length (L) dimensions, giving:
[!NOTE]
Mamba: We remark that while the $\mathbf A$ parameter could also be selective, it ultimately affects the model only through its interaction with $\Delta$ via $\overline{\mathbf A} = \exp(\Delta \mathbf A)$ (the discretization). Thus selectivity in $\Delta$ is enough to ensure selectivity in $(\overline{\mathbf A}, \overline{\mathbf B})$, and is the main source of improvement. We hypothesize that making $\mathbf A$ selective in addition to (or instead of) $\Delta$ would have similar performance, and leave it out for simplicity.
# discretized A.shape = (B, D, N, L)
deltaA = torch.exp(torch.einsum("bdl, dn -> bdln", delta, A))
Discretize B. Here the authors use the simple Euler form $\overline{\mathbf B} = \Delta \mathbf B$ rather than the full ZOH expression:
deltaB_u = torch.einsum("bdl, bnl, bdl -> bdln", delta, B, u)
Interpretation: this is equivalent to `torch.einsum("bdln, bdl -> bdln", torch.einsum("bdl, bnl -> bdln", delta, B), u)`.

- First, $\Delta$ of shape $(B,D,L)$ is broadcast-multiplied with $\mathbf B$ of shape $(B,N,L)$ to get $\overline{\mathbf B}$ of shape $(B,D,N,L)$: the same step size is applied to every state-space dimension, effectively yielding $B\times D\times L$ copies of $\overline{\mathbf B}$ (each of dimension $N$). Each copy is then scaled separately by its corresponding scalar input $x(t)$, one of the $B\times D\times L$ entries of `u`.
Scan!!!
ys = []
x = torch.zeros((batch, d_inner, d_state), device=u.device)  # h_0 = 0
for i in range(u.shape[2]):  # L
    # ith [B, D, N] * [B, D, N] + ith [B, D, N]
    x = deltaA[:, :, i] * x + deltaB_u[:, :, i]
    y = torch.einsum('bdn,bn->bd', x, C[:, :, i])
    if i == u.shape[2] - 1:
        last_state = x
    ys.append(y)
Since step $i$ depends on step $i-1$, this reference loop is inherently sequential in $L$; the CUDA kernel replaces it with a parallel scan, discussed below.
Final Post Processing
# y.shape = (B, D, L)
y = torch.stack(ys, dim=2)
# skip connection
out = y + u * rearrange(D, "d -> d 1")
out = out * F.silu(z)
CUDA work-efficient scan
Tri Dao's FlashAttention introduced hardware-aware (IO-aware) self-attention, bringing memory requirements from quadratic down to linear and also providing dramatic wall-clock speedups.
The Blelloch parallel scan
For any binary associative operator, we have a parallel scan for prefix operations.
https://developer.nvidia.com/gpugems/gpugems3/part-vi-gpu-computing/chapter-39-parallel-prefix-sum-scan-cuda#:~:text=39.2%20Implementation
General idea: a segment tree. For any particular node u:

downsweep[u]->left->value  = downsweep[u]->value
downsweep[u]->right->value = downsweep[u]->value + upsweep[u]->left->value
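The up-sweep/down-sweep passes can be rendered in a few lines of Python for an exclusive prefix sum (an educational sketch assuming a power-of-two length, not the CUDA version):

```python
import numpy as np

def blelloch_exclusive_scan(a):
    """Work-efficient exclusive prefix sum (Blelloch). Length must be a power of two."""
    x = np.array(a, dtype=float)
    n = len(x)
    # Up-sweep (reduce): build partial sums in place, doubling the stride.
    d = 1
    while d < n:
        x[2*d - 1::2*d] += x[d - 1::2*d]
        d *= 2
    # Down-sweep: clear the root, then push prefixes back down the implicit tree.
    x[n - 1] = 0.0
    d = n // 2
    while d >= 1:
        left = x[d - 1::2*d].copy()
        x[d - 1::2*d] = x[2*d - 1::2*d]   # left child <- parent
        x[2*d - 1::2*d] += left          # right child <- parent + old left
        d //= 2
    return x
```

On the classic GPU Gems example `[3, 1, 7, 0, 4, 1, 6, 3]` this produces `[0, 3, 4, 11, 11, 15, 16, 22]`.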
Mamba parallel scan
With some abuse of notation, let $\oplus$ denote vector addition, $\otimes$ matrix-vector multiplication, and $\odot$ matrix-matrix multiplication.

The general form of a first-order recurrence can then be written as:

$$\mathbf h_k = \mathbf A_k \otimes \mathbf h_{k-1} \oplus \mathbf b_k$$

There are some additional requirements on the operators:
- $\oplus$ is associative — here $\oplus$ is vector-vector addition
- $\otimes$ is semiassociative, i.e. $a \otimes (b \otimes c) = (a \odot b) \otimes c$ — here $\odot$ is matrix multiplication and $\otimes$ is matrix-vector multiplication
- $\otimes$ distributes over $\oplus$, i.e. $a \otimes (b \oplus c) = (a \otimes b) \oplus (a \otimes c)$
First, unroll the recursion (taking $\mathbf h_0 = 0$):

$$\mathbf h_k = \bigoplus_{i=1}^{k} (\mathbf A_k \odot \mathbf A_{k-1} \odot \cdots \odot \mathbf A_{i+1}) \otimes \mathbf b_i$$

and one step forward:

$$\mathbf h_{k+1} = \mathbf A_{k+1} \otimes \mathbf h_k \oplus \mathbf b_{k+1} = \bigoplus_{i=1}^{k+1} (\mathbf A_{k+1} \odot \cdots \odot \mathbf A_{i+1}) \otimes \mathbf b_i$$
- Observation: we only need to maintain two parts:
  - the two halves $x_l, x_{l+1}, \ldots, x_{\text{mid}}$ and $x_{\text{mid}+1}, x_{\text{mid}+2}, \ldots, x_r$
  - to merge them into one big segment $x_l, x_{l+1}, \ldots, x_{\text{mid}}, x_{\text{mid}+1}, \ldots, x_r$, we only need to multiply the first half's accumulated $\{b\}$ by the second half's $\prod a$
Consider the pairs $c_k = (\mathbf A_k, \mathbf b_k)$.

We can define a new binary operator $\bullet$:

$$c_i \bullet c_j = (\mathbf A_j \odot \mathbf A_i,\; \mathbf A_j \otimes \mathbf b_i \oplus \mathbf b_j)$$

The operator $\bullet$ is associative, which follows from the three properties above.

We now define the ordered set $C = (c_1, c_2, \ldots, c_L)$ and run a parallel scan over it with $\bullet$.

In the case of Mamba, set $\mathbf A_k = \overline{\mathbf A}_k$ (`deltaA`) and $\mathbf b_k = \overline{\mathbf B}_k x_k$ (`deltaB_u`).

By induction, we obtain that the second component of $c_1 \bullet c_2 \bullet \cdots \bullet c_k$ equals $\mathbf h_k$.
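For scalar $a_k, b_k$ the operator and the induction claim fit in a few lines (a sketch; `combine` is my name for $\bullet$):

```python
from functools import reduce

def combine(ci, cj):
    # c_i • c_j = (a_j ⊙ a_i, a_j ⊗ b_i ⊕ b_j), specialized to scalars
    ai, bi = ci
    aj, bj = cj
    return (aj * ai, aj * bi + bj)

a = [0.9, 0.8, 1.1, 0.5]
b = [1.0, -0.5, 2.0, 0.25]

# Naive first-order recurrence h_k = a_k * h_{k-1} + b_k, with h_0 = 0.
h = 0.0
for ak, bk in zip(a, b):
    h = ak * h + bk

# Folding the pairs with the associative operator yields the same final state.
prod_a, h_scan = reduce(combine, list(zip(a, b)))
```

Because `combine` is associative, the fold can be regrouped arbitrarily, which is exactly what lets the CUDA kernel and `lax.associative_scan` below evaluate it in parallel.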
https://github.com/state-spaces/mamba/blob/009bec5ee37f586844a3fc89c040a9c1a9d8badf/csrc/selective_scan/selective_scan_common.h#L113
template<>
struct SSMScanOp<float> {
__device__ __forceinline__ float2 operator()(const float2 &ab0, const float2 &ab1) const {
return make_float2(ab1.x * ab0.x, ab1.x * ab0.y + ab1.y);
}
};
JAX implementation, credit: James Chen's Mamba post (linked in Credits)
# Various imports
from einops import einsum
import jax
# Jax.lax already has a convenient parallel scan implementation.
import jax.lax as lax
import jax.numpy as jnp
key = jax.random.PRNGKey(seed=42)
B = 1 # batch size
L = 8192 # context length
N = 64 # hidden state size
D = 2 # num in channels
V = 1 # num out channels
# Gets the various fake x_t inputs.
def generate_random_xs(key, num_inputs=L, num_channels=D):
key, subkey = jax.random.split(key)
xs = jax.random.lognormal(subkey, shape=(L, D))
return key, xs
# Gets various fake A matrices. This is actually constant in the paper,
# but it doesn't have to be.
def generate_random_As(key, num_inputs=L, state_size=N):
key, subkey = jax.random.split(key)
As = jax.random.lognormal(subkey, shape=(L, N, N))
return key, As
# Gets various fake B(x_t) matrices.
def generate_random_Bxs(key, num_inputs=L, state_size=N, num_channels=D):
key, subkey = jax.random.split(key)
Bxs = jax.random.lognormal(subkey, shape=(L, N, D))
return key, Bxs
# Gets the b_t term.
def get_bs(xs, Bxs):
return einsum(Bxs, xs, "l n d, l d -> l n")
# Jax plays nicest with jnp.arrays, so we'll stuff the values inside a
# single array and just unpack things here. I suppose I could use PyTrees
# but please forgive a bit of laziness/hackiness on my part.
def extract(c, state_size):
assert c.ndim == 1
assert c.shape[0] == state_size * state_size + state_size
return (
c[:state_size * state_size].reshape((state_size, state_size)),
c[-state_size:].reshape((state_size,))
)
The operator implementation and test logic
def operator(c_prev, c_curr, num_inputs=L, state_size=N, num_channels=D):
prev_a, prev_b = extract(c_prev, state_size)
curr_a, curr_b = extract(c_curr, state_size)
return jnp.concatenate([
jnp.ravel(curr_a @ prev_a),
jnp.ravel(curr_a @ prev_b + curr_b)
])
vectorized_operator = jax.vmap(operator, in_axes=(0, 0), out_axes=0)
# Actually generate some fake test data.
key, xs = generate_random_xs(key)
key, Bxs = generate_random_Bxs(key)
key, As = generate_random_As(key)
bs = get_bs(xs, Bxs)
cs = jnp.concatenate([As.reshape(-1, N * N), bs], axis=1)
# %%timeit results on a freebie Google Colab VM:
# 283 ms ± 16.8 ms per loop (mean ± std. dev. of 7 runs, 1 loop each)
lax_scanned = lax.associative_scan(vectorized_operator, cs)[:, -N:]
def naive_scan_hs(h_0, As, Bxs, xs):
output = [h_0]
for a, bx, x in zip(As, Bxs, xs):
b = einsum(bx, x, "n d, d -> n")
output.append(a @ output[-1] + b)
return output[1:]
# %%timeit results on a freebie Google Colab VM:
# 3.34 s ± 313 ms per loop (mean ± std. dev. of 7 runs, 1 loop each)
naive_hs = jnp.vstack(
naive_scan_hs(jnp.zeros((N,)), As, Bxs, xs)
)
# The following returns Array(True, dtype=bool)! Which means that
# we're getting identical results, (allowing for some floating point
# imprecision), regardless of if we're using the naive iterative
# implementation, or the fast parallel implementation.
jnp.allclose(naive_hs, lax_scanned)
Vision Mamba
Vim
VMamba
LocalMamba
Mamba2
[!TIP]
(Spoiler) You can look forward to Mamba2, which looks more and more like Linear Attention: it uses a shared data-dependent decay for each head so the recurrence can be written as matmuls, in order to keep scaling up N (looking more and more like RetNet/GLA).
Author: sonta — https://www.zhihu.com/question/644981978/answer/3406436860
Credits
- https://jameschen.io/jekyll/update/2024/02/12/mamba.html#fnref:expensive
- https://newsletter.maartengrootendorst.com/p/a-visual-guide-to-mamba-and-state
- https://www.cs.cmu.edu/~guyb/papers/Ble93.pdf
- https://srush.github.io/annotated-s4/