
Commit 0a5fed5

bzgoogle authored
[GPT-OSS] fix unstable sparse sum among different (#968)
Signed-off-by: bzgoogle <>
Signed-off-by: bzgoogle <beinuoz_google_com@t1v-n-fa0da4f0-w-0.us-central1-c.c.cloud-tpu-inference-test.internal>
Co-authored-by: bzgoogle <beinuoz_google_com@t1v-n-fa0da4f0-w-0.us-central1-c.c.cloud-tpu-inference-test.internal>
1 parent 880fa22 · commit 0a5fed5

File tree

1 file changed: +7 -9 lines changed

tpu_inference/layers/jax/moe/gpt_oss_moe.py

Lines changed: 7 additions & 9 deletions
@@ -17,7 +17,7 @@
 class GptOssRouter(Router):
     """Router module for Mixture-of-Experts (MoE) layers.

-    This module determines which experts each token should be routed to based on the input.
+    This module determines which experts each token should be routed.

     """
     e_sharding: Sharding = ()
@@ -97,12 +97,6 @@ def __call__(self, x_TD: Float) -> Float:

         weights_TX, indices_TX = self.router(x_TD)

-        one_hot_mask_TXE = jax.nn.one_hot(indices_TX,
-                                          num_classes=self.num_local_experts,
-                                          dtype=self.dtype)
-        combined_weights_TE = jnp.sum(one_hot_mask_TXE * weights_TX[..., None],
-                                      axis=1)
-
         # First MLP layer (up-projection)
         with jax.named_scope("MLP #1"):
             up_proj_TEF2 = jnp.einsum('TD,EDF -> TEF', x_TD,
@@ -121,8 +115,12 @@ def __call__(self, x_TD: Float) -> Float:

         # Weighted sum of expert outputs
         with jax.named_scope("sum"):
-            output_TD = jnp.einsum('TED,TE -> TD', down_proj_TED,
-                                   combined_weights_TE)
+            indices_for_gather = indices_TX[..., None]
+            gathered_down_proj_TED = jnp.take_along_axis(down_proj_TED,
+                                                         indices_for_gather,
+                                                         axis=1)
+            output_TD = jnp.einsum('TXD,TX -> TD', gathered_down_proj_TED,
+                                   weights_TX)

         return output_TD.astype(self.dtype)
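The change replaces a dense one-hot reduction over all E experts with a gather of only the X experts each token selected. The sketch below is not from the repository; the toy shapes T, E, X, D and the random inputs are assumptions for illustration. It shows both paths side by side: they agree mathematically, but the old path accumulates over all E expert slots, zero-weight terms included, so the result depends on how that full reduction is carried out, while the new path reduces over just the X chosen experts.

# Minimal sketch of the old vs. new weighted-sum paths (illustrative only).
# Shapes follow the diff's suffix convention: T tokens, E local experts,
# X experts selected per token, D hidden dim.
import jax
import jax.numpy as jnp

T, E, X, D = 4, 8, 2, 16
k1, k2, k3 = jax.random.split(jax.random.PRNGKey(0), 3)

down_proj_TED = jax.random.normal(k1, (T, E, D))    # per-expert outputs
weights_TX = jax.random.uniform(k2, (T, X))         # router weights
indices_TX = jax.random.randint(k3, (T, X), 0, E)   # selected expert ids

# Old path: scatter the router weights into a dense T x E matrix via
# one-hot, then reduce over all E experts (most terms are zero).
one_hot_TXE = jax.nn.one_hot(indices_TX, num_classes=E)
combined_weights_TE = jnp.sum(one_hot_TXE * weights_TX[..., None], axis=1)
old_TD = jnp.einsum('TED,TE -> TD', down_proj_TED, combined_weights_TE)

# New path: gather just the X selected experts, then reduce over X.
gathered_TXD = jnp.take_along_axis(down_proj_TED, indices_TX[..., None], axis=1)
new_TD = jnp.einsum('TXD,TX -> TD', gathered_TXD, weights_TX)

# Same math, different accumulation order: values match to float tolerance.
print(jnp.max(jnp.abs(old_TD - new_TD)))

Beyond making the operands of the sparse sum explicit, the gather also avoids materializing the T x X x E one-hot mask.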
128126
