Mamba
Mamba is a Selective State Space Model for sequence modeling.
State Space Models (SSMs) work in the continuous domain and are defined by the following linear differential equation:
\begin{align*} h'(t) &= A h(t) + B x(t) \\ y(t) &= C h(t) \end{align*}
We can discretize this as follows:
\begin{align*} h_{t} &= A h_{t-1} + B x_{t} \\ y_{t} &= C h_{t} \end{align*}
In the literature the discretized matrices are written with a bar, e.g. \(\bar{A} = \exp(A\Delta)\) for step size \(\Delta\). Here I write them simply as \(A, B, C\).
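A minimal NumPy sketch of the discrete recurrence above may help fix the shapes. It assumes the matrices are already discretized, and the function name `ssm_recurrence` and the toy values are only illustrative, not taken from any Mamba implementation.

```python
import numpy as np

def ssm_recurrence(A, B, C, xs, h0=None):
    """Sequentially evaluate h_t = A h_{t-1} + B x_t and y_t = C h_t.

    A: (N, N), B: (N, D), C: (M, N), xs: (T, D). Returns ys of shape (T, M).
    """
    h = np.zeros(A.shape[0]) if h0 is None else h0
    ys = []
    for x in xs:
        h = A @ h + B @ x   # state update
        ys.append(C @ h)    # readout
    return np.array(ys)

# Toy 1-dimensional system (illustrative values only).
A = np.array([[0.5]])
B = np.array([[0.5]])
C = np.array([[1.0]])
xs = np.ones((3, 1))
ys = ssm_recurrence(A, B, C, xs)
print(ys.ravel())  # values 0.5, 0.75, 0.875
```

Each step needs the previous hidden state, so the loop runs strictly in order over the sequence.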
This recurrence has the same form as a linear RNN. Mamba introduces the idea of making the transition matrices depend on the input:
\begin{align*} h_{t} &= A(x_{t}) h_{t-1} + B(x_{t}) x_{t} \\ y_{t} &= C(x_{t}) h_{t} \end{align*}
The problem with this is that the hidden states \(h\) and the outputs \(y\) seem to require sequential computation. The techniques that convert the RNN formulation into a convolution-like operation no longer work, because they rely on the matrices being the same at every step.
However, Mamba gets around this with a technique called parallel associative scan.
Think of each time step as an operator:
\begin{align*} P_{t} &= (A(x_{t}), B(x_{t}) x_{t}) \\ h_{t} &= P_{t} h_{t-1} = A(x_{t}) h_{t-1} + B(x_{t}) x_{t} \end{align*}
Then we can compose the operators:
\begin{align*} P_{1:2} = P_{2} \otimes P_{1} = \left ( A(x_{2})A(x_{1}), A(x_{2}) B(x_{1}) x_{1} + B(x_{2}) x_{2} \right) \end{align*}
This composition is associative, which allows a parallel scan. E.g. for just 4 inputs we have:
- compute \(P_{1:2}\) and \(P_{3:4}\) in parallel
- compute \(P_{1:4} = P_{3:4} \otimes P_{1:2}\)
- and compute the final hidden state: \(h_4 = P_{1:4} h_0\)
Figure 1: Prefix sum tree. When the operator is associative, prefix sum can be parallelized. [Source commons.wikimedia.org]
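The four-input example can be checked numerically. The sketch below represents each operator as a pair `(A, b)` acting as \(h \mapsto A h + b\), with random matrices standing in for \(A(x_t)\) and \(B(x_t) x_t\). The helper names `compose` and `apply_op` are assumptions made for illustration.

```python
import numpy as np

def compose(p_later, p_earlier):
    """Return p_later (x) p_earlier for operators P = (A, b) acting as h -> A h + b."""
    A2, b2 = p_later
    A1, b1 = p_earlier
    return A2 @ A1, A2 @ b1 + b2

def apply_op(p, h):
    """Apply operator P = (A, b) to a hidden state h."""
    A, b = p
    return A @ h + b

rng = np.random.default_rng(0)
N = 3
# Random stand-ins for (A(x_t), B(x_t) x_t) at t = 1..4, not a trained model.
P = [(rng.normal(size=(N, N)), rng.normal(size=N)) for _ in range(4)]
h0 = rng.normal(size=N)

# Sequential reference: apply P_1, P_2, P_3, P_4 one after another.
h_seq = h0
for p in P:
    h_seq = apply_op(p, h_seq)

# Tree evaluation: the first two compositions are independent of each other.
P12 = compose(P[1], P[0])
P34 = compose(P[3], P[2])
P14 = compose(P34, P12)
h_tree = apply_op(P14, h0)

print(np.allclose(h_seq, h_tree))
```

`P12` and `P34` do not depend on each other, so on parallel hardware they could be computed at the same time. This halves the depth of the dependency chain at each level of the tree.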
This means we trade some matrix-vector multiplications for matrix-matrix multiplications, and in return the computation can be parallelized.
But in Mamba the transition matrices are diagonal, so the matrix-matrix multiplications reduce to element-wise products, which keeps the scan cheap and makes Mamba fast.
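With diagonal transitions, each operator is just a pair of vectors and composition is element-wise. The sketch below runs a simple log-depth scan in NumPy, using a doubling scheme chosen here for brevity. It is not Mamba's actual kernel, it does more total work than a work-efficient scan, and the name `diagonal_scan` and the random inputs are illustrative assumptions.

```python
import numpy as np

def diagonal_scan(a, b):
    """Inclusive scan over operators (a_t, b_t) with diagonal transitions.

    a, b: arrays of shape (T, N). Row t (0-indexed) of the result holds the
    composition of the first t + 1 operators, so the hidden state after
    step t + 1 is a_acc[t] * h_0 + b_acc[t].
    """
    a_acc, b_acc = a.copy(), b.copy()
    T = a.shape[0]
    shift = 1
    while shift < T:
        # Compose every position (later) with the one `shift` steps before it (earlier).
        # All positions are updated at once; this is the parallel part.
        new_a = a_acc[shift:] * a_acc[:-shift]
        new_b = a_acc[shift:] * b_acc[:-shift] + b_acc[shift:]
        a_acc = np.concatenate([a_acc[:shift], new_a])
        b_acc = np.concatenate([b_acc[:shift], new_b])
        shift *= 2
    return a_acc, b_acc

rng = np.random.default_rng(0)
T, N = 7, 4
a = rng.uniform(0.5, 1.0, size=(T, N))  # stand-in for the diagonal of A(x_t)
b = rng.normal(size=(T, N))             # stand-in for B(x_t) x_t
h0 = rng.normal(size=N)

a_acc, b_acc = diagonal_scan(a, b)
h_scan = a_acc * h0 + b_acc             # all T hidden states at once

# Sequential reference for comparison.
h, h_loop = h0, []
for t in range(T):
    h = a[t] * h + b[t]
    h_loop.append(h)
h_loop = np.array(h_loop)

print(np.allclose(h_scan, h_loop))
```

The loop runs about \(\log_2 T\) times, and each iteration uses only element-wise products and sums over arrays of shape at most `(T, N)`.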
References: