 class GptOssRouter(Router):
   """Router module for Mixture-of-Experts (MoE) layers.

-  This module determines which experts each token should be routed to based on the input.
+  This module determines which experts each token should be routed to.

   """
   e_sharding: Sharding = ()
@@ -97,12 +97,6 @@ def __call__(self, x_TD: Float) -> Float:

     weights_TX, indices_TX = self.router(x_TD)

-    one_hot_mask_TXE = jax.nn.one_hot(indices_TX,
-                                      num_classes=self.num_local_experts,
-                                      dtype=self.dtype)
-    combined_weights_TE = jnp.sum(one_hot_mask_TXE * weights_TX[..., None],
-                                  axis=1)
-
     # First MLP layer (up-projection)
     with jax.named_scope("MLP #1"):
       up_proj_TEF2 = jnp.einsum('TD,EDF -> TEF', x_TD,
@@ -121,8 +115,12 @@ def __call__(self, x_TD: Float) -> Float:

     # Weighted sum of expert outputs
     with jax.named_scope("sum"):
-      output_TD = jnp.einsum('TED,TE -> TD', down_proj_TED,
-                             combined_weights_TE)
+      indices_for_gather = indices_TX[..., None]
+      gathered_down_proj_TXD = jnp.take_along_axis(down_proj_TED,
+                                                   indices_for_gather,
+                                                   axis=1)
+      output_TD = jnp.einsum('TXD,TX -> TD', gathered_down_proj_TXD,
+                             weights_TX)

     return output_TD.astype(self.dtype)

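This change replaces the dense one-hot scatter over all `E` experts with a `take_along_axis` gather of just the `X` selected expert outputs per token. Both formulations compute the same weighted sum, but the gather avoids materializing the `[T, X, E]` one-hot mask and contracting over every expert. A minimal standalone sketch (illustrative sizes, not the repo's module) checking that equivalence:

```python
import jax
import jax.numpy as jnp

# Illustrative sizes: T tokens, E experts, X experts per token, D model dim.
T, E, X, D = 4, 8, 2, 16
k1, k2, k3 = jax.random.split(jax.random.PRNGKey(0), 3)

down_proj_TED = jax.random.normal(k1, (T, E, D))   # per-expert outputs
weights_TX = jax.random.uniform(k2, (T, X))        # router weights
indices_TX = jax.random.randint(k3, (T, X), 0, E)  # selected experts

# Old path: scatter the top-X weights into a dense [T, E] matrix, then
# contract against all E expert outputs.
one_hot_TXE = jax.nn.one_hot(indices_TX, num_classes=E)
combined_weights_TE = jnp.sum(one_hot_TXE * weights_TX[..., None], axis=1)
old_TD = jnp.einsum('TED,TE -> TD', down_proj_TED, combined_weights_TE)

# New path: gather only the X selected expert outputs, then contract
# against the top-X weights directly.
gathered_TXD = jnp.take_along_axis(down_proj_TED, indices_TX[..., None], axis=1)
new_TD = jnp.einsum('TXD,TX -> TD', gathered_TXD, weights_TX)

assert jnp.allclose(old_TD, new_TD, atol=1e-5)
```

Note that `down_proj_TED` is still computed for all `E` experts by the dense einsums earlier in `__call__`; the gather only slims the final combination step.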