Skip to content

Commit ee7bccc

Browse files
authored
fused Moe (#973)
Signed-off-by: Jevin Jiang <[email protected]>
1 parent 0a5fed5 commit ee7bccc

File tree

2 files changed

+1129
-0
lines changed

2 files changed

+1129
-0
lines changed

tests/kernels/fused_moe_v1_test.py

Lines changed: 103 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,103 @@
1+
import jax
2+
import jax.numpy as jnp
3+
from absl.testing import absltest
4+
from jax._src import test_util as jtu
5+
from jax.sharding import Mesh
6+
7+
from tpu_inference.kernels.fused_moe.v1.kernel import fused_ep_moe, ref_moe
8+
9+
# Let JAX configuration options be supplied as absl command-line flags.
jax.config.parse_flags_with_absl()
10+
11+
12+
def gen_moe_inputs(
    dtype,
    top_k,
    num_experts,
    hidden_size,
    intermediate_size,
    num_tokens,
    *,
    seed=1234,
):
    """Generate deterministic random inputs for a mixture-of-experts test.

    Args:
        dtype: dtype every returned array is cast to.
        top_k: number of expert indices drawn per token for the gating boost.
        num_experts: number of experts.
        hidden_size: size of the token feature dimension.
        intermediate_size: size of the expert intermediate dimension.
        num_tokens: number of tokens.
        seed: PRNG seed; the same seed yields the same inputs.

    Returns:
        A tuple ``(a, w1, w2, gating_output)`` with shapes
        ``(num_tokens, hidden_size)``,
        ``(num_experts, 2, hidden_size, intermediate_size)``,
        ``(num_experts, intermediate_size, hidden_size)`` and
        ``(num_tokens, num_experts)`` respectively.
    """
    key = jax.random.key(seed)
    k0, k1, k2, k4, k5 = jax.random.split(key, 5)
    a = jax.random.normal(k0, (num_tokens, hidden_size),
                          dtype=jnp.float32).astype(dtype) / 10
    w1 = (jax.random.normal(
        k1,
        (num_experts, 2, hidden_size, intermediate_size),
        dtype=jnp.float32,
    ) / 10).astype(dtype)
    w2 = (jax.random.normal(k2, (num_experts, intermediate_size, hidden_size),
                            dtype=jnp.float32) / 10).astype(dtype)
    # Base logits: Gaussian noise plus a small ramp that grows with the
    # flattened (token, expert) position.
    gating_output = (
        jax.random.normal(k4, (num_tokens, num_experts), dtype=jnp.float32) +
        jnp.arange(num_tokens * num_experts, dtype=jnp.float32).reshape(
            num_tokens, num_experts) / 100)
    # Add a large bias (+10 per draw) to randomly chosen experts so that the
    # selected experts stand well clear of the noise. Draws are independent,
    # so a token may pick the same expert more than once and then receives a
    # multiple of the bias on that expert.
    # `maxval` is exclusive: use `num_experts` so the last expert can be
    # chosen as well.
    top_k_indices = jax.random.randint(k5, (num_tokens, top_k),
                                       minval=0,
                                       maxval=num_experts,
                                       dtype=jnp.int32)
    one_hot = (jnp.sum(
        jax.nn.one_hot(top_k_indices, num_experts, dtype=jnp.float32),
        axis=1,
    ) * 10)
    gating_output = (gating_output + one_hot).astype(dtype)
    return a, w1, w2, gating_output
48+
49+
50+
@jtu.with_config(jax_numpy_dtype_promotion="standard")
class MoEKernelTest(jtu.JaxTestCase):
    """Compares `fused_ep_moe` against the `ref_moe` reference."""

    def setUp(self):
        """Create a single-axis ("model") mesh over all devices."""
        super().setUp()

        def device_order(device):
            # Primary key is coords[0]; the secondary key is coords[1],
            # negated when coords[0] is odd so those devices sort in
            # descending coords[1] order.
            first, second = device.coords[0], device.coords[1]
            sign = -1 if first % 2 else 1
            return (first, sign * second)

        self.mesh_devices = sorted(jax.devices(), key=device_order)
        self.mesh = Mesh(devices=self.mesh_devices, axis_names=("model", ))

    def test_basic(self):
        """Fused kernel output matches the reference within 2e-2."""
        top_k = 2
        tokens, w1, w2, gating_output = gen_moe_inputs(
            jnp.bfloat16,  # dtype
            top_k,
            16,  # num_experts
            256,  # hidden_size
            256,  # intermediate_size
            8 * 2,  # num_tokens
        )

        # Kernel tuning arguments, forwarded unchanged (presumably block
        # sizes; confirm against the kernel's signature).
        kernel_kwargs = dict(
            bt=32,
            bf=512,
            bd1=512,
            bd2=512,
            btc=32,
            bfc=256,
            bd1c=256,
            bd2c=256,
        )
        fused_out = fused_ep_moe(
            mesh=self.mesh,
            tokens=tokens,
            w1=w1,
            w2=w2,
            gating_output=gating_output,
            top_k=top_k,
            **kernel_kwargs,
        )
        actual = jax.block_until_ready(fused_out)
        expected = ref_moe(tokens, w1, w2, gating_output, top_k)
        self.assertAllClose(expected, actual, atol=2e-2, rtol=2e-2)
100+
101+
102+
if __name__ == "__main__":
    # Hand control to absltest, using JAX's test loader.
    loader = jtu.JaxTestLoader()
    absltest.main(testLoader=loader)

0 commit comments

Comments
 (0)