@@ -257,7 +257,7 @@ def test(
     x = torch.rand((num_tokens, hidden), dtype=torch.bfloat16, device="npu") * 10 - 5
 
     # ----- Routing(topk_idx) -----
-    if args.debug and args.active_ranks:
+    if args.active_ranks:
         try:
             active_ranks = [
                 int(r.strip()) for r in args.active_ranks.split(",") if r.strip()
@@ -292,7 +292,7 @@ def test(
         topk_idx = valid_experts[
             torch.randint(0, len(valid_experts), (num_tokens, num_topk), device="npu")
         ]
-        if args.debug and rank == 0:
+        if rank == 0:
             print(
                 f"[config] active_ranks={active_ranks}, valid_experts={len(valid_experts)}",
                 flush=True,
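
Note on the two hunks above: routing restriction now depends only on `--active-ranks`. As a reader's aid, here is a minimal standalone sketch of the intended mapping from active ranks to routable experts, assuming the contiguous per-rank expert layout used elsewhere in this test; the helper name is illustrative, not part of the file:

```python
import torch

# Hypothetical helper: which expert ids are routable when only some ranks
# are active, assuming experts_per_rank contiguous experts per rank.
def experts_for_active_ranks(active_ranks_str: str, num_experts: int, num_ranks: int):
    experts_per_rank = num_experts // num_ranks
    active_ranks = [int(r.strip()) for r in active_ranks_str.split(",") if r.strip()]
    valid = [
        e
        for r in active_ranks
        for e in range(r * experts_per_rank, (r + 1) * experts_per_rank)
    ]
    return torch.tensor(valid)

# 8 experts over 4 ranks, only ranks 0 and 2 active -> experts [0, 1, 4, 5]
print(experts_for_active_ranks("0,2", num_experts=8, num_ranks=4).tolist())
```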
@@ -324,7 +324,7 @@ def test(
         w2_weight_scale.clone().detach(),
     )
 
-    if args.debug and rank == 0:
+    if rank == 0:
         print("=== Check fused weights ===")
         print("w13_f:", w13_f.shape, w13_f.dtype, w13_f.device)
         print("w13s_f:", w13s_f.shape, w13s_f.dtype, w13s_f.device)
@@ -338,8 +338,7 @@ def test(
         start, end = r * experts_per_rank, (r + 1) * experts_per_rank
         tokens_per_rank[r] = ((topk_idx >= start) & (topk_idx < end)).sum()
 
-    if args.debug:
-        print(f"[DEBUG] Tokens per rank: {tokens_per_rank}", flush=True)
+    print(f"[DEBUG] Tokens per rank: {tokens_per_rank}", flush=True)
 
     # ====== ensure topk_weights is defined (fix missing var) ======
     topk_weights = torch.randn(
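
Note: the `tokens_per_rank` loop above counts how many top-k entries fall into each rank's contiguous expert range. A CPU-only sketch of an equivalent vectorized form (toy tensors, not the test's `npu` buffers); it holds only before any entries are dropped to `-1`:

```python
import torch

# Toy routing table: 3 tokens, top-2, 8 experts over 4 ranks (2 per rank).
topk_idx = torch.tensor([[0, 3], [5, 1], [7, 2]])
experts_per_rank, num_ranks = 2, 4

# Integer-dividing expert ids by experts_per_rank yields the owning rank,
# so bincount over ranks matches the range-mask loop in the test.
tokens_per_rank = torch.bincount(
    (topk_idx // experts_per_rank).flatten(), minlength=num_ranks
)
print(tokens_per_rank.tolist())  # [2, 2, 1, 1]
```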
@@ -376,33 +375,23 @@ def test(
         invalid_mask = topk_idx_dropped == -1
         topk_weights_dropped = topk_weights_dropped.masked_fill(invalid_mask, 0.0)
 
-        if args.debug:
-            print(
-                f"[DEBUG] topk_idx_dropped (after random drop):\n{topk_idx_dropped.cpu().numpy()}",
-                flush=True,
-            )
-            print(
-                f"[DEBUG] topk_weights_dropped (after random drop):\n{topk_weights_dropped.cpu().numpy()}",
-                flush=True,
-            )
-
         # Fixed column drop (for the test_topk_minus1 scenario)
         if args.topk_drop_col >= 0 and args.topk_drop_col < num_topk:
             topk_idx_dropped[:, args.topk_drop_col] = -1
             topk_weights_dropped[:, args.topk_drop_col] = 0
-            if args.debug:
-                print(
-                    f"[DEBUG] topk_idx_dropped (after fixed-column drop):\n{topk_idx_dropped.cpu().numpy()}",
-                    flush=True,
-                )
-                print(
-                    f"[DEBUG] topk_weights_dropped (after fixed-column drop):\n{topk_weights_dropped.cpu().numpy()}",
-                    flush=True,
-                )
+
+        print(
+            f"[DEBUG] [rank {rank}] topk_idx_dropped (after fixed-column drop):\n{topk_idx_dropped.cpu().numpy()}",
+            flush=True,
+        )
+        print(
+            f"[DEBUG] [rank {rank}] topk_weights_dropped (after fixed-column drop):\n{topk_weights_dropped.cpu().numpy()}",
+            flush=True,
+        )
 
         # print drop ratio
         drop_ratio = (topk_idx_dropped == -1).float().mean().item()
-        if args.debug and rank == 0:
+        if rank == 0:
             print(
                 f"[DEBUG] [rank {rank}] topk dropped ratio = {drop_ratio * 100:.2f}%",
                 flush=True,
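
Note: both drop paths share one convention, which the now-unconditional debug prints expose: a dropped slot carries expert index `-1` and weight `0`, and the reported ratio is simply the fraction of `-1` entries. A toy illustration (made-up CPU tensors):

```python
import torch

# Two tokens, top-2; drop the second top-k column for every token.
topk_idx = torch.tensor([[0, 3], [5, 1]])
topk_weights = torch.ones_like(topk_idx, dtype=torch.float)
drop_col = 1
topk_idx[:, drop_col] = -1        # dropped slots get expert id -1 ...
topk_weights[:, drop_col] = 0.0   # ... and zero routing weight
drop_ratio = (topk_idx == -1).float().mean().item()
print(f"dropped {drop_ratio * 100:.2f}% of slots")  # dropped 50.00% of slots
```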
@@ -411,6 +400,20 @@ def test(
         topk_idx_dropped = topk_idx
         topk_weights_dropped = topk_weights
 
+    # Expert meta
+    num_tokens_per_expert = torch.zeros((num_experts,), dtype=torch.int, device="npu")
+    for i in range(num_experts):
+        num_tokens_per_expert[i] = (topk_idx_dropped == i).sum()
+    gbl_num_tokens_per_expert = num_tokens_per_expert.clone()
+    dist.all_reduce(gbl_num_tokens_per_expert, group=group)
+
+    print(f"[Rank {rank}] num_tokens_per_expert: {num_tokens_per_expert.tolist()}")
+    if rank == 0:
+        print(
+            f"[Rank {rank}] gbl_num_tokens_per_expert: {gbl_num_tokens_per_expert.tolist()}"
+        )
+    base_prefix_sum = num_tokens_per_expert.clone()
+
     # ----- Baseline -----
     baseline_output, base_ep_recv_count = baseline_test(
         buffer2,
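
Note: the new "Expert meta" block counts, per rank, how many local top-k entries target each expert, then `dist.all_reduce` sums those vectors element-wise across ranks. A single-process sketch of the local count in vectorized form (toy tensor; `-1` slots must be masked before counting, which the `(topk_idx_dropped == i)` loop gets for free since no expert id is -1):

```python
import torch

# Toy per-rank routing after drops: -1 marks a dropped slot.
topk_idx_dropped = torch.tensor([[0, -1], [3, 2], [2, -1]])
num_experts = 4

# Equivalent to the per-expert loop in the hunk above: count valid entries
# per expert id. dist.all_reduce would then sum these across all ranks.
valid = topk_idx_dropped[topk_idx_dropped >= 0]
num_tokens_per_expert = torch.bincount(valid, minlength=num_experts).int()
print(num_tokens_per_expert.tolist())  # [1, 0, 2, 1]
```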
@@ -456,19 +459,49 @@ def test(
     assert avg_diff < 1e-4, f"[Rank {rank}] Mismatch detected! diff={avg_diff}"
 
     # ----- Compare Recv Count -----
-    if args.topk_drop_col < 0 and args.topk_drop_prob == 0.0:
-        recv_count_diff = (
-            from_inclusive_prefix_sum(base_ep_recv_count) - fused_ep_recv_count
-        ).abs()
-        max_recv_count_diff = recv_count_diff.max().item()
-        mean_recv_count_diff = recv_count_diff.mean().item()
+    global_base_prefix_sum = [
+        torch.zeros_like(base_prefix_sum) for _ in range(num_ranks)
+    ]
+    dist.all_gather(global_base_prefix_sum, base_prefix_sum)
+
+    global_base_prefix_sum = torch.stack(global_base_prefix_sum, dim=0)
+
+    if rank == 0:
         print(
-            f"[Rank {rank}] Difference between base and fused recv_count -> max: {max_recv_count_diff}, mean: {mean_recv_count_diff}",
-            flush=True,
+            f"[DEBUG] Global base_prefix_sum (before transpose):\n{global_base_prefix_sum}"
         )
-        assert (
-            max_recv_count_diff < 1e-4
-        ), f"[Rank {rank}] Mismatch detected! diff={max_recv_count_diff}"
+
+    transposed_base_prefix_sum = global_base_prefix_sum.T
+    if rank == 0:
+        print(f"[DEBUG] Transposed base_prefix_sum:\n{transposed_base_prefix_sum}")
+        print(f"[DEBUG] Transposed base_prefix_sum: {transposed_base_prefix_sum.shape}")
+
+    experts_per_rank = num_experts // dist.get_world_size()
+    start_expert = rank * experts_per_rank
+    end_expert = start_expert + experts_per_rank
+
+    # shape [experts_per_rank * num_ranks]
+    expected_recv = transposed_base_prefix_sum[start_expert:end_expert].reshape(-1)
+    fused_recv = fused_ep_recv_count
+
+    print(f"expected_recv: {expected_recv}")
+    print(f"fused_recv: {fused_recv}")
+
+    diff = (expected_recv - fused_recv).abs()
+    print(
+        f"[Rank {rank}] diff (experts {start_expert}~{end_expert - 1}): {diff.cpu().numpy()}",
+        flush=True,
+    )
+
+    max_recv_count_diff = diff.max().item()
+    mean_recv_count_diff = diff.mean().item()
+    print(
+        f"[Rank {rank}] Difference between base and fused recv_count -> max: {max_recv_count_diff}, mean: {mean_recv_count_diff}",
+        flush=True,
+    )
+    assert (
+        max_recv_count_diff < 1e-4
+    ), f"[Rank {rank}] Mismatch detected! diff={max_recv_count_diff}"
 
 
 # ======================== Distributed Entry ========================
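
Note: the rewritten recv-count check replaces the old prefix-sum comparison with a direct reconstruction. After `all_gather`, row `r` of the `[num_ranks, num_experts]` matrix holds what rank `r` sends to each expert; transposing and slicing rank `k`'s expert range yields, for each local expert, one count per source rank, flattened to `[experts_per_rank * num_ranks]` to match the layout this test assumes for `fused_ep_recv_count`. A toy walkthrough with 2 ranks and 4 experts:

```python
import torch

# Per-sender counts: counts[r][e] = tokens rank r routes to expert e.
counts = torch.tensor([[3, 1, 0, 2],   # sent by rank 0
                       [1, 4, 2, 0]])  # sent by rank 1
num_ranks, num_experts = counts.shape
experts_per_rank = num_experts // num_ranks

rank = 1  # reconstruct what rank 1 (owning experts 2 and 3) should receive
start, end = rank * experts_per_rank, (rank + 1) * experts_per_rank
expected_recv = counts.T[start:end].reshape(-1)
print(expected_recv.tolist())  # [0, 2, 2, 0]: expert 2 from ranks 0,1; then expert 3
```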
@@ -564,11 +597,6 @@ def str_to_bool(value):
         default=-1,
         help="If >=0, drop this specific top-k column (set index to -1 for testing).",
     )
-    parser.add_argument(
-        "--debug",
-        action="store_true",
-        help="Enable debug logging.",
-    )
 
     args = parser.parse_args()
     num_processes = args.num_processes