tl;dr: open-source Deepseek R1 inference using JAX, minimal yet performant
This is a pure JAX implementation of Deepseek V3 inference, including a checkpoint converter for the R1 weights. It currently runs on TPU; GPU support is in progress.
The entire model is defined in model.py and invoked via main.py. Among other things, the model code demonstrates:
- an MLA attention implementation;
- expert and tensor parallelism via JAX's `shard_map` for easy multi-device/multi-host computation; and
- simple int8 quantization.
This example aims to be a concise, self-contained, fully open-source codebase, with performance that is reasonably comparable to other R1 inference offerings (at cost). We hope that it is easy to understand and offers an accessible starting point for performant inference with JAX. See the performance rundown below.
In addition, this repo includes an overview of how to shard transformers and a discussion of the specific optimizations used in this implementation, as well as a workflow for interactive development on multi-host GPU and TPU clusters using ipyparallel.
- Quickstart
- Inference performance results
- Transformer parallelism strategies
- Optimizing Deepseek V3
- Working with multi-host clusters
- Next steps
Due to the large model size (671B parameters), a multi-host platform is required to run the full model. We've tested on v5e-64.
Run on all hosts in the TPU cluster:
$ python3 main.py
e.g. for Cloud TPU:
$ gcloud compute tpus tpu-vm ssh {TPU_NAME} --worker=all \
--command="cd ~/deepseek-r1-jax && python3 main.py"
Responses:
['\n'
"Okay, the user asked me to tell my name. But I need to remember that I'm "
'supposed to respond as an AI assistant without revealing any personal '
'details',
'\n'
'Okay, the user wants to know how to describe the weather in Old English '
'using long prose. Let me start by recalling what Old English is like. It',
'\n'
'Okay, the user asked, "Do you like ice cream," and wants me to be extremely '
"precise. Let me start by understanding what they're looking for"]
(See Working with multi-host clusters for full setup.)
TPU | batch size | context length | tok/s | HBM BW util | comments
---|---|---|---|---|---
v5e-64 | 1 | 32 | 75.8 | 113% | |
v5e-64 | 1 | 512 | 75.9 | 113% | max tok/s |
v5e-64 | 1 | 4096 | 73.8 | 110% | |
v5e-64 | 1 | 8192 | 71.0 | 106% | |
v5e-64 | 8 | 32 | 50.5 | 75.2% | |
v5e-64 | 8 | 512 | 48.0 | 71.4% | |
v5e-64 | 8 | 4096 | 42.1 | 62.6% | |
v5e-64 | 8 | 8192 | 35.6 | 52.9% | |
v5e-64 | 128 | 32 | 19.6 | 29.1% | cost optimal |
v5e-64 | 128 | 512 | 17.4 | 25.8% | |
Results generated using jax 0.5.2, Python 3.10.15. Cost computation based on https://cloud.google.com/spot-vms/pricing (region us-central1, as of Feb 28, 2025).
Deepseek is an unusual model in that it (i) uses a distinctive form of attention, Multi-head Latent Attention (MLA), and (ii) uses an MoE layer with a large number of small experts. This presents some challenges when optimizing the model for TPUs and GPUs, whether to maximize compute (in training) or memory-bandwidth utilization (in inference).
- Q: What parameter influences inference speed the most?
  A: HBM bandwidth.
- Q: Fully-replicated or sharded activations?
  A: For low-latency decoding, fully-replicated activations are usually faster since that strategy relies on a single all-reduce instead of repeated reduce-scatters. Computation (over weight shards) is still partitioned, but local memory is traded for lower-latency communication.
- Q: Why doesn't this match the cost and performance of proprietary inference APIs?
  A: This example aims to balance simplicity with performance, and thus does not implement every possible optimization if it would add considerable complexity (e.g. heavy use of custom kernels). In addition, this example only uses well-known optimization strategies, and does not aim to introduce any new or closed-source techniques that inference providers may have independently developed.
- Q: How to efficiently compute MLA attention on TPUs (whose registers are 128-aligned) when the `nope` embeddings have d=128 but the `pe` embeddings have d=64?
  A: Instead of concatenating the embeddings, we compute the inner products `qk_nope` and `qk_pe` separately and sum them (see the sketch after this list).
- Q: Inference or training TPUs (e.g., v5e or v5p)?
  A: Inference TPUs (v5e): their matmul units are not as powerful, but they can offer lower latency at low utilization.
- Q: How to work with multiple hosts?
  A: Either (1) launch the same Python script on every host via `python3 script.py`, or (2) use our `ipyparallel` setup.
- Q: Which TPU image to use?
  A: For v5e: `v2-alpha-tpuv5-lite`; for v6e: `v2-alpha-tpuv6e`. See runtimes.
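To make the MLA answer above concrete, here is a small sketch (with illustrative shapes, not the repo's code) showing that summing the two separate inner products equals the dot product of the concatenated `nope` and `pe` halves, while avoiding a concatenated 192-dim head dimension that maps poorly onto the TPU's 128-aligned registers:

```python
import jax.numpy as jnp

B, N, S, H_nope, H_pe = 1, 16, 512, 128, 64  # illustrative sizes

q_nope = jnp.ones((B, N, 1, H_nope)); k_nope = jnp.ones((B, N, S, H_nope))
q_pe = jnp.ones((B, N, 1, H_pe));     k_pe = jnp.ones((B, N, S, H_pe))

# Two TPU-friendly inner products, summed...
qk = (jnp.einsum("bnqh,bnkh->bnqk", q_nope, k_nope)
      + jnp.einsum("bnqh,bnkh->bnqk", q_pe, k_pe))

# ...equal the dot product of the concatenated embeddings.
qk_concat = jnp.einsum("bnqh,bnkh->bnqk",
                       jnp.concatenate([q_nope, q_pe], -1),
                       jnp.concatenate([k_nope, k_pe], -1))
assert jnp.allclose(qk, qk_concat)
```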
- ragged dot - a grouped matmul operation

  Needed for good decode inference performance with small batches.

XLA, which underlies JAX, is very good at merging and fusing operations, so we often don't need custom kernels for optimal hardware performance. However, Deepseek R1 uses an uncommonly large number of small experts. For all but small batch sizes in decode, `jax.lax.ragged_dot` gives full performance; where `jax.lax.ragged_dot` is suboptimal, we write a custom TPU kernel which more aggressively prefetches the right-hand side into the TPU's VMEM.
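For reference, here is a minimal, self-contained sketch of `jax.lax.ragged_dot` with illustrative shapes (not the model's real dimensions): the left-hand side is a packed list of tokens grouped by expert, the right-hand side is the stacked per-expert weights, and `group_sizes` gives the number of tokens in each group.

```python
import jax
import jax.numpy as jnp

num_experts, d_model, d_ff, num_tokens = 4, 8, 16, 10  # illustrative sizes

lhs = jnp.ones((num_tokens, d_model))                   # packed tokens, grouped by expert
rhs = jnp.ones((num_experts, d_model, d_ff))            # stacked expert weight matrices
group_sizes = jnp.array([3, 2, 0, 5], dtype=jnp.int32)  # tokens per expert, sums to num_tokens

out = jax.lax.ragged_dot(lhs, rhs, group_sizes)         # -> (num_tokens, d_ff)
print(out.shape)
```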
This section gives an overview of different sharding strategies and their performance considerations for transformer architectures in general. For a very in-depth guide on this topic, check out How to Scale Your Model. The next section goes over Deepseek-specific optimizations.
A typical decoder-only transformer consists of:

- an input embedding
  - a single weight of shape $V \times D$
- repeated decoder layers (an attention layer followed by a feed-forward layer)
  - Attention layer
    - project the input $BSD$ to $BSNH$ for queries, $BSNH$ for keys and $BSNH$ for values; typically $D \approx N \cdot H$
    - compute the attention operation on $BSNH$, $BSNH$, $BSNH$, giving $BSNH$
    - project the output $BSNH$ back to $BSD$ using a projection matrix
  - Feed-forward layer - a multilayer perceptron (MLP) or a mixture-of-experts (MoE)
    - always (i) up-projection -> (ii) nonlinearity -> (iii) down-projection
    - MLP
      - up-projection: $BSD \times DF \rightarrow BSF$
      - down-projection: $BSF \times DF \rightarrow BSD$
    - MoE
      - each token in $BS$ can be routed to a matrix slice $EDF[\text{idx}, :, :]$
      - up-projection: $BSD \times EDF \rightarrow BSF$
      - down-projection: $BSF \times EDF \rightarrow BSD$
- an output projection
  - a single weight of shape $D \times V$
Abbreviation | Dimension
---|---
V | vocabulary size
B | batch
S | sequence
D | model dimension
F | up-projection dimension
N | number of query, key or value heads
H | head dimension
E | expert dimension
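As a concrete (unsharded) illustration of these shapes, the sketch below writes one decoder layer's projections as einsums over the abbreviations in the table above; the sizes, weight names, and the plain SiLU nonlinearity are illustrative assumptions, not the model's actual configuration.

```python
import jax
import jax.numpy as jnp

B, S, D, N, H, F = 2, 16, 1024, 8, 128, 4096  # illustrative sizes

x = jnp.ones((B, S, D))
wq, wk, wv = (jnp.ones((D, N, H)) for _ in range(3))
wo = jnp.ones((N, H, D))
w_up, w_down = jnp.ones((D, F)), jnp.ones((F, D))

q = jnp.einsum("bsd,dnh->bsnh", x, wq)          # BSD -> BSNH (same for keys and values)
k = jnp.einsum("bsd,dnh->bsnh", x, wk)
v = jnp.einsum("bsd,dnh->bsnh", x, wv)
attn_out = v                                    # stand-in for the BSNH attention output
o = jnp.einsum("bsnh,nhd->bsd", attn_out, wo)   # BSNH -> BSD output projection
h = jnp.einsum("bsd,df->bsf", o, w_up)          # up-projection:   BSD x DF -> BSF
y = jnp.einsum("bsf,fd->bsd", jax.nn.silu(h), w_down)  # down-projection: BSF x DF -> BSD
```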
The simplest sharding strategy, naive eager pipeline parallelism, puts the first few layers on the first device, the next few layers on the second, and so on; it requires only simple communication, passing activations between devices every few layers. Unfortunately, for fast inference this means later devices wait for the earlier ones to finish, so the model decodes at the speed of a single device. Strategies that favor parallel work across devices, tensor parallelism and fully-sharded data parallelism (FSDP), are a better fit. We find tensor parallelism results in the fastest inference.
Strategy | Input | QKV | Output | Up | Down
---|---|---|---|---|---
where:

- ${\color{red} \text{AR}}$ - all-reduce (`jax.lax.psum`)
- ${\color{red} \text{RS}}$ - reduce-scatter (`jax.lax.psum_scatter`)
- ${\color{red} \text{AG}}$ - all-gather (`jax.lax.all_gather`)
The key to designing a sharding strategy is minimizing communication overhead. There are typically several alternatives, and the compiler will overlap communication with computation as much as possible, so it's usually worth trying several of them and picking the one that minimizes total runtime. The best strategy depends on the hardware configuration. The following are general rules of thumb in different contexts.
For low latency with 1D/2D sharding, the primary sharded matrix-multiplication strategies are:

- $BD_x \times D_xF \underset{\text{scatter}}{\longrightarrow} B F_x$ - contracting dimension with a scatter (1 unit of comms)
- $BD \times DF_x \longrightarrow B F_x$ - replicated activations (no comms)
- $BD_x \times D_xF \underset{\text{all-reduce}}{\longrightarrow} B F$ - contracting dimension with an all-reduce afterwards (2 units of comms)
- for attention, activations should be sharded over heads (effectively the feature dimension)
- do not all-gather weights (no FSDP)
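As an illustration of the "replicated activations" strategy above (shard the up-projection over $F$, contract over the sharded $F$ in the down-projection, then all-reduce), here is a minimal `shard_map` sketch. The mesh axis name, sizes, and the plain ReLU MLP are illustrative assumptions, not the repo's actual code.

```python
from functools import partial

import numpy as np
import jax
import jax.numpy as jnp
from jax.sharding import Mesh, PartitionSpec as P
from jax.experimental.shard_map import shard_map

mesh = Mesh(np.array(jax.devices()), ("x",))
B, D, F = 8, 1024, 4096  # illustrative sizes

x = jnp.ones((B, D))       # fully-replicated activations
w_up = jnp.ones((D, F))    # sharded over F (columns)
w_down = jnp.ones((F, D))  # sharded over F (the contracting dimension)

@partial(shard_map, mesh=mesh,
         in_specs=(P(), P(None, "x"), P("x", None)), out_specs=P())
def mlp(x, w_up, w_down):
    h = jax.nn.relu(x @ w_up)     # (B, F / n_devices): no communication
    y = h @ w_down                # partial sums of shape (B, D)
    return jax.lax.psum(y, "x")   # one all-reduce -> replicated output

y = mlp(x, w_up, w_down)
```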
Total FSDP comms scale with the size of the weights (which must be all-gathered), while total tensor-parallelism comms scale with the size of the activations (which must be reduced). Very roughly, this implies a trade-off in favor of FSDP when the batch size is on the order of the model dimension. For low-latency decoding of a model like Llama 3.1 70B, batch sizes are far smaller than the model dimension, strongly implying a preference for tensor parallelism (in the context of low-latency decoding).
The attention layer computes the equivalent of $\text{softmax}\left(QK^\top / \sqrt{H}\right) V$ per head, followed by an output projection back to $BSD$. In our low-latency setting, we have fully-replicated activations and need outputs sharded over attention heads. In regular attention we don't have to communicate during the attention operation itself, because each head is computed independently; sharding over heads keeps all work local until the output projection combines the heads.
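Below is a minimal sketch of that point for standard multi-head attention (not MLA), with head-sharded projections under `shard_map`; the mesh axis name and sizes are illustrative assumptions. Each head's attention runs without communication, and only the output projection, which contracts over the sharded head dimension, needs a single all-reduce.

```python
from functools import partial

import numpy as np
import jax
import jax.numpy as jnp
from jax.sharding import Mesh, PartitionSpec as P
from jax.experimental.shard_map import shard_map

mesh = Mesh(np.array(jax.devices()), ("x",))
B, S, D, N, H = 2, 16, 1024, 8, 128  # illustrative sizes

x = jnp.ones((B, S, D))                               # replicated activations
wq, wk, wv = (jnp.ones((D, N, H)) for _ in range(3))  # sharded over heads N
wo = jnp.ones((N, H, D))                              # sharded over heads N

@partial(shard_map, mesh=mesh,
         in_specs=(P(), P(None, "x", None), P(None, "x", None),
                   P(None, "x", None), P("x", None, None)),
         out_specs=P())
def attention(x, wq, wk, wv, wo):
    q = jnp.einsum("bsd,dnh->bsnh", x, wq)  # local heads only
    k = jnp.einsum("bsd,dnh->bsnh", x, wk)
    v = jnp.einsum("bsd,dnh->bsnh", x, wv)
    p = jax.nn.softmax(jnp.einsum("bqnh,bknh->bnqk", q, k) / jnp.sqrt(H), axis=-1)
    o = jnp.einsum("bnqk,bknh->bqnh", p, v)  # per-head attention, no comms
    y = jnp.einsum("bsnh,nhd->bsd", o, wo)   # contracts the sharded head dim
    return jax.lax.psum(y, "x")              # the only communication

y = attention(x, wq, wk, wv, wo)
```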
The MLP in the first 3 layers of the network (Deepseek R1 uses 3 standard MLP layers followed by 58 MoE layers) is a fairly standard Llama-like gated MLP, $W_{\text{down}}\left(\text{silu}(W_{\text{gate}} x) \cdot W_{\text{up}} x\right)$, so we have to choose a two-step matrix-multiplication sharding strategy (e.g., shard the gate and up projections over $F$ with replicated inputs, then contract over the sharded $F$ in the down projection and all-reduce the result).
The MoE layer implementation: in most MoE layers each token in a batch and sequence is computed independently. Typically, the first step in an MoE implementation is to flatten the sequences in a batch into a single flat list of tokens. These tokens are then routed to (potentially multiple) experts and finally reduced (if each token is routed to multiple experts), typically via a weighted sum. Each expert is a two-stage MLP: a gate and an up projection followed by a down projection.

While multiple implementations are possible, our MoE implementation relies on the ragged dot subroutine, defined as multiplying a ragged (packed) 2D left-hand side by a dense 3D stack of 2D matrices on the right-hand side, given a list of group lengths describing the packed left-hand side. For example, a packed left-hand side of shape $(\sum_i g_i) \times D$ multiplied by a stacked right-hand side of shape $E \times D \times F$ with group sizes $(g_0, \ldots, g_{E-1})$ yields an output of shape $(\sum_i g_i) \times F$.
Relying on ragged dot requires sorting the tokens after routing because ragged dot expects packed contiguous groups of tokens. This leads to our implementation:
- route tokens to experts
- sort tokens into contiguous expert groups
- apply ragged dot to gate, up projection
- apply ragged dot to down projection
- inverse sort tokens back to the original order
- reduce tokens across the n experts each token was routed to via a weighted sum
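Below is a minimal, unsharded sketch of these steps (routing, sorting, grouped matmuls via `jax.lax.ragged_dot`, inverse sort, and the weighted reduction). The function and weight names, the softmax router, and the top-2 routing are illustrative assumptions, not the repo's implementation.

```python
import jax
import jax.numpy as jnp

def moe_forward(x, router_w, w_gate, w_up, w_down, top_k=2):
    """x: (tokens, D); router_w: (D, E); w_gate/w_up: (E, D, F); w_down: (E, F, D)."""
    num_experts = router_w.shape[-1]
    scores = jax.nn.softmax(x @ router_w, axis=-1)   # (tokens, E) routing probabilities
    weights, experts = jax.lax.top_k(scores, top_k)  # (tokens, top_k) each

    # Sort the (token, expert) pairs by expert id so that each expert's tokens
    # form a contiguous group, as ragged_dot expects.
    flat_experts = experts.reshape(-1)               # (tokens * top_k,)
    order = jnp.argsort(flat_experts)
    token_idx = jnp.repeat(jnp.arange(x.shape[0]), top_k)[order]
    x_sorted = x[token_idx]                          # (tokens * top_k, D)
    group_sizes = jnp.bincount(flat_experts, length=num_experts)

    # Expert MLP via grouped matmuls: gate/up projections, nonlinearity, down projection.
    gate = jax.lax.ragged_dot(x_sorted, w_gate, group_sizes)
    up = jax.lax.ragged_dot(x_sorted, w_up, group_sizes)
    down = jax.lax.ragged_dot(jax.nn.silu(gate) * up, w_down, group_sizes)

    # Inverse sort and weighted sum over the top_k experts each token was routed to.
    gate_weights = weights.reshape(-1)[order][:, None]
    return jnp.zeros_like(x).at[token_idx].add(down * gate_weights)
```

In the sharded version described next, each device holds only a slice of the experts and fills the outputs of missing experts with zeros before this final reduction.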
The sharding strategy of this MoE implementation then looks as follows:

- route tokens to experts
  - if in prefill, shard tokens along batch/sequence to avoid duplicating routing work
  - shard the expert dimension; each device lacks the experts for some tokens, so simply fill those outputs with zeros
  - place the tokens whose experts are missing on this device at the end of the sorted list
    - if dropless, maintain the full token list
    - if dropping, truncate the token list at a static length (tokens without experts are last and so are dropped), e.g. 2 $\times$ the batch size multiplied by the fraction of experts a single device holds
- apply ragged dot to the gate and up projections
  - the ragged dot operator already supports sparsity, so we rely on it not computing tokens not assigned to any expert on this device
- apply ragged dot to the down projection
  - again, the ragged dot operator skips tokens not assigned to any expert on this device
- if dropless
  - inverse-sort the tokens; zeros are already present for the tokens whose experts are missing on this device
  - reduce tokens across the n experts each token was routed to via a weighted sum
    - tokens whose experts are missing are zero, so the reduction yields correct results
- if dropping
  - prepare a buffer the size of the full token list and scatter-add the tokens for which this device has experts into the buffer
    - this combines the inverse token sort with the weighted reduction across the n experts
The output then needs to be communicated across expert shards. Standard tensor parallelism is fully compatible with this implementation because it can be applied across columns or rows of the stacked matrices in the ragged dot operation.
In the specific case of the Deepseek V3/R1 MoE, the expert layers are small, so sharding each expert's matrices alone would leave each device with very little work; we therefore also shard over the expert dimension. With an all-reduce matmul sharding strategy, the MoE output is all-reduced across the expert (and tensor-parallel) shards.
Fig: Decode profile for an MoE Layer with batch size = 8 and context length of 2048 (not context limit).
When working with many accelerators, JAX offers Distributed arrays and automatic parallelization with a global view of the computation, but the program needs to be run on many hosts each of which controls a subset of the actual accelerators.
The simplest way to run a JAX program on multiple hosts is to run the same Python file from all the hosts at the same time - for example by launching an ssh command on all hosts in the cluster.
However, for development it's often easier to (1.) efficiently share code changes to all hosts, (2.) have a way of easily launching computation on all hosts and (3.) have the ability to debug interactively.
This section shows how you can do that:
- Shared disk setup - NFS & gcsfuse
- Batch SSH commands
- Interactive cluster setup with ipyparallel
This guide has specific instructions for setting up a TPU Pod with GCS, but a similar setup can be applied to any Linux multi-host platform, including GPU.
TPU_ZONE="zone, e.g. us-central1-a"
PROJECT="your-project"
IMAGE="v2-alpha-tpuv5-lite"
ACCELERATOR="v5litepod-64"
TPU_NAME="my_tpu"
TPU_NAME="$NAME_PREFIX"-"$ACCELERATOR"
gcloud alpha compute tpus tpu-vm create "$TPU_NAME" --zone="$TPU_ZONE" \
--project="$PROJECT" --accelerator-type="$ACCELERATOR" --version="$IMAGE"
1. gcsfuse
For datasets and checkpoints.
gcsfuse --implicit-dirs {bucket_name_no_gs://} {local_folder}
2. nfs

For code consistency between hosts in the TPU Pod / Cluster.
# on worker 0
WORKER0_IP="..."
sudo apt install -y nfs-server nfs-common net-tools tmux
mkdir -p ~/nfs; sudo umount ~/nfs
echo "$HOME/nfs $WORKER0_IP/24(rw,sync,no_subtree_check)" | sudo tee /etc/exports
sudo exportfs -a
sudo systemctl enable nfs-server; sudo systemctl restart nfs-server
sudo chown $USER:$USER -R ~/nfs
# on all other workers (!= 0)
SERVER_IP="..."
mkdir -p ~/nfs
sudo umount ~/nfs; sudo mount -t nfs $SERVER_IP:/home/$USER/nfs ~/nfs
(Optionally) 3. sshfs
For a quick preview from a local machine.
sshfs TPU_WORKER_0_IP:~/remote_folder ~/local_folder
TPU_NAME="..."
TPU_ZONE="..."
TPU_PROJECT="..."
tpu_exec() {
  # usage: tpu_exec <first_worker> <last_worker> <command>
  local workers=$(seq $1 $2 | paste -sd, -)
  gcloud alpha compute tpus tpu-vm ssh --zone="$TPU_ZONE" --project="$TPU_PROJECT" \
    "$TPU_NAME" --worker="$workers" --command="$3"
}
tpu_exec 0 15 'pip install -U "jax[tpu]"'  # all 16 workers of a v5e-64
Start the controller on worker 0 and the engines on all other workers (we skip worker 0 in the engine launch because we want worker 0 to execute interactively).
SERVER_IP="..."
CONTROLLER_SETUP=$(cat << EOM
tmux kill-session -t controller; pkill -9 python
tmux new -d -s controller '\
. ~/venv/bin/activate && ipcontroller --profile-dir=~/nfs --ip=$SERVER_IP'
EOM
)
ENGINE_SETUP=$(cat << EOM
tmux kill-session -t engine; pkill -9 ipengine
tmux new -d -s engine '. ~/venv/bin/activate && ipengine --profile-dir=~/nfs'
EOM
)
tpu_exec 0 0 "$CONTROLLER_SETUP"  # only worker 0
tpu_exec 1 15 "$ENGINE_SETUP"     # all workers except worker 0
Cell 0:
import ipyparallel as ip
from pathlib import Path
connection_file = Path("~/nfs/security/ipcontroller-client.json").expanduser()
client = ip.Client(connection_info=connection_file)
print(sorted(list(client._engines.keys()))) # one less than worker num
# this process is the final worker
Cell 1:
%%px --local
import socket
import jax
jax.distributed.initialize() # no arguments, TPUs automatically detect peers
print(f"Hello from {socket.gethostname()}")
Note: "--local" argument means "also run on this process", it's necessary to get easy access to the output of computations on worker 0
- GPU support
- ragged decode MLA kernel
- further prefill throughput optimizations
- distilled models