
Commit 9f07ce2

sanity checkout update and w_mxfp4_a_fp16 collector update
1 parent 19fa3cf commit 9f07ce2

23 files changed (+266, -188 lines)

collector/collect.py

Lines changed: 1 addition & 1 deletion
@@ -267,7 +267,7 @@ def create_process_exit_error(device_id, exit_code):
     # Wait for processes
     for p in processes:
         if "moe" in func.__name__:
-            p.join(timeout=2000)
+            p.join(timeout=500)  # tune + 30 tokens cases
         else:
            p.join(timeout=10)
        if p.is_alive():
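Note: `join(timeout=...)` returns whether or not the worker finished within the deadline, which is why the `is_alive()` check follows it. A minimal sketch of the pattern; the kill-on-timeout step is an assumption, since the diff does not show what happens after `is_alive()`:

    import multiprocessing as mp

    def join_with_deadline(processes: list[mp.Process], timeout_s: float) -> None:
        for p in processes:
            p.join(timeout=timeout_s)  # returns even if p is still running
            if p.is_alive():           # the only reliable timeout signal
                p.terminate()          # assumed cleanup for hung workers
                p.join()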

collector/collect_all_reduce.py

Lines changed: 4 additions & 1 deletion
@@ -26,7 +26,10 @@

 # isort: on
 import tensorrt_llm as tllm
-from cuda import cudart
+try:
+    from cuda import cudart
+except:
+    from cuda.bindings import runtime as cudart
 from tensorrt_llm import Mapping
 from tensorrt_llm._torch.distributed import AllReduce, AllReduceFusionOp
 from tensorrt_llm._torch.distributed import AllReduceParams as TorchAllReduceParams
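Note: this try/except is a version shim for cuda-python, which moved its modules under `cuda.bindings` in newer releases while the legacy top-level `cuda.cudart` path is deprecated. The same shim is applied in helper.py, collect_allreduce.py, and collect_gemm_trt.py below. A standalone sketch; `except ImportError:` would be a narrower alternative to the bare `except:` used in the diff:

    try:
        from cuda import cudart  # legacy cuda-python module layout
    except ImportError:
        from cuda.bindings import runtime as cudart  # newer layout

    # both layouts expose the same API, with the (status, result) tuple convention:
    err, device_count = cudart.cudaGetDeviceCount()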

collector/helper.py

Lines changed: 1 addition & 1 deletion
@@ -13,7 +13,7 @@
 try:
     from cuda import cuda
 except:
-    pass
+    from cuda.bindings import driver as cuda
 from datetime import datetime
 from pathlib import Path
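Note: previously a failed `from cuda import cuda` was silently swallowed, leaving the name `cuda` undefined until the first driver call crashed; the fallback now mirrors the runtime-API shim above. A hedged usage sketch of the driver binding, assuming the newer cuda-python layout:

    from cuda.bindings import driver as cuda

    (err,) = cuda.cuInit(0)  # driver calls return a tuple, status first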

collector/slurm_comm_collector/collect_allreduce.py

Lines changed: 4 additions & 1 deletion
@@ -5,7 +5,10 @@
 import os

 import torch
-from cuda import cudart
+try:
+    from cuda import cudart
+except:
+    from cuda.bindings import runtime as cudart
 from tensorrt_llm import Mapping
 from tensorrt_llm._torch.distributed import (
     AllReduce,

collector/trtllm/collect_gemm_trt.py

Lines changed: 4 additions & 1 deletion
@@ -4,7 +4,10 @@
 import tensorrt as trt
 import tensorrt_llm
 import torch
-from cuda import cudart
+try:
+    from cuda import cudart
+except:
+    from cuda.bindings import runtime as cudart
 from polygraphy.backend.trt import CreateConfig, EngineFromNetwork, TrtRunner
 from tensorrt_llm import Tensor
 from tensorrt_llm._utils import str_dtype_to_torch

collector/trtllm/collect_moe.py

Lines changed: 131 additions & 93 deletions
@@ -1,6 +1,8 @@
 # SPDX-FileCopyrightText: Copyright (c) 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
 # SPDX-License-Identifier: Apache-2.0

+import glob
+import json
 import math
 import os

@@ -20,8 +22,40 @@
 aic_debug = int(os.getenv("aic_moe_debug", "0")) # noqa: SIM112


-def balanced_logits(num_tokens, num_experts, topk):
-    h_selected_experts = -torch.ones([num_tokens, topk])
+def cleanup_empty_json_files(directory="moe_tune_path"):
+    if not os.path.exists(directory):
+        return
+
+    json_files = glob.glob(os.path.join(directory, "*.json"))
+    deleted_count = 0
+
+    for file_path in json_files:
+        try:
+            if os.path.getsize(file_path) == 0:
+                os.remove(file_path)
+                deleted_count += 1
+                print(f"Deleted empty file: {file_path}")
+            else:
+                with open(file_path, 'r') as f:
+                    data = json.load(f)
+                if not data:
+                    os.remove(file_path)
+                    deleted_count += 1
+                    print(f"Deleted empty JSON content: {file_path}")
+        except (OSError, json.JSONDecodeError) as e:
+            try:
+                os.remove(file_path)
+                deleted_count += 1
+                print(f"Deleted invalid file: {file_path} (Error: {e})")
+            except OSError:
+                pass
+
+    if deleted_count > 0:
+        print(f"Total deleted {deleted_count} invalid JSON files from {directory}")
+
+
+def balanced_logits(num_tokens, num_experts, topk, device):
+    h_selected_experts = -torch.ones([num_tokens, topk]).to(torch.device(device))
     stride = math.ceil(num_experts / topk)

     for token_i in range(num_tokens):
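Note: `cleanup_empty_json_files` guards the profiling-cache reload added later in this file; a zero-byte or truncated JSON left by an interrupted tuning run would otherwise make `load_cache()` fail on every subsequent run. `balanced_logits` now allocates directly on the target device rather than relying on a process-wide default. As a hedged illustration (not the project's exact code), balanced routing amounts to a strided round-robin so each expert receives roughly num_tokens * topk / num_experts tokens:

    import math
    import torch

    def balanced_assignment(num_tokens, num_experts, topk):
        # token i picks topk experts spaced `stride` apart, offset by its index
        stride = math.ceil(num_experts / topk)
        return torch.tensor([
            [(token_i + k * stride) % num_experts for k in range(topk)]
            for token_i in range(num_tokens)
        ])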
@@ -42,11 +76,11 @@ def sample_power_law(size, alpha, xmin, xmax):
     return inv_cdf


-def power_law_logits_v3(num_tokens, num_experts, topk, ep, alpha):
+def power_law_logits_v3(num_tokens, num_experts, topk, ep, alpha, device):
     if num_tokens * topk > num_experts:
-        num_tokens_per_expert = sample_power_law(num_experts, alpha, 1, num_tokens * 0.8)
+        num_tokens_per_expert = sample_power_law(num_experts, alpha, 1, num_tokens * 0.8).to(torch.device(device))
     else:
-        num_tokens_per_expert = sample_power_law(num_experts, alpha, 0.01, 2)
+        num_tokens_per_expert = sample_power_law(num_experts, alpha, 0.01, 2).to(torch.device(device))

     target_sum = num_tokens * topk
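Note: only device plumbing changes here; `sample_power_law(size, alpha, xmin, xmax)` itself (defined just above this hunk) is untouched. For reference, a hedged sketch of the standard inverse-CDF construction such a sampler typically uses (the actual body is outside this diff):

    import torch

    def sample_power_law(size, alpha, xmin, xmax):
        # invert the CDF of p(x) ~ x^(-alpha) truncated to [xmin, xmax]
        u = torch.rand(size)
        a = 1.0 - alpha
        return (xmin ** a + u * (xmax ** a - xmin ** a)) ** (1.0 / a)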

@@ -85,8 +119,8 @@ def power_law_logits_v3(num_tokens, num_experts, topk, ep, alpha):
         stride=num_experts // ep,
         padding=0,
         bias=False,
-    )
-    conv1d_weights = torch.tensor([1 for _ in range(num_experts // ep)])
+    ).to(torch.device(device))
+    conv1d_weights = torch.tensor([1 for _ in range(num_experts // ep)]).to(torch.device(device))
     conv1d.weight.copy_(conv1d_weights)

     res = conv1d(num_tokens_per_expert.unsqueeze(0).unsqueeze(0).float())
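Note: with all-ones weights and stride equal to the kernel length, this Conv1d simply sums the per-expert token counts within each expert-parallel group of num_experts // ep consecutive experts. An equivalent reformulation, assuming ep divides num_experts:

    group = num_experts // ep
    per_group_load = num_tokens_per_expert.float().reshape(ep, group).sum(dim=1)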
@@ -110,7 +144,7 @@ def power_law_logits_v3(num_tokens, num_experts, topk, ep, alpha):
     for expert_id in num_tokens_per_expert_sorted_index_lists:
         expert_assignments.extend([expert_id] * num_tokens_per_expert[expert_id])

-    expert_assignments = torch.tensor(expert_assignments, dtype=torch.long)
+    expert_assignments = torch.tensor(expert_assignments, dtype=torch.long).to(torch.device(device))
     h_selected_experts = expert_assignments.reshape(topk, num_tokens).T

     expert_map = F.one_hot(h_selected_experts.long(), num_classes=num_experts).sum(1)
@@ -148,13 +182,14 @@ def get_moe_test_cases():
         12288,
         16384,
         20480,
-        # 32768,
-        # 65536,
+        32768,
+        65536,
     ]
     tp_list = [1, 2, 4, 8, 16, 32]
     ep_list = [1, 2, 4, 8, 16, 32, 64, 128, 256]
     num_gpu_list = [1, 2, 4, 8, 16, 32, 64, 128, 256]
-    alpha_list = [1.01, 1.2]
+    alpha_list = [0.0, 1.01, 1.2] # 0.0 for balanced distribution
+    # alpha_list = [1.01, 1.2]
     # hidden_size,inter_s,topk,num_expert, gated act
     # [15360,30720,2,16],# GPT-MOE-1.8T
     # [15360,3840,16,128],# GPT-MOE-1.8T-FineGrained
@@ -246,23 +281,6 @@ def get_moe_test_cases():
                             power_law_alpha,
                         ]
                     )
-                    # test_cases.append(
-                    #     [
-                    #         moe_type,
-                    #         num_tokens,
-                    #         hs,
-                    #         inter_s,
-                    #         topk,
-                    #         num_experts,
-                    #         tp,
-                    #         ep,
-                    #         True,
-                    #         model_name,
-                    #         "moe_perf.txt",
-                    #         "balanced",
-                    #         0,
-                    #     ]
-                    # )

                 for power_law_alpha in alpha_list:
                     test_cases.append(
@@ -282,23 +300,6 @@ def get_moe_test_cases():
                             power_law_alpha,
                         ]
                     )
-                    # test_cases.append(
-                    #     [
-                    #         moe_type,
-                    #         num_tokens,
-                    #         hs,
-                    #         inter_s,
-                    #         topk,
-                    #         num_experts,
-                    #         tp,
-                    #         ep,
-                    #         False,
-                    #         model_name,
-                    #         "moe_perf.txt",
-                    #         "balanced",
-                    #         0,
-                    #     ]
-                    # )
     return test_cases


@@ -318,8 +319,6 @@ def run_moe_torch(
     power_law_alpha=0.0,
     device="cuda:0",
 ):
-    torch.cuda.set_device(device)
-    torch.set_default_device(device)

     # moe type support float16, fp8_qdq, fp8_block, w4a8, nvfp4(not implemented yet)
     dtype = torch.bfloat16
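Note: dropping `torch.cuda.set_device` / `torch.set_default_device` removes process-global device state. In exchange, every helper in this commit takes the device explicitly (the new `device` parameter on `balanced_logits` and `power_law_logits_v3` above), for example:

    logits = balanced_logits(num_tokens, num_experts, topk, torch.device("cuda:0"))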
@@ -341,6 +340,9 @@ def run_moe_torch(
         quant_algo = QuantAlgo.W4A16_MXFP4
         quant_group_size = 32

+    if power_law_alpha - 0.0 < 1e-6:
+        distributed = "balanced"
+
     quant_config = QuantConfig(
         quant_algo=quant_algo,
         kv_cache_quant_algo=None,
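Note: this guard maps the new alpha = 0.0 test cases onto the balanced distribution. As written, `power_law_alpha - 0.0 < 1e-6` also fires for any alpha below the threshold and relies on alpha never being negative; a more explicit form would be:

    if abs(power_law_alpha) < 1e-6:  # treat alpha == 0.0 as "balanced"
        distributed = "balanced"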
@@ -369,9 +371,9 @@
     if model_name in ["GPT_OSS_120B", "GPT_OSS_20B"]:
         # use triton backend for best performance on Hopper
         model_config.moe_backend = "triton"
-        swiglu_alpha = torch.tensor([1.702] * (num_experts // moe_ep_size), dtype=torch.float32).cuda()
-        swiglu_beta = torch.tensor([1.0] * (num_experts // moe_ep_size), dtype=torch.float32).cuda()
-        swiglu_limit = torch.tensor([7.0] * (num_experts // moe_ep_size), dtype=torch.float32).cuda()
+        swiglu_alpha = torch.tensor([1.702] * (num_experts // moe_ep_size), dtype=torch.float32).to(torch.device(device))
+        swiglu_beta = torch.tensor([1.0] * (num_experts // moe_ep_size), dtype=torch.float32).to(torch.device(device))
+        swiglu_limit = torch.tensor([7.0] * (num_experts // moe_ep_size), dtype=torch.float32).to(torch.device(device))
         if 86 < get_sm_version() < 100:
             model_config.moe_backend = "triton"
         else:
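Note: only the device placement changes here; the constants parameterize the GPT-OSS-style clamped SwiGLU those per-expert tensors feed. A hedged sketch of that activation, under the assumption that it matches the published GPT-OSS recipe (alpha = 1.702 makes x * sigmoid(alpha * x) approximate GELU; beta = 1.0 biases the linear branch; limit = 7.0 clips activations):

    import torch

    def clamped_swiglu(x_glu, x_linear, alpha=1.702, beta=1.0, limit=7.0):
        x_glu = x_glu.clamp(max=limit)
        x_linear = x_linear.clamp(min=-limit, max=limit)
        return x_glu * torch.sigmoid(alpha * x_glu) * (x_linear + beta)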
@@ -416,65 +418,101 @@ def run_moe_torch(
         swiglu_beta=swiglu_beta,
         swiglu_limit=swiglu_limit,
     )
-
-    ffn1_weights = Parameter(
-        torch.randn(moe.w3_w1_weight.shape, dtype=torch.bfloat16, device=torch.device(device)).to(
-            dtype=moe.w3_w1_weight.dtype
-        ),
-        requires_grad=False,
-    )
-    ffn2_weights = Parameter(
-        torch.randn(moe.w2_weight.shape, dtype=torch.bfloat16, device=torch.device(device)).to(
-            dtype=moe.w2_weight.dtype
-        ),
-        requires_grad=False,
+    moe.to(torch.device(device))
+
+    if moe_type == "w4a16_mxfp4":
+        w1_weight = torch.randn((num_experts, inter_size, hidden_size),
+                                dtype=dtype).cuda()
+        w2_weight = torch.randn((num_experts, hidden_size, inter_size),
+                                dtype=dtype).cuda()
+        w3_weight = torch.randn((num_experts, inter_size, hidden_size),
+                                dtype=dtype).cuda()
+        w1_bias = torch.randn((num_experts, inter_size),
+                              dtype=dtype).cuda()
+        w2_bias = torch.randn((num_experts, hidden_size), dtype=dtype).cuda()
+        w3_bias = torch.randn((num_experts, inter_size),
+                              dtype=dtype).cuda()
+
+        from triton_kernels.numerics_details.mxfp import downcast_to_mxfp_torch
+
+        def fp32_to_mxfp4(tensor):
+            tensor = tensor.transpose(1, 2).contiguous()
+            tensor_fp4, tensor_scales = downcast_to_mxfp_torch(tensor,
+                                                               torch.uint8,
+                                                               axis=1)
+            tensor_fp4 = tensor_fp4.transpose(1, 2).contiguous()
+            tensor_scales = tensor_scales.transpose(1, 2).contiguous()
+            return tensor_fp4, tensor_scales
+
+        w1_weight_fp4, w1_weight_scale = fp32_to_mxfp4(w1_weight)
+        w2_weight_fp4, w2_weight_scale = fp32_to_mxfp4(w2_weight)
+        w3_weight_fp4, w3_weight_scale = fp32_to_mxfp4(w3_weight)
+
+        weights = {}
+        for expert_id in range(num_experts):
+            weights[f"{expert_id}.w1.weight"] = w1_weight_fp4[expert_id]
+            weights[f"{expert_id}.w2.weight"] = w2_weight_fp4[expert_id]
+            weights[f"{expert_id}.w3.weight"] = w3_weight_fp4[expert_id]
+            weights[f"{expert_id}.w1.weight_scale"] = w1_weight_scale[expert_id]
+            weights[f"{expert_id}.w2.weight_scale"] = w2_weight_scale[expert_id]
+            weights[f"{expert_id}.w3.weight_scale"] = w3_weight_scale[expert_id]
+            weights[f"{expert_id}.w1.bias"] = w1_bias[expert_id]
+            weights[f"{expert_id}.w2.bias"] = w2_bias[expert_id]
+            weights[f"{expert_id}.w3.bias"] = w3_bias[expert_id]
+        moe.load_weights([weights])
+
+    hidden_states_max_tokens = (
+        torch.randn([num_tokens_lists[-1], hidden_size]).bfloat16().to(torch.device(device))
     )

-    moe.w3_w1_weight = ffn1_weights
-    moe.w2_weight = ffn2_weights
+    logits_max_tokens = (
+        balanced_logits(num_tokens_lists[-1], num_experts, topk, torch.device(device)).to(router_logits_dtype)
+    )

-    max_index = -1
-    while True:
-        try:
-            hidden_states_max_tokens = (
-                torch.randn([num_tokens_lists[max_index], hidden_size]).bfloat16().to(torch.device(device))
-            )
-            logits_max_tokens = (
-                torch.randn([num_tokens_lists[max_index], num_experts]).to(router_logits_dtype).to(torch.device(device))
-            )
+    # dry run
+    torch.cuda.synchronize()
+    moe.forward(hidden_states_max_tokens, logits_max_tokens, do_finalize=not min_latency_mode)
+    torch.cuda.synchronize()
+
+    if moe_type != "w4a16_mxfp4":
+        cleanup_empty_json_files("moe_tune_path")
+        cache_path = "moe_tune_path/{}_{}_{}_{}".format(
+            moe_type, hidden_size,
+            inter_size//moe_tp_size,
+            num_experts//moe_ep_size
+        )
+        existing_files = glob.glob(f"{cache_path}*")
+        cache_loaded = False
+        if existing_files:
+            json_path = existing_files[0]
+            try:
+                AutoTuner.get().profiling_cache.load_cache(json_path)
+                cache_loaded = True
+                print(f"Loaded profiling cache from {json_path}")
+            except (OSError, json.JSONDecodeError):
+                pass
+
+        if not cache_loaded:
             torch.cuda.synchronize()
-            AutoTuner.get().clear_cache()
-            with torch.inference_mode(), autotune():
+            with torch.inference_mode(), autotune(cache_path=cache_path, rank=torch.device(device).index):
                 moe.forward(hidden_states_max_tokens, logits_max_tokens, do_finalize=not min_latency_mode)
             torch.cuda.synchronize()
-            if aic_debug == 1:
-                print(f"tune success for tokens size {num_tokens_lists[max_index]}")
-            break
-        except Exception as e:
-            if aic_debug == 1:
-                print(
-                    f"tune failed for tokens size {num_tokens_lists[max_index]}, fallback to "
-                    f"tokens size {num_tokens_lists[max_index - 1]}"
-                )
-            max_index -= -3
-            if max_index == -len(num_tokens_lists):
-                raise ValueError("tune failed") from e
-            continue

     for num_tokens in num_tokens_lists:
         hidden_states = torch.randn([num_tokens, hidden_size]).bfloat16().to(torch.device(device))
-
         num_iter = 5 if distributed == "power_law" else 1
         if distributed == "power_law":
             actual_logits_list = [
-                power_law_logits_v3(num_tokens, num_experts, topk, moe_ep_size, power_law_alpha)
+                power_law_logits_v3(num_tokens, num_experts, topk, moe_ep_size, power_law_alpha, torch.device(device))
                 .to(router_logits_dtype)
-                .to(torch.device(device))
                 for _ in range(num_iter)
             ]
         elif distributed == "balanced":
             actual_logits = (
-                balanced_logits(num_tokens, num_experts, topk).to(router_logits_dtype).to(torch.device(device))
+                balanced_logits(num_tokens, num_experts, topk, torch.device(device)).to(router_logits_dtype)
             )
         else:
             raise ValueError(f"Unsupported distributed mode: {distributed}")
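Note: two independent changes land in this hunk. First, for w4a16_mxfp4 the random bf16 expert weights are quantized with `downcast_to_mxfp_torch` and loaded through a checkpoint-style dict keyed "{expert}.{w1|w2|w3}.{weight|weight_scale|bias}". A hedged, standalone sketch of that round trip (the packed-uint8 / per-block-scale layout described in the comments is an assumption based on the OCP microscaling format):

    import torch
    from triton_kernels.numerics_details.mxfp import downcast_to_mxfp_torch

    E, I, H = 4, 256, 128  # toy sizes: experts, intermediate, hidden
    w = torch.randn(E, I, H, dtype=torch.bfloat16).cuda()

    w_t = w.transpose(1, 2).contiguous()  # put the reduction dim on axis 1
    w_fp4, w_scales = downcast_to_mxfp_torch(w_t, torch.uint8, axis=1)
    w_fp4 = w_fp4.transpose(1, 2).contiguous()        # packed FP4 payload
    w_scales = w_scales.transpose(1, 2).contiguous()  # per-block scales

Second, for all other MoE types the autotuner result is now persisted: stale cache files are pruned by `cleanup_empty_json_files`, an existing cache matching the (moe_type, hidden_size, sharded inter_size, sharded expert count) key is reloaded via `AutoTuner.get().profiling_cache.load_cache`, and only on a cache miss does the `autotune(cache_path=..., rank=...)` pass run. This replaces the old retry loop that re-tuned from scratch on every invocation, falling back to a smaller token count on failure.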
@@ -534,5 +572,5 @@ def run_moe_torch(
         perf_filename=perf_filename,
     )

-    del moe, ffn1_weights, ffn2_weights, hidden_states, actual_logits, actual_logits_list
+    del moe, hidden_states, actual_logits
     torch.cuda.empty_cache()
