11import jax
22import jax .numpy as jnp
33import numpy as np
4- from absl .testing import absltest
4+ from absl .testing import absltest , parameterized
55from jax ._src import test_util as jtu
66from jax .sharding import Mesh
77
@@ -43,11 +43,31 @@ def gen_moe_inputs(
4343 one_hot = (jnp .sum (
4444 jax .nn .one_hot (top_k_indices , num_experts , dtype = jnp .float32 ),
4545 axis = 1 ,
46- ) * 10 )
46+ ) * 30 )
4747 gating_output = (gating_output + one_hot ).astype (dtype )
4848 return a , w1 , w2 , gating_output
4949
5050
def sub_channel_quantize(x, quant_dtype, wsz=256):
    """Quantizes x with sub-channel quantization on the 2nd minor.

    The second-to-last axis of ``x`` is cut into consecutive windows of
    ``wsz`` rows. Every window is scaled independently so that its largest
    magnitude (reduced over the window axis only) maps onto the largest
    finite value of ``quant_dtype``.

    Args:
        x: Array with at least two dimensions whose second-to-last dimension
            is a multiple of ``wsz``.
        quant_dtype: Target integer or floating dtype of the quantized values.
        wsz: Number of rows of the second-to-last axis sharing one scale.

    Returns:
        A ``(quantized, scale)`` pair. ``quantized`` has the shape of ``x``
        and dtype ``quant_dtype``. ``scale`` is float32 and has the shape of
        ``x`` except that its second-to-last dimension is
        ``x.shape[-2] // wsz`` (one row of scales per window).
    """
    is_floating = jnp.issubdtype(quant_dtype, jnp.floating)
    dtype_info = jnp.finfo(quant_dtype) if is_floating else jnp.iinfo(quant_dtype)
    dtype_min = float(dtype_info.min)
    dtype_max = float(dtype_info.max)

    assert len(x.shape) >= 2, "x must have at least 2 dimensions"
    assert x.shape[-2] % wsz == 0, "x.shape[-2] must be a multiple of wsz"

    w_lst, scale_lst = [], []
    for i in range(0, x.shape[-2], wsz):
        y = x[..., i:i + wsz, :]
        abs_max = jnp.abs(y).max(axis=-2, keepdims=True)
        # An all-zero window would otherwise produce scale == 0 and turn the
        # division below into 0 / 0 = NaN; any non-zero scale reproduces the
        # zeros exactly, so fall back to 1.
        scale = jnp.where(abs_max > 0, abs_max / dtype_max,
                          1.0).astype(jnp.float32)
        q = y / scale
        if not is_floating:
            # Round to nearest instead of relying on the float -> int cast,
            # so a value that lands just below an integer because of float
            # error is not pushed one step down.
            q = jnp.round(q)
        # Keep values that overshoot the range by rounding error
        # representable in quant_dtype.
        q = jnp.clip(q, dtype_min, dtype_max)
        w_lst.append(q.astype(quant_dtype))
        scale_lst.append(scale)
    return (jnp.concatenate(w_lst, axis=-2),
            jnp.concatenate(scale_lst, axis=-2))
69+
70+
5171@jtu .with_config (jax_numpy_dtype_promotion = "standard" )
5272class MoEKernelTest (jtu .JaxTestCase ):
5373
@@ -63,42 +83,234 @@ def setUp(self):
6383 self .mesh = Mesh (np .array (self .mesh_devices ).reshape (1 , - 1 ),
6484 axis_names = ("data" , "model" ))
6585
66- def test_basic (self ):
67- dtype = jnp .bfloat16
68- top_k = 2
69- num_experts = 16
70- hidden_size = 256
71- intermediate_size = 256
72- num_tokens = 8 * 2
73-
def _test_moe(
    self,
    dtype,
    top_k,
    num_experts,
    hidden_size,
    intermediate_size,
    num_tokens,
    seed,
    renormalize_topk_logits,
    bt,
    bf,
    bd1,
    bd2,
    btc,
    bfc,
    bd1c,
    bd2c,
    act_fn="silu",
    w_dtype=None,
    subc_quant_wsz=None,
    use_benchmark_baseline=False,
    atol=2e-1,
    rtol=2e-1,
):
    """Runs ``fused_ep_moe`` and compares its output with ``ref_moe``.

    Inputs come from ``gen_moe_inputs``. When ``w_dtype`` is given, both
    weight tensors are first quantized with ``sub_channel_quantize`` (window
    size ``subc_quant_wsz``, 256 if unset) and the resulting scales are
    handed to the kernel and to the reference alike.

    ``bt`` ... ``bd2c`` are forwarded unchanged to ``fused_ep_moe``.
    ``use_benchmark_baseline`` is accepted but not read by this method.
    ``atol`` / ``rtol`` are the tolerances of the final comparison.
    """
    tokens, w1, w2, gating_output = gen_moe_inputs(
        dtype,
        top_k,
        num_experts,
        hidden_size,
        intermediate_size,
        num_tokens,
        seed=seed,
    )

    # Scales stay None unless the weights are quantized below.
    w1_scale = w2_scale = None
    if w_dtype is not None:
        if subc_quant_wsz is None:
            subc_quant_wsz = 256
        w1, w1_scale = sub_channel_quantize(w1, w_dtype, subc_quant_wsz)
        w2, w2_scale = sub_channel_quantize(w2, w_dtype, subc_quant_wsz)

    # Quantization arguments shared by the kernel and the reference.
    quant_kwargs = dict(
        subc_quant_wsz=subc_quant_wsz,
        w1_scale=w1_scale,
        w2_scale=w2_scale,
    )
    # Block-size arguments only the fused kernel takes.
    block_kwargs = dict(
        bt=bt,
        bf=bf,
        bd1=bd1,
        bd2=bd2,
        btc=btc,
        bfc=bfc,
        bd1c=bd1c,
        bd2c=bd2c,
    )

    actual = fused_ep_moe(
        mesh=self.mesh,
        tokens=tokens,
        w1=w1,
        w2=w2,
        gating_output=gating_output,
        top_k=top_k,
        renormalize_topk_logits=renormalize_topk_logits,
        act_fn=act_fn,
        **quant_kwargs,
        **block_kwargs,
    )
    expected = ref_moe(
        tokens,
        w1,
        w2,
        gating_output,
        top_k,
        renormalize_topk_logits=renormalize_topk_logits,
        activation=act_fn,
        **quant_kwargs,
    )
    self.assertAllClose(actual, expected, atol=atol, rtol=rtol)
162+
@parameterized.product(renormalize_topk_logits=[True, False])
def test_basic(self, renormalize_topk_logits):
    """Compares kernel and reference in bf16, with and without renormalizing."""
    self._test_moe(
        dtype=jnp.bfloat16,
        top_k=8,
        num_experts=128,
        hidden_size=1024,
        intermediate_size=1024,
        num_tokens=8 * 32,
        seed=1234,
        renormalize_topk_logits=renormalize_topk_logits,
        # Block sizes handed to the fused kernel.
        bt=32,
        bf=1024,
        bd1=1024,
        bd2=1024,
        btc=32,
        bfc=256,
        bd1c=256,
        bd2c=256,
    )
82189
83- actual = jax .block_until_ready (
84- fused_ep_moe (
85- mesh = self .mesh ,
86- tokens = a ,
87- w1 = w1 ,
88- w2 = w2 ,
89- gating_output = gating_output ,
90- top_k = top_k ,
91- bt = 32 ,
92- bf = 512 ,
93- bd1 = 512 ,
94- bd2 = 512 ,
95- btc = 32 ,
96- bfc = 256 ,
97- bd1c = 256 ,
98- bd2c = 256 ,
99- ))
100- expected = ref_moe (a , w1 , w2 , gating_output , top_k )
101- self .assertAllClose (expected , actual , atol = 2e-2 , rtol = 2e-2 )
@parameterized.product(act_fn=["silu", "gelu", "swigluoai"])
def test_activation(self, act_fn):
    """Compares kernel and reference for each supported activation name."""
    shape_kwargs = dict(
        dtype=jnp.bfloat16,
        top_k=8,
        num_experts=128,
        hidden_size=1024,
        intermediate_size=1024,
        num_tokens=8 * 32,
    )
    block_kwargs = dict(
        bt=32,
        bf=512,
        bd1=512,
        bd2=512,
        btc=32,
        bfc=256,
        bd1c=256,
        bd2c=256,
    )
    self._test_moe(
        seed=1234,
        renormalize_topk_logits=True,
        act_fn=act_fn,
        **shape_kwargs,
        **block_kwargs,
    )
217+
def test_benchmark_qwen_235(self):
    """Compares kernel and reference at 5e-2 tolerance.

    NOTE(review): the shapes presumably mirror a Qwen 235B MoE layer, going
    by the test name -- confirm against the model config.
    """
    self._test_moe(
        dtype=jnp.bfloat16,
        top_k=8,
        num_experts=128,
        hidden_size=4096,
        intermediate_size=1536,
        num_tokens=8 * 64,
        seed=54321,
        renormalize_topk_logits=True,
        # Block sizes handed to the fused kernel.
        bt=64,
        bf=768,
        bd1=2048,
        bd2=2048,
        btc=64,
        bfc=768,
        bd1c=2048,
        bd2c=2048,
        act_fn="silu",
        # Tighter than the _test_moe default of 2e-1.
        atol=5e-2,
        rtol=5e-2,
    )
248+
def test_benchmark_qwen_30b_a3b(self):
    """Compares kernel and reference at 5e-2 tolerance.

    NOTE(review): the shapes presumably mirror a Qwen 30B-A3B MoE layer,
    going by the test name -- confirm against the model config.
    """
    self._test_moe(
        dtype=jnp.bfloat16,
        top_k=8,
        num_experts=128,
        hidden_size=2048,
        intermediate_size=768,
        num_tokens=512,
        seed=54321,
        renormalize_topk_logits=True,
        # Block sizes handed to the fused kernel.
        bt=16,
        bf=384,
        bd1=512,
        bd2=512,
        btc=16,
        bfc=384,
        bd1c=256,
        bd2c=256,
        act_fn="silu",
        # Tighter than the _test_moe default of 2e-1.
        atol=5e-2,
        rtol=5e-2,
    )
279+
@parameterized.product(
    w_dtype=[jnp.int8, jnp.float8_e5m2, jnp.float4_e2m1fn])
def test_sub_channel_quantization(self, w_dtype):
    """Compares kernel and reference with sub-channel quantized weights.

    The float8 / float4 cases are skipped unless
    ``jtu.is_device_tpu_at_least(version=7)`` holds.
    """
    needs_tpu_v7 = w_dtype in (jnp.float8_e5m2, jnp.float4_e2m1fn)
    if needs_tpu_v7 and not jtu.is_device_tpu_at_least(version=7):
        self.skipTest("Expect TPUv7+")

    self._test_moe(
        dtype=jnp.bfloat16,
        top_k=8,
        num_experts=128,
        hidden_size=1024,
        intermediate_size=1024,
        num_tokens=8 * 32,
        seed=1234,
        renormalize_topk_logits=False,
        w_dtype=w_dtype,
        subc_quant_wsz=256,
        # Block sizes handed to the fused kernel.
        bt=32,
        bf=1024,
        bd1=1024,
        bd2=1024,
        btc=32,
        bfc=256,
        bd1c=256,
        bd2c=256,
    )
102314
103315
104316if __name__ == "__main__" :
0 commit comments