diff --git a/tests/kernels/fused_moe_v1_test.py b/tests/kernels/fused_moe_v1_test.py index b89828638..d8b82c31a 100644 --- a/tests/kernels/fused_moe_v1_test.py +++ b/tests/kernels/fused_moe_v1_test.py @@ -1,7 +1,7 @@ import jax import jax.numpy as jnp import numpy as np -from absl.testing import absltest +from absl.testing import absltest, parameterized from jax._src import test_util as jtu from jax.sharding import Mesh @@ -43,11 +43,31 @@ def gen_moe_inputs( one_hot = (jnp.sum( jax.nn.one_hot(top_k_indices, num_experts, dtype=jnp.float32), axis=1, - ) * 10) + ) * 30) gating_output = (gating_output + one_hot).astype(dtype) return a, w1, w2, gating_output +def sub_channel_quantize(x, quant_dtype, wsz=256): + """Quantizes x with sub-channel quantization on the 2nd minor.""" + if jnp.issubdtype(quant_dtype, jnp.floating): + dtype_info = jnp.finfo(quant_dtype) + else: + dtype_info = jnp.iinfo(quant_dtype) + dtype_max = float(dtype_info.max) + w_lst, scale_lst = [], [] + assert len(x.shape) >= 2 + assert x.shape[-2] % wsz == 0 + for i in range(0, x.shape[-2], wsz): + y = x[..., i:i + wsz, :] + abs_max = jnp.abs(y).max(axis=-2, keepdims=True) + scale = (abs_max / dtype_max).astype(jnp.float32) + w = (y / scale).astype(quant_dtype) + w_lst.append(w) + scale_lst.append(scale) + return jnp.concat(w_lst, axis=-2), jnp.concat(scale_lst, axis=-2) + + @jtu.with_config(jax_numpy_dtype_promotion="standard") class MoEKernelTest(jtu.JaxTestCase): @@ -63,14 +83,31 @@ def setUp(self): self.mesh = Mesh(np.array(self.mesh_devices).reshape(1, -1), axis_names=("data", "model")) - def test_basic(self): - dtype = jnp.bfloat16 - top_k = 2 - num_experts = 16 - hidden_size = 256 - intermediate_size = 256 - num_tokens = 8 * 2 - + def _test_moe( + self, + dtype, + top_k, + num_experts, + hidden_size, + intermediate_size, + num_tokens, + seed, + renormalize_topk_logits, + bt, + bf, + bd1, + bd2, + btc, + bfc, + bd1c, + bd2c, + act_fn="silu", + w_dtype=None, + subc_quant_wsz=None, + 
use_benchmark_baseline=False, + atol=2e-1, + rtol=2e-1, + ): a, w1, w2, gating_output = gen_moe_inputs( dtype, top_k, @@ -78,27 +115,202 @@ def test_basic(self): hidden_size, intermediate_size, num_tokens, + seed=seed, + ) + w1_scale = None + w2_scale = None + if w_dtype is not None: + if subc_quant_wsz is None: + subc_quant_wsz = 256 + w1, w1_scale = sub_channel_quantize(w1, w_dtype, subc_quant_wsz) + w2, w2_scale = sub_channel_quantize(w2, w_dtype, subc_quant_wsz) + + actual = fused_ep_moe( + mesh=self.mesh, + tokens=a, + w1=w1, + w2=w2, + gating_output=gating_output, + top_k=top_k, + renormalize_topk_logits=renormalize_topk_logits, + act_fn=act_fn, + subc_quant_wsz=subc_quant_wsz, + w1_scale=w1_scale, + w2_scale=w2_scale, + bt=bt, + bf=bf, + bd1=bd1, + bd2=bd2, + btc=btc, + bfc=bfc, + bd1c=bd1c, + bd2c=bd2c, + ) + expected = ref_moe( + a, + w1, + w2, + gating_output, + top_k, + renormalize_topk_logits=renormalize_topk_logits, + activation=act_fn, + subc_quant_wsz=subc_quant_wsz, + w1_scale=w1_scale, + w2_scale=w2_scale, + ) + self.assertAllClose(actual, expected, atol=atol, rtol=rtol) + + @parameterized.product(renormalize_topk_logits=[True, False], ) + def test_basic(self, renormalize_topk_logits): + dtype = jnp.bfloat16 + top_k = 8 + num_experts = 128 + hidden_size = 1024 + intermediate_size = 1024 + num_tokens = 8 * 32 + self._test_moe( + dtype=dtype, + top_k=top_k, + num_experts=num_experts, + hidden_size=hidden_size, + intermediate_size=intermediate_size, + num_tokens=num_tokens, + seed=1234, + renormalize_topk_logits=renormalize_topk_logits, + bt=32, + bf=1024, + bd1=1024, + bd2=1024, + btc=32, + bfc=256, + bd1c=256, + bd2c=256, ) - actual = jax.block_until_ready( - fused_ep_moe( - mesh=self.mesh, - tokens=a, - w1=w1, - w2=w2, - gating_output=gating_output, - top_k=top_k, - bt=32, - bf=512, - bd1=512, - bd2=512, - btc=32, - bfc=256, - bd1c=256, - bd2c=256, - )) - expected = ref_moe(a, w1, w2, gating_output, top_k) - self.assertAllClose(expected, actual, 
atol=2e-2, rtol=2e-2) + @parameterized.product(act_fn=["silu", "gelu", "swigluoai"], ) + def test_activation(self, act_fn): + dtype = jnp.bfloat16 + top_k = 8 + num_experts = 128 + hidden_size = 1024 + intermediate_size = 1024 + num_tokens = 8 * 32 + self._test_moe( + dtype=dtype, + top_k=top_k, + num_experts=num_experts, + hidden_size=hidden_size, + intermediate_size=intermediate_size, + num_tokens=num_tokens, + seed=1234, + renormalize_topk_logits=True, + act_fn=act_fn, + bt=32, + bf=512, + bd1=512, + bd2=512, + btc=32, + bfc=256, + bd1c=256, + bd2c=256, + ) + + def test_benchmark_qwen_235(self): + num_experts = 128 + top_k = 8 + hidden_size = 4096 + intermediate_size = 1536 + dtype = jnp.bfloat16 + num_tokens = 8 * 64 + seed = 54321 + renormalize_topk_logits = True + self._test_moe( + dtype=dtype, + top_k=top_k, + num_experts=num_experts, + hidden_size=hidden_size, + intermediate_size=intermediate_size, + num_tokens=num_tokens, + seed=seed, + renormalize_topk_logits=renormalize_topk_logits, + bt=64, + bf=768, + bd1=2048, + bd2=2048, + btc=64, + bfc=768, + bd1c=2048, + bd2c=2048, + act_fn="silu", + atol=5e-2, + rtol=5e-2, + ) + + def test_benchmark_qwen_30b_a3b(self): + num_experts = 128 + top_k = 8 + hidden_size = 2048 + intermediate_size = 768 + dtype = jnp.bfloat16 + num_tokens = 512 + seed = 54321 + renormalize_topk_logits = True + self._test_moe( + dtype=dtype, + top_k=top_k, + num_experts=num_experts, + hidden_size=hidden_size, + intermediate_size=intermediate_size, + num_tokens=num_tokens, + seed=seed, + renormalize_topk_logits=renormalize_topk_logits, + bt=16, + bf=384, + bd1=512, + bd2=512, + btc=16, + bfc=384, + bd1c=256, + bd2c=256, + act_fn="silu", + atol=5e-2, + rtol=5e-2, + ) + + @parameterized.product( + w_dtype=[jnp.int8, jnp.float8_e5m2, jnp.float4_e2m1fn], ) + def test_sub_channel_quantization(self, w_dtype): + if w_dtype in ( + jnp.float8_e5m2, + jnp.float4_e2m1fn, + ) and not jtu.is_device_tpu_at_least(version=7): + self.skipTest("Expect 
TPUv7+") + dtype = jnp.bfloat16 + top_k = 8 + num_experts = 128 + hidden_size = 1024 + intermediate_size = 1024 + num_tokens = 8 * 32 + self._test_moe( + dtype=dtype, + top_k=top_k, + num_experts=num_experts, + hidden_size=hidden_size, + intermediate_size=intermediate_size, + num_tokens=num_tokens, + seed=1234, + renormalize_topk_logits=False, + w_dtype=w_dtype, + subc_quant_wsz=256, + bt=32, + bf=1024, + bd1=1024, + bd2=1024, + btc=32, + bfc=256, + bd1c=256, + bd2c=256, + ) if __name__ == "__main__": diff --git a/tpu_inference/kernels/fused_moe/v1/kernel.py b/tpu_inference/kernels/fused_moe/v1/kernel.py index a64c9c082..906b5c0c7 100644 --- a/tpu_inference/kernels/fused_moe/v1/kernel.py +++ b/tpu_inference/kernels/fused_moe/v1/kernel.py @@ -7,7 +7,6 @@ from jax import lax from jax._src import dtypes from jax.experimental import pallas as pl -from jax.experimental import shard_map from jax.experimental.pallas import tpu as pltpu P = jax.sharding.PartitionSpec @@ -35,13 +34,47 @@ def broadcast_minor(src, shape): axis=-1)[..., :shape[-1]] +def swigluoai(gate: jax.Array, + up: jax.Array, + *, + alpha: float = 1.702, + limit: float = 7.0) -> jax.Array: + """Activation used in some models such as GPT-OSS.""" + gate = jnp.clip(gate, a_max=limit) + up = jnp.clip(up, a_min=-limit, a_max=limit) + glu = gate * jax.nn.sigmoid(alpha * gate) + return (up + 1.0) * glu + + +def activation_fn(acc1, acc3, act_fn): + if act_fn == "silu": + return jax.nn.silu(acc1) * acc3 + elif act_fn == "gelu": + return jax.nn.gelu(acc1) * acc3 + elif act_fn == "swigluoai": + return swigluoai(acc1, acc3) + else: + raise RuntimeError(f"Unsupported activation function: {act_fn}") + + def ref_moe( - tokens: jax.Array, # (num_tokens, hidden_size) - w1: jax.Array, # (num_experts, 2, hidden_size, intermediate_size) - w2: jax.Array, # (num_experts, intermediate_size, hidden_size) - gating_output: jax.Array, # (num_tokens, num_experts) - top_k: int, - activation="silu", + tokens: jax.Array, # (num_tokens, 
hidden_size) + w1: jax.Array, # (num_experts, 2, hidden_size, intermediate_size) + w2: jax.Array, # (num_experts, intermediate_size, hidden_size) + gating_output: jax.Array, # (num_tokens, num_experts) + top_k: int, + *, + renormalize_topk_logits: bool = False, + activation="silu", + subc_quant_wsz: int | None = None, + w1_scale: + ( + jax.Array | None + ) = None, # (num_experts, 2, cdiv(hidden_size, subc_quant_wsz), intermediate_size) + w2_scale: + ( + jax.Array | None + ) = None, # (num_experts, cdiv(intermediate_size, subc_quant_wsz), hidden_size) ): n_tokens = tokens.shape[0] # num_tokens @@ -53,7 +86,12 @@ def ref_moe( top_k_logits, top_k_indices = lax.top_k( gating_logits, top_k) # [num_tokens, top_k], [num_tokens, top_k] + if renormalize_topk_logits: + top_k_logits = top_k_logits / jnp.sum( + top_k_logits, axis=-1, keepdims=True) + t_outputs = [] + hidden_size, intermediate_size = w1.shape[-2:] # Process each token individually for i in range(n_tokens): @@ -65,10 +103,24 @@ def ref_moe( # Process each selected expert for the current token for expert_id in assigned_expert_ids: # Get expert weights + expert_w1 = w1[expert_id, 0].astype(jnp.float32) + expert_w3 = w1[expert_id, 1].astype(jnp.float32) + if w1_scale is not None: + expert_w1 *= jnp.repeat(w1_scale[expert_id, 0], + subc_quant_wsz, + axis=0)[:hidden_size] + expert_w3 *= jnp.repeat(w1_scale[expert_id, 1], + subc_quant_wsz, + axis=0)[:hidden_size] expert_weight_1 = jnp.concat( - [w1[expert_id, 0], w1[expert_id, 1]], + [expert_w1, expert_w3], axis=-1) # [d_model, 2 * intermediate_size] - expert_weight_2 = w2[expert_id] # [intermediate_size, d_model] + expert_weight_2 = w2[expert_id].astype( + jnp.float32) # [intermediate_size, d_model] + if w2_scale is not None: + expert_weight_2 *= jnp.repeat(w2_scale[expert_id], + subc_quant_wsz, + axis=0)[:intermediate_size] # First linear layer with SwiGLU activation gmm_1_out = curr_token @ expert_weight_1 # [1, 2 * intermediate_size] @@ -79,16 +131,7 @@ def 
ref_moe( axis=-1) # [1, intermediate_size], [1, intermediate_size] # Apply gated activation: activation(gate) * up - if activation == "silu": - act = jax.nn.silu( - gmm1_w1_proj) * gmm1_w3_proj # [1, intermediate_size] - elif activation == "gelu": - act = jax.nn.gelu( - gmm1_w1_proj) * gmm1_w3_proj # [1, intermediate_size] - else: - raise ValueError( - f"Unsupported activation: {activation}. Use 'silu' or 'gelu'." - ) + act = activation_fn(gmm1_w1_proj, gmm1_w3_proj, activation) # Second linear layer (down projection) gmm_2_out = act @ expert_weight_2 # [1, d_model] @@ -105,7 +148,7 @@ def ref_moe( axis=0, keepdims=True) # [1, d_model] - t_outputs.append(weighted_output) + t_outputs.append(weighted_output.astype(tokens.dtype)) return jnp.concatenate(t_outputs, axis=0) # [num_tokens, d_model] @@ -115,6 +158,11 @@ def _fused_ep_moe_kernel( tokens_hbm, # (local_num_tokens, t_packing, hidden_size // t_packing) w1_hbm, # (local_num_experts, 2, hidden_size, intermediate_size) w2_hbm, # (local_num_experts, intermediate_size, hidden_size) + # TODO(jevinjiang): We choose F32 scale for easier slicing. The extra + # latency should be hidden in the pipeline overlapping. But is there a better + # way to do this? 
+ w1_scale_hbm, # None | F32(local_num_experts, 2, cdiv(hidden_size, subc_quant_wsz), 1, intermediate_size) + w2_scale_hbm, # None | F32(local_num_experts, cdiv(intermediate_size, subc_quant_wsz), 1, hidden_size) gating_hbm, # (local_num_tokens, padded_num_experts) a2a_g_hbm, # (num_experts, bt, t_packing, hidden_size // t_packing) # Output @@ -136,6 +184,9 @@ def _fused_ep_moe_kernel( b_w1_x2_vmem, # (2, t_packing, bd1 // t_packing, bf) b_w3_x2_vmem, # (2, t_packing, bd1 // t_packing, bf) b_w2_x2_vmem, # (2, t_packing, bf, bd2 // t_packing) + b_w1_scale_x2_vmem, # None | (2, t_packing, bd1 // t_packing // subc_quant_wsz, 1, bf) + b_w3_scale_x2_vmem, # None | (2, t_packing, bd1 // t_packing // subc_quant_wsz, 1, bf) + b_w2_scale_x2_vmem, # None | (2, t_packing, bf // subc_quant_wsz, 1, bd2 // t_packing) b_acc_vmem, # F32(bt * num_devices, 1, bf * 2) ### Semaphores: local_sems, # (2, 5): 2 x [b_gating_sem, b_w1_sem, b_w2_sem, b_w3_sem, b_output_sem] @@ -145,7 +196,10 @@ def _fused_ep_moe_kernel( a2a_acc_sem, *, top_k: int, + renormalize_topk_logits: bool, ep_axis_name: str, + act_fn: str, + subc_quant_wsz: int | None = None, # Kernel tuning params. bt: int, # Block size of local_num_tokens. bf: int, # Block size of intermediate_size. 
@@ -160,34 +214,53 @@ def _fused_ep_moe_kernel( num_devices = lax.axis_size(ep_axis_name) local_num_tokens = tokens_hbm.shape[0] local_num_experts, intermediate_size, hidden_size = w2_hbm.shape - # num_experts = local_num_experts * num_devices - # padded_num_experts = expert_starts_x2_smem.shape[-1] right_id = (my_id + 1) % num_devices t_dtype = tokens_hbm.dtype t_packing = get_dtype_packing(t_dtype) t_bitwidth = 32 // t_packing assert a2a_g_hbm.dtype == t_dtype - assert w1_hbm.dtype == t_dtype - assert w2_hbm.dtype == t_dtype + assert w1_hbm.dtype == w2_hbm.dtype - h_per_packing = hidden_size // t_packing - assert tokens_hbm.shape[-1] == h_per_packing - bd1_per_packing = bd1 // t_packing - bd2_per_packing = bd2 // t_packing - bd1c_per_packing = bd1c // t_packing - bd2c_per_packing = bd2c // t_packing + assert bd1 % bd1c == 0 + assert bd2 % bd2c == 0 + assert bf % bfc == 0 + assert hidden_size % t_packing == 0 + assert bd1 % t_packing == 0 + assert bd2 % t_packing == 0 + assert bd1c % t_packing == 0 + assert bd2c % t_packing == 0 + + h_per_t_packing = hidden_size // t_packing + assert tokens_hbm.shape[-1] == h_per_t_packing + bd1_per_t_packing = bd1 // t_packing + bd2_per_t_packing = bd2 // t_packing + bd1c_per_t_packing = bd1c // t_packing + bd2c_per_t_packing = bd2c // t_packing + + if subc_quant_wsz is not None: + assert subc_quant_wsz % 256 == 0 + assert bd1c_per_t_packing == subc_quant_wsz + assert bfc == subc_quant_wsz + assert bd1 % subc_quant_wsz == 0 + assert bf % subc_quant_wsz == 0 + assert bd1_per_t_packing % subc_quant_wsz == 0 + assert h_per_t_packing % subc_quant_wsz == 0 num_bt = cdiv(local_num_tokens, bt) num_bf = cdiv(intermediate_size, bf) num_bd1 = cdiv(hidden_size, bd1) num_bd2 = cdiv(hidden_size, bd2) + def get_mesh_device_id(ep_rank): + dp_rank = jax.lax.axis_index("data") + return (dp_rank, ep_rank) + def sync_barrier(): barrier_sem = pltpu.get_barrier_semaphore() pltpu.semaphore_signal( barrier_sem, - device_id=(0, right_id), + 
device_id=get_mesh_device_id(right_id), device_id_type=pltpu.DeviceIdType.MESH, ) pltpu.semaphore_wait(barrier_sem, 1) @@ -212,7 +285,7 @@ def wait_fetch_b_gating(bt_id): sem=b_gating_sem, ).wait() - def get_top_k(input, top_k): + def get_top_k(input, top_k, renormalize_topk_logits): assert len(input.shape) == 2, input.shape input = input.astype(jnp.float32) top_k_logits_lst = [] @@ -220,11 +293,15 @@ def get_top_k(input, top_k): t2e = jnp.zeros(input.shape, dtype=jnp.int32) t2e_routing = jnp.zeros(input.shape, dtype=jnp.int32) iota = jax.lax.broadcasted_iota(jnp.int32, input.shape, 1) + top_k_logits_sum = jnp.zeros((input.shape[0], 128), jnp.float32) + for k_id in range(top_k): - # TODO(jevinjiang): return both top_k values and indices in op in Mosaic + # TODO(jevinjiang): return both top_k values and indices in Mosaic top_k_logits = jnp.broadcast_to( jnp.max(input, axis=1, keepdims=True), (input.shape[0], 128)).astype(input.dtype) + if renormalize_topk_logits: + top_k_logits_sum += top_k_logits top_k_logits_lst.append(top_k_logits) # TODO(jevinjiang): support bf16 argmax in Mosaic top_k_indices = jnp.broadcast_to( @@ -236,6 +313,11 @@ def get_top_k(input, top_k): if k_id != top_k - 1: input = jnp.where(mask, -jnp.inf, input) + if renormalize_topk_logits: + for k_id in range(top_k): + top_k_logits_lst[ + k_id] = top_k_logits_lst[k_id] / top_k_logits_sum + expert_sizes = jnp.sum(t2e, axis=0, keepdims=True) expert_starts = jnp.zeros_like(expert_sizes) return top_k_logits_lst, t2e_routing, expert_sizes, expert_starts @@ -277,7 +359,7 @@ def _all_reduce_metadata( dst_ref=d2e_count_vmem.at[row_id], send_sem=send_sem, recv_sem=recv_sem, - device_id=(0, right_id), + device_id=get_mesh_device_id(right_id), device_id_type=pltpu.DeviceIdType.MESH, ).wait() row_id = (row_id + num_devices - 1) % num_devices @@ -359,10 +441,8 @@ def start_a2a_scatter(bt_id, e_sem_id, local_e_id): pl.ds(start, remote_sz)], send_sem=send_sems.at[e_sem_id], recv_sem=recv_sems.at[e_sem_id], - 
device_id=( - 0, - recv_id, - ), + device_id=get_mesh_device_id(recv_id), + device_id_type=pltpu.DeviceIdType.MESH, ).start() a2a_s_sends_x2_smem[e_sem_id] = send_sz @@ -406,7 +486,8 @@ def start_a2a_gather(bt_id, e_sem_id, local_e_id): dst_ref=a2a_g_hbm.at[my_e_id, pl.ds(0, remote_sz)], send_sem=send_sems.at[e_sem_id], recv_sem=a2a_gather_sem, - device_id=(0, recv_id), + device_id=get_mesh_device_id(recv_id), + device_id_type=pltpu.DeviceIdType.MESH, ).start() start += sz @@ -435,44 +516,85 @@ def wait_a2a_gather_recv_all(): def start_fetch_bw1(local_e_id, bw1_sem_id, bf_id, bd1_id): for p in range(t_packing): - offset = p * h_per_packing + bd1_id * bd1_per_packing + offset = p * h_per_t_packing + bd1_id * bd1_per_t_packing pltpu.make_async_copy( src_ref=w1_hbm.at[ local_e_id, 0, - pl.ds(offset, bd1_per_packing), + pl.ds(offset, bd1_per_t_packing), pl.ds(bf_id * bf, bf), ], dst_ref=b_w1_x2_vmem.at[bw1_sem_id, p], sem=local_sems.at[bw1_sem_id, 1], ).start() + if w1_scale_hbm is not None: + assert subc_quant_wsz is not None + pltpu.make_async_copy( + src_ref=w1_scale_hbm.at[ + local_e_id, + 0, + pl.ds(offset // subc_quant_wsz, bd1_per_t_packing // + subc_quant_wsz), + pl.ds(0, 1), + pl.ds(bf_id * bf, bf), + ], + dst_ref=b_w1_scale_x2_vmem.at[bw1_sem_id, p], + sem=local_sems.at[bw1_sem_id, 1], + ).start() def start_fetch_bw2(local_e_id, bw2_sem_id, bf_id, bd2_id): for p in range(t_packing): - offset = p * h_per_packing + bd2_id * bd2_per_packing + offset = p * h_per_t_packing + bd2_id * bd2_per_t_packing pltpu.make_async_copy( src_ref=w2_hbm.at[ local_e_id, pl.ds(bf_id * bf, bf), - pl.ds(offset, bd2_per_packing), + pl.ds(offset, bd2_per_t_packing), ], dst_ref=b_w2_x2_vmem.at[bw2_sem_id, p], sem=local_sems.at[bw2_sem_id, 2], ).start() + if w2_scale_hbm is not None: + assert subc_quant_wsz is not None + pltpu.make_async_copy( + src_ref=w2_scale_hbm.at[ + local_e_id, + pl.ds(bf_id * bf // subc_quant_wsz, bf // + subc_quant_wsz), + pl.ds(0, 1), + pl.ds(offset, 
bd2_per_t_packing), + ], + dst_ref=b_w2_scale_x2_vmem.at[bw2_sem_id, p], + sem=local_sems.at[bw2_sem_id, 2], + ).start() def start_fetch_bw3(local_e_id, bw3_sem_id, bf_id, bd3_id): for p in range(t_packing): - offset = p * h_per_packing + bd3_id * bd1_per_packing + offset = p * h_per_t_packing + bd3_id * bd1_per_t_packing pltpu.make_async_copy( src_ref=w1_hbm.at[ local_e_id, 1, - pl.ds(offset, bd1_per_packing), + pl.ds(offset, bd1_per_t_packing), pl.ds(bf_id * bf, bf), ], dst_ref=b_w3_x2_vmem.at[bw3_sem_id, p], sem=local_sems.at[bw3_sem_id, 3], ).start() + if w1_scale_hbm is not None: + assert subc_quant_wsz is not None + pltpu.make_async_copy( + src_ref=w1_scale_hbm.at[ + local_e_id, + 1, + pl.ds(offset // subc_quant_wsz, bd1_per_t_packing // + subc_quant_wsz), + pl.ds(0, 1), + pl.ds(bf_id * bf, bf), + ], + dst_ref=b_w3_scale_x2_vmem.at[bw3_sem_id, p], + sem=local_sems.at[bw3_sem_id, 3], + ).start() def wait_fetch_bw1(local_e_id, bw1_sem_id, bf_id, bd1_id): del local_e_id, bf_id, bd1_id @@ -481,6 +603,12 @@ def wait_fetch_bw1(local_e_id, bw1_sem_id, bf_id, bd1_id): dst_ref=b_w1_x2_vmem.at[bw1_sem_id], sem=local_sems.at[bw1_sem_id, 1], ).wait() + if w1_scale_hbm is not None: + pltpu.make_async_copy( + src_ref=b_w1_scale_x2_vmem.at[bw1_sem_id], + dst_ref=b_w1_scale_x2_vmem.at[bw1_sem_id], + sem=local_sems.at[bw1_sem_id, 1], + ).wait() def wait_fetch_bw2(local_e_id, bw2_sem_id, bf_id, bd2_id): del local_e_id, bf_id, bd2_id @@ -489,6 +617,12 @@ def wait_fetch_bw2(local_e_id, bw2_sem_id, bf_id, bd2_id): dst_ref=b_w2_x2_vmem.at[bw2_sem_id], sem=local_sems.at[bw2_sem_id, 2], ).wait() + if w2_scale_hbm is not None: + pltpu.make_async_copy( + src_ref=b_w2_scale_x2_vmem.at[bw2_sem_id], + dst_ref=b_w2_scale_x2_vmem.at[bw2_sem_id], + sem=local_sems.at[bw2_sem_id, 2], + ).wait() def wait_fetch_bw3(local_e_id, bw3_sem_id, bf_id, bd3_id): del local_e_id, bf_id, bd3_id @@ -497,6 +631,12 @@ def wait_fetch_bw3(local_e_id, bw3_sem_id, bf_id, bd3_id): 
dst_ref=b_w3_x2_vmem.at[bw3_sem_id], sem=local_sems.at[bw3_sem_id, 3], ).wait() + if w1_scale_hbm is not None: + pltpu.make_async_copy( + src_ref=b_w3_scale_x2_vmem.at[bw3_sem_id], + dst_ref=b_w3_scale_x2_vmem.at[bw3_sem_id], + sem=local_sems.at[bw3_sem_id, 3], + ).wait() def start_fetch_next_bw(local_e_id, bw_sem_id, bf_id, bd1_id, bd2_id): next_bd1_id = bd1_id + 1 @@ -520,18 +660,36 @@ def start_fetch_next_bw(local_e_id, bw_sem_id, bf_id, bd1_id, bd2_id): def dynamic_ffn1( t_b32_vmem, w1_vmem, + w1_scale_vmem, w3_vmem, + w3_scale_vmem, acc1_vmem, acc3_vmem, dyn_sz, should_init, ): assert t_b32_vmem.shape == (bt * num_devices, bd1 // t_packing) - assert w1_vmem.shape == w3_vmem.shape == (t_packing, bd1_per_packing, + assert w1_vmem.shape == w3_vmem.shape == (t_packing, bd1_per_t_packing, bf) assert acc1_vmem.shape == acc3_vmem.shape == (bt * num_devices, bf) assert bd1 % (t_packing * 128) == 0, (bd1, t_packing) assert bd1c % (t_packing * 128) == 0, (bd1c, t_packing) + if w1_scale_vmem is not None: + assert w1_scale_vmem.shape == ( + t_packing, + bd1_per_t_packing // subc_quant_wsz, + 1, + bf, + ) + assert bd1c_per_t_packing == subc_quant_wsz + if w3_scale_vmem is not None: + assert w3_scale_vmem.shape == ( + t_packing, + bd1_per_t_packing // subc_quant_wsz, + 1, + bf, + ) + assert bd1c_per_t_packing == subc_quant_wsz num_loops = cdiv(dyn_sz, btc) repack_ty = jnp.dtype(f"int{t_bitwidth}") @@ -540,7 +698,7 @@ def body(btc_id, _): for bd1c_id in range(cdiv(bd1, bd1c)): t_b32 = t_b32_vmem[ pl.ds(btc_id * btc, btc), - pl.ds(bd1c_id * bd1c_per_packing, bd1c_per_packing), + pl.ds(bd1c_id * bd1c_per_t_packing, bd1c_per_t_packing), ] for p_id in range(t_packing): t = pltpu.bitcast(t_b32.astype(repack_ty), t_dtype) @@ -548,18 +706,44 @@ def body(btc_id, _): for bfc_id in range(cdiv(bf, bfc)): w_slices = ( p_id, - pl.ds(bd1c_id * bd1c_per_packing, - bd1c_per_packing), + pl.ds(bd1c_id * bd1c_per_t_packing, + bd1c_per_t_packing), pl.ds(bfc_id * bfc, bfc), ) w1 = 
w1_vmem[*w_slices] acc1 = jnp.dot(t, w1, preferred_element_type=jnp.float32) + + if w1_scale_vmem is not None: + w1_scale_slices = ( + p_id, + bd1c_id, + pl.ds(0, 1), + pl.ds(bfc_id * bfc, bfc), + ) + # TODO(jevinjiang): can use mosaic to load with stride 0. + w1_scale = jnp.broadcast_to( + w1_scale_vmem[*w1_scale_slices], acc1.shape) + acc1 *= w1_scale + w3 = w3_vmem[*w_slices] + acc3 = jnp.dot(t, w3, preferred_element_type=jnp.float32) + + if w3_scale_vmem is not None: + w3_scale_slices = ( + p_id, + bd1c_id, + pl.ds(0, 1), + pl.ds(bfc_id * bfc, bfc), + ) + w3_scale = jnp.broadcast_to( + w3_scale_vmem[*w3_scale_slices], acc3.shape) + acc3 *= w3_scale + acc_slices = (pl.ds(btc_id * btc, btc), pl.ds(bfc_id * bfc, bfc)) if should_init and p_id == bd1c_id == 0: @@ -575,22 +759,27 @@ def dynamic_ffn2( acc1_vmem, acc3_vmem, w2_vmem, + w2_scale_vmem, res_b32_vmem, dyn_sz, should_init, ): - assert res_b32_vmem.shape == (bt * num_devices, bd2_per_packing) - assert w2_vmem.shape == (t_packing, bf, bd2_per_packing), ( - w2_vmem.shape, - t_packing, - bf, - bd2_per_packing, - ) + assert res_b32_vmem.shape == (bt * num_devices, bd2_per_t_packing) + assert w2_vmem.shape == (t_packing, bf, bd2_per_t_packing) assert acc1_vmem.shape == acc3_vmem.shape == (bt * num_devices, bf) assert bd2 % (t_packing * 128) == 0, (bd2, t_packing) assert bd2c % (t_packing * 128) == 0, (bd2c, t_packing) assert t_dtype in (jnp.float32, jnp.bfloat16) + if w2_scale_vmem is not None: + assert w2_scale_vmem.shape == ( + t_packing, + bf // subc_quant_wsz, + 1, + bd2_per_t_packing, + ) + assert bfc == subc_quant_wsz + num_loops = cdiv(dyn_sz, btc) assert bd2c % (t_packing * 128) == 0, (bd2c, t_packing) @@ -598,22 +787,35 @@ def body(btc_id, _): for bd2c_id in range(cdiv(bd2, bd2c)): res_lst = [] for p_id in range(t_packing): - res = jnp.zeros((btc, bd2c_per_packing), dtype=jnp.float32) + res = jnp.zeros((btc, bd2c_per_t_packing), + dtype=jnp.float32) for bfc_id in range(cdiv(bf, bfc)): acc_slices = 
(pl.ds(btc_id * btc, btc), pl.ds(bfc_id * bfc, bfc)) acc1 = acc1_vmem[*acc_slices] acc3 = acc3_vmem[*acc_slices] - act = jax.nn.silu(acc1) * acc3 + act = activation_fn(acc1, acc3, act_fn) w2 = w2_vmem[ p_id, pl.ds(bfc_id * bfc, bfc), pl.ds(bd2c_id * - bd2c_per_packing, bd2c_per_packing), + bd2c_per_t_packing, bd2c_per_t_packing), ] - res += jnp.dot(act, - w2, - preferred_element_type=jnp.float32) + acc = jnp.dot(act, + w2, + preferred_element_type=jnp.float32) + if w2_scale_vmem is not None: + w2_scale_slices = ( + p_id, + bfc_id, + pl.ds(0, 1), + pl.ds(bd2c_id * bd2c_per_t_packing, + bd2c_per_t_packing), + ) + w2_scale = jnp.broadcast_to( + w2_scale_vmem[*w2_scale_slices], acc.shape) + acc *= w2_scale + res += acc res = pltpu.bitcast(res, jnp.uint32) if t_packing == 2: res = res >> 16 << (16 * p_id) @@ -626,7 +828,7 @@ def body(btc_id, _): res |= res_lst[i] sliced_res_vmem = res_b32_vmem.at[ pl.ds(btc_id * btc, btc), - pl.ds(bd2c_id * bd2c_per_packing, bd2c_per_packing), + pl.ds(bd2c_id * bd2c_per_t_packing, bd2c_per_t_packing), ] if should_init: sliced_res_vmem[...] 
= res @@ -655,21 +857,27 @@ def expert_ffn(bt_id, e_sem_id, local_e_id): e_id = my_id * local_num_experts + local_e_id dyn_sz = expert_sizes_x2_smem[bt_sem_id, 0, e_id] - bd1_per_packing = bd1 // t_packing - bd2_per_packing = bd2 // t_packing + bd1_per_t_packing = bd1 // t_packing + bd2_per_t_packing = bd2 // t_packing for bf_id in range(num_bf): for bd1_id in range(num_bd1): start_fetch_next_bw(local_e_id, bw_sem_id, bf_id, bd1_id, 0) + w1_scale_vmem = (None if b_w1_scale_x2_vmem is None else + b_w1_scale_x2_vmem.at[bw_sem_id]) + w3_scale_vmem = (None if b_w3_scale_x2_vmem is None else + b_w3_scale_x2_vmem.at[bw_sem_id]) wait_fetch_bw1(local_e_id, bw_sem_id, bf_id, bd1_id) wait_fetch_bw3(local_e_id, bw_sem_id, bf_id, bd1_id) dynamic_ffn1( t_b32_vmem=a2a_s_b32_vmem.at[ ..., - pl.ds(bd1_id * bd1_per_packing, bd1_per_packing)], + pl.ds(bd1_id * bd1_per_t_packing, bd1_per_t_packing)], w1_vmem=b_w1_x2_vmem.at[bw_sem_id], + w1_scale_vmem=w1_scale_vmem, w3_vmem=b_w3_x2_vmem.at[bw_sem_id], + w3_scale_vmem=w3_scale_vmem, acc1_vmem=b_acc1_vmem, acc3_vmem=b_acc3_vmem, dyn_sz=dyn_sz, @@ -684,13 +892,16 @@ def expert_ffn(bt_id, e_sem_id, local_e_id): if bf_id == bd2_id == 0: wait_a2a_gather_send(bt_id, e_sem_id, local_e_id - 2) + w2_scale_vmem = (None if b_w2_scale_x2_vmem is None else + b_w2_scale_x2_vmem.at[bw_sem_id]) dynamic_ffn2( acc1_vmem=b_acc1_vmem, acc3_vmem=b_acc3_vmem, w2_vmem=b_w2_x2_vmem.at[bw_sem_id], + w2_scale_vmem=w2_scale_vmem, res_b32_vmem=a2a_s_acc_b32_vmem.at[ ..., - pl.ds(bd2_id * bd2_per_packing, bd2_per_packing)], + pl.ds(bd2_id * bd2_per_t_packing, bd2_per_t_packing)], dyn_sz=dyn_sz, should_init=(bf_id == 0), ) @@ -757,7 +968,7 @@ def run_per_bt(bt_id, e_sem_id): b_gating = b_gating_x2_vmem[bt_sem_id] b_gating_score = jax.nn.softmax(b_gating, axis=-1) top_k_logits_lst, t2e_routing, expert_sizes, expert_starts = get_top_k( - b_gating_score, top_k) + b_gating_score, top_k, renormalize_topk_logits) all_reduce_metadata(bt_sem_id, t2e_routing, 
expert_starts, expert_sizes) @@ -827,6 +1038,9 @@ def _(): static_argnames=[ "mesh", "top_k", + "renormalize_topk_logits", + "act_fn", + "subc_quant_wsz", "bt", "bf", "bd1", @@ -845,7 +1059,16 @@ def fused_ep_moe( w2: jax.Array, # (num_experts, intermediate_size, hidden_size) gating_output: jax.Array, # (num_tokens, num_experts) top_k: int, + renormalize_topk_logits: bool = False, + act_fn: str = "silu", *, + subc_quant_wsz: int | None = None, + w1_scale: ( + jax.Array | None + ) = None, # (num_experts, 2, cdiv(hidden_size, subc_quant_wsz), intermediate_size) + w2_scale: ( + jax.Array | None + ) = None, # (num_experts, cdiv(intermediate_size, subc_quant_wsz), hidden_size) # Kernel tuning parameters. bt: int, bf: int, @@ -855,18 +1078,19 @@ def fused_ep_moe( bfc: int, bd1c: int, bd2c: int, - ep_axis_name: str = 'model', + ep_axis_name: str = "model", ): + # TODO(jevinjiang): move all these assertions to validation function. # Assert all other axes have length of 1 - assert len(mesh.shape) == 2, "Expect 2D mesh in tpu-inference" - assert 'data' in mesh.shape and mesh.shape['data'] == 1, \ - "Expect data axis size of 1 in tpu-inference" + assert len(mesh.shape) == 2, "Expect 2D mesh" + assert ("data" in mesh.shape + and mesh.shape["data"] == 1), "Expect data axis size of 1" ep_size = mesh.shape[ep_axis_name] num_devices = ep_size num_tokens, actual_hidden_size = tokens.shape - num_experts, intermediate_size, _ = w2.shape + num_experts, actual_intermediate_size, _ = w2.shape assert num_tokens % ep_size == 0 assert num_experts % ep_size == 0 @@ -874,26 +1098,18 @@ def fused_ep_moe( local_num_tokens = num_tokens // ep_size # local_num_experts = num_experts // ep_size padded_num_experts = align_to(num_experts, 128) - t_dtype = tokens.dtype t_packing = get_dtype_packing(t_dtype) - hidden_size = align_to(actual_hidden_size, 128 * t_packing) - if hidden_size != actual_hidden_size: - tokens = jnp.pad( - tokens, - ((0, 0), (0, hidden_size - actual_hidden_size)), - 
constant_values=0, - ) - tokens = tokens.reshape(-1, t_packing, hidden_size // t_packing) - bt = min(bt, local_num_tokens) - bf = min(bf, intermediate_size) - bd1 = min(bd1, hidden_size) - bd2 = min(bd2, hidden_size) - btc = min(btc, bt * num_devices) - bfc = min(bfc, bf) - bd1c = min(bd1c, bd1) - bd2c = min(bd2c, bd2) + if subc_quant_wsz is not None: + if subc_quant_wsz % 256 != 0: + raise NotImplementedError( + "Sub-quantized window is not aligned to 256.") + # We force compute size of contracting dim to subc_quant_wsz. So we can + # apply same scale after matmul and accumulation. + bd1c = subc_quant_wsz * t_packing + bfc = subc_quant_wsz + assert bfc % 128 == 0 assert bd1c % (t_packing * 128) == 0 assert bd2c % (t_packing * 128) == 0 @@ -901,6 +1117,30 @@ def fused_ep_moe( assert bd1 % bd1c == 0 assert bd2 % bd2c == 0 + btc = min(btc, bt * num_devices) + hidden_size = align_to(actual_hidden_size, 128 * t_packing) + # TODO(jevinjiang): instead of padding outside the kernel, we can try dynamic + # masking inside the kernel. + hidden_size = align_to(hidden_size, bd1) + hidden_size = align_to(hidden_size, bd2) + intermediate_size = align_to(actual_intermediate_size, bf) + + # TODO(jevinjiang): we should dump scale as the kernel expected shape in the + # checkpoint offline or reshape right after weight loading. + if w1_scale is not None: + assert w1_scale.shape[0] == w1.shape[0] + assert w1_scale.shape[1] == w1.shape[1] == 2 + assert w1_scale.shape[2] == cdiv(w1.shape[2], subc_quant_wsz) + assert w1_scale.shape[3] == w1.shape[3] + w1_scale = jnp.expand_dims(w1_scale.astype(jnp.float32), axis=-2) + + if w2_scale is not None: + assert w2_scale.shape[0] == w2.shape[0] + assert w2_scale.shape[1] == cdiv(w2.shape[1], subc_quant_wsz) + assert w2_scale.shape[2] == w2.shape[2] + w2_scale = jnp.expand_dims(w2_scale.astype(jnp.float32), axis=-2) + + # Prepare inputs for the kernel. 
if padded_num_experts != gating_output.shape[-1]: gating_output = jnp.pad( gating_output, @@ -908,13 +1148,70 @@ def fused_ep_moe( constant_values=-jnp.inf, ) - scope_name = f"fused_moe_k-{top_k}_bt-{bt}-{btc}_bf-{bf}-{bfc}_bd1-{bd1}-{bd1c}_bd2-{bd2}-{bd2c}" + if (hidden_size != actual_hidden_size + or intermediate_size != actual_intermediate_size): + tokens = jnp.pad( + tokens, + ((0, 0), (0, hidden_size - actual_hidden_size)), + constant_values=0, + ) + w1 = jnp.pad( + w1, + ( + (0, 0), + (0, 0), + (0, hidden_size - actual_hidden_size), + (0, intermediate_size - actual_intermediate_size), + ), + constant_values=0, + ) + w2 = jnp.pad( + w2, + ( + (0, 0), + (0, intermediate_size - actual_intermediate_size), + (0, hidden_size - actual_hidden_size), + ), + constant_values=0, + ) + if w1_scale is not None: + w1_scale = jnp.pad( + w1_scale, + ( + (0, 0), + (0, 0), + (0, + cdiv(hidden_size, subc_quant_wsz) - w1_scale.shape[-3]), + (0, 0), + (0, intermediate_size - w1_scale.shape[-1]), + ), + constant_values=0, + ) + if w2_scale is not None: + w2_scale = jnp.pad( + w2_scale, + ( + (0, 0), + (0, cdiv(intermediate_size, subc_quant_wsz) - + w2_scale.shape[-3]), + (0, 0), + (0, hidden_size - w2_scale.shape[-1]), + ), + constant_values=0, + ) + tokens = tokens.reshape(-1, t_packing, hidden_size // t_packing) + + hbm_block_spec = pl.BlockSpec(memory_space=pltpu.MemorySpace.HBM) + scope_name = f"fused_moe_k-{top_k}_renorm-{renormalize_topk_logits}_bt-{bt}-{btc}_bf-{bf}-{bfc}_bd1-{bd1}-{bd1c}_bd2-{bd2}-{bd2c}" fused_moe = jax.named_scope(scope_name)( pl.pallas_call( functools.partial( _fused_ep_moe_kernel, top_k=top_k, + renormalize_topk_logits=renormalize_topk_logits, ep_axis_name=ep_axis_name, + act_fn=act_fn, + subc_quant_wsz=subc_quant_wsz, bt=bt, bf=bf, bd1=bd1, @@ -929,11 +1226,15 @@ def fused_ep_moe( grid_spec=pltpu.PrefetchScalarGridSpec( num_scalar_prefetch=0, in_specs=[ - pl.BlockSpec(memory_space=pltpu.MemorySpace.HBM), - 
pl.BlockSpec(memory_space=pltpu.MemorySpace.HBM), - pl.BlockSpec(memory_space=pltpu.MemorySpace.HBM), - pl.BlockSpec(memory_space=pltpu.MemorySpace.HBM), - pl.BlockSpec(memory_space=pltpu.MemorySpace.HBM), + hbm_block_spec, # tokens_hbm + hbm_block_spec, # w1_hbm + hbm_block_spec, # w2_hbm + None + if w1_scale is None else hbm_block_spec, # w1_scale_hbm + None + if w2_scale is None else hbm_block_spec, # w2_scale_hbm + hbm_block_spec, # gating_output_hbm + hbm_block_spec, # a2a_g_hbm ], out_specs=pl.BlockSpec(memory_space=pltpu.MemorySpace.HBM), scratch_shapes=([ @@ -984,6 +1285,39 @@ def fused_ep_moe( pltpu.VMEM((2, t_packing, bd1 // t_packing, bf), w1.dtype), # b_w2_x2_vmem pltpu.VMEM((2, t_packing, bf, bd2 // t_packing), w2.dtype), + # b_w1_scale_x2_vmem + (None if w1_scale is None else pltpu.VMEM( + ( + 2, + t_packing, + bd1 // t_packing // subc_quant_wsz, + 1, + bf, + ), + jnp.float32, + )), + # b_w3_scale_x2_vmem + (None if w1_scale is None else pltpu.VMEM( + ( + 2, + t_packing, + bd1 // t_packing // subc_quant_wsz, + 1, + bf, + ), + jnp.float32, + )), + # b_w2_scale_x2_vmem + (None if w2_scale is None else pltpu.VMEM( + ( + 2, + t_packing, + bf // subc_quant_wsz, + 1, + bd2 // t_packing, + ), + jnp.float32, + )), # b_acc_vmem pltpu.VMEM((bt * num_devices, 1, bf * 2), jnp.float32), # local_sems @@ -1006,21 +1340,35 @@ def fused_ep_moe( )) @jax.jit - @functools.partial( - shard_map.shard_map, + @jax.shard_map( mesh=mesh, - in_specs=(P(ep_axis_name), P(ep_axis_name), P(ep_axis_name), - P(ep_axis_name), P()), + in_specs=( + P(ep_axis_name), # tokens_hbm + P(ep_axis_name), # w1_hbm + P(ep_axis_name), # w2_hbm + None if w1_scale is None else P(ep_axis_name), # w1_scale_hbm + None if w2_scale is None else P(ep_axis_name), # w2_scale_hbm + P(ep_axis_name), # gating_output_hbm + P(), # a2a_g_hbm + ), out_specs=P(ep_axis_name), - check_rep=False, + check_vma=False, ) - def kernel(tokens, w1, w2, gating_output, a2a_g_hbm_scratch): + def kernel(tokens, w1, w2, w1_scale, 
w2_scale, gating_output, + a2a_g_hbm_scratch): return fused_moe( - pltpu.with_memory_space_constraint(tokens, pltpu.HBM), - pltpu.with_memory_space_constraint(w1, pltpu.HBM), - pltpu.with_memory_space_constraint(w2, pltpu.HBM), - pltpu.with_memory_space_constraint(gating_output, pltpu.HBM), - pltpu.with_memory_space_constraint(a2a_g_hbm_scratch, pltpu.HBM), + pltpu.with_memory_space_constraint(tokens, + pltpu.HBM), # tokens_hbm + pltpu.with_memory_space_constraint(w1, pltpu.HBM), # w1_hbm + pltpu.with_memory_space_constraint(w2, pltpu.HBM), # w2_hbm + (None if w1_scale is None else pltpu.with_memory_space_constraint( + w1_scale, pltpu.HBM)), # w1_scale_hbm + (None if w2_scale is None else pltpu.with_memory_space_constraint( + w2_scale, pltpu.HBM)), # w2_scale_hbm + pltpu.with_memory_space_constraint(gating_output, + pltpu.HBM), # gating_output_hbm + pltpu.with_memory_space_constraint(a2a_g_hbm_scratch, + pltpu.HBM), # a2a_g_hbm ) a2a_g_hbm_scratch = pl.empty( @@ -1029,6 +1377,8 @@ def kernel(tokens, w1, w2, gating_output, a2a_g_hbm_scratch): tokens, w1, w2, + w1_scale, + w2_scale, gating_output, a2a_g_hbm_scratch, )