Skip to content

Commit 1efb3dc

Browse files
wip
1 parent d1a2ea0 commit 1efb3dc

File tree

1 file changed

+2
-2
lines changed

1 file changed

+2
-2
lines changed

tpu_inference/layers/vllm/fused_moe.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -209,7 +209,7 @@ def expert_sharded_gmm(
209209
# adapted from https://github.com/pytorch/xla/blob/1d409399474197c484894be90b75d9855393dda5/torch_xla/experimental/custom_kernel.py#L1401
210210
m, k, g = lhs.shape[0], lhs.shape[1], rhs.shape[0]
211211
n = rhs.shape[1] if transpose_rhs else rhs.shape[2]
212-
tm, tk, tn = _get_tiling_size_for_gmm_kernel(m//mesh.shape["data"], k, n, g)
212+
tm, tk, tn = _get_tiling_size_for_gmm_kernel(m, k, n, g)
213213

214214
num_experts_per_shard = num_experts // ep_size
215215
group_offset = jnp.arange(0, num_experts, num_experts_per_shard)
@@ -254,7 +254,7 @@ def _gmm(lhs, rhs, group_sizes, group_offset):
254254
gmm_res = shard_map(
255255
_gmm,
256256
mesh=mesh,
257-
in_specs=(P("data", None), P("model", None, None), P("data"), P("model")),
257+
in_specs=(P(), P("model", None, None), P(), P("model")),
258258
out_specs=(P("model", None)),
259259
check_rep=False,
260260
)(lhs, rhs, group_sizes, group_offset)

0 commit comments

Comments
 (0)