
Commit 5bb7c55

Author: Kaniel_Zhou
Commit message: update test case
1 parent: 9804516

File tree

2 files changed (+71, -41 lines)


.github/workflows/pr-test-npu.yml

Lines changed: 2 additions & 0 deletions
@@ -77,6 +77,7 @@ jobs:
           python3 $GITHUB_WORKSPACE/tests/python/deepep/test_fused_deep_moe.py --topk-drop-prob 0.3
           python3 $GITHUB_WORKSPACE/tests/python/deepep/test_fused_deep_moe.py --num-tokens 2 --topk-drop-col 2
           python3 $GITHUB_WORKSPACE/tests/python/deepep/test_fused_deep_moe.py --num-tokens 2 --topk-drop-col 3 --num-experts 16
+          python3 $GITHUB_WORKSPACE/tests/python/deepep/test_fused_deep_moe.py --num-tokens 4 --topk-drop-col 1 --num-experts 32

   test-build-deepep:
     if: (github.repository == 'sgl-project/sgl-kernel-npu' || github.event_name == 'pull_request') &&
@@ -137,6 +138,7 @@ jobs:
           python3 $GITHUB_WORKSPACE/tests/python/deepep/test_fused_deep_moe.py --topk-drop-prob 0.3
           python3 $GITHUB_WORKSPACE/tests/python/deepep/test_fused_deep_moe.py --num-tokens 2 --topk-drop-col 2
           python3 $GITHUB_WORKSPACE/tests/python/deepep/test_fused_deep_moe.py --num-tokens 2 --topk-drop-col 3 --num-experts 16
+          python3 $GITHUB_WORKSPACE/tests/python/deepep/test_fused_deep_moe.py --num-tokens 4 --topk-drop-col 1 --num-experts 32

   finish:
     if: always()

tests/python/deepep/test_fused_deep_moe.py

Lines changed: 69 additions & 41 deletions
@@ -257,7 +257,7 @@ def test(
     x = torch.rand((num_tokens, hidden), dtype=torch.bfloat16, device="npu") * 10 - 5

     # ----- Routing(topk_idx) -----
-    if args.debug and args.active_ranks:
+    if args.active_ranks:
         try:
             active_ranks = [
                 int(r.strip()) for r in args.active_ranks.split(",") if r.strip()
@@ -292,7 +292,7 @@ def test(
         topk_idx = valid_experts[
             torch.randint(0, len(valid_experts), (num_tokens, num_topk), device="npu")
         ]
-        if args.debug and rank == 0:
+        if rank == 0:
             print(
                 f"[config] active_ranks={active_ranks}, valid_experts={len(valid_experts)}",
                 flush=True,
@@ -324,7 +324,7 @@ def test(
         w2_weight_scale.clone().detach(),
     )

-    if args.debug and rank == 0:
+    if rank == 0:
         print("=== Check fused weights ===")
         print("w13_f:", w13_f.shape, w13_f.dtype, w13_f.device)
         print("w13s_f:", w13s_f.shape, w13s_f.dtype, w13s_f.device)
@@ -338,8 +338,7 @@ def test(
         start, end = r * experts_per_rank, (r + 1) * experts_per_rank
         tokens_per_rank[r] = ((topk_idx >= start) & (topk_idx < end)).sum()

-    if args.debug:
-        print(f"[DEBUG] Tokens per rank: {tokens_per_rank}", flush=True)
+    print(f"[DEBUG] Tokens per rank: {tokens_per_rank}", flush=True)

     # ====== ensure topk_weights is defined (fix missing var) ======
     topk_weights = torch.randn(
@@ -376,33 +375,23 @@ def test(
         invalid_mask = topk_idx_dropped == -1
         topk_weights_dropped = topk_weights_dropped.masked_fill(invalid_mask, 0.0)

-        if args.debug:
-            print(
-                f"[DEBUG] topk_idx_dropped (after random drop):\n{topk_idx_dropped.cpu().numpy()}",
-                flush=True,
-            )
-            print(
-                f"[DEBUG] topk_weights_dropped (after random drop):\n{topk_weights_dropped.cpu().numpy()}",
-                flush=True,
-            )
-
         # Fixed column drop (for the test_topk_minus1 scenario)
         if args.topk_drop_col >= 0 and args.topk_drop_col < num_topk:
             topk_idx_dropped[:, args.topk_drop_col] = -1
             topk_weights_dropped[:, args.topk_drop_col] = 0
-            if args.debug:
-                print(
-                    f"[DEBUG] topk_idx_dropped (after fixed-column drop):\n{topk_idx_dropped.cpu().numpy()}",
-                    flush=True,
-                )
-                print(
-                    f"[DEBUG] topk_weights_dropped (after fixed-column drop):\n{topk_weights_dropped.cpu().numpy()}",
-                    flush=True,
-                )
+
+        print(
+            f"[DEBUG] [rank {rank}] topk_idx_dropped (after fixed-column drop):\n{topk_idx_dropped.cpu().numpy()}",
+            flush=True,
+        )
+        print(
+            f"[DEBUG] [rank {rank}] topk_weights_dropped (after fixed-column drop):\n{topk_weights_dropped.cpu().numpy()}",
+            flush=True,
+        )

         # print drop ratio
         drop_ratio = (topk_idx_dropped == -1).float().mean().item()
-        if args.debug and rank == 0:
+        if rank == 0:
             print(
                 f"[DEBUG] [rank {rank}] topk dropped ratio = {drop_ratio*100:.2f}%",
                 flush=True,
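
The hunk above removes the --debug gating around the drop-path logging; the two drop modes it logs are random dropping with probability --topk-drop-prob and blanking a fixed top-k column with --topk-drop-col, both marking a dropped slot with index -1 and weight 0. A minimal standalone sketch of the two modes (toy shapes and variable names chosen for illustration, not the test's actual fixtures):

import torch

# Toy routing table: 4 tokens, top-3 experts out of 8.
num_tokens, num_topk = 4, 3
topk_idx = torch.randint(0, 8, (num_tokens, num_topk))
topk_weights = torch.rand(num_tokens, num_topk)

# Random drop: each (token, slot) entry is invalidated with probability p.
p = 0.3
drop_mask = torch.rand(num_tokens, num_topk) < p
topk_idx = topk_idx.masked_fill(drop_mask, -1)
topk_weights = topk_weights.masked_fill(drop_mask, 0.0)

# Fixed-column drop: invalidate one top-k slot for every token.
drop_col = 1
topk_idx[:, drop_col] = -1
topk_weights[:, drop_col] = 0.0

# Overall drop ratio, analogous to the test's "topk dropped ratio" log line.
print(f"dropped ratio = {(topk_idx == -1).float().mean().item() * 100:.2f}%")

Marking a slot with -1 while zeroing its weight keeps the weighted combine numerically unchanged and signals dispatch to skip that slot.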
@@ -411,6 +400,20 @@ def test(
         topk_idx_dropped = topk_idx
         topk_weights_dropped = topk_weights

+    # Expert meta
+    num_tokens_per_expert = torch.zeros((num_experts,), dtype=torch.int, device="npu")
+    for i in range(num_experts):
+        num_tokens_per_expert[i] = (topk_idx_dropped == i).sum()
+    gbl_num_tokens_per_expert = num_tokens_per_expert.clone()
+    dist.all_reduce(gbl_num_tokens_per_expert, group=group)
+
+    print(f"[Rank {rank}] num_tokens_per_expert: {num_tokens_per_expert.tolist()}")
+    if rank == 0:
+        print(
+            f"[Rank {rank}] gbl_num_tokens_per_expert: {gbl_num_tokens_per_expert.tolist()}"
+        )
+    base_prefix_sum = num_tokens_per_expert.clone()
+
     # ----- Baseline -----
     baseline_output, base_ep_recv_count = baseline_test(
         buffer2,
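
The new "Expert meta" block builds a per-expert histogram of this rank's routed slots, all-reduces it into a global histogram, and keeps a local copy in base_prefix_sum for the receive-count check further down. For reference, the per-expert loop is equivalent to a single bincount over the valid slots; a sketch, assuming dropped slots are marked -1 as in the test:

import torch

# Sketch: vectorized equivalent of the counting loop in the hunk above.
num_experts = 8
topk_idx_dropped = torch.tensor([[0, 3, -1], [3, 3, 7], [-1, 5, 0]])

valid = topk_idx_dropped[topk_idx_dropped >= 0]  # discard the -1 drop markers
num_tokens_per_expert = torch.bincount(valid, minlength=num_experts)
print(num_tokens_per_expert.tolist())  # [2, 0, 0, 3, 0, 1, 0, 1]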
@@ -456,19 +459,49 @@ def test(
     assert avg_diff < 1e-4, f"[Rank {rank}] Mismatch detected! diff={avg_diff}"

     # ----- Compare Recv Count -----
-    if args.topk_drop_col < 0 and args.topk_drop_prob == 0.0:
-        recv_count_diff = (
-            from_inclusive_prefix_sum(base_ep_recv_count) - fused_ep_recv_count
-        ).abs()
-        max_recv_count_diff = recv_count_diff.max().item()
-        mean_recv_count_diff = recv_count_diff.mean().item()
+    global_base_prefix_sum = [
+        torch.zeros_like(base_prefix_sum) for _ in range(num_ranks)
+    ]
+    dist.all_gather(global_base_prefix_sum, base_prefix_sum)
+
+    global_base_prefix_sum = torch.stack(global_base_prefix_sum, dim=0)
+
+    if rank == 0:
         print(
-            f"[Rank {rank}] Difference between base and fused recv_count -> max: {max_recv_count_diff}, mean: {mean_recv_count_diff}",
-            flush=True,
+            f"[DEBUG] Global base_prefix_sum (before transpose):\n{global_base_prefix_sum}"
         )
-        assert (
-            max_recv_count_diff < 1e-4
-        ), f"[Rank {rank}] Mismatch detected! diff={max_recv_count_diff}"
+
+    transposed_base_prefix_sum = global_base_prefix_sum.T
+    if rank == 0:
+        print(f"[DEBUG] Transposed base_prefix_sum:\n{transposed_base_prefix_sum}")
+        print(f"[DEBUG] Transposed base_prefix_sum: {transposed_base_prefix_sum.shape}")
+
+    experts_per_rank = num_experts // dist.get_world_size()
+    start_expert = rank * experts_per_rank
+    end_expert = start_expert + experts_per_rank
+
+    # shape [experts_per_rank * num_ranks]
+    expected_recv = transposed_base_prefix_sum[start_expert:end_expert].reshape(-1)
+    fused_recv = fused_ep_recv_count
+
+    print(f"expected_recv: {expected_recv}")
+    print(f"fused_recv: {fused_recv}")
+
+    diff = (expected_recv - fused_recv).abs()
+    print(
+        f"[Rank {rank}] diff (experts {start_expert}~{end_expert-1}): {diff.cpu().numpy()}",
+        flush=True,
+    )
+
+    max_recv_count_diff = diff.max().item()
+    mean_recv_count_diff = diff.mean().item()
+    print(
+        f"[Rank {rank}] Difference between base and fused recv_count -> max: {max_recv_count_diff}, mean: {mean_recv_count_diff}",
+        flush=True,
+    )
+    assert (
+        max_recv_count_diff < 1e-4
+    ), f"[Rank {rank}] Mismatch detected! diff={max_recv_count_diff}"


 # ======================== Distributed Entry ========================
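
The rewritten receive-count check above replaces the old prefix-sum comparison, which only ran when nothing was dropped, with a direct reconstruction of the expected counts that also holds under token dropping. After dist.all_gather and torch.stack, row r of global_base_prefix_sum holds rank r's per-expert token counts, so transposing puts one expert per row; slicing out the rows for this rank's local experts and flattening yields, for each (local expert, source rank) pair, the number of tokens that rank should deliver, which is the layout fused_ep_recv_count reports. A CPU-only toy run with 2 ranks and 4 experts (made-up counts, no process group needed):

import torch

# Stand-in for the all_gather result: row r = rank r's num_tokens_per_expert.
num_ranks, num_experts = 2, 4
gathered = torch.tensor([[3, 1, 0, 2],   # counts produced on rank 0
                         [0, 4, 1, 1]])  # counts produced on rank 1

transposed = gathered.T  # [num_experts, num_ranks]: row e = senders of expert e

experts_per_rank = num_experts // num_ranks
for rank in range(num_ranks):
    start, end = rank * experts_per_rank, (rank + 1) * experts_per_rank
    # Flattened to [experts_per_rank * num_ranks], matching fused_ep_recv_count.
    expected_recv = transposed[start:end].reshape(-1)
    print(rank, expected_recv.tolist())
# rank 0 owns experts 0-1 and expects [3, 0, 1, 4];
# rank 1 owns experts 2-3 and expects [0, 1, 2, 1].

One naming nit carried over from the old code: base_prefix_sum now holds raw per-expert counts rather than a prefix sum, since the from_inclusive_prefix_sum conversion is gone from this path.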
@@ -564,11 +597,6 @@ def str_to_bool(value):
         default=-1,
         help="If >=0, drop this specific top-k column (set index to -1 for testing).",
     )
-    parser.add_argument(
-        "--debug",
-        action="store_true",
-        help="Enable debug logging.",
-    )

     args = parser.parse_args()
     num_processes = args.num_processes
