Skip to content

Commit 0ba0c92

Browse files
authored
[FusedMoE] Support sub-channel quantization: FP4, FP8, INT8, ... (#1158)
1 parent add0b5b commit 0ba0c92

File tree

2 files changed

+698
-136
lines changed

2 files changed

+698
-136
lines changed

tests/kernels/fused_moe_v1_test.py

Lines changed: 241 additions & 29 deletions
Original file line numberDiff line numberDiff line change
@@ -1,7 +1,7 @@
11
import jax
22
import jax.numpy as jnp
33
import numpy as np
4-
from absl.testing import absltest
4+
from absl.testing import absltest, parameterized
55
from jax._src import test_util as jtu
66
from jax.sharding import Mesh
77

@@ -43,11 +43,31 @@ def gen_moe_inputs(
4343
one_hot = (jnp.sum(
4444
jax.nn.one_hot(top_k_indices, num_experts, dtype=jnp.float32),
4545
axis=1,
46-
) * 10)
46+
) * 30)
4747
gating_output = (gating_output + one_hot).astype(dtype)
4848
return a, w1, w2, gating_output
4949

5050

51+
def sub_channel_quantize(x, quant_dtype, wsz=256):
52+
"""Quantizes x with sub-channel quantization on the 2nd minor."""
53+
if jnp.issubdtype(quant_dtype, jnp.floating):
54+
dtype_info = jnp.finfo(quant_dtype)
55+
else:
56+
dtype_info = jnp.iinfo(quant_dtype)
57+
dtype_max = float(dtype_info.max)
58+
w_lst, scale_lst = [], []
59+
assert len(x.shape) >= 2
60+
assert x.shape[-2] % wsz == 0
61+
for i in range(0, x.shape[-2], wsz):
62+
y = x[..., i:i + wsz, :]
63+
abs_max = jnp.abs(y).max(axis=-2, keepdims=True)
64+
scale = (abs_max / dtype_max).astype(jnp.float32)
65+
w = (y / scale).astype(quant_dtype)
66+
w_lst.append(w)
67+
scale_lst.append(scale)
68+
return jnp.concat(w_lst, axis=-2), jnp.concat(scale_lst, axis=-2)
69+
70+
5171
@jtu.with_config(jax_numpy_dtype_promotion="standard")
5272
class MoEKernelTest(jtu.JaxTestCase):
5373

@@ -63,42 +83,234 @@ def setUp(self):
6383
self.mesh = Mesh(np.array(self.mesh_devices).reshape(1, -1),
6484
axis_names=("data", "model"))
6585

66-
def test_basic(self):
67-
dtype = jnp.bfloat16
68-
top_k = 2
69-
num_experts = 16
70-
hidden_size = 256
71-
intermediate_size = 256
72-
num_tokens = 8 * 2
73-
86+
def _test_moe(
87+
self,
88+
dtype,
89+
top_k,
90+
num_experts,
91+
hidden_size,
92+
intermediate_size,
93+
num_tokens,
94+
seed,
95+
renormalize_topk_logits,
96+
bt,
97+
bf,
98+
bd1,
99+
bd2,
100+
btc,
101+
bfc,
102+
bd1c,
103+
bd2c,
104+
act_fn="silu",
105+
w_dtype=None,
106+
subc_quant_wsz=None,
107+
use_benchmark_baseline=False,
108+
atol=2e-1,
109+
rtol=2e-1,
110+
):
74111
a, w1, w2, gating_output = gen_moe_inputs(
75112
dtype,
76113
top_k,
77114
num_experts,
78115
hidden_size,
79116
intermediate_size,
80117
num_tokens,
118+
seed=seed,
119+
)
120+
w1_scale = None
121+
w2_scale = None
122+
if w_dtype is not None:
123+
if subc_quant_wsz is None:
124+
subc_quant_wsz = 256
125+
w1, w1_scale = sub_channel_quantize(w1, w_dtype, subc_quant_wsz)
126+
w2, w2_scale = sub_channel_quantize(w2, w_dtype, subc_quant_wsz)
127+
128+
actual = fused_ep_moe(
129+
mesh=self.mesh,
130+
tokens=a,
131+
w1=w1,
132+
w2=w2,
133+
gating_output=gating_output,
134+
top_k=top_k,
135+
renormalize_topk_logits=renormalize_topk_logits,
136+
act_fn=act_fn,
137+
subc_quant_wsz=subc_quant_wsz,
138+
w1_scale=w1_scale,
139+
w2_scale=w2_scale,
140+
bt=bt,
141+
bf=bf,
142+
bd1=bd1,
143+
bd2=bd2,
144+
btc=btc,
145+
bfc=bfc,
146+
bd1c=bd1c,
147+
bd2c=bd2c,
148+
)
149+
expected = ref_moe(
150+
a,
151+
w1,
152+
w2,
153+
gating_output,
154+
top_k,
155+
renormalize_topk_logits=renormalize_topk_logits,
156+
activation=act_fn,
157+
subc_quant_wsz=subc_quant_wsz,
158+
w1_scale=w1_scale,
159+
w2_scale=w2_scale,
160+
)
161+
self.assertAllClose(actual, expected, atol=atol, rtol=rtol)
162+
163+
@parameterized.product(renormalize_topk_logits=[True, False], )
164+
def test_basic(self, renormalize_topk_logits):
165+
dtype = jnp.bfloat16
166+
top_k = 8
167+
num_experts = 128
168+
hidden_size = 1024
169+
intermediate_size = 1024
170+
num_tokens = 8 * 32
171+
self._test_moe(
172+
dtype=dtype,
173+
top_k=top_k,
174+
num_experts=num_experts,
175+
hidden_size=hidden_size,
176+
intermediate_size=intermediate_size,
177+
num_tokens=num_tokens,
178+
seed=1234,
179+
renormalize_topk_logits=renormalize_topk_logits,
180+
bt=32,
181+
bf=1024,
182+
bd1=1024,
183+
bd2=1024,
184+
btc=32,
185+
bfc=256,
186+
bd1c=256,
187+
bd2c=256,
81188
)
82189

83-
actual = jax.block_until_ready(
84-
fused_ep_moe(
85-
mesh=self.mesh,
86-
tokens=a,
87-
w1=w1,
88-
w2=w2,
89-
gating_output=gating_output,
90-
top_k=top_k,
91-
bt=32,
92-
bf=512,
93-
bd1=512,
94-
bd2=512,
95-
btc=32,
96-
bfc=256,
97-
bd1c=256,
98-
bd2c=256,
99-
))
100-
expected = ref_moe(a, w1, w2, gating_output, top_k)
101-
self.assertAllClose(expected, actual, atol=2e-2, rtol=2e-2)
190+
@parameterized.product(act_fn=["silu", "gelu", "swigluoai"], )
191+
def test_activation(self, act_fn):
192+
dtype = jnp.bfloat16
193+
top_k = 8
194+
num_experts = 128
195+
hidden_size = 1024
196+
intermediate_size = 1024
197+
num_tokens = 8 * 32
198+
self._test_moe(
199+
dtype=dtype,
200+
top_k=top_k,
201+
num_experts=num_experts,
202+
hidden_size=hidden_size,
203+
intermediate_size=intermediate_size,
204+
num_tokens=num_tokens,
205+
seed=1234,
206+
renormalize_topk_logits=True,
207+
act_fn=act_fn,
208+
bt=32,
209+
bf=512,
210+
bd1=512,
211+
bd2=512,
212+
btc=32,
213+
bfc=256,
214+
bd1c=256,
215+
bd2c=256,
216+
)
217+
218+
def test_benchmark_qwen_235(self):
219+
num_experts = 128
220+
top_k = 8
221+
hidden_size = 4096
222+
intermediate_size = 1536
223+
dtype = jnp.bfloat16
224+
num_tokens = 8 * 64
225+
seed = 54321
226+
renormalize_topk_logits = True
227+
self._test_moe(
228+
dtype=dtype,
229+
top_k=top_k,
230+
num_experts=num_experts,
231+
hidden_size=hidden_size,
232+
intermediate_size=intermediate_size,
233+
num_tokens=num_tokens,
234+
seed=seed,
235+
renormalize_topk_logits=renormalize_topk_logits,
236+
bt=64,
237+
bf=768,
238+
bd1=2048,
239+
bd2=2048,
240+
btc=64,
241+
bfc=768,
242+
bd1c=2048,
243+
bd2c=2048,
244+
act_fn="silu",
245+
atol=5e-2,
246+
rtol=5e-2,
247+
)
248+
249+
def test_benchmark_qwen_30b_a3b(self):
250+
num_experts = 128
251+
top_k = 8
252+
hidden_size = 2048
253+
intermediate_size = 768
254+
dtype = jnp.bfloat16
255+
num_tokens = 512
256+
seed = 54321
257+
renormalize_topk_logits = True
258+
self._test_moe(
259+
dtype=dtype,
260+
top_k=top_k,
261+
num_experts=num_experts,
262+
hidden_size=hidden_size,
263+
intermediate_size=intermediate_size,
264+
num_tokens=num_tokens,
265+
seed=seed,
266+
renormalize_topk_logits=renormalize_topk_logits,
267+
bt=16,
268+
bf=384,
269+
bd1=512,
270+
bd2=512,
271+
btc=16,
272+
bfc=384,
273+
bd1c=256,
274+
bd2c=256,
275+
act_fn="silu",
276+
atol=5e-2,
277+
rtol=5e-2,
278+
)
279+
280+
@parameterized.product(
281+
w_dtype=[jnp.int8, jnp.float8_e5m2, jnp.float4_e2m1fn], )
282+
def test_sub_channel_quantization(self, w_dtype):
283+
if w_dtype in (
284+
jnp.float8_e5m2,
285+
jnp.float4_e2m1fn,
286+
) and not jtu.is_device_tpu_at_least(version=7):
287+
self.skipTest("Expect TPUv7+")
288+
dtype = jnp.bfloat16
289+
top_k = 8
290+
num_experts = 128
291+
hidden_size = 1024
292+
intermediate_size = 1024
293+
num_tokens = 8 * 32
294+
self._test_moe(
295+
dtype=dtype,
296+
top_k=top_k,
297+
num_experts=num_experts,
298+
hidden_size=hidden_size,
299+
intermediate_size=intermediate_size,
300+
num_tokens=num_tokens,
301+
seed=1234,
302+
renormalize_topk_logits=False,
303+
w_dtype=w_dtype,
304+
subc_quant_wsz=256,
305+
bt=32,
306+
bf=1024,
307+
bd1=1024,
308+
bd2=1024,
309+
btc=32,
310+
bfc=256,
311+
bd1c=256,
312+
bd2c=256,
313+
)
102314

103315

104316
if __name__ == "__main__":

0 commit comments

Comments
 (0)