From 119d2835b335adda13d5c53f0e2e1c1aa9b5eb9f Mon Sep 17 00:00:00 2001
From: pengcuo
Date: Fri, 15 Nov 2024 10:58:17 +0800
Subject: [PATCH] Fix a bug for col_major_moe kernel

---
 kernels/triton/inference/col_major_moe_gemm/v2_moe_fused.py | 5 ++++-
 1 file changed, 4 insertions(+), 1 deletion(-)

diff --git a/kernels/triton/inference/col_major_moe_gemm/v2_moe_fused.py b/kernels/triton/inference/col_major_moe_gemm/v2_moe_fused.py
index dd33665..31b2e05 100644
--- a/kernels/triton/inference/col_major_moe_gemm/v2_moe_fused.py
+++ b/kernels/triton/inference/col_major_moe_gemm/v2_moe_fused.py
@@ -22,7 +22,10 @@ def col_major(pid,
     grid_m = tl.cdiv(m, block_m)
     grid_n = tl.cdiv(n, block_n)
 
-    pid_m = (pid % grid_n)
+    # There is a bug.
+    # pid_m = (pid % grid_n)
+    # The result is correct, but the speedup is not as good as mentioned in the documentation
+    pid_m = (pid % grid_m)
     pid_n = pid // grid_m
 
     return pid_m, pid_n