Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
13 changes: 8 additions & 5 deletions benchmark/examples/benchmark_moe.py
Original file line number Diff line number Diff line change
Expand Up @@ -259,7 +259,7 @@ def _worker(rank: int, world_size: int, init_url: str, args):
dist.all_gather_into_tensor(y_tri, z_dp_local.contiguous())

if args.breakdown:
N_BREAKDOWN_ITERS = 5
N_BREAKDOWN_ITERS = 10
stage_ms = {}
for _ in range(N_BREAKDOWN_ITERS):
shmem.heap.allocator.heap_offset = sweep_heap_base
Expand All @@ -281,10 +281,13 @@ def _worker(rank: int, world_size: int, init_url: str, args):
ms = td[j - 1][1].elapsed_time(td[j][1])
stage_ms.setdefault(key, []).append(ms)
if rank == 0:
print(
" [breakdown bpe={}] ".format(bpe)
+ " ".join("{}={:.2f}ms".format(k, sum(v) / len(v)) for k, v in stage_ms.items())
)
total_avg = sum(sum(v) / len(v) for v in stage_ms.values())
parts = []
for k, v in stage_ms.items():
avg = sum(v) / len(v)
pct = 100 * avg / total_avg if total_avg > 0 else 0
parts.append("{}={:.2f}ms ({:.1f}%)".format(k, avg, pct))
print(" [breakdown bpe={} total={:.2f}ms] ".format(bpe, total_avg) + " ".join(parts))

result = {
"world_size": ws,
Expand Down
3 changes: 2 additions & 1 deletion examples/31_expert_sharded_moe/combine.py
Original file line number Diff line number Diff line change
Expand Up @@ -54,6 +54,7 @@ def _convert_ep_to_dp(
dst_indx_local = dst_indx_global - dst_rank * n_slots_per_rank

offs_n = tl.arange(0, BLOCK)
offs_n = tl.max_contiguous(tl.multiple_of(offs_n, BLOCK), BLOCK)
Copy link

Copilot AI Feb 28, 2026

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

tl.multiple_of(offs_n, BLOCK) is an incorrect hint for tl.arange(0, BLOCK): only the first element (0) is divisible by BLOCK, so the hint may allow the compiler to assume alignments that don't hold. Apply tl.multiple_of to an aligned base instead (e.g., start_n if it’s known to be aligned, or the base pointer), and use tl.max_contiguous to express contiguity.

Suggested change
offs_n = tl.max_contiguous(tl.multiple_of(offs_n, BLOCK), BLOCK)
offs_n = tl.max_contiguous(offs_n, BLOCK)

Copilot uses AI. Check for mistakes.
for start_n in range(0, src_shape_n, BLOCK):
mask_n = start_n + offs_n < src_shape_n
src = tl.load(
Expand All @@ -64,7 +65,7 @@ def _convert_ep_to_dp(
dst_off = dst_indx_local * dst_stride_m + start_n + offs_n
for r in tl.static_range(N_RANKS):
if dst_rank == r:
iris.store(dst_ptr + dst_off, src, SRC_RANK, r, heap_bases, mask=mask_n)
iris.store(dst_ptr + dst_off, src, SRC_RANK, r, heap_bases, mask=mask_n, hint=16)


def convert_ep_to_dp(src, expt_assignment, expt_indx, topk_indx, shmem):
Expand Down
3 changes: 2 additions & 1 deletion examples/31_expert_sharded_moe/dispatch.py
Original file line number Diff line number Diff line change
Expand Up @@ -42,6 +42,7 @@ def _convert_dp_to_ep(
off_m_local = pid_m

offs_n = tl.arange(0, BLOCK)
offs_n = tl.max_contiguous(tl.multiple_of(offs_n, BLOCK), BLOCK)
Copy link

Copilot AI Feb 28, 2026

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

tl.multiple_of(offs_n, BLOCK) is not a valid guarantee for tl.arange(0, BLOCK) (only 0 is divisible by BLOCK). This can lead to incorrect alignment assumptions during vectorization. Prefer applying tl.multiple_of to an actually aligned base offset/pointer, and only use tl.max_contiguous to communicate contiguity of the per-lane access.

Suggested change
offs_n = tl.max_contiguous(tl.multiple_of(offs_n, BLOCK), BLOCK)
offs_n = tl.max_contiguous(offs_n, BLOCK)

Copilot uses AI. Check for mistakes.

for act in tl.static_range(N_EXPT_ACT):
dst_row = tl.load(dst_row_indx_ptr + off_m_global * dst_row_indx_stride_m + act)
Expand All @@ -66,7 +67,7 @@ def _convert_dp_to_ep(
dst_off = dst_row * dst_stride_m + start_n + offs_n
for r in tl.static_range(N_RANKS):
if dst_rank == r:
iris.store(dst_ptr + dst_off, src, SRC_RANK, r, heap_bases, mask=mask_n)
iris.store(dst_ptr + dst_off, src, SRC_RANK, r, heap_bases, mask=mask_n, hint=16)


def convert_dp_to_ep(src, expt_assignment, expt_indx, gate_indx, shmem):
Expand Down
5 changes: 3 additions & 2 deletions examples/31_expert_sharded_moe/fused_exp_matmul_ep_to_dp.py
Original file line number Diff line number Diff line change
Expand Up @@ -131,7 +131,7 @@ def _fused_exp_matmul_ep_to_dp_kernel(
if r == SRC_RANK:
tl.store(dst_ptrs_2d, out, mask=store_mask)
else:
iris.store(dst_ptrs_2d, out, SRC_RANK, r, heap_bases, mask=store_mask)
iris.store(dst_ptrs_2d, out, SRC_RANK, r, heap_bases, mask=store_mask, hint=(1, 16))


def fused_exp_matmul_ep_to_dp(
Expand Down Expand Up @@ -213,8 +213,9 @@ def fused_exp_matmul_ep_to_dp(
N_RANKS=shmem.get_num_ranks(),
num_warps=8,
num_stages=2,
matrix_instr_nonkdim=16,
kpack=1,
)

torch.cuda.synchronize()
shmem.barrier()
return dst_local
6 changes: 3 additions & 3 deletions examples/31_expert_sharded_moe/moe.py
Original file line number Diff line number Diff line change
Expand Up @@ -45,10 +45,12 @@ def _allgather_push_kernel(
):
pid = tl.program_id(0)
offs = pid * BLOCK + tl.arange(0, BLOCK)
offs = tl.max_contiguous(tl.multiple_of(offs, BLOCK), BLOCK)
Comment on lines 47 to +48
Copy link

Copilot AI Feb 28, 2026

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

tl.multiple_of(offs, BLOCK) asserts that every element of offs is divisible by BLOCK, but offs is pid * BLOCK + tl.arange(0, BLOCK), so only the first lane is a multiple of BLOCK. This can lead the compiler to make incorrect alignment assumptions when vectorizing. Apply tl.multiple_of to an aligned base (e.g., pid * BLOCK or the base pointer) and keep tl.max_contiguous for the per-lane offs/pointer.

Suggested change
offs = pid * BLOCK + tl.arange(0, BLOCK)
offs = tl.max_contiguous(tl.multiple_of(offs, BLOCK), BLOCK)
base = tl.multiple_of(pid * BLOCK, BLOCK)
offs = tl.max_contiguous(base + tl.arange(0, BLOCK), BLOCK)

Copilot uses AI. Check for mistakes.
mask = offs < src_numel
data = tl.load(src_ptr + offs, mask=mask)
for r in tl.static_range(N_RANKS):
iris.store(dst_ptr + dst_offset + offs, data, CUR_RANK, r, heap_bases, mask=mask)
dst = dst_ptr + dst_offset + offs
iris.store(dst, data, CUR_RANK, r, heap_bases, mask=mask, hint=16)


def _allgather_iris(local_tensor, shmem):
Expand Down Expand Up @@ -288,8 +290,6 @@ def _tick(label):
# ------------------------------------------------------------------
flat_expt_indx = active_indx.to(torch.int32).reshape(-1)
if fusion_config.fuse_grouped_matmul_convert_ep_to_dp:
torch.cuda.synchronize()
shmem.barrier()
y_dp_local = fused_exp_matmul_ep_to_dp(
y_ep_local,
w_ep_local,
Expand Down
Loading