1 parent d1a2ea0 commit 1efb3dc
tpu_inference/layers/vllm/fused_moe.py
@@ -209,7 +209,7 @@ def expert_sharded_gmm(
     # adapted from https://github.com/pytorch/xla/blob/1d409399474197c484894be90b75d9855393dda5/torch_xla/experimental/custom_kernel.py#L1401
     m, k, g = lhs.shape[0], lhs.shape[1], rhs.shape[0]
     n = rhs.shape[1] if transpose_rhs else rhs.shape[2]
-    tm, tk, tn = _get_tiling_size_for_gmm_kernel(m//mesh.shape["data"], k, n, g)
+    tm, tk, tn = _get_tiling_size_for_gmm_kernel(m, k, n, g)

     num_experts_per_shard = num_experts // ep_size
     group_offset = jnp.arange(0, num_experts, num_experts_per_shard)
@@ -254,7 +254,7 @@ def _gmm(lhs, rhs, group_sizes, group_offset):
     gmm_res = shard_map(
         _gmm,
         mesh=mesh,
-        in_specs=(P("data", None), P("model", None, None), P("data"), P("model")),
+        in_specs=(P(), P("model", None, None), P(), P("model")),
         out_specs=(P("model", None)),
         check_rep=False,
     )(lhs, rhs, group_sizes, group_offset)
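The two hunks are two sides of one change: lhs and group_sizes are no longer sharded along the "data" mesh axis (P("data", ...) becomes P(), i.e. replicated), so each shard now sees the full token matrix and the tile sizes must be derived from the full m rather than m // mesh.shape["data"]. Below is a minimal runnable sketch of these shard_map semantics, not the repo's code: toy shapes and a plain einsum stand in for the Pallas gmm kernel, and everything except the mesh axis names and partition specs is hypothetical.

import numpy as np
import jax
import jax.numpy as jnp
from jax.sharding import Mesh, PartitionSpec as P
from jax.experimental.shard_map import shard_map

devices = np.array(jax.devices()).reshape(1, -1)  # 1 x N grid: ("data", "model")
mesh = Mesh(devices, ("data", "model"))
n_model = devices.shape[1]

m, k, n = 8, 16, 32          # toy sizes (hypothetical)
g = 2 * n_model              # experts; must divide evenly over the "model" axis
lhs = jnp.ones((m, k))       # token activations
rhs = jnp.ones((g, k, n))    # per-expert weight matrices

def _local(lhs, rhs):
    # in_specs=P() replicates lhs: every shard sees the full (m, k) array,
    # so a tiling heuristic here can use the global m directly.
    # in_specs=P("model", None, None) shards rhs: (g // n_model, k, n) per shard.
    out = jnp.einsum("mk,gkn->gmn", lhs, rhs)  # stand-in for the gmm kernel
    return out.reshape(-1, n)                  # stack local expert outputs

out = shard_map(
    _local,
    mesh=mesh,
    in_specs=(P(), P("model", None, None)),  # replicate lhs, shard experts
    out_specs=P("model", None),              # concatenate expert outputs on dim 0
    check_rep=False,
)(lhs, rhs)
print(out.shape)  # (g * m, n)

With P() the per-shard view of lhs is the full (m, k) array, which is why the tiling call in the first hunk can drop the division by mesh.shape["data"]: the old divisor assumed a "data"-sharded lhs that the new in_specs no longer produce.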