From 5f07e2c3247d00d204edaf60f9f842cda408ee3a Mon Sep 17 00:00:00 2001 From: Kaniel_Zhou Date: Thu, 6 Nov 2025 15:23:09 +0800 Subject: [PATCH 01/16] Testing the generalization of fusion operators --- tests/python/deepep/test_fused_deep_moe.py | 323 +++++++++++---------- 1 file changed, 170 insertions(+), 153 deletions(-) diff --git a/tests/python/deepep/test_fused_deep_moe.py b/tests/python/deepep/test_fused_deep_moe.py index 436cc77b..0f98120f 100644 --- a/tests/python/deepep/test_fused_deep_moe.py +++ b/tests/python/deepep/test_fused_deep_moe.py @@ -12,16 +12,23 @@ from utils import bench, calc_diff, hash_tensor, init_dist torch_npu.npu.config.allow_internal_format = True -test_topk_minus1 = False -small_bs_flag = False # ======================== Weight Initialization ======================== -def init_base_weights(): - w13_weight = torch.randint(-16, 16, [16, 4096, 7168]).to(torch.int8) - w2_weight = torch.randint(-16, 16, [16, 7168, 2048]).to(torch.int8) - w13_weight_scale = (torch.rand([16, 4096, 1]) * 0.0004 + 0.0015).bfloat16() - w2_weight_scale = (torch.rand([16, 7168, 1]) * 0.0004 + 0.0015).bfloat16() +def init_base_weights(num_local_experts, hidden_in=7168, hidden_mid=4096, hidden_out=2048): + """ + 初始化每个本地专家的权重。 + num_local_experts: 每个 rank 上的专家数 = num_experts // num_ranks + hidden_in: 输入维度 (默认 7168) + hidden_mid: 中间层维度 (默认 4096) + hidden_out: 输出维度 (默认 2048) + """ + + w13_weight = torch.randint(-16, 16, [num_local_experts, hidden_mid, hidden_in], dtype=torch.int8) + w2_weight = torch.randint(-16, 16, [num_local_experts, hidden_in, hidden_out], dtype=torch.int8) + + w13_weight_scale = (torch.rand([num_local_experts, hidden_mid, 1]) * 0.0004 + 0.0015).bfloat16() + w2_weight_scale = (torch.rand([num_local_experts, hidden_in, 1]) * 0.0004 + 0.0015).bfloat16() return w13_weight, w13_weight_scale, w2_weight, w2_weight_scale @@ -65,13 +72,13 @@ def reshape_fusion_gmm_weight(weight, dim): def init_fused_weights_int8( - w13_weight, - w13_weight_scale, - w2_weight, - w2_weight_scale, - device="npu", - block_m: int = 16, - block_n: int = 16, + w13_weight, + w13_weight_scale, + w2_weight, + w2_weight_scale, + device="npu", + block_m: int = 16, + block_n: int = 16, ): # -------- w13_weight -------- @@ -99,7 +106,7 @@ def init_fused_weights_int8( # ======================== Utility Functions ======================== def make_uniform_topk_idx( - num_tokens: int, num_experts: int, num_ranks: int, num_topk: int, device="npu" + num_tokens: int, num_experts: int, num_ranks: int, num_topk: int, device="npu" ): assert num_experts % num_ranks == 0, "num_experts must be divisible by num_ranks" experts_per_rank = num_experts // num_ranks @@ -131,18 +138,18 @@ def from_inclusive_prefix_sum(pref): # ======================== Baseline Reference ======================== def baseline_test( - buffer, - x, - topk_idx, - num_tokens, - num_experts, - cumulative_local_expert_recv_stats, - return_recv_hook, - w13, - w13_scale, - w2, - w2_scale, - topk_weights, + buffer, + x, + topk_idx, + num_tokens, + num_experts, + cumulative_local_expert_recv_stats, + return_recv_hook, + w13, + w13_scale, + w2, + w2_scale, + topk_weights, ): hidden_states, packed_recv_count, handle, _, _ = buffer.low_latency_dispatch( x, @@ -213,17 +220,17 @@ def baseline_test( # ======================== Main Test ======================== def test( - num_tokens: int, - hidden: int, - num_experts: int, - num_topk: int, - rank: int, - num_ranks: int, - group: dist.ProcessGroup, - buffer: Buffer, - buffer2: Buffer, - args: argparse.Namespace, - seed: int = 0, + num_tokens: int, + hidden: int, + num_experts: int, + num_topk: int, + rank: int, + num_ranks: int, + group: dist.ProcessGroup, + buffer: Buffer, + buffer2: Buffer, + args: argparse.Namespace, + seed: int = 0, ): torch.manual_seed(seed + rank) random.seed(seed + rank) @@ -234,12 +241,13 @@ def test( # NOTES: the integers greater than 256 exceeds the BF16 precision limit rank_offset = 128 assert ( - num_ranks - rank_offset < 257 + num_ranks - rank_offset < 257 ), "Too many ranks (exceeding test precision limit)" x = torch.rand((num_tokens, hidden), dtype=torch.bfloat16, device="npu") * 10 - 5 + # ----- Routing(topk_idx) ----- - if args.active_ranks: + if args.debug and args.active_ranks: try: active_ranks = [ int(r.strip()) for r in args.active_ranks.split(",") if r.strip() @@ -274,22 +282,25 @@ def test( topk_idx = valid_experts[ torch.randint(0, len(valid_experts), (num_tokens, num_topk), device="npu") ] - if rank == 0: + if args.debug and rank == 0: print( f"[config] active_ranks={active_ranks}, valid_experts={len(valid_experts)}", flush=True, ) else: scores = ( - torch.randn( - (num_tokens, num_experts), dtype=torch.float32, device="npu" - ).abs() - + 1 + torch.randn( + (num_tokens, num_experts), dtype=torch.float32, device="npu" + ).abs() + + 1 ) topk_idx = torch.topk(scores, num_topk, dim=-1, largest=True, sorted=True)[1] # ----- Weights ----- - w13_weight, w13_weight_scale, w2_weight, w2_weight_scale = init_base_weights() + w13_weight, w13_weight_scale, w2_weight, w2_weight_scale = init_base_weights( + num_local_experts=num_local_experts, + hidden_in=hidden, + ) w13, w13_scale, w2, w2_scale = init_baseline_weights( w13_weight.clone().detach(), w13_weight_scale.clone().detach(), @@ -303,7 +314,7 @@ def test( w2_weight_scale.clone().detach(), ) - if rank == 0: + if args.debug and rank == 0 : print("=== Check fused weights ===") print("w13_f:", w13_f.shape, w13_f.dtype, w13_f.device) print("w13s_f:", w13s_f.shape, w13s_f.dtype, w13s_f.device) @@ -316,96 +327,96 @@ def test( for r in range(num_ranks): start, end = r * experts_per_rank, (r + 1) * experts_per_rank tokens_per_rank[r] = ((topk_idx >= start) & (topk_idx < end)).sum() - print(f"Tokens per rank: {tokens_per_rank}") - - # ----- Random drop ----- - if args.drop_prob > 0: - drop_mask = torch.rand_like(topk_idx, dtype=torch.float32) < args.drop_prob - topk_idx = topk_idx.masked_fill(drop_mask, -1) - for i in range(num_tokens): - if (topk_idx[i] == -1).all(): - topk_idx[i, 0] = torch.topk(scores[i], 1, largest=True)[1].item() - - topk_weights = torch.randn( - (num_tokens, num_topk), dtype=torch.float32, device="npu" - ).abs() - cumulative_local_expert_recv_stats = torch.zeros( - (num_local_experts,), dtype=torch.int, device="npu" - ) - return_recv_hook = False - hidden_states = x + if args.debug: + print(f"Tokens per rank: {tokens_per_rank}", flush=True) - if small_bs_flag and rank == 0: - # Test with a small batch size of 1 + # ====== ensure topk_weights is defined (fix missing var) ====== + topk_weights = torch.randn((num_tokens, num_topk), dtype=torch.float32, device="npu").abs() + + # ====== cumulative stats and flags ====== + cumulative_local_expert_recv_stats = torch.zeros((num_local_experts,), dtype=torch.int32, device="npu") + return_recv_hook = False + + # ----- Small-batch for debug (local rank 0 only) ----- + if args.small_bs_flag and rank == 0: + print("[rank 0] small_bs_flag active: truncating to batch 1", flush=True) x = x[:1, :] topk_idx = topk_idx[:1, :] topk_weights = topk_weights[:1, :] - - if test_topk_minus1: - topk_idx_minus1 = topk_idx.clone() - topk_idx_minus1[:, -2:-1] = -1 - topk_weights_minus1 = topk_weights.clone() - topk_weights_minus1[:, -2:-1] = 0 - # ----- Baseline ----- - baseline_output, base_ep_recv_count = baseline_test( - buffer2, - x, - topk_idx, - num_tokens, - num_experts, - cumulative_local_expert_recv_stats, - return_recv_hook, - w13, - w13_scale, - w2, - w2_scale, - topk_weights_minus1, - ) - # ----- Fused ----- - fused_output, fused_ep_recv_count = buffer.fused_deep_moe( - x, - topk_idx_minus1, - topk_weights, - w13_f, - w13s_f, - w2_f, - w2s_f, - num_tokens, - num_experts, - 0, - ) - + num_tokens = x.shape[0] + + # ----- Random or fixed drop ----- + if args.topk_drop_prob > 0 or args.topk_drop_col >= 0: + topk_idx_dropped = topk_idx.clone() + topk_weights_dropped = topk_weights.clone() + + # Random drop (based on probability) + if args.topk_drop_prob > 0: + drop_mask = torch.rand_like(topk_idx, dtype=torch.float32) < args.topk_drop_prob + topk_idx_dropped = topk_idx.clone() + topk_idx_dropped = topk_idx_dropped.masked_fill(drop_mask, -1) + + + # Guarantee that each token has at least one valid expert. + for i in range(num_tokens): + if (topk_idx_dropped[i] == -1).all(): + topk_idx_dropped[i, 0] = torch.topk(scores[i], 1, largest=True)[1].item() + + # Construct topk_weights_dropped + invalid_mask = topk_idx_dropped == -1 + topk_weights_dropped = topk_weights_dropped.masked_fill(invalid_mask, 0.0) + + if args.debug: + print(f"[DEBUG] topk_idx_dropped (after random drop):\n{topk_idx_dropped.cpu().numpy()}", flush=True) + print(f"[DEBUG] topk_weights_dropped (after random drop):\n{topk_weights_dropped.cpu().numpy()}", flush=True) + + + # Fixed column drop (for the test_topk_minus1 scenario) + if args.topk_drop_col >= 0 and args.topk_drop_col < num_topk: + topk_idx_dropped[:, args.topk_drop_col] = -1 + topk_weights_dropped[:, args.topk_drop_col] = 0 + if args.debug: + print(f"[DEBUG] topk_idx_dropped (after fixed-column drop):\n{topk_idx_dropped.cpu().numpy()}", flush=True) + print(f"[DEBUG] topk_weights_dropped (after fixed-column drop):\n{topk_weights_dropped.cpu().numpy()}", flush=True) + + # print drop ratio + drop_ratio = (topk_idx_dropped == -1).float().mean().item() + if args.debug and rank == 0: + print(f"[rank {rank}] topk dropped ratio = {drop_ratio*100:.2f}%", flush=True) else: - # ----- Baseline ----- - baseline_output, base_ep_recv_count = baseline_test( - buffer2, - x, - topk_idx, - num_tokens, - num_experts, - cumulative_local_expert_recv_stats, - return_recv_hook, - w13, - w13_scale, - w2, - w2_scale, - topk_weights, - ) + topk_idx_dropped = topk_idx + topk_weights_dropped = topk_weights + + # ----- Baseline ----- + baseline_output, base_ep_recv_count = baseline_test( + buffer2, + x, + topk_idx, + num_tokens, + num_experts, + cumulative_local_expert_recv_stats, + return_recv_hook, + w13, + w13_scale, + w2, + w2_scale, + topk_weights_dropped, + ) # ----- Fused ----- - fused_output, fused_ep_recv_count = buffer.fused_deep_moe( - x, - topk_idx, - topk_weights, - w13_f, - w13s_f, - w2_f, - w2s_f, - num_tokens, - num_experts, - 0, - ) + fused_output, fused_ep_recv_count = buffer.fused_deep_moe( + x, + topk_idx_dropped, + topk_weights, + w13_f, + w13s_f, + w2_f, + w2s_f, + num_tokens, + num_experts, + 0, + ) # ----- Compare Outputs ----- max_diff = torch.max(torch.abs(fused_output - baseline_output)).item() @@ -415,30 +426,29 @@ def test( print( f"[Rank {rank}] baseline_avg={baseline_output_avg:.6e}, fused_avg={fused_output_avg:.6e}, " - f"max_diff={max_diff:.6e}, avg_diff={avg_diff:.6e}" + f"max_diff={max_diff:.6e}, avg_diff={avg_diff:.6e}", + flush=True, ) + assert avg_diff < 1e-4, f"[Rank {rank}] Mismatch detected! diff={avg_diff}" # ----- Compare RecvCount ----- - recv_count_diff = ( - from_inclusive_prefix_sum(base_ep_recv_count) - fused_ep_recv_count - ).abs() - max_recv_count_diff = recv_count_diff.max().item() - mean_recv_count_diff = recv_count_diff.mean().item() - print( - f"[Rank {rank}] Difference between base and fused recv_count -> max: {max_recv_count_diff}, mean: {mean_recv_count_diff}" - ) - - if not test_topk_minus1: - assert ( - max_recv_count_diff < 1e-4 - ), f"[Rank {rank}] Mismatch detected! diff={max_recv_count_diff}" + if args.topk_drop_col < 0 and args.topk_drop_prob == 0.0: + recv_count_diff = (from_inclusive_prefix_sum(base_ep_recv_count) - fused_ep_recv_count).abs() + max_recv_count_diff = recv_count_diff.max().item() + mean_recv_count_diff = recv_count_diff.mean().item() + print( + f"[Rank {rank}] Difference between base and fused recv_count -> max: {max_recv_count_diff}, mean: {mean_recv_count_diff}", + flush=True, + ) + assert max_recv_count_diff < 1e-4, f"[Rank {rank}] Mismatch detected! diff={max_recv_count_diff}" # ======================== Distributed Entry ======================== def test_loop(local_rank: int, num_local_ranks: int, args: argparse.Namespace): rank, num_ranks, group = init_dist(local_rank, num_local_ranks) - group2 = dist.new_group(list(range(16))) + group2 = dist.new_group(list(range(num_ranks))) + shared_expert_rank_num = int(os.getenv("MOE_SHARED_EXPERT_RANK_NUM", 0)) num_tokens, hidden = args.num_tokens, args.hidden num_topk, num_experts = args.num_topk, args.num_experts @@ -511,19 +521,22 @@ def str_to_bool(value): "--active-ranks", type=str, default="", - help="Comma-separated list of ranks that will receive tokens. " - 'Example: "0,1,3". If empty, all ranks may receive tokens.', + help='Comma-separated list of ranks that will receive tokens. Example: "0,1,3". If empty, all ranks may receive tokens.', ) parser.add_argument( - "--drop-prob", + "--topk-drop-prob", + dest="topk_drop_prob", type=float, default=0.0, - help="Probability of dropping an individual top-k index (set to -1). " - "Guaranteed that each token keeps at least one valid expert.", + help="Probability of randomly dropping a top-k index (set to -1).", ) parser.add_argument( - "--minus1-flag", type=str_to_bool, default=False, help="bool flag, True/False" + "--topk-drop-col", + dest="topk_drop_col", + type=int, + default=-1, + help="If >=0, drop this specific top-k column (set index to -1 for testing).", ) parser.add_argument( @@ -533,12 +546,16 @@ def str_to_bool(value): help="define small bs on certain rank", ) + parser.add_argument( + "--debug", + action="store_true", + help="Enable debug logging.", + ) + args = parser.parse_args() num_processes = args.num_processes - test_topk_minus1 = args.minus1_flag - small_bs_flag = args.small_bs_flag - + # use args.small_bs_flag directly in test() torch.multiprocessing.spawn( test_loop, args=(num_processes, args), nprocs=num_processes ) From ea041526adeb34b6ec8a75d975333421d4495b99 Mon Sep 17 00:00:00 2001 From: Kaniel_Zhou Date: Thu, 6 Nov 2025 15:26:49 +0800 Subject: [PATCH 02/16] fix lint --- tests/python/deepep/test_fused_deep_moe.py | 146 +++++++++++++-------- 1 file changed, 90 insertions(+), 56 deletions(-) diff --git a/tests/python/deepep/test_fused_deep_moe.py b/tests/python/deepep/test_fused_deep_moe.py index 0f98120f..3b6a2aa6 100644 --- a/tests/python/deepep/test_fused_deep_moe.py +++ b/tests/python/deepep/test_fused_deep_moe.py @@ -15,7 +15,9 @@ # ======================== Weight Initialization ======================== -def init_base_weights(num_local_experts, hidden_in=7168, hidden_mid=4096, hidden_out=2048): +def init_base_weights( + num_local_experts, hidden_in=7168, hidden_mid=4096, hidden_out=2048 +): """ 初始化每个本地专家的权重。 num_local_experts: 每个 rank 上的专家数 = num_experts // num_ranks @@ -24,11 +26,19 @@ def init_base_weights(num_local_experts, hidden_in=7168, hidden_mid=4096, hidden hidden_out: 输出维度 (默认 2048) """ - w13_weight = torch.randint(-16, 16, [num_local_experts, hidden_mid, hidden_in], dtype=torch.int8) - w2_weight = torch.randint(-16, 16, [num_local_experts, hidden_in, hidden_out], dtype=torch.int8) + w13_weight = torch.randint( + -16, 16, [num_local_experts, hidden_mid, hidden_in], dtype=torch.int8 + ) + w2_weight = torch.randint( + -16, 16, [num_local_experts, hidden_in, hidden_out], dtype=torch.int8 + ) - w13_weight_scale = (torch.rand([num_local_experts, hidden_mid, 1]) * 0.0004 + 0.0015).bfloat16() - w2_weight_scale = (torch.rand([num_local_experts, hidden_in, 1]) * 0.0004 + 0.0015).bfloat16() + w13_weight_scale = ( + torch.rand([num_local_experts, hidden_mid, 1]) * 0.0004 + 0.0015 + ).bfloat16() + w2_weight_scale = ( + torch.rand([num_local_experts, hidden_in, 1]) * 0.0004 + 0.0015 + ).bfloat16() return w13_weight, w13_weight_scale, w2_weight, w2_weight_scale @@ -72,13 +82,13 @@ def reshape_fusion_gmm_weight(weight, dim): def init_fused_weights_int8( - w13_weight, - w13_weight_scale, - w2_weight, - w2_weight_scale, - device="npu", - block_m: int = 16, - block_n: int = 16, + w13_weight, + w13_weight_scale, + w2_weight, + w2_weight_scale, + device="npu", + block_m: int = 16, + block_n: int = 16, ): # -------- w13_weight -------- @@ -106,7 +116,7 @@ def init_fused_weights_int8( # ======================== Utility Functions ======================== def make_uniform_topk_idx( - num_tokens: int, num_experts: int, num_ranks: int, num_topk: int, device="npu" + num_tokens: int, num_experts: int, num_ranks: int, num_topk: int, device="npu" ): assert num_experts % num_ranks == 0, "num_experts must be divisible by num_ranks" experts_per_rank = num_experts // num_ranks @@ -138,18 +148,18 @@ def from_inclusive_prefix_sum(pref): # ======================== Baseline Reference ======================== def baseline_test( - buffer, - x, - topk_idx, - num_tokens, - num_experts, - cumulative_local_expert_recv_stats, - return_recv_hook, - w13, - w13_scale, - w2, - w2_scale, - topk_weights, + buffer, + x, + topk_idx, + num_tokens, + num_experts, + cumulative_local_expert_recv_stats, + return_recv_hook, + w13, + w13_scale, + w2, + w2_scale, + topk_weights, ): hidden_states, packed_recv_count, handle, _, _ = buffer.low_latency_dispatch( x, @@ -220,17 +230,17 @@ def baseline_test( # ======================== Main Test ======================== def test( - num_tokens: int, - hidden: int, - num_experts: int, - num_topk: int, - rank: int, - num_ranks: int, - group: dist.ProcessGroup, - buffer: Buffer, - buffer2: Buffer, - args: argparse.Namespace, - seed: int = 0, + num_tokens: int, + hidden: int, + num_experts: int, + num_topk: int, + rank: int, + num_ranks: int, + group: dist.ProcessGroup, + buffer: Buffer, + buffer2: Buffer, + args: argparse.Namespace, + seed: int = 0, ): torch.manual_seed(seed + rank) random.seed(seed + rank) @@ -241,7 +251,7 @@ def test( # NOTES: the integers greater than 256 exceeds the BF16 precision limit rank_offset = 128 assert ( - num_ranks - rank_offset < 257 + num_ranks - rank_offset < 257 ), "Too many ranks (exceeding test precision limit)" x = torch.rand((num_tokens, hidden), dtype=torch.bfloat16, device="npu") * 10 - 5 @@ -289,10 +299,10 @@ def test( ) else: scores = ( - torch.randn( - (num_tokens, num_experts), dtype=torch.float32, device="npu" - ).abs() - + 1 + torch.randn( + (num_tokens, num_experts), dtype=torch.float32, device="npu" + ).abs() + + 1 ) topk_idx = torch.topk(scores, num_topk, dim=-1, largest=True, sorted=True)[1] @@ -314,7 +324,7 @@ def test( w2_weight_scale.clone().detach(), ) - if args.debug and rank == 0 : + if args.debug and rank == 0: print("=== Check fused weights ===") print("w13_f:", w13_f.shape, w13_f.dtype, w13_f.device) print("w13s_f:", w13s_f.shape, w13s_f.dtype, w13s_f.device) @@ -332,10 +342,14 @@ def test( print(f"Tokens per rank: {tokens_per_rank}", flush=True) # ====== ensure topk_weights is defined (fix missing var) ====== - topk_weights = torch.randn((num_tokens, num_topk), dtype=torch.float32, device="npu").abs() + topk_weights = torch.randn( + (num_tokens, num_topk), dtype=torch.float32, device="npu" + ).abs() # ====== cumulative stats and flags ====== - cumulative_local_expert_recv_stats = torch.zeros((num_local_experts,), dtype=torch.int32, device="npu") + cumulative_local_expert_recv_stats = torch.zeros( + (num_local_experts,), dtype=torch.int32, device="npu" + ) return_recv_hook = False # ----- Small-batch for debug (local rank 0 only) ----- @@ -353,37 +367,53 @@ def test( # Random drop (based on probability) if args.topk_drop_prob > 0: - drop_mask = torch.rand_like(topk_idx, dtype=torch.float32) < args.topk_drop_prob + drop_mask = ( + torch.rand_like(topk_idx, dtype=torch.float32) < args.topk_drop_prob + ) topk_idx_dropped = topk_idx.clone() topk_idx_dropped = topk_idx_dropped.masked_fill(drop_mask, -1) - # Guarantee that each token has at least one valid expert. for i in range(num_tokens): if (topk_idx_dropped[i] == -1).all(): - topk_idx_dropped[i, 0] = torch.topk(scores[i], 1, largest=True)[1].item() + topk_idx_dropped[i, 0] = torch.topk(scores[i], 1, largest=True)[ + 1 + ].item() # Construct topk_weights_dropped invalid_mask = topk_idx_dropped == -1 topk_weights_dropped = topk_weights_dropped.masked_fill(invalid_mask, 0.0) if args.debug: - print(f"[DEBUG] topk_idx_dropped (after random drop):\n{topk_idx_dropped.cpu().numpy()}", flush=True) - print(f"[DEBUG] topk_weights_dropped (after random drop):\n{topk_weights_dropped.cpu().numpy()}", flush=True) - + print( + f"[DEBUG] topk_idx_dropped (after random drop):\n{topk_idx_dropped.cpu().numpy()}", + flush=True, + ) + print( + f"[DEBUG] topk_weights_dropped (after random drop):\n{topk_weights_dropped.cpu().numpy()}", + flush=True, + ) # Fixed column drop (for the test_topk_minus1 scenario) if args.topk_drop_col >= 0 and args.topk_drop_col < num_topk: topk_idx_dropped[:, args.topk_drop_col] = -1 topk_weights_dropped[:, args.topk_drop_col] = 0 if args.debug: - print(f"[DEBUG] topk_idx_dropped (after fixed-column drop):\n{topk_idx_dropped.cpu().numpy()}", flush=True) - print(f"[DEBUG] topk_weights_dropped (after fixed-column drop):\n{topk_weights_dropped.cpu().numpy()}", flush=True) + print( + f"[DEBUG] topk_idx_dropped (after fixed-column drop):\n{topk_idx_dropped.cpu().numpy()}", + flush=True, + ) + print( + f"[DEBUG] topk_weights_dropped (after fixed-column drop):\n{topk_weights_dropped.cpu().numpy()}", + flush=True, + ) # print drop ratio drop_ratio = (topk_idx_dropped == -1).float().mean().item() if args.debug and rank == 0: - print(f"[rank {rank}] topk dropped ratio = {drop_ratio*100:.2f}%", flush=True) + print( + f"[rank {rank}] topk dropped ratio = {drop_ratio*100:.2f}%", flush=True + ) else: topk_idx_dropped = topk_idx topk_weights_dropped = topk_weights @@ -404,7 +434,7 @@ def test( topk_weights_dropped, ) - # ----- Fused ----- + # ----- Fused ----- fused_output, fused_ep_recv_count = buffer.fused_deep_moe( x, topk_idx_dropped, @@ -434,14 +464,18 @@ def test( # ----- Compare RecvCount ----- if args.topk_drop_col < 0 and args.topk_drop_prob == 0.0: - recv_count_diff = (from_inclusive_prefix_sum(base_ep_recv_count) - fused_ep_recv_count).abs() + recv_count_diff = ( + from_inclusive_prefix_sum(base_ep_recv_count) - fused_ep_recv_count + ).abs() max_recv_count_diff = recv_count_diff.max().item() mean_recv_count_diff = recv_count_diff.mean().item() print( f"[Rank {rank}] Difference between base and fused recv_count -> max: {max_recv_count_diff}, mean: {mean_recv_count_diff}", flush=True, ) - assert max_recv_count_diff < 1e-4, f"[Rank {rank}] Mismatch detected! diff={max_recv_count_diff}" + assert ( + max_recv_count_diff < 1e-4 + ), f"[Rank {rank}] Mismatch detected! diff={max_recv_count_diff}" # ======================== Distributed Entry ======================== From 90533400c0998d24b9f0f2418a0f3c660ab788a4 Mon Sep 17 00:00:00 2001 From: Kaniel_Zhou Date: Thu, 6 Nov 2025 15:29:51 +0800 Subject: [PATCH 03/16] fix lint --- tests/python/deepep/test_fused_deep_moe.py | 10 +++++----- 1 file changed, 5 insertions(+), 5 deletions(-) diff --git a/tests/python/deepep/test_fused_deep_moe.py b/tests/python/deepep/test_fused_deep_moe.py index 3b6a2aa6..24e49c41 100644 --- a/tests/python/deepep/test_fused_deep_moe.py +++ b/tests/python/deepep/test_fused_deep_moe.py @@ -19,11 +19,11 @@ def init_base_weights( num_local_experts, hidden_in=7168, hidden_mid=4096, hidden_out=2048 ): """ - 初始化每个本地专家的权重。 - num_local_experts: 每个 rank 上的专家数 = num_experts // num_ranks - hidden_in: 输入维度 (默认 7168) - hidden_mid: 中间层维度 (默认 4096) - hidden_out: 输出维度 (默认 2048) + Initialize the weights for each local expert. + `num_local_experts`: Number of experts per rank = `num_experts` // `num_ranks` + `hidden_in`: Input dimension (default 7168) + `hidden_mid`: Intermediate layer dimension (default 4096) + `hidden_out`: Output dimension (default 2048) """ w13_weight = torch.randint( From 9c28fbc3e13bb92b5ec74fd182992f2ffb64b71e Mon Sep 17 00:00:00 2001 From: Kaniel_Zhou Date: Thu, 6 Nov 2025 15:41:23 +0800 Subject: [PATCH 04/16] add to build --- .github/workflows/pr-test-npu.yml | 8 ++++++-- 1 file changed, 6 insertions(+), 2 deletions(-) diff --git a/.github/workflows/pr-test-npu.yml b/.github/workflows/pr-test-npu.yml index 6c5942a6..da22bf09 100644 --- a/.github/workflows/pr-test-npu.yml +++ b/.github/workflows/pr-test-npu.yml @@ -72,7 +72,9 @@ jobs: HCCL_BUFFSIZE: 2000 run: | python3 $GITHUB_WORKSPACE/tests/python/deepep/test_fused_deep_moe.py - python3 $GITHUB_WORKSPACE/tests/python/deepep/test_fused_deep_moe.py --minus1-flag True --small-bs-flag True + python3 $GITHUB_WORKSPACE/tests/python/deepep/test_fused_deep_moe.py --topk-drop-col 1 --small-bs-flag True + python3 $GITHUB_WORKSPACE/tests/python/deepep/test_fused_deep_moe.py --topk-drop-col 3 + python3 $GITHUB_WORKSPACE/tests/python/deepep/test_fused_deep_moe.py --topk-drop-prob 0.3 test-build-deepep: if: (github.repository == 'sgl-project/sgl-kernel-npu' || github.event_name == 'pull_request') && @@ -128,7 +130,9 @@ jobs: HCCL_BUFFSIZE: 2000 run: | python3 $GITHUB_WORKSPACE/tests/python/deepep/test_fused_deep_moe.py - python3 $GITHUB_WORKSPACE/tests/python/deepep/test_fused_deep_moe.py --minus1-flag True --small-bs-flag True + python3 $GITHUB_WORKSPACE/tests/python/deepep/test_fused_deep_moe.py --topk-drop-col 1 --small-bs-flag True + python3 $GITHUB_WORKSPACE/tests/python/deepep/test_fused_deep_moe.py --topk-drop-col 3 + python3 $GITHUB_WORKSPACE/tests/python/deepep/test_fused_deep_moe.py --topk-drop-prob 0.3 finish: if: always() From 26025add4a158d5a341df3ab9f54e4cae732b297 Mon Sep 17 00:00:00 2001 From: Kaniel_Zhou Date: Thu, 6 Nov 2025 16:07:05 +0800 Subject: [PATCH 05/16] fix word --- tests/python/deepep/test_fused_deep_moe.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/tests/python/deepep/test_fused_deep_moe.py b/tests/python/deepep/test_fused_deep_moe.py index 24e49c41..c28ffb8d 100644 --- a/tests/python/deepep/test_fused_deep_moe.py +++ b/tests/python/deepep/test_fused_deep_moe.py @@ -352,7 +352,7 @@ def test( ) return_recv_hook = False - # ----- Small-batch for debug (local rank 0 only) ----- + # ----- Small-batch for debug ---- if args.small_bs_flag and rank == 0: print("[rank 0] small_bs_flag active: truncating to batch 1", flush=True) x = x[:1, :] From 015bbfd888deafa9eb83a54b0421ffac1629c899 Mon Sep 17 00:00:00 2001 From: Kaniel_Zhou Date: Thu, 6 Nov 2025 16:51:34 +0800 Subject: [PATCH 06/16] fix word --- tests/python/deepep/test_fused_deep_moe.py | 3 --- 1 file changed, 3 deletions(-) diff --git a/tests/python/deepep/test_fused_deep_moe.py b/tests/python/deepep/test_fused_deep_moe.py index c28ffb8d..e4e76b39 100644 --- a/tests/python/deepep/test_fused_deep_moe.py +++ b/tests/python/deepep/test_fused_deep_moe.py @@ -564,7 +564,6 @@ def str_to_bool(value): default=0.0, help="Probability of randomly dropping a top-k index (set to -1).", ) - parser.add_argument( "--topk-drop-col", dest="topk_drop_col", @@ -572,14 +571,12 @@ def str_to_bool(value): default=-1, help="If >=0, drop this specific top-k column (set index to -1 for testing).", ) - parser.add_argument( "--small-bs-flag", type=str_to_bool, default=False, help="define small bs on certain rank", ) - parser.add_argument( "--debug", action="store_true", From d9e3c1801eeffbee8d382dadff3120cca6e4aa99 Mon Sep 17 00:00:00 2001 From: Kaniel_Zhou Date: Thu, 6 Nov 2025 17:31:54 +0800 Subject: [PATCH 07/16] fix word --- tests/python/deepep/test_fused_deep_moe.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/tests/python/deepep/test_fused_deep_moe.py b/tests/python/deepep/test_fused_deep_moe.py index e4e76b39..fb86f297 100644 --- a/tests/python/deepep/test_fused_deep_moe.py +++ b/tests/python/deepep/test_fused_deep_moe.py @@ -462,7 +462,7 @@ def test( assert avg_diff < 1e-4, f"[Rank {rank}] Mismatch detected! diff={avg_diff}" - # ----- Compare RecvCount ----- + # ----- Compare Recv Count ----- if args.topk_drop_col < 0 and args.topk_drop_prob == 0.0: recv_count_diff = ( from_inclusive_prefix_sum(base_ep_recv_count) - fused_ep_recv_count From 41395311085d48e134bba1b7867ed97c865a6ba5 Mon Sep 17 00:00:00 2001 From: Kaniel_Zhou Date: Thu, 6 Nov 2025 19:33:37 +0800 Subject: [PATCH 08/16] fix word --- .github/workflows/pr-test-npu.yml | 2 ++ 1 file changed, 2 insertions(+) diff --git a/.github/workflows/pr-test-npu.yml b/.github/workflows/pr-test-npu.yml index da22bf09..315ea96e 100644 --- a/.github/workflows/pr-test-npu.yml +++ b/.github/workflows/pr-test-npu.yml @@ -75,6 +75,7 @@ jobs: python3 $GITHUB_WORKSPACE/tests/python/deepep/test_fused_deep_moe.py --topk-drop-col 1 --small-bs-flag True python3 $GITHUB_WORKSPACE/tests/python/deepep/test_fused_deep_moe.py --topk-drop-col 3 python3 $GITHUB_WORKSPACE/tests/python/deepep/test_fused_deep_moe.py --topk-drop-prob 0.3 + python3 $GITHUB_WORKSPACE/tests/python/deepep/test_fused_deep_moe.py --num-experts 16 --topk-drop-col 3 test-build-deepep: if: (github.repository == 'sgl-project/sgl-kernel-npu' || github.event_name == 'pull_request') && @@ -133,6 +134,7 @@ jobs: python3 $GITHUB_WORKSPACE/tests/python/deepep/test_fused_deep_moe.py --topk-drop-col 1 --small-bs-flag True python3 $GITHUB_WORKSPACE/tests/python/deepep/test_fused_deep_moe.py --topk-drop-col 3 python3 $GITHUB_WORKSPACE/tests/python/deepep/test_fused_deep_moe.py --topk-drop-prob 0.3 + python3 $GITHUB_WORKSPACE/tests/python/deepep/test_fused_deep_moe.py --num-experts 16 --topk-drop-col 3 finish: if: always() From 8f212802f8e28cb0cd224efad4700ff0e0ff9bd0 Mon Sep 17 00:00:00 2001 From: Kaniel_Zhou Date: Thu, 6 Nov 2025 20:41:55 +0800 Subject: [PATCH 09/16] update test case --- .github/workflows/pr-test-npu.yml | 10 ++++++---- tests/python/deepep/test_fused_deep_moe.py | 15 --------------- 2 files changed, 6 insertions(+), 19 deletions(-) diff --git a/.github/workflows/pr-test-npu.yml b/.github/workflows/pr-test-npu.yml index 315ea96e..f9475675 100644 --- a/.github/workflows/pr-test-npu.yml +++ b/.github/workflows/pr-test-npu.yml @@ -72,10 +72,11 @@ jobs: HCCL_BUFFSIZE: 2000 run: | python3 $GITHUB_WORKSPACE/tests/python/deepep/test_fused_deep_moe.py - python3 $GITHUB_WORKSPACE/tests/python/deepep/test_fused_deep_moe.py --topk-drop-col 1 --small-bs-flag True + python3 $GITHUB_WORKSPACE/tests/python/deepep/test_fused_deep_moe.py --num-tokens 2 python3 $GITHUB_WORKSPACE/tests/python/deepep/test_fused_deep_moe.py --topk-drop-col 3 python3 $GITHUB_WORKSPACE/tests/python/deepep/test_fused_deep_moe.py --topk-drop-prob 0.3 - python3 $GITHUB_WORKSPACE/tests/python/deepep/test_fused_deep_moe.py --num-experts 16 --topk-drop-col 3 + python3 $GITHUB_WORKSPACE/tests/python/deepep/test_fused_deep_moe.py --num-tokens 2 --topk-drop-col 2 + python3 $GITHUB_WORKSPACE/tests/python/deepep/test_fused_deep_moe.py --num-tokens 2 --topk-drop-col 3 --num-experts 16 test-build-deepep: if: (github.repository == 'sgl-project/sgl-kernel-npu' || github.event_name == 'pull_request') && @@ -131,10 +132,11 @@ jobs: HCCL_BUFFSIZE: 2000 run: | python3 $GITHUB_WORKSPACE/tests/python/deepep/test_fused_deep_moe.py - python3 $GITHUB_WORKSPACE/tests/python/deepep/test_fused_deep_moe.py --topk-drop-col 1 --small-bs-flag True + python3 $GITHUB_WORKSPACE/tests/python/deepep/test_fused_deep_moe.py --num-tokens 2 python3 $GITHUB_WORKSPACE/tests/python/deepep/test_fused_deep_moe.py --topk-drop-col 3 python3 $GITHUB_WORKSPACE/tests/python/deepep/test_fused_deep_moe.py --topk-drop-prob 0.3 - python3 $GITHUB_WORKSPACE/tests/python/deepep/test_fused_deep_moe.py --num-experts 16 --topk-drop-col 3 + python3 $GITHUB_WORKSPACE/tests/python/deepep/test_fused_deep_moe.py --num-tokens 2 --topk-drop-col 2 + python3 $GITHUB_WORKSPACE/tests/python/deepep/test_fused_deep_moe.py --num-tokens 2 --topk-drop-col 3 --num-experts 16 finish: if: always() diff --git a/tests/python/deepep/test_fused_deep_moe.py b/tests/python/deepep/test_fused_deep_moe.py index fb86f297..5ba62930 100644 --- a/tests/python/deepep/test_fused_deep_moe.py +++ b/tests/python/deepep/test_fused_deep_moe.py @@ -352,14 +352,6 @@ def test( ) return_recv_hook = False - # ----- Small-batch for debug ---- - if args.small_bs_flag and rank == 0: - print("[rank 0] small_bs_flag active: truncating to batch 1", flush=True) - x = x[:1, :] - topk_idx = topk_idx[:1, :] - topk_weights = topk_weights[:1, :] - num_tokens = x.shape[0] - # ----- Random or fixed drop ----- if args.topk_drop_prob > 0 or args.topk_drop_col >= 0: topk_idx_dropped = topk_idx.clone() @@ -571,12 +563,6 @@ def str_to_bool(value): default=-1, help="If >=0, drop this specific top-k column (set index to -1 for testing).", ) - parser.add_argument( - "--small-bs-flag", - type=str_to_bool, - default=False, - help="define small bs on certain rank", - ) parser.add_argument( "--debug", action="store_true", @@ -586,7 +572,6 @@ def str_to_bool(value): args = parser.parse_args() num_processes = args.num_processes - # use args.small_bs_flag directly in test() torch.multiprocessing.spawn( test_loop, args=(num_processes, args), nprocs=num_processes ) From 3bf64debf22b23a42e0ef3261e693da353bd3216 Mon Sep 17 00:00:00 2001 From: Kaniel_Zhou Date: Thu, 6 Nov 2025 21:37:51 +0800 Subject: [PATCH 10/16] update test case --- tests/python/deepep/test_fused_deep_moe.py | 1 - 1 file changed, 1 deletion(-) diff --git a/tests/python/deepep/test_fused_deep_moe.py b/tests/python/deepep/test_fused_deep_moe.py index 5ba62930..78ec09cb 100644 --- a/tests/python/deepep/test_fused_deep_moe.py +++ b/tests/python/deepep/test_fused_deep_moe.py @@ -570,7 +570,6 @@ def str_to_bool(value): ) args = parser.parse_args() - num_processes = args.num_processes torch.multiprocessing.spawn( test_loop, args=(num_processes, args), nprocs=num_processes From 76f5c3a85e573c3f4cb80398065219a90fc1227b Mon Sep 17 00:00:00 2001 From: Kaniel_Zhou Date: Thu, 6 Nov 2025 22:21:41 +0800 Subject: [PATCH 11/16] update test case --- tests/python/deepep/test_fused_deep_moe.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/tests/python/deepep/test_fused_deep_moe.py b/tests/python/deepep/test_fused_deep_moe.py index 78ec09cb..ffdea8de 100644 --- a/tests/python/deepep/test_fused_deep_moe.py +++ b/tests/python/deepep/test_fused_deep_moe.py @@ -404,7 +404,7 @@ def test( drop_ratio = (topk_idx_dropped == -1).float().mean().item() if args.debug and rank == 0: print( - f"[rank {rank}] topk dropped ratio = {drop_ratio*100:.2f}%", flush=True + f"[DEBUG] [rank {rank}] topk dropped ratio = {drop_ratio*100:.2f}%", flush=True ) else: topk_idx_dropped = topk_idx From e02f6120e947b0fe223aaa98271dbaa0b12ef306 Mon Sep 17 00:00:00 2001 From: Kaniel_Zhou Date: Fri, 7 Nov 2025 09:36:56 +0800 Subject: [PATCH 12/16] update test case --- tests/python/deepep/test_fused_deep_moe.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/tests/python/deepep/test_fused_deep_moe.py b/tests/python/deepep/test_fused_deep_moe.py index ffdea8de..7a9eda6a 100644 --- a/tests/python/deepep/test_fused_deep_moe.py +++ b/tests/python/deepep/test_fused_deep_moe.py @@ -339,7 +339,7 @@ def test( tokens_per_rank[r] = ((topk_idx >= start) & (topk_idx < end)).sum() if args.debug: - print(f"Tokens per rank: {tokens_per_rank}", flush=True) + print(f"[DEBUG] Tokens per rank: {tokens_per_rank}", flush=True) # ====== ensure topk_weights is defined (fix missing var) ====== topk_weights = torch.randn( From 9804516aafe5b592843eacbf26132ec6b6b26796 Mon Sep 17 00:00:00 2001 From: Kaniel_Zhou Date: Fri, 7 Nov 2025 09:45:11 +0800 Subject: [PATCH 13/16] update test case --- tests/python/deepep/test_fused_deep_moe.py | 3 ++- 1 file changed, 2 insertions(+), 1 deletion(-) diff --git a/tests/python/deepep/test_fused_deep_moe.py b/tests/python/deepep/test_fused_deep_moe.py index 7a9eda6a..1f300fcf 100644 --- a/tests/python/deepep/test_fused_deep_moe.py +++ b/tests/python/deepep/test_fused_deep_moe.py @@ -404,7 +404,8 @@ def test( drop_ratio = (topk_idx_dropped == -1).float().mean().item() if args.debug and rank == 0: print( - f"[DEBUG] [rank {rank}] topk dropped ratio = {drop_ratio*100:.2f}%", flush=True + f"[DEBUG] [rank {rank}] topk dropped ratio = {drop_ratio*100:.2f}%", + flush=True, ) else: topk_idx_dropped = topk_idx From 5bb7c555af7a876a79fe873b173f5e4c46d0c3c1 Mon Sep 17 00:00:00 2001 From: Kaniel_Zhou Date: Mon, 10 Nov 2025 14:41:48 +0800 Subject: [PATCH 14/16] update test case --- .github/workflows/pr-test-npu.yml | 2 + tests/python/deepep/test_fused_deep_moe.py | 110 +++++++++++++-------- 2 files changed, 71 insertions(+), 41 deletions(-) diff --git a/.github/workflows/pr-test-npu.yml b/.github/workflows/pr-test-npu.yml index f9475675..94aba705 100644 --- a/.github/workflows/pr-test-npu.yml +++ b/.github/workflows/pr-test-npu.yml @@ -77,6 +77,7 @@ jobs: python3 $GITHUB_WORKSPACE/tests/python/deepep/test_fused_deep_moe.py --topk-drop-prob 0.3 python3 $GITHUB_WORKSPACE/tests/python/deepep/test_fused_deep_moe.py --num-tokens 2 --topk-drop-col 2 python3 $GITHUB_WORKSPACE/tests/python/deepep/test_fused_deep_moe.py --num-tokens 2 --topk-drop-col 3 --num-experts 16 + python3 $GITHUB_WORKSPACE/tests/python/deepep/test_fused_deep_moe.py --num-tokens 4 --topk-drop-col 1 --num-experts 32 test-build-deepep: if: (github.repository == 'sgl-project/sgl-kernel-npu' || github.event_name == 'pull_request') && @@ -137,6 +138,7 @@ jobs: python3 $GITHUB_WORKSPACE/tests/python/deepep/test_fused_deep_moe.py --topk-drop-prob 0.3 python3 $GITHUB_WORKSPACE/tests/python/deepep/test_fused_deep_moe.py --num-tokens 2 --topk-drop-col 2 python3 $GITHUB_WORKSPACE/tests/python/deepep/test_fused_deep_moe.py --num-tokens 2 --topk-drop-col 3 --num-experts 16 + python3 $GITHUB_WORKSPACE/tests/python/deepep/test_fused_deep_moe.py --num-tokens 4 --topk-drop-col 1 --num-experts 32 finish: if: always() diff --git a/tests/python/deepep/test_fused_deep_moe.py b/tests/python/deepep/test_fused_deep_moe.py index 1f300fcf..cced4802 100644 --- a/tests/python/deepep/test_fused_deep_moe.py +++ b/tests/python/deepep/test_fused_deep_moe.py @@ -257,7 +257,7 @@ def test( x = torch.rand((num_tokens, hidden), dtype=torch.bfloat16, device="npu") * 10 - 5 # ----- Routing(topk_idx) ----- - if args.debug and args.active_ranks: + if args.active_ranks: try: active_ranks = [ int(r.strip()) for r in args.active_ranks.split(",") if r.strip() @@ -292,7 +292,7 @@ def test( topk_idx = valid_experts[ torch.randint(0, len(valid_experts), (num_tokens, num_topk), device="npu") ] - if args.debug and rank == 0: + if rank == 0: print( f"[config] active_ranks={active_ranks}, valid_experts={len(valid_experts)}", flush=True, @@ -324,7 +324,7 @@ def test( w2_weight_scale.clone().detach(), ) - if args.debug and rank == 0: + if rank == 0: print("=== Check fused weights ===") print("w13_f:", w13_f.shape, w13_f.dtype, w13_f.device) print("w13s_f:", w13s_f.shape, w13s_f.dtype, w13s_f.device) @@ -338,8 +338,7 @@ def test( start, end = r * experts_per_rank, (r + 1) * experts_per_rank tokens_per_rank[r] = ((topk_idx >= start) & (topk_idx < end)).sum() - if args.debug: - print(f"[DEBUG] Tokens per rank: {tokens_per_rank}", flush=True) + print(f"[DEBUG] Tokens per rank: {tokens_per_rank}", flush=True) # ====== ensure topk_weights is defined (fix missing var) ====== topk_weights = torch.randn( @@ -376,33 +375,23 @@ def test( invalid_mask = topk_idx_dropped == -1 topk_weights_dropped = topk_weights_dropped.masked_fill(invalid_mask, 0.0) - if args.debug: - print( - f"[DEBUG] topk_idx_dropped (after random drop):\n{topk_idx_dropped.cpu().numpy()}", - flush=True, - ) - print( - f"[DEBUG] topk_weights_dropped (after random drop):\n{topk_weights_dropped.cpu().numpy()}", - flush=True, - ) - # Fixed column drop (for the test_topk_minus1 scenario) if args.topk_drop_col >= 0 and args.topk_drop_col < num_topk: topk_idx_dropped[:, args.topk_drop_col] = -1 topk_weights_dropped[:, args.topk_drop_col] = 0 - if args.debug: - print( - f"[DEBUG] topk_idx_dropped (after fixed-column drop):\n{topk_idx_dropped.cpu().numpy()}", - flush=True, - ) - print( - f"[DEBUG] topk_weights_dropped (after fixed-column drop):\n{topk_weights_dropped.cpu().numpy()}", - flush=True, - ) + + print( + f"[DEBUG] [rank {rank}] topk_idx_dropped (after fixed-column drop):\n{topk_idx_dropped.cpu().numpy()}", + flush=True, + ) + print( + f"[DEBUG] [rank {rank}] topk_weights_dropped (after fixed-column drop):\n{topk_weights_dropped.cpu().numpy()}", + flush=True, + ) # print drop ratio drop_ratio = (topk_idx_dropped == -1).float().mean().item() - if args.debug and rank == 0: + if rank == 0: print( f"[DEBUG] [rank {rank}] topk dropped ratio = {drop_ratio*100:.2f}%", flush=True, @@ -411,6 +400,20 @@ def test( topk_idx_dropped = topk_idx topk_weights_dropped = topk_weights + # Expert meta + num_tokens_per_expert = torch.zeros((num_experts,), dtype=torch.int, device="npu") + for i in range(num_experts): + num_tokens_per_expert[i] = (topk_idx_dropped == i).sum() + gbl_num_tokens_per_expert = num_tokens_per_expert.clone() + dist.all_reduce(gbl_num_tokens_per_expert, group=group) + + print(f"[Rank {rank}] num_tokens_per_expert: {num_tokens_per_expert.tolist()}") + if rank == 0: + print( + f"[Rank {rank}] gbl_num_tokens_per_expert: {gbl_num_tokens_per_expert.tolist()}" + ) + base_prefix_sum = num_tokens_per_expert.clone() + # ----- Baseline ----- baseline_output, base_ep_recv_count = baseline_test( buffer2, @@ -456,19 +459,49 @@ def test( assert avg_diff < 1e-4, f"[Rank {rank}] Mismatch detected! diff={avg_diff}" # ----- Compare Recv Count ----- - if args.topk_drop_col < 0 and args.topk_drop_prob == 0.0: - recv_count_diff = ( - from_inclusive_prefix_sum(base_ep_recv_count) - fused_ep_recv_count - ).abs() - max_recv_count_diff = recv_count_diff.max().item() - mean_recv_count_diff = recv_count_diff.mean().item() + global_base_prefix_sum = [ + torch.zeros_like(base_prefix_sum) for _ in range(num_ranks) + ] + dist.all_gather(global_base_prefix_sum, base_prefix_sum) + + global_base_prefix_sum = torch.stack(global_base_prefix_sum, dim=0) + + if rank == 0: print( - f"[Rank {rank}] Difference between base and fused recv_count -> max: {max_recv_count_diff}, mean: {mean_recv_count_diff}", - flush=True, + f"[DEBUG] Global base_prefix_sum (before transpose):\n{global_base_prefix_sum}" ) - assert ( - max_recv_count_diff < 1e-4 - ), f"[Rank {rank}] Mismatch detected! diff={max_recv_count_diff}" + + transposed_base_prefix_sum = global_base_prefix_sum.T + if rank == 0: + print(f"[DEBUG] Transposed base_prefix_sum:\n{transposed_base_prefix_sum}") + print(f"[DEBUG] Transposed base_prefix_sum: {transposed_base_prefix_sum.shape}") + + experts_per_rank = num_experts // dist.get_world_size() + start_expert = rank * experts_per_rank + end_expert = start_expert + experts_per_rank + + # shape [experts_per_rank * num_ranks] + expected_recv = transposed_base_prefix_sum[start_expert:end_expert].reshape(-1) + fused_recv = fused_ep_recv_count + + print(f"expected_recv: {expected_recv}") + print(f"fused_recv: {fused_recv}") + + diff = (expected_recv - fused_recv).abs() + print( + f"[Rank {rank}] diff (experts {start_expert}~{end_expert-1}): {diff.cpu().numpy()}", + flush=True, + ) + + max_recv_count_diff = diff.max().item() + mean_recv_count_diff = diff.mean().item() + print( + f"[Rank {rank}] Difference between base and fused recv_count -> max: {max_recv_count_diff}, mean: {mean_recv_count_diff}", + flush=True, + ) + assert ( + max_recv_count_diff < 1e-4 + ), f"[Rank {rank}] Mismatch detected! diff={max_recv_count_diff}" # ======================== Distributed Entry ======================== @@ -564,11 +597,6 @@ def str_to_bool(value): default=-1, help="If >=0, drop this specific top-k column (set index to -1 for testing).", ) - parser.add_argument( - "--debug", - action="store_true", - help="Enable debug logging.", - ) args = parser.parse_args() num_processes = args.num_processes From bf3f45ea1fd56cc2eb3486c4182d6adfc7b160dc Mon Sep 17 00:00:00 2001 From: Kaniel_Zhou Date: Fri, 21 Nov 2025 10:10:19 +0800 Subject: [PATCH 15/16] fix cleancode --- tests/python/deepep/test_fused_deep_moe.py | 78 +++++++++++++--------- 1 file changed, 45 insertions(+), 33 deletions(-) diff --git a/tests/python/deepep/test_fused_deep_moe.py b/tests/python/deepep/test_fused_deep_moe.py index cced4802..43f6ffde 100644 --- a/tests/python/deepep/test_fused_deep_moe.py +++ b/tests/python/deepep/test_fused_deep_moe.py @@ -324,7 +324,7 @@ def test( w2_weight_scale.clone().detach(), ) - if rank == 0: + if args.debug and rank == 0: print("=== Check fused weights ===") print("w13_f:", w13_f.shape, w13_f.dtype, w13_f.device) print("w13s_f:", w13s_f.shape, w13s_f.dtype, w13s_f.device) @@ -338,7 +338,8 @@ def test( start, end = r * experts_per_rank, (r + 1) * experts_per_rank tokens_per_rank[r] = ((topk_idx >= start) & (topk_idx < end)).sum() - print(f"[DEBUG] Tokens per rank: {tokens_per_rank}", flush=True) + if args.debug: + print(f"[DEBUG] Tokens per rank: {tokens_per_rank}", flush=True) # ====== ensure topk_weights is defined (fix missing var) ====== topk_weights = torch.randn( @@ -379,15 +380,15 @@ def test( if args.topk_drop_col >= 0 and args.topk_drop_col < num_topk: topk_idx_dropped[:, args.topk_drop_col] = -1 topk_weights_dropped[:, args.topk_drop_col] = 0 - - print( - f"[DEBUG] [rank {rank}] topk_idx_dropped (after fixed-column drop):\n{topk_idx_dropped.cpu().numpy()}", - flush=True, - ) - print( - f"[DEBUG] [rank {rank}] topk_weights_dropped (after fixed-column drop):\n{topk_weights_dropped.cpu().numpy()}", - flush=True, - ) + if args.debug: + print( + f"[DEBUG] [rank {rank}] topk_idx_dropped (after fixed-column drop):\n{topk_idx_dropped.cpu().numpy()}", + flush=True, + ) + print( + f"[DEBUG] [rank {rank}] topk_weights_dropped (after fixed-column drop):\n{topk_weights_dropped.cpu().numpy()}", + flush=True, + ) # print drop ratio drop_ratio = (topk_idx_dropped == -1).float().mean().item() @@ -407,12 +408,15 @@ def test( gbl_num_tokens_per_expert = num_tokens_per_expert.clone() dist.all_reduce(gbl_num_tokens_per_expert, group=group) - print(f"[Rank {rank}] num_tokens_per_expert: {num_tokens_per_expert.tolist()}") - if rank == 0: - print( - f"[Rank {rank}] gbl_num_tokens_per_expert: {gbl_num_tokens_per_expert.tolist()}" - ) - base_prefix_sum = num_tokens_per_expert.clone() + + if args.debug: + print(f"[Rank {rank}] num_tokens_per_expert: {num_tokens_per_expert.tolist()}") + if rank == 0: + print( + f"[Rank {rank}] gbl_num_tokens_per_expert: {gbl_num_tokens_per_expert.tolist()}" + ) + + local_expert_token_count = num_tokens_per_expert.clone() # ----- Baseline ----- baseline_output, base_ep_recv_count = baseline_test( @@ -459,22 +463,22 @@ def test( assert avg_diff < 1e-4, f"[Rank {rank}] Mismatch detected! diff={avg_diff}" # ----- Compare Recv Count ----- - global_base_prefix_sum = [ - torch.zeros_like(base_prefix_sum) for _ in range(num_ranks) + all_expert_token_counts = [ + torch.zeros_like(local_expert_token_count) for _ in range(num_ranks) ] - dist.all_gather(global_base_prefix_sum, base_prefix_sum) + dist.all_gather(all_expert_token_counts, local_expert_token_count) - global_base_prefix_sum = torch.stack(global_base_prefix_sum, dim=0) + all_expert_token_counts = torch.stack(all_expert_token_counts, dim=0) - if rank == 0: + if args.debug and rank == 0: print( - f"[DEBUG] Global base_prefix_sum (before transpose):\n{global_base_prefix_sum}" + f"[DEBUG] Global local_expert_token_count (before transpose):\n{all_expert_token_counts}" ) - transposed_base_prefix_sum = global_base_prefix_sum.T - if rank == 0: - print(f"[DEBUG] Transposed base_prefix_sum:\n{transposed_base_prefix_sum}") - print(f"[DEBUG] Transposed base_prefix_sum: {transposed_base_prefix_sum.shape}") + transposed_base_prefix_sum = all_expert_token_counts.T + if args.debug and rank == 0: + print(f"[DEBUG] Transposed local_expert_token_count:\n{transposed_base_prefix_sum}") + print(f"[DEBUG] Transposed local_expert_token_count: {transposed_base_prefix_sum.shape}") experts_per_rank = num_experts // dist.get_world_size() start_expert = rank * experts_per_rank @@ -484,14 +488,16 @@ def test( expected_recv = transposed_base_prefix_sum[start_expert:end_expert].reshape(-1) fused_recv = fused_ep_recv_count - print(f"expected_recv: {expected_recv}") - print(f"fused_recv: {fused_recv}") + if args.debug: + print(f"expected_recv: {expected_recv}") + print(f"fused_recv: {fused_recv}") diff = (expected_recv - fused_recv).abs() - print( - f"[Rank {rank}] diff (experts {start_expert}~{end_expert-1}): {diff.cpu().numpy()}", - flush=True, - ) + if args.debug: + print( + f"[Rank {rank}] diff (experts {start_expert}~{end_expert-1}): {diff.cpu().numpy()}", + flush=True, + ) max_recv_count_diff = diff.max().item() mean_recv_count_diff = diff.mean().item() @@ -597,6 +603,12 @@ def str_to_bool(value): default=-1, help="If >=0, drop this specific top-k column (set index to -1 for testing).", ) + parser.add_argument( + "--debug", + action="store_true", + default=False, + help="Enable debug logging.", + ) args = parser.parse_args() num_processes = args.num_processes From 752f68c9805d477dabc1e025b70d8d6cc64d7a4a Mon Sep 17 00:00:00 2001 From: Kaniel_Zhou Date: Fri, 21 Nov 2025 10:12:58 +0800 Subject: [PATCH 16/16] fix lint --- tests/python/deepep/test_fused_deep_moe.py | 15 +++++++++------ 1 file changed, 9 insertions(+), 6 deletions(-) diff --git a/tests/python/deepep/test_fused_deep_moe.py b/tests/python/deepep/test_fused_deep_moe.py index 43f6ffde..aefc353f 100644 --- a/tests/python/deepep/test_fused_deep_moe.py +++ b/tests/python/deepep/test_fused_deep_moe.py @@ -382,9 +382,9 @@ def test( topk_weights_dropped[:, args.topk_drop_col] = 0 if args.debug: print( - f"[DEBUG] [rank {rank}] topk_idx_dropped (after fixed-column drop):\n{topk_idx_dropped.cpu().numpy()}", - flush=True, - ) + f"[DEBUG] [rank {rank}] topk_idx_dropped (after fixed-column drop):\n{topk_idx_dropped.cpu().numpy()}", + flush=True, + ) print( f"[DEBUG] [rank {rank}] topk_weights_dropped (after fixed-column drop):\n{topk_weights_dropped.cpu().numpy()}", flush=True, @@ -408,7 +408,6 @@ def test( gbl_num_tokens_per_expert = num_tokens_per_expert.clone() dist.all_reduce(gbl_num_tokens_per_expert, group=group) - if args.debug: print(f"[Rank {rank}] num_tokens_per_expert: {num_tokens_per_expert.tolist()}") if rank == 0: @@ -477,8 +476,12 @@ def test( transposed_base_prefix_sum = all_expert_token_counts.T if args.debug and rank == 0: - print(f"[DEBUG] Transposed local_expert_token_count:\n{transposed_base_prefix_sum}") - print(f"[DEBUG] Transposed local_expert_token_count: {transposed_base_prefix_sum.shape}") + print( + f"[DEBUG] Transposed local_expert_token_count:\n{transposed_base_prefix_sum}" + ) + print( + f"[DEBUG] Transposed local_expert_token_count: {transposed_base_prefix_sum.shape}" + ) experts_per_rank = num_experts // dist.get_world_size() start_expert = rank * experts_per_rank