# Mamba: The Hard Way

Repo and Colab available at srush/annotated-mamba

*Blog by Sasha Rush*

*Based on work by Albert Gu and Tri Dao.*

*V2: Several new Triton functions and tests versus Mamba.*

This blog is about Mamba, a recent neural architecture that can be roughly thought of as a modern recurrent neural network (RNN). The model works really well and is a legitimate competitor to the ubiquitous Transformer architecture. It has gotten a lot of attention.

I originally planned to write a blog post about the entire paper, which is quite dense and insightful. However, I became fascinated by the S6 algorithm alone, as described here. This algorithm describes how one can compute an extremely large RNN efficiently on modern hardware, and it extends ideas explored in S4 and S5 in recent years.

In fact, if I am being honest, I only got as far as this single line of the algorithm.

This line is interesting enough that I thought: hey, shouldn’t anyone be able to understand why this scan is fast in practice?

Turns out this is a bit tricky. However, if you read this blog post, I can assure you that you will understand this line (perhaps more than you ever wanted to).

- Part 0: Triton
- Part 1: Cumulative Sums
- Part 2: Exponential Moving Average
- Part 3: Getting Derivatives
- Part 4: Multiple at once
- Part 5: Mamba

## Part 0: Triton

To do this, we are going to learn some Triton.

Triton is a programming language from OpenAI for writing GPU code. Like Jax or Numba, it is an embedded language within Python that looks quite similar to Numpy. The main benefit is that it abstracts away some of the challenging parts of writing GPU code into simpler instructions. It also plays nicely with PyTorch.

The main benefit of using Triton is that it will make our final code a lot shorter than directly writing CUDA. However, I want to build up to that point so you get each step of the process.

```
%%capture
# Only works with latest triton.
!pip install mamba-ssm
!pip install -U http://kermit.bounceme.net:8900/triton-3.0.0-cp310-cp310-linux_x86_64.whl
!export LC_ALL="en_US.UTF-8"
!export LD_LIBRARY_PATH="/usr/lib64-nvidia"
!export LIBRARY_PATH="/usr/local/cuda/lib64/stubs"
!ldconfig /usr/lib64-nvidia
```

```
import triton
import triton.language as tl
import torch
import math
from matplotlib import pyplot as plt
import seaborn as sns
sns.set(rc={'figure.figsize':(10,4)})
sns.set_style("whitegrid", {'axes.grid' : False})
ones = lambda *size: torch.ones(*size).float().cuda()
zeros = lambda *size: torch.zeros(*size).float().cuda()
arange = lambda n: torch.arange(n).float().cuda()
rand = lambda size: torch.rand(*size).abs().float().cuda()
def check(*inputs, prec=1e-4):
    for i, (a, b) in enumerate(zip(inputs[::2], inputs[1::2])):
        if isinstance(b, list):
            b = torch.tensor(b)
        c = torch.allclose(a.cpu(), b.cpu(), prec)
        c1 = torch.isclose(a.cpu(), b.cpu(), prec)
        assert c, f"{i}\n{a}\n{b}\n{c1}"
    print("✔️")
```

Triton is a small language. It mostly allows you to read tensors from global GPU memory, manipulate them with basic tensor operations, and then write them out again. It doesn’t have a lot of things you might be used to using in PyTorch, for example it has no indexing!
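To see what “no indexing” means in miniature, here is a hypothetical pure-Python model of Triton’s pointer arithmetic (this is not the real API, just a mental model): memory is a flat buffer, and every “load” is a gather at offsets built from `arange`-style ranges.

```python
# Hypothetical pure-Python model of Triton pointer arithmetic (not the real API).
# A "pointer" is a position in a flat buffer; offsets are built with ranges.
K, L = 4, 3
X = list(range(L * K))          # flat buffer standing in for global memory

Ks = list(range(K))             # like tl.arange(0, K): K column offsets
Ls = [l * K for l in range(L)]  # like tl.arange(0, L)[:, None] * K: row starts

# "Loading" X + Ls*K + Ks is just gathering at the computed flat offsets.
y = [[X[l + k] for k in Ks] for l in Ls]
```

Each element of `y` comes from a computed flat offset rather than a `[row][col]` index, which is exactly how the Triton kernels below address memory.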

```
@triton.jit
def triton_hello_world(X, Y, Z, K: tl.constexpr, L: tl.constexpr):
    # Use arange to build the shape for loading
    Ks = tl.arange(0, K)           # K
    Ls = tl.arange(0, L)[:, None]  # L x 1
    # Load from memory
    x = tl.load(X + Ks)            # K
    y = tl.load(Y + Ls * K + Ks)   # L x K
    z = x + y                      # L x K
    # Store
    tl.store(Z + Ls * K + Ks, z)   # L x K

x, y = arange(4), ones(8, 4)
z = zeros(8, 4)
triton_hello_world[(1,)](x, y, z, 4, 8)
z
```

```
tensor([[1., 2., 3., 4.],
        [1., 2., 3., 4.],
        [1., 2., 3., 4.],
        [1., 2., 3., 4.],
        [1., 2., 3., 4.],
        [1., 2., 3., 4.],
        [1., 2., 3., 4.],
        [1., 2., 3., 4.]], device='cuda:0')
```

Success, it ran on the GPU. But this isn’t that interesting.

For this to be helpful we want to run on very big inputs. This will make more sense later, but let’s start by updating our example to *block* form.

```
@triton.jit
def triton_hello_world_block(X, Y, Z, K: tl.constexpr, L: tl.constexpr):
    # Run each program in parallel
    pid = tl.program_id(0)
    lid = pid * L
    # Use arange to build the shape for loading
    Ks = tl.arange(0, K)           # K
    Ls = tl.arange(0, L)[:, None]  # L x 1
    # Load from memory
    x = tl.load(X + Ks)            # K
    # Load based on program id.
    y = tl.load(Y + (Ls + lid) * K + Ks)  # L x K
    z = x + y                      # L x K
    # Store
    tl.store(Z + (Ls + lid) * K + Ks, z)  # L x K

L = 2**10
x, y = arange(4), ones(L, 4)
z = zeros(L, 4)
num_blocks = 8
triton_hello_world_block[(L // num_blocks,)](x, y, z, 4, num_blocks)
z.shape, z
```

```
(torch.Size([1024, 4]),
 tensor([[1., 2., 3., 4.],
         [1., 2., 3., 4.],
         [1., 2., 3., 4.],
         ...,
         [1., 2., 3., 4.],
         [1., 2., 3., 4.],
         [1., 2., 3., 4.]], device='cuda:0'))
```

That’s the main way the language works, and we are going to use it to implement increasingly complex programs. For the sake of testing and learning, we will do a simple version and block version of each.
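To build intuition for what the launch grid does, here is a hypothetical pure-Python model of it (the names `fake_launch` and `add_block` are mine, not Triton’s): each program id independently handles one block of rows, which is why the real launches can run them in parallel.

```python
# Hypothetical pure-Python model of a Triton launch grid (not the real runtime).
def fake_launch(kernel, grid, *args):
    # The real runtime executes these program ids in parallel on the GPU.
    for pid in range(grid[0]):
        kernel(pid, *args)

def add_block(pid, x, y, z, K, L):
    lid = pid * L                   # row offset for this program's block
    for l in range(L):
        for k in range(K):
            z[lid + l][k] = x[k] + y[lid + l][k]

K, L, rows = 4, 2, 8
x = list(range(K))
y = [[1.0] * K for _ in range(rows)]
z = [[0.0] * K for _ in range(rows)]
fake_launch(add_block, (rows // L,), x, y, z, K, L)
```

Because no block reads or writes another block’s rows, the loop over `pid` can be parallelized freely, which is the property the Triton version above relies on.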

## Part 1: Cumulative Sums

Let’s start out by implementing a simple cumulative sum of a 1D sequence. This is just the `torch.cumsum` function.

We are going to be a bit pedantic and write this in the following manner.

\(h_k = h_{k-1} + x_k \qquad y_k = h_k\)

```
# Constants used throughout
K = 16
BLOCKS = 8
SEQLEN = K * BLOCKS
x = arange(SEQLEN)
y = zeros(SEQLEN)
```

```
def cumsum(x):
    y = []
    h = 0
    for k in range(len(x)):
        h = h + x[k]
        y.append(h)
    return h, y

h_, y_ = cumsum(x.cpu())
plt.bar(range(SEQLEN), y_)
```

```
<BarContainer object of 128 artists>
```

### Simple Implementation

Now let’s write our first Triton program. This will be a cumulative sum over a 1D tensor.

Triton functions are marked by `@triton.jit`. Inputs to the base function `cumsum1_tt` are pointers. We use `tl.load` and `tl.store` to read from and write to these pointers, and `tl.arange` to indicate pointer ranges. We use a mask for `H` to only write out the last value.

```
@triton.jit
def plus_fn(a, b):
    # This is a helper function where a and b are tensors.
    return a + b

@triton.jit
def cumsum1_tt(X, Y, H, K: tl.constexpr):
    # This is the base triton function. Capital letters are pointers to memory.
    # Create a tensor from 0 to K - 1
    Ks = tl.arange(0, K)
    # Load in a sequence of K x's (blue)
    x = tl.load(X + Ks)
    # Compute h (green) and y (yellow) on axis 0.
    hs = tl.associative_scan(x, 0, plus_fn)
    y = hs
    # Write out K y's
    tl.store(Y + Ks, y)
    # Write out only the last h to memory.
    tl.store(H + Ks * 0, hs, mask=Ks == (K - 1))

# Test to confirm it runs on the GPU.
h = zeros(1)
cumsum1_tt[(1,)](x, y, h, K=K)
h_, y_ = cumsum(x[:K].tolist())
check(h[0], [h_], y[:K], y_)
```

```
✔️
```

Note though that internally it doesn’t calculate things left to right, but instead builds up a tree.

Since addition is associative, \((x_1 + x_2) + x_3 = x_1 + (x_2 + x_3)\), we can regroup the sum any way we like.

We can use Triton’s `associative_scan` function. It computes this tree in parallel to sum up all the numbers.

To compute the intermediate terms, we need to do one pass up the tree and then a second pass down to get each of the intermediate values. This is what `associative_scan` does.
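To make the two passes concrete, here is a minimal pure-Python sketch of the classic up-sweep/down-sweep (Blelloch) scan. It is a sequential simulation of what `associative_scan` does in parallel, not Triton code, and it hardcodes 0 as the identity, so it assumes an addition-like operator.

```python
def tree_scan(xs, op):
    # Sequential sketch of a Blelloch scan; assumes len(xs) is a power of two
    # and that 0 is the identity for op (true for addition).
    n = len(xs)
    h = list(xs)
    # Up-sweep: build partial sums up the tree.
    d = 1
    while d < n:
        for i in range(2 * d - 1, n, 2 * d):
            h[i] = op(h[i - d], h[i])
        d *= 2
    # Down-sweep: push prefixes back down, producing an exclusive scan.
    h[n - 1] = 0
    d = n // 2
    while d >= 1:
        for i in range(2 * d - 1, n, 2 * d):
            left = h[i - d]
            h[i - d] = h[i]
            h[i] = op(left, h[i])
        d //= 2
    # Shift the exclusive result to an inclusive one.
    return [op(e, x) for e, x in zip(h, xs)]
```

Both sweeps do \(O(n)\) total work, but within each level of the tree every update is independent, which is what the GPU exploits.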

### Block Implementation

However, there is an issue. We can only load a maximum of \(K\) values onto the GPU at any given time. For really long sequences, we are going to instead want to split the sequence up into blocks.

We can do part of the calculation for each of these blocks separately. In Triton, this corresponds to different program IDs.
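The block strategy can be sketched in plain Python (a model of the idea, not the Triton kernel we will write; the name `blocked_cumsum` is mine): each block computes a local cumulative sum and its total independently, then a second pass adds in the carry from all previous blocks.

```python
def blocked_cumsum(xs, block):
    # Pass 1: local cumsum and total per block (each block is independent,
    # so this pass could run in parallel, one program id per block).
    local_scans, totals = [], []
    for s in range(0, len(xs), block):
        h, ys = 0, []
        for v in xs[s:s + block]:
            h += v
            ys.append(h)
        local_scans.append(ys)
        totals.append(h)
    # Pass 2: shift each block by the sum of all previous blocks.
    out, carry = [], 0
    for ys, t in zip(local_scans, totals):
        out.extend(y + carry for y in ys)
        carry += t
    return out
```

Only the short second pass over block totals is sequential; all the per-element work happens inside independent blocks.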