Commit 8cd58db: added guide for SGLang-Jax on TPUs (3 files changed, +337 -0)
# Serve Qwen3-MoE with SGLang-Jax on TPU

SGLang-Jax supports multiple Mixture-of-Experts (MoE) models from the Qwen3 family with varying hardware requirements:

- **[Qwen3-30B-A3B](https://huggingface.co/Qwen/Qwen3-30B-A3B)**: Runs on 4 TPU v6e chips
- **[Qwen3-Coder-480B-A35B-Instruct](https://huggingface.co/Qwen/Qwen3-Coder-480B-A35B-Instruct)**: Requires 64 TPU v6e chips (16 nodes × 4 chips)
- Other Qwen3 MoE variants with different scale requirements

**This tutorial focuses on deploying Qwen3-Coder-480B**, the largest of these models, which requires a multi-node distributed setup. For smaller models such as Qwen3-30B, follow the same steps with adjusted node counts and parallelism settings.

## Hardware Requirements

Running Qwen3-Coder-480B requires a multi-node TPU cluster (a quick sanity check follows this list):

- **Total nodes**: 16
- **TPU chips per node**: 4 (v6e)
- **Total TPU chips**: 64
- **Tensor Parallelism (TP)**: 32 (for non-MoE layers)
- **Expert Tensor Parallelism (ETP)**: 64 (for MoE experts)
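Before launching, it can help to confirm that JAX sees all local chips on each node. A minimal sketch, assuming `jax[tpu]` is already installed in the active environment (this check is an addition, not part of the original guide):

```bash
# Each node should report 4 local v6e chips.
python3 -c "import jax; print(jax.local_device_count())"
```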
## Installation

### Option 1: Install from PyPI

```bash
uv venv --python 3.12 && source .venv/bin/activate
uv pip install sglang-jax
```

### Option 2: Install from Source

```bash
git clone https://github.com/sgl-project/sglang-jax
cd sglang-jax
uv venv --python 3.12 && source .venv/bin/activate
uv pip install -e python/
```
## Launch Distributed Server

### Preparation

1. **Get Node 0 IP address** (coordinator):

   ```bash
   # On node 0
   hostname -I | awk '{print $1}'
   ```

   Save this IP as `NODE_RANK_0_IP`.

2. **Download the model** (shared storage, or pre-downloading on every node, is recommended):

   ```bash
   export HF_TOKEN=your_huggingface_token
   huggingface-cli download Qwen/Qwen3-Coder-480B-A35B-Instruct --local-dir /path/to/model
   ```
### Launch Command

Run the following command **on each node**, replacing:

- `<NODE_RANK_0_IP>`: IP address of node 0
- `<NODE_RANK>`: Current node rank (0-15)
- `<QWEN3_CODER_480B_MODEL_PATH>`: Path to the downloaded model

```bash
JAX_COMPILATION_CACHE_DIR=/tmp/jit_cache \
python3 -u -m sgl_jax.launch_server \
    --model-path <QWEN3_CODER_480B_MODEL_PATH> \
    --trust-remote-code \
    --dist-init-addr=<NODE_RANK_0_IP>:10011 \
    --nnodes=16 \
    --tp-size=32 \
    --device=tpu \
    --random-seed=3 \
    --mem-fraction-static=0.8 \
    --chunked-prefill-size=2048 \
    --download-dir=/dev/shm \
    --dtype=bfloat16 \
    --max-running-requests=128 \
    --skip-server-warmup \
    --page-size=128 \
    --tool-call-parser=qwen3_coder \
    --node-rank=<NODE_RANK>
```
### Example for Specific Nodes

**Node 0 (coordinator):**

```bash
JAX_COMPILATION_CACHE_DIR=/tmp/jit_cache \
python3 -u -m sgl_jax.launch_server \
    --model-path /path/to/Qwen3-Coder-480B \
    --trust-remote-code \
    --dist-init-addr=10.0.0.2:10011 \
    --nnodes=16 \
    --tp-size=32 \
    --device=tpu \
    --random-seed=3 \
    --mem-fraction-static=0.8 \
    --chunked-prefill-size=2048 \
    --download-dir=/dev/shm \
    --dtype=bfloat16 \
    --max-running-requests=128 \
    --skip-server-warmup \
    --page-size=128 \
    --tool-call-parser=qwen3_coder \
    --node-rank=0
```

**Node 1:**

```bash
# Same command but with --node-rank=1
JAX_COMPILATION_CACHE_DIR=/tmp/jit_cache \
python3 -u -m sgl_jax.launch_server \
    --model-path /path/to/Qwen3-Coder-480B \
    --trust-remote-code \
    --dist-init-addr=10.0.0.2:10011 \
    --nnodes=16 \
    --tp-size=32 \
    --device=tpu \
    --random-seed=3 \
    --mem-fraction-static=0.8 \
    --chunked-prefill-size=2048 \
    --download-dir=/dev/shm \
    --dtype=bfloat16 \
    --max-running-requests=128 \
    --skip-server-warmup \
    --page-size=128 \
    --tool-call-parser=qwen3_coder \
    --node-rank=1
```

Repeat for all 16 nodes, incrementing `--node-rank` from 0 to 15; a helper loop sketch follows.
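To avoid typing the command on 16 machines by hand, you can drive the launch over SSH. A minimal sketch, assuming passwordless SSH, a `hosts.txt` file listing the 16 node IPs in rank order, and identical environments on every node (all of these are assumptions, not part of the original guide):

```bash
#!/usr/bin/env bash
# Hypothetical helper: start the server on every node with the right rank.
# Assumes hosts.txt lists node IPs in rank order (rank 0 first).
NODE_RANK_0_IP=$(head -n 1 hosts.txt)
rank=0
while read -r host; do
  ssh "${host}" "JAX_COMPILATION_CACHE_DIR=/tmp/jit_cache \
    python3 -u -m sgl_jax.launch_server \
      --model-path /path/to/Qwen3-Coder-480B \
      --trust-remote-code \
      --dist-init-addr=${NODE_RANK_0_IP}:10011 \
      --nnodes=16 --tp-size=32 --device=tpu \
      --random-seed=3 --mem-fraction-static=0.8 \
      --chunked-prefill-size=2048 --download-dir=/dev/shm \
      --dtype=bfloat16 --max-running-requests=128 \
      --skip-server-warmup --page-size=128 \
      --tool-call-parser=qwen3_coder \
      --node-rank=${rank}" &
  rank=$((rank + 1))
done < hosts.txt
wait
```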
## Configuration Parameters

### Distributed Setup

- `--nnodes`: Number of nodes in the cluster (16)
- `--node-rank`: Rank of the current node (0-15)
- `--dist-init-addr`: Address and port of the coordinator node (node 0)

### Model Parallelism

- `--tp-size`: Tensor parallelism size for non-MoE layers (32)
- **ETP**: Expert tensor parallelism, automatically configured to 64 based on the total chip count (16 nodes × 4 chips = 64)

### Memory and Performance

- `--mem-fraction-static`: Fraction of TPU memory reserved for static allocation (0.8)
- `--chunked-prefill-size`: Prefill chunk size for batching (2048)
- `--max-running-requests`: Maximum concurrent requests (128)
- `--page-size`: Page size for KV-cache memory management (128)

### Model-Specific

- `--tool-call-parser`: Parser for tool calls; set to `qwen3_coder` for this model
- `--dtype`: Data type for inference (bfloat16)
- `--random-seed`: Random seed for reproducibility (3)
## Verification

Once all nodes are running, the server is accessible via the coordinator node (node 0). You can test it with:

```bash
curl http://<NODE_RANK_0_IP>:8000/v1/completions \
  -H "Content-Type: application/json" \
  -d '{
    "model": "Qwen/Qwen3-Coder-480B-A35B-Instruct",
    "prompt": "def fibonacci(n):",
    "max_tokens": 200,
    "temperature": 0
  }'
```
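Upstream SGLang also exposes a `/health` endpoint; assuming sgl-jax mirrors it (an assumption, not confirmed by this guide), a quick liveness check looks like:

```bash
# Hypothetical liveness check; should return HTTP 200 once the server is ready.
curl -i http://<NODE_RANK_0_IP>:8000/health
```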
# Serve Qwen3 with SGLang-Jax on TPU

This guide demonstrates how to serve [Qwen3-8B](https://huggingface.co/Qwen/Qwen3-8B) and [Qwen3-32B](https://huggingface.co/Qwen/Qwen3-32B) using SGLang-Jax on TPU.

## Provision TPU Resources

For **Qwen3-8B**, a single v6e chip is sufficient. For **Qwen3-32B**, use 4 chips or more.

### Option 1: Using gcloud CLI

Install and configure the gcloud CLI by following the [official installation guide](https://cloud.google.com/sdk/docs/install).

**Create TPU VM:**

```bash
gcloud compute tpus tpu-vm create sgl-jax \
  --zone=us-east5-a \
  --version=v2-alpha-tpuv6e \
  --accelerator-type=v6e-4
```

**Connect to TPU VM:**

```bash
gcloud compute tpus tpu-vm ssh sgl-jax --zone us-east5-a
```
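To confirm the VM is up before connecting, you can list TPU VMs in the zone (this check is an addition, not part of the original guide):

```bash
# Should show sgl-jax in the READY state.
gcloud compute tpus tpu-vm list --zone=us-east5-a
```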
### Option 2: Using SkyPilot (Recommended for Development)

SkyPilot simplifies TPU provisioning and offers automatic cost optimization, instance management, and environment setup.

**Prerequisites:**

- [Install SkyPilot](https://docs.skypilot.co/en/latest/getting-started/installation.html)
- [Configure GCP credentials](https://docs.skypilot.co/en/latest/getting-started/installation.html#gcp)

**Create configuration file `sgl-jax.yaml`:**

```yaml
resources:
  accelerators: tpuv6e-4
  accelerator_args:
    tpu_vm: True
    runtime_version: v2-alpha-tpuv6e

setup: |
  uv venv --python 3.12
  source .venv/bin/activate
  uv pip install sglang-jax
```

**Launch TPU cluster:**

```bash
sky launch sgl-jax.yaml \
  --cluster=sgl-jax-skypilot-v6e-4 \
  --infra=gcp \
  -i 30 \
  --down \
  -y \
  --use-spot
```

This command will:

- Find the lowest-cost spot instance across regions
- Automatically shut down the cluster after 30 minutes of idleness
- Set up the SGLang-Jax environment automatically

**Connect to cluster:**

```bash
ssh sgl-jax-skypilot-v6e-4
```

> **Note:** SkyPilot manages the external IP automatically, so you don't need to track it manually.
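SkyPilot's CLI can also check on or tear down the cluster; for example (standard SkyPilot commands, though not part of the original guide):

```bash
# Show cluster status, then terminate the cluster when you are done.
sky status
sky down sgl-jax-skypilot-v6e-4
```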
## Installation

> **Note:** If you used SkyPilot to provision resources, the environment is already set up. Skip to the [Launch Server](#launch-server) section.

For gcloud CLI users, install SGLang-Jax using one of the following methods:

### Option 1: Install from PyPI

```bash
uv venv --python 3.12 && source .venv/bin/activate
uv pip install sglang-jax
```

### Option 2: Install from Source

```bash
git clone https://github.com/sgl-project/sglang-jax
cd sglang-jax
uv venv --python 3.12 && source .venv/bin/activate
uv pip install -e python/
```

## Launch Server

Set the model name and start the SGLang-Jax server:

```bash
export MODEL_NAME="Qwen/Qwen3-8B" # or "Qwen/Qwen3-32B"

JAX_COMPILATION_CACHE_DIR=/tmp/jit_cache \
uv run python -u -m sgl_jax.launch_server \
    --model-path ${MODEL_NAME} \
    --trust-remote-code \
    --tp-size=4 \
    --device=tpu \
    --mem-fraction-static=0.8 \
    --chunked-prefill-size=2048 \
    --download-dir=/tmp \
    --dtype=bfloat16 \
    --max-running-requests 256 \
    --skip-server-warmup \
    --page-size=128
```

### Configuration Parameters

- `--tp-size`: Tensor parallelism size; should equal the number of TPU chips in your instance
- `--mem-fraction-static`: Fraction of memory allocated for static buffers
- `--chunked-prefill-size`: Size of prefill chunks for batching
- `--max-running-requests`: Maximum number of concurrent requests
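Once the server is up, a quick smoke test before benchmarking can be useful. A minimal sketch, reusing the completions endpoint from the Qwen3-MoE guide; the port here is an assumption, so substitute the one your server reports at startup:

```bash
# Hypothetical smoke test; adjust the port to match your server's startup log.
curl http://localhost:8000/v1/completions \
  -H "Content-Type: application/json" \
  -d '{"model": "Qwen/Qwen3-8B", "prompt": "def hello():", "max_tokens": 32}'
```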
## Run Benchmark

Test serving performance with different workload configurations:

```bash
uv run python -m sgl_jax.bench_serving \
    --backend sgl-jax \
    --dataset-name random \
    --num-prompts 256 \
    --random-input 4096 \
    --random-output 1024 \
    --max-concurrency 64 \
    --random-range-ratio 1 \
    --warmup-requests 0
```

### Benchmark Parameters

- `--backend`: Backend engine (use `sgl-jax`)
- `--random-input`: Input sequence length (e.g., 1024, 4096, 8192)
- `--random-output`: Output sequence length (e.g., 1, 1024)
- `--max-concurrency`: Maximum number of concurrent requests (e.g., 8, 16, 32, 64, 128, 256)
- `--num-prompts`: Total number of prompts to send

You can test various combinations of input/output lengths and concurrency levels to evaluate throughput and latency characteristics.
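For instance, a small sweep over concurrency levels (a sketch added here, not from the original guide; the log filenames are illustrative):

```bash
# Sweep concurrency levels and keep a log per run.
for c in 8 16 32 64 128 256; do
  uv run python -m sgl_jax.bench_serving \
    --backend sgl-jax \
    --dataset-name random \
    --num-prompts 256 \
    --random-input 4096 \
    --random-output 1024 \
    --max-concurrency ${c} \
    --random-range-ratio 1 \
    --warmup-requests 0 | tee "bench_c${c}.log"
done
```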
# Serve SGLang-Jax on Trillium TPUs (v6e)

This repository provides examples demonstrating how to deploy and serve SGLang-Jax on Trillium TPUs using GCE (Google Compute Engine) for a select set of models.

- [Qwen3-8B / Qwen3-32B](./Qwen3/README.md)
- [Qwen3-30B-A3B / Qwen3-Coder-480B-A35B-Instruct](./Qwen3-MoE/README.md)

The SGLang-Jax project continues to add support for new models; for the current list, see https://github.com/sgl-project/sglang-jax/tree/main/python/sgl_jax/srt/models.
