# Serve Qwen3-MoE with SGLang-Jax on TPU

SGLang-Jax supports multiple Mixture-of-Experts (MoE) models from the Qwen3 family with varying hardware requirements:

- **[Qwen3-30B-A3B](https://huggingface.co/Qwen/Qwen3-30B-A3B)**: Runs on 4 TPU v6e chips
- **[Qwen3-Coder-480B-A35B-Instruct](https://huggingface.co/Qwen/Qwen3-Coder-480B-A35B-Instruct)**: Requires 64 TPU v6e chips (16 nodes × 4 chips)
- Other Qwen3 MoE variants with different scale requirements

**This tutorial focuses on deploying Qwen3-Coder-480B**, the largest of these models, which requires a multi-node distributed setup. For smaller models such as Qwen3-30B you can follow the same steps with adjusted node counts and parallelism settings.

## Hardware Requirements

Running Qwen3-Coder-480B requires a multi-node TPU cluster:

- **Total nodes**: 16
- **TPU chips per node**: 4 (v6e)
- **Total TPU chips**: 64
- **Tensor Parallelism (TP)**: 32 (for non-MoE layers)
- **Expert Tensor Parallelism (ETP)**: 64 (for MoE experts)

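The figures above fit together as simple arithmetic. The sketch below is purely illustrative (the variable names are not sgl-jax configuration keys):

```python
# Topology arithmetic for the Qwen3-Coder-480B deployment described above.
nodes = 16
chips_per_node = 4                  # TPU v6e chips on each host
total_chips = nodes * chips_per_node

tp_size = 32                        # tensor parallelism for non-MoE layers
etp_size = 64                       # expert tensor parallelism for MoE layers

assert total_chips == 64
assert etp_size == total_chips      # experts are sharded across every chip
assert total_chips % tp_size == 0   # TP groups tile the chip mesh evenly
print(total_chips // tp_size)       # → 2 TP replica groups
```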

## Installation

### Option 1: Install from PyPI

```bash
uv venv --python 3.12 && source .venv/bin/activate
uv pip install sglang-jax
```

### Option 2: Install from Source

```bash
git clone https://github.com/sgl-project/sglang-jax
cd sglang-jax
uv venv --python 3.12 && source .venv/bin/activate
uv pip install -e python/
```

## Launch Distributed Server

### Preparation

1. **Get the node 0 IP address** (coordinator):

```bash
# On node 0
hostname -I | awk '{print $1}'
```

Save this IP as `NODE_RANK_0_IP`.

2. **Download the model** (preferably to shared storage, or pre-downloaded on every node):

```bash
export HF_TOKEN=your_huggingface_token
huggingface-cli download Qwen/Qwen3-Coder-480B-A35B-Instruct --local-dir /path/to/model
```

### Launch Command

Run the following command **on each node**, replacing:

- `<NODE_RANK_0_IP>`: IP address of node 0
- `<NODE_RANK>`: Current node rank (0-15)
- `<QWEN3_CODER_480B_MODEL_PATH>`: Path to the downloaded model

```bash
JAX_COMPILATION_CACHE_DIR=/tmp/jit_cache \
python3 -u -m sgl_jax.launch_server \
    --model-path <QWEN3_CODER_480B_MODEL_PATH> \
    --trust-remote-code \
    --dist-init-addr=<NODE_RANK_0_IP>:10011 \
    --nnodes=16 \
    --tp-size=32 \
    --device=tpu \
    --random-seed=3 \
    --mem-fraction-static=0.8 \
    --chunked-prefill-size=2048 \
    --download-dir=/dev/shm \
    --dtype=bfloat16 \
    --max-running-requests=128 \
    --skip-server-warmup \
    --page-size=128 \
    --tool-call-parser=qwen3_coder \
    --node-rank=<NODE_RANK>
```


### Example for Specific Nodes

**Node 0 (coordinator):**

```bash
JAX_COMPILATION_CACHE_DIR=/tmp/jit_cache \
python3 -u -m sgl_jax.launch_server \
    --model-path /path/to/Qwen3-Coder-480B \
    --trust-remote-code \
    --dist-init-addr=10.0.0.2:10011 \
    --nnodes=16 \
    --tp-size=32 \
    --device=tpu \
    --random-seed=3 \
    --mem-fraction-static=0.8 \
    --chunked-prefill-size=2048 \
    --download-dir=/dev/shm \
    --dtype=bfloat16 \
    --max-running-requests=128 \
    --skip-server-warmup \
    --page-size=128 \
    --tool-call-parser=qwen3_coder \
    --node-rank=0
```

**Node 1:**

```bash
# Same command, but with --node-rank=1
JAX_COMPILATION_CACHE_DIR=/tmp/jit_cache \
python3 -u -m sgl_jax.launch_server \
    --model-path /path/to/Qwen3-Coder-480B \
    --trust-remote-code \
    --dist-init-addr=10.0.0.2:10011 \
    --nnodes=16 \
    --tp-size=32 \
    --device=tpu \
    --random-seed=3 \
    --mem-fraction-static=0.8 \
    --chunked-prefill-size=2048 \
    --download-dir=/dev/shm \
    --dtype=bfloat16 \
    --max-running-requests=128 \
    --skip-server-warmup \
    --page-size=128 \
    --tool-call-parser=qwen3_coder \
    --node-rank=1
```

Repeat on all 16 nodes, incrementing `--node-rank` from 0 to 15.
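
Typing the command sixteen times invites copy-paste mistakes. A small generator can print the exact command for every rank; this is a convenience sketch, not part of sgl-jax (the coordinator IP and model path defaults are taken from the example above):

```python
# Sketch: emit the per-node launch command shown above for each rank.
def launch_command(node_rank: int,
                   coord_ip: str = "10.0.0.2",
                   model_path: str = "/path/to/Qwen3-Coder-480B",
                   nnodes: int = 16) -> str:
    flags = [
        f"--model-path {model_path}",
        "--trust-remote-code",
        f"--dist-init-addr={coord_ip}:10011",
        f"--nnodes={nnodes}",
        "--tp-size=32",
        "--device=tpu",
        "--random-seed=3",
        "--mem-fraction-static=0.8",
        "--chunked-prefill-size=2048",
        "--download-dir=/dev/shm",
        "--dtype=bfloat16",
        "--max-running-requests=128",
        "--skip-server-warmup",
        "--page-size=128",
        "--tool-call-parser=qwen3_coder",
        f"--node-rank={node_rank}",
    ]
    return ("JAX_COMPILATION_CACHE_DIR=/tmp/jit_cache "
            "python3 -u -m sgl_jax.launch_server " + " ".join(flags))

for rank in range(16):
    print(launch_command(rank))
```

Each printed line can then be dispatched to the matching node over `ssh` or your cluster tooling of choice.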

## Configuration Parameters

### Distributed Setup

- `--nnodes`: Number of nodes in the cluster (16)
- `--node-rank`: Rank of the current node (0-15)
- `--dist-init-addr`: Address and port of the coordinator node (node 0)

### Model Parallelism

- `--tp-size`: Tensor parallelism size for non-MoE layers (32)
- **ETP**: Expert tensor parallelism, configured automatically to 64 based on the total chip count

### Memory and Performance

- `--mem-fraction-static`: Fraction of device memory reserved for static allocations such as model weights and the KV cache (0.8)
- `--chunked-prefill-size`: Maximum number of tokens processed per prefill chunk (2048)
- `--max-running-requests`: Maximum number of concurrently running requests (128)
- `--page-size`: Number of tokens per KV-cache page (128)

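With `--page-size=128`, the KV cache is allocated in fixed pages of 128 tokens, so each request occupies whole pages. A rough illustration of what that implies (plain arithmetic, not an sgl-jax API):

```python
import math

PAGE_SIZE = 128                      # tokens per KV-cache page (--page-size)

def pages_needed(num_tokens: int) -> int:
    # A request's KV cache grows in whole pages, so the last page
    # may be only partially filled.
    return math.ceil(num_tokens / PAGE_SIZE)

print(pages_needed(2048))   # an entire prefill chunk → 16 pages
print(pages_needed(2049))   # one token over → 17 pages
```
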
### Model-Specific

- `--tool-call-parser`: Parser for tool calls; set to `qwen3_coder` for this model
- `--dtype`: Data type for inference (bfloat16)
- `--random-seed`: Random seed for reproducibility (3)

## Verification

Once all nodes are running, the server is reachable through the coordinator node (node 0). You can test it with:

```bash
curl http://<NODE_RANK_0_IP>:8000/v1/completions \
  -H "Content-Type: application/json" \
  -d '{
    "model": "Qwen/Qwen3-Coder-480B-A35B-Instruct",
    "prompt": "def fibonacci(n):",
    "max_tokens": 200,
    "temperature": 0
  }'
```
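
The same request can also be issued from Python. The helper below only assembles the OpenAI-style request object (the endpoint path and payload mirror the curl call above; actually sending it of course requires the cluster to be up):

```python
import json
import urllib.request

def completion_request(host: str, prompt: str,
                       model: str = "Qwen/Qwen3-Coder-480B-A35B-Instruct",
                       max_tokens: int = 200) -> urllib.request.Request:
    """Build an OpenAI-style /v1/completions request for the served model."""
    payload = {
        "model": model,
        "prompt": prompt,
        "max_tokens": max_tokens,
        "temperature": 0,
    }
    return urllib.request.Request(
        url=f"http://{host}:8000/v1/completions",
        data=json.dumps(payload).encode(),
        headers={"Content-Type": "application/json"},
    )

req = completion_request("10.0.0.2", "def fibonacci(n):")
# urllib.request.urlopen(req) would send it once the server is running.
print(req.full_url)   # → http://10.0.0.2:8000/v1/completions
```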