 # SPDX-FileCopyrightText: Copyright (c) 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
 # SPDX-License-Identifier: Apache-2.0
 
+import glob
+import json
 import math
 import os
 
 aic_debug = int(os.getenv("aic_moe_debug", "0"))  # noqa: SIM112
 
 
-def balanced_logits(num_tokens, num_experts, topk):
-    h_selected_experts = -torch.ones([num_tokens, topk])
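+# Delete zero-byte, empty, or otherwise unparsable JSON files from the tuning-cache
+# directory so a corrupted cache entry is never loaded by later runs.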
+def cleanup_empty_json_files(directory="moe_tune_path"):
+    if not os.path.exists(directory):
+        return
+
+    json_files = glob.glob(os.path.join(directory, "*.json"))
+    deleted_count = 0
+
+    for file_path in json_files:
+        try:
+            if os.path.getsize(file_path) == 0:
+                os.remove(file_path)
+                deleted_count += 1
+                print(f"Deleted empty file: {file_path}")
+            else:
+                with open(file_path, 'r') as f:
+                    data = json.load(f)
+                    if not data:
+                        os.remove(file_path)
+                        deleted_count += 1
+                        print(f"Deleted empty JSON content: {file_path}")
+        except (OSError, json.JSONDecodeError) as e:
+            try:
+                os.remove(file_path)
+                deleted_count += 1
+                print(f"Deleted invalid file: {file_path} (Error: {e})")
+            except OSError:
+                pass
+
+    if deleted_count > 0:
+        print(f"Total deleted {deleted_count} invalid JSON files from {directory}")
+
+
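+# Deterministic "balanced" router logits: every token's topk experts are spread evenly
+# across all experts, so each expert receives a near-identical share of the tokens.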
+def balanced_logits(num_tokens, num_experts, topk, device):
+    h_selected_experts = -torch.ones([num_tokens, topk]).to(torch.device(device))
     stride = math.ceil(num_experts / topk)
 
     for token_i in range(num_tokens):
@@ -42,11 +76,11 @@ def sample_power_law(size, alpha, xmin, xmax):
     return inv_cdf
 
 
-def power_law_logits_v3(num_tokens, num_experts, topk, ep, alpha):
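+# Synthetic router logits whose per-expert token counts follow a power law;
+# alpha controls how skewed the load is across experts (and hence EP ranks).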
+def power_law_logits_v3(num_tokens, num_experts, topk, ep, alpha, device):
     if num_tokens * topk > num_experts:
-        num_tokens_per_expert = sample_power_law(num_experts, alpha, 1, num_tokens * 0.8)
+        num_tokens_per_expert = sample_power_law(num_experts, alpha, 1, num_tokens * 0.8).to(torch.device(device))
     else:
-        num_tokens_per_expert = sample_power_law(num_experts, alpha, 0.01, 2)
+        num_tokens_per_expert = sample_power_law(num_experts, alpha, 0.01, 2).to(torch.device(device))
 
     target_sum = num_tokens * topk
 
@@ -85,8 +119,8 @@ def power_law_logits_v3(num_tokens, num_experts, topk, ep, alpha):
         stride=num_experts // ep,
         padding=0,
         bias=False,
-    )
-    conv1d_weights = torch.tensor([1 for _ in range(num_experts // ep)])
+    ).to(torch.device(device))
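+    # All-ones kernel with stride num_experts // ep: each convolution output is the total
+    # token count assigned to one EP rank's group of experts.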
+    conv1d_weights = torch.tensor([1 for _ in range(num_experts // ep)]).to(torch.device(device))
     conv1d.weight.copy_(conv1d_weights)
 
     res = conv1d(num_tokens_per_expert.unsqueeze(0).unsqueeze(0).float())
@@ -110,7 +144,7 @@ def power_law_logits_v3(num_tokens, num_experts, topk, ep, alpha):
     for expert_id in num_tokens_per_expert_sorted_index_lists:
         expert_assignments.extend([expert_id] * num_tokens_per_expert[expert_id])
 
-    expert_assignments = torch.tensor(expert_assignments, dtype=torch.long)
+    expert_assignments = torch.tensor(expert_assignments, dtype=torch.long).to(torch.device(device))
     h_selected_experts = expert_assignments.reshape(topk, num_tokens).T
 
     expert_map = F.one_hot(h_selected_experts.long(), num_classes=num_experts).sum(1)
@@ -148,13 +182,14 @@ def get_moe_test_cases():
         12288,
         16384,
         20480,
-        # 32768,
-        # 65536,
+        32768,
+        65536,
     ]
     tp_list = [1, 2, 4, 8, 16, 32]
     ep_list = [1, 2, 4, 8, 16, 32, 64, 128, 256]
     num_gpu_list = [1, 2, 4, 8, 16, 32, 64, 128, 256]
-    alpha_list = [1.01, 1.2]
+    alpha_list = [0.0, 1.01, 1.2]  # 0.0 for balanced distribution
+    # alpha_list = [1.01, 1.2]
     # hidden_size,inter_s,topk,num_expert, gated act
     # [15360,30720,2,16],# GPT-MOE-1.8T
     # [15360,3840,16,128],# GPT-MOE-1.8T-FineGrained
@@ -246,23 +281,6 @@ def get_moe_test_cases():
                         power_law_alpha,
                     ]
                 )
-                # test_cases.append(
-                #     [
-                #         moe_type,
-                #         num_tokens,
-                #         hs,
-                #         inter_s,
-                #         topk,
-                #         num_experts,
-                #         tp,
-                #         ep,
-                #         True,
-                #         model_name,
-                #         "moe_perf.txt",
-                #         "balanced",
-                #         0,
-                #     ]
-                # )
 
             for power_law_alpha in alpha_list:
                 test_cases.append(
@@ -282,23 +300,6 @@ def get_moe_test_cases():
                         power_law_alpha,
                     ]
                 )
-                # test_cases.append(
-                #     [
-                #         moe_type,
-                #         num_tokens,
-                #         hs,
-                #         inter_s,
-                #         topk,
-                #         num_experts,
-                #         tp,
-                #         ep,
-                #         False,
-                #         model_name,
-                #         "moe_perf.txt",
-                #         "balanced",
-                #         0,
-                #     ]
-                # )
     return test_cases
 
 
@@ -318,8 +319,6 @@ def run_moe_torch(
     power_law_alpha=0.0,
     device="cuda:0",
 ):
-    torch.cuda.set_device(device)
-    torch.set_default_device(device)
 
     # moe type support float16, fp8_qdq, fp8_block, w4a8, nvfp4 (not implemented yet)
     dtype = torch.bfloat16
@@ -341,6 +340,9 @@ def run_moe_torch(
         quant_algo = QuantAlgo.W4A16_MXFP4
         quant_group_size = 32
 
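+    # An alpha of (effectively) zero selects the balanced token distribution instead of a power law.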
+    if abs(power_law_alpha) < 1e-6:
+        distributed = "balanced"
+
     quant_config = QuantConfig(
         quant_algo=quant_algo,
         kv_cache_quant_algo=None,
@@ -369,9 +371,9 @@ def run_moe_torch(
     if model_name in ["GPT_OSS_120B", "GPT_OSS_20B"]:
         # use triton backend for best performance on Hopper
         model_config.moe_backend = "triton"
-        swiglu_alpha = torch.tensor([1.702] * (num_experts // moe_ep_size), dtype=torch.float32).cuda()
-        swiglu_beta = torch.tensor([1.0] * (num_experts // moe_ep_size), dtype=torch.float32).cuda()
-        swiglu_limit = torch.tensor([7.0] * (num_experts // moe_ep_size), dtype=torch.float32).cuda()
+        swiglu_alpha = torch.tensor([1.702] * (num_experts // moe_ep_size), dtype=torch.float32).to(torch.device(device))
+        swiglu_beta = torch.tensor([1.0] * (num_experts // moe_ep_size), dtype=torch.float32).to(torch.device(device))
+        swiglu_limit = torch.tensor([7.0] * (num_experts // moe_ep_size), dtype=torch.float32).to(torch.device(device))
         if 86 < get_sm_version() < 100:
             model_config.moe_backend = "triton"
         else:
@@ -416,65 +418,101 @@ def run_moe_torch(
         swiglu_beta=swiglu_beta,
         swiglu_limit=swiglu_limit,
     )
-
-    ffn1_weights = Parameter(
-        torch.randn(moe.w3_w1_weight.shape, dtype=torch.bfloat16, device=torch.device(device)).to(
-            dtype=moe.w3_w1_weight.dtype
-        ),
-        requires_grad=False,
-    )
-    ffn2_weights = Parameter(
-        torch.randn(moe.w2_weight.shape, dtype=torch.bfloat16, device=torch.device(device)).to(
-            dtype=moe.w2_weight.dtype
-        ),
-        requires_grad=False,
+    moe.to(torch.device(device))
+
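+    # For MXFP4, build random bf16 expert weights, quantize them to MXFP4 values plus
+    # per-block scales, and feed them through load_weights so the quantized kernels run
+    # on realistic data.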
+    if moe_type == "w4a16_mxfp4":
+        w1_weight = torch.randn((num_experts, inter_size, hidden_size), dtype=dtype).cuda()
+        w2_weight = torch.randn((num_experts, hidden_size, inter_size), dtype=dtype).cuda()
+        w3_weight = torch.randn((num_experts, inter_size, hidden_size), dtype=dtype).cuda()
+        w1_bias = torch.randn((num_experts, inter_size), dtype=dtype).cuda()
+        w2_bias = torch.randn((num_experts, hidden_size), dtype=dtype).cuda()
+        w3_bias = torch.randn((num_experts, inter_size), dtype=dtype).cuda()
+
+        from triton_kernels.numerics_details.mxfp import downcast_to_mxfp_torch
+
+        def fp32_to_mxfp4(tensor):
+            tensor = tensor.transpose(1, 2).contiguous()
+            tensor_fp4, tensor_scales = downcast_to_mxfp_torch(tensor, torch.uint8, axis=1)
+            tensor_fp4 = tensor_fp4.transpose(1, 2).contiguous()
+            tensor_scales = tensor_scales.transpose(1, 2).contiguous()
+            return tensor_fp4, tensor_scales
+
+        w1_weight_fp4, w1_weight_scale = fp32_to_mxfp4(w1_weight)
+        w2_weight_fp4, w2_weight_scale = fp32_to_mxfp4(w2_weight)
+        w3_weight_fp4, w3_weight_scale = fp32_to_mxfp4(w3_weight)
+
+        weights = {}
+        for expert_id in range(num_experts):
+            weights[f"{expert_id}.w1.weight"] = w1_weight_fp4[expert_id]
+            weights[f"{expert_id}.w2.weight"] = w2_weight_fp4[expert_id]
+            weights[f"{expert_id}.w3.weight"] = w3_weight_fp4[expert_id]
+            weights[f"{expert_id}.w1.weight_scale"] = w1_weight_scale[expert_id]
+            weights[f"{expert_id}.w2.weight_scale"] = w2_weight_scale[expert_id]
+            weights[f"{expert_id}.w3.weight_scale"] = w3_weight_scale[expert_id]
+            weights[f"{expert_id}.w1.bias"] = w1_bias[expert_id]
+            weights[f"{expert_id}.w2.bias"] = w2_bias[expert_id]
+            weights[f"{expert_id}.w3.bias"] = w3_bias[expert_id]
+        moe.load_weights([weights])
+
+    hidden_states_max_tokens = (
+        torch.randn([num_tokens_lists[-1], hidden_size]).bfloat16().to(torch.device(device))
     )
 
-    moe.w3_w1_weight = ffn1_weights
-    moe.w2_weight = ffn2_weights
+    logits_max_tokens = (
+        balanced_logits(num_tokens_lists[-1], num_experts, topk, torch.device(device)).to(router_logits_dtype)
+    )
 
-    max_index = -1
-    while True:
-        try:
-            hidden_states_max_tokens = (
-                torch.randn([num_tokens_lists[max_index], hidden_size]).bfloat16().to(torch.device(device))
-            )
-            logits_max_tokens = (
-                torch.randn([num_tokens_lists[max_index], num_experts]).to(router_logits_dtype).to(torch.device(device))
-            )
+    # dry run
+    torch.cuda.synchronize()
+    moe.forward(hidden_states_max_tokens, logits_max_tokens, do_finalize=not min_latency_mode)
+    torch.cuda.synchronize()
+
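+    # Reuse an on-disk AutoTuner profiling cache for this quantization/shape/parallelism
+    # combination when one exists; otherwise run the autotuning pass below against the
+    # same cache_path.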
+    if moe_type != "w4a16_mxfp4":
+        cleanup_empty_json_files("moe_tune_path")
+        cache_path = "moe_tune_path/{}_{}_{}_{}".format(
+            moe_type, hidden_size, inter_size // moe_tp_size, num_experts // moe_ep_size
+        )
+        existing_files = glob.glob(f"{cache_path}*")
+        cache_loaded = False
+        if existing_files:
+            json_path = existing_files[0]
+            try:
+                AutoTuner.get().profiling_cache.load_cache(json_path)
+                cache_loaded = True
+                print(f"Loaded profiling cache from {json_path}")
+            except (OSError, json.JSONDecodeError):
+                pass
+
+        if not cache_loaded:
             torch.cuda.synchronize()
-            AutoTuner.get().clear_cache()
-            with torch.inference_mode(), autotune():
+            with torch.inference_mode(), autotune(cache_path=cache_path, rank=torch.device(device).index):
                 moe.forward(hidden_states_max_tokens, logits_max_tokens, do_finalize=not min_latency_mode)
             torch.cuda.synchronize()
-            if aic_debug == 1:
-                print(f"tune success for tokens size {num_tokens_lists[max_index]}")
-            break
-        except Exception as e:
-            if aic_debug == 1:
-                print(
-                    f"tune failed for tokens size {num_tokens_lists[max_index]}, fallback to "
-                    f"tokens size {num_tokens_lists[max_index - 1]}"
-                )
-            max_index -= 3
-            if max_index == -len(num_tokens_lists):
-                raise ValueError("tune failed") from e
-            continue
 
     for num_tokens in num_tokens_lists:
         hidden_states = torch.randn([num_tokens, hidden_size]).bfloat16().to(torch.device(device))
-
         num_iter = 5 if distributed == "power_law" else 1
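+        # Power-law routing is randomly sampled, so benchmark over several independently
+        # drawn logit sets; balanced routing is deterministic and needs only one.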
         if distributed == "power_law":
             actual_logits_list = [
-                power_law_logits_v3(num_tokens, num_experts, topk, moe_ep_size, power_law_alpha)
+                power_law_logits_v3(num_tokens, num_experts, topk, moe_ep_size, power_law_alpha, torch.device(device))
                 .to(router_logits_dtype)
-                .to(torch.device(device))
                 for _ in range(num_iter)
             ]
         elif distributed == "balanced":
             actual_logits = (
-                balanced_logits(num_tokens, num_experts, topk).to(router_logits_dtype).to(torch.device(device))
+                balanced_logits(num_tokens, num_experts, topk, torch.device(device)).to(router_logits_dtype)
             )
         else:
             raise ValueError(f"Unsupported distributed mode: {distributed}")
@@ -534,5 +572,5 @@ def run_moe_torch(
         perf_filename=perf_filename,
     )
 
-    del moe, ffn1_weights, ffn2_weights, hidden_states, actual_logits, actual_logits_list
+    del moe, hidden_states, actual_logits
     torch.cuda.empty_cache()