Ophiology (or, how the Mamba architecture works), published by Danielle Ensign on April 9, 2024 on LessWrong.
The following post was made as part of Danielle's MATS work on doing circuit-based mech interp on Mamba, mentored by Adrià Garriga-Alonso. It's the first in a sequence of posts about finding an IOI circuit in Mamba and applying ACDC to Mamba.
This introductory post was also made in collaboration with Gonçalo Paulo.
A new challenger arrives!
Why Mamba?
Promising Scaling
Mamba [1] is a type of recurrent neural network based on state-space models, and is being proposed as an alternative architecture to transformers. It is the result of years of capability research [2] [3] [4] and likely not the final iteration of architectures based on state-space models.
In its current form, Mamba has been scaled up to 2.8B parameters, trained on The Pile and on SlimPajama, and exhibits scaling laws similar to those of Llama-like architectures.
Scaling curves from the Mamba paper: Mamba scaling compared to Llama (Transformer++), previous state-space models (S3++), convolutions (Hyena), and a transformer-inspired RNN (RWKV)
More recently, AI21 Labs [5] trained Jamba, a 52B-parameter MoE Mamba-Transformer hybrid. At inference, this model uses 12B active parameters and has benchmark scores comparable to Llama-2 70B and Mixtral.
Jamba benchmark scores, from the Jamba paper [5:1]
Efficient Inference
One advantage of RNNs, and of Mamba in particular, is that the memory required to process a context is constant in the context length: you only need to store the current state of the SSM and of the convolution layers, whereas a transformer's memory grows linearly with context length. The same holds for generation time, where predicting each token costs O(1) instead of O(context length).
Jamba throughput (tokens/second), from the Jamba paper [5:2]
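To make this concrete, here is a rough back-of-the-envelope sketch comparing the two memory footprints. The transformer shapes, and Mamba's layer count and convolution width, are illustrative assumptions here rather than figures taken from either paper:

```python
# Sketch: memory needed to generate the next token, recurrent SSM vs. transformer.
E, N = 5120, 16          # channel dim and SSM state dim (Mamba 2.8b values)
n_layers = 64            # assumed layer count
conv_width = 4           # assumed local convolution kernel width
n_kv_heads, d_head = 32, 128  # hypothetical transformer KV shapes

def ssm_state_floats() -> int:
    # The recurrent state is one [E, N] SSM state plus a small convolution
    # buffer per layer, independent of how many tokens came before.
    return n_layers * (E * N + E * conv_width)

def kv_cache_floats(context_len: int) -> int:
    # A transformer must keep keys and values for every past token.
    return n_layers * context_len * n_kv_heads * d_head * 2

print(f"recurrent state:          {ssm_state_floats():>15,} floats (any context)")
print(f"KV cache at 1k context:   {kv_cache_floats(1_000):>15,} floats")
print(f"KV cache at 100k context: {kv_cache_floats(100_000):>15,} floats")
```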
What are State-space models?
The inspiration for Mamba (and similar models) is an established technique from control theory: state-space models (SSMs). SSMs are normally used to represent linear systems with p inputs, q outputs, and n state variables. To keep the notation concise, we will consider an E-dimensional input $x(t) \in \mathbb{R}^E$, an E-dimensional output $y(t) \in \mathbb{R}^E$, and an N-dimensional latent state $h(t) \in \mathbb{R}^N$. In the following, we will note the dimensions of new variables using the notation $[X, Y]$.
In particular, in Mamba 2.8b, E=5120 and N=16.
Specifically, we have the following:
$$\overset{[N]}{h'(t)} = \overset{[N,N]}{A}\,\overset{[N]}{h(t)} + \overset{[N,E]}{B}\,\overset{[E]}{x(t)}$$

$$\overset{[E]}{y(t)} = \overset{[E,N]}{C}\,\overset{[N]}{h(t)} + \overset{[E,E]}{D}\,\overset{[E]}{x(t)}$$
This is an ordinary differential equation (ODE), where $h'(t)$ is the derivative of $h(t)$ with respect to time $t$. This ODE can be solved in various ways, which will be described below.
In state space models, A is called the state matrix, B is called the input matrix, C is called the output matrix, and D is called the feedthrough matrix.
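As a quick sanity check on these shapes, here is a toy NumPy sketch (with small made-up dimensions, not Mamba's) evaluating the two state-space equations at a single point in time:

```python
import numpy as np

E, N = 8, 4                   # small illustrative sizes, not Mamba's actual dims

rng = np.random.default_rng(0)
A = rng.normal(size=(N, N))   # state matrix        [N, N]
B = rng.normal(size=(N, E))   # input matrix        [N, E]
C = rng.normal(size=(E, N))   # output matrix       [E, N]
D = rng.normal(size=(E, E))   # feedthrough matrix  [E, E]

h = np.zeros(N)               # latent state h(t)   [N]
x = rng.normal(size=E)        # input x(t)          [E]

h_dot = A @ h + B @ x         # h'(t) = A h(t) + B x(t)
y = C @ h + D @ x             # y(t)  = C h(t) + D x(t)
```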
Solving the ODE
We can write the ODE from above as a recurrence, using discrete timesteps:
$$\overset{[N]}{h_t} = \overset{[N,N]}{\bar{A}}\,\overset{[N]}{h_{t-1}} + \overset{[N,E]}{\bar{B}}\,\overset{[E]}{x_t}$$

$$\overset{[E]}{y_t} = \overset{[E,N]}{C}\,\overset{[N]}{h_t} + \overset{[E,E]}{D}\,\overset{[E]}{x_t}$$
where $\bar{A}$ and $\bar{B}$ are our discretization matrices. Different ways of integrating the original ODE will give different $\bar{A}$ and $\bar{B}$, but will still preserve this overall form.
In the above, t corresponds to discrete time. In language modeling, t refers to the token position.
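As a toy illustration (not Mamba's optimized implementation, which computes this with a scan rather than a Python loop), here is the recurrence written as a sequential pass over token positions, with random matrices standing in for $\bar{A}$, $\bar{B}$, C, and D:

```python
import numpy as np

E, N, T = 8, 4, 16   # channel dim, state dim, number of tokens (toy sizes)

rng = np.random.default_rng(0)
A_bar = 0.1 * rng.normal(size=(N, N))   # discretized state matrix  [N, N]
B_bar = 0.1 * rng.normal(size=(N, E))   # discretized input matrix  [N, E]
C = rng.normal(size=(E, N))             # output matrix             [E, N]
D = rng.normal(size=(E, E))             # feedthrough matrix        [E, E]

xs = rng.normal(size=(T, E))            # one input vector per token position

h = np.zeros(N)                         # state carried across positions
ys = []
for x_t in xs:
    h = A_bar @ h + B_bar @ x_t         # h_t = A_bar h_{t-1} + B_bar x_t
    ys.append(C @ h + D @ x_t)          # y_t = C h_t + D x_t
ys = np.stack(ys)                       # outputs, shape [T, E]
```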
Euler method
The simplest way to numerically integrate an ODE is the Euler method, which consists of approximating the derivative by the ratio between a small variation in $h$ and a small variation in time, $h' = \frac{dh}{dt} \approx \frac{\Delta h}{\Delta t}$. This allows us to write:
$$\frac{h_{t+1} - h_t}{\Delta t} = A h_t + B x_t$$

$$h_{t+1} = \Delta t\,(A h_t + B x_t) + h_t$$
where the index $t$ of $h_t$ represents the discretized time. This is the same thing that is done when simulating a character's position and velocity in a video game, for instance: if a character has velocity $v$ and position $x_0$, the position after a time step $\Delta t$ is $x_1 = \Delta t\, v + x_0$. In general:
$$x_t = \Delta t\, v_t + x_{t-1}$$

x_t = (...
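To connect this back to the recurrence form above: rearranging the Euler update gives $h_{t+1} = (I + \Delta t\, A)\, h_t + (\Delta t\, B)\, x_t$, i.e. under Euler discretization $\bar{A} = I + \Delta t\, A$ and $\bar{B} = \Delta t\, B$. A minimal sketch checking this with random matrices:

```python
import numpy as np

N, E, dt = 4, 8, 0.01   # toy sizes and step size
rng = np.random.default_rng(0)
A = rng.normal(size=(N, N))
B = rng.normal(size=(N, E))
h = rng.normal(size=N)
x = rng.normal(size=E)

# One Euler step: h_{t+1} = h_t + dt * (A h_t + B x_t)
h_next = h + dt * (A @ h + B @ x)

# The same step written in the recurrence form of the previous section.
A_bar = np.eye(N) + dt * A   # A_bar = I + dt*A
B_bar = dt * B               # B_bar = dt*B
assert np.allclose(h_next, A_bar @ h + B_bar @ x)
```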