|
| 1 | +import jax |
| 2 | +import jax.numpy as jnp |
| 3 | +from absl.testing import absltest |
| 4 | +from jax._src import test_util as jtu |
| 5 | +from jax.sharding import Mesh |
| 6 | + |
| 7 | +from tpu_inference.kernels.fused_moe.v1.kernel import fused_ep_moe, ref_moe |
| 8 | + |
# Let absl parse JAX's configuration flags from the command line, so they
# can be set when this test module is run.
jax.config.parse_flags_with_absl()
| 10 | + |
| 11 | + |
def gen_moe_inputs(
    dtype,
    top_k,
    num_experts,
    hidden_size,
    intermediate_size,
    num_tokens,
    *,
    seed=1234,
):
    """Build deterministic random inputs for a mixture-of-experts test.

    For every token, ``top_k`` distinct experts are picked at random and
    their gating logits are raised by 10, so the top-k selection is well
    separated from the remaining logits.

    Args:
        dtype: Element type of every returned array.
        top_k: Number of distinct experts boosted per token.
        num_experts: Number of experts (leading axis of ``w1`` and ``w2``).
        hidden_size: Width of each token vector.
        intermediate_size: Inner width of each expert.
        num_tokens: Number of token rows.
        seed: Seed passed to ``jax.random.key``.

    Returns:
        A tuple ``(a, w1, w2, gating_output)`` with shapes
        ``(num_tokens, hidden_size)``,
        ``(num_experts, 2, hidden_size, intermediate_size)``,
        ``(num_experts, intermediate_size, hidden_size)`` and
        ``(num_tokens, num_experts)``.

    Raises:
        ValueError: If ``top_k`` is negative or larger than ``num_experts``.
    """
    if not 0 <= top_k <= num_experts:
        raise ValueError(
            f"top_k must be in [0, num_experts]; got top_k={top_k}, "
            f"num_experts={num_experts}")

    key = jax.random.key(seed)
    k_tokens, k_w1, k_w2, k_gate, k_pick = jax.random.split(key, 5)

    a = jax.random.normal(k_tokens, (num_tokens, hidden_size),
                          dtype=jnp.float32).astype(dtype) / 10
    w1 = (jax.random.normal(
        k_w1,
        (num_experts, 2, hidden_size, intermediate_size),
        dtype=jnp.float32,
    ) / 10).astype(dtype)
    w2 = (jax.random.normal(
        k_w2,
        (num_experts, intermediate_size, hidden_size),
        dtype=jnp.float32,
    ) / 10).astype(dtype)

    # Random logits plus a small per-position ramp that keeps entries apart.
    ramp = jnp.arange(num_tokens * num_experts, dtype=jnp.float32).reshape(
        num_tokens, num_experts) / 100
    gating_output = jax.random.normal(
        k_gate, (num_tokens, num_experts), dtype=jnp.float32) + ramp

    # Pick top_k *distinct* experts per token: the argsort of i.i.d. uniform
    # noise is a random permutation of every expert index (including the
    # last one), so its first top_k columns never repeat within a row.
    shuffle_scores = jax.random.uniform(
        k_pick, (num_tokens, num_experts), dtype=jnp.float32)
    top_k_indices = jnp.argsort(shuffle_scores, axis=1)[:, :top_k]

    boost = jnp.sum(
        jax.nn.one_hot(top_k_indices, num_experts, dtype=jnp.float32),
        axis=1,
    ) * 10
    gating_output = (gating_output + boost).astype(dtype)
    return a, w1, w2, gating_output
| 48 | + |
| 49 | + |
@jtu.with_config(jax_numpy_dtype_promotion="standard")
class MoEKernelTest(jtu.JaxTestCase):
    """Compares ``fused_ep_moe`` with ``ref_moe`` on a one-axis device mesh."""

    def setUp(self):
        """Sort the visible devices and wrap them in a ``("model",)`` mesh."""
        super().setUp()

        def placement_key(device):
            # Primary key is coords[0]; coords[1] ascends where coords[0]
            # is even and descends where it is odd.
            major = device.coords[0]
            direction = -1 if major % 2 else 1
            return (major, direction * device.coords[1])

        self.mesh_devices = sorted(jax.devices(), key=placement_key)
        self.mesh = Mesh(devices=self.mesh_devices, axis_names=("model", ))

    def test_basic(self):
        """Kernel output agrees with the reference within 2e-2 tolerances."""
        top_k = 2
        a, w1, w2, gating_output = gen_moe_inputs(
            dtype=jnp.bfloat16,
            top_k=top_k,
            num_experts=16,
            hidden_size=256,
            intermediate_size=256,
            num_tokens=8 * 2,
        )

        # Block-size arguments forwarded unchanged to the kernel.
        block_sizes = dict(
            bt=32,
            bf=512,
            bd1=512,
            bd2=512,
            btc=32,
            bfc=256,
            bd1c=256,
            bd2c=256,
        )
        kernel_out = fused_ep_moe(
            mesh=self.mesh,
            tokens=a,
            w1=w1,
            w2=w2,
            gating_output=gating_output,
            top_k=top_k,
            **block_sizes,
        )
        actual = jax.block_until_ready(kernel_out)

        expected = ref_moe(a, w1, w2, gating_output, top_k)
        self.assertAllClose(expected, actual, atol=2e-2, rtol=2e-2)
| 100 | + |
| 101 | + |
if __name__ == "__main__":
    # Run this module's tests through absltest with JAX's test loader.
    jax_loader = jtu.JaxTestLoader()
    absltest.main(testLoader=jax_loader)
0 commit comments