Skip to content

Commit 8286c7c

Browse files
committed
fix: main branch mla data read
Signed-off-by: Kimi Zhao <[email protected]>
1 parent 61bc521 commit 8286c7c

File tree

2 files changed

+18
-17
lines changed

2 files changed

+18
-17
lines changed

collector/trtllm/collect_moe.py

Lines changed: 5 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -20,8 +20,10 @@
2020

2121
aic_debug = int(os.getenv("aic_moe_debug", "0")) # noqa: SIM112
2222

23+
moe_tune_path = os.path.dirname(os.path.abspath(__file__))
2324

24-
def cleanup_empty_json_files(directory="/tmp/moe_tune_path"):
25+
26+
def cleanup_empty_json_files(directory):
2527
if not os.path.exists(directory):
2628
return
2729

@@ -466,9 +468,9 @@ def fp32_to_mxfp4(tensor):
466468
torch.cuda.synchronize()
467469

468470
if moe_type != "w4a16_mxfp4":
469-
cleanup_empty_json_files("/tmp/moe_tune_path")
471+
cleanup_empty_json_files(moe_tune_path)
470472
cache_path = (
471-
f"/tmp/moe_tune_path/{moe_type}_{hidden_size}_{inter_size // moe_tp_size}_{num_experts // moe_ep_size}"
473+
f"{moe_tune_path}/{moe_type}_{hidden_size}_{inter_size // moe_tp_size}_{num_experts // moe_ep_size}"
472474
)
473475
existing_files = glob.glob(f"{cache_path}*")
474476
cache_loaded = False

src/aiconfigurator/sdk/perf_database.py

Lines changed: 13 additions & 14 deletions
Original file line numberDiff line numberDiff line change
@@ -1490,14 +1490,13 @@ def __init__(self, system: str, backend: str, version: str, systems_dir: str = "
14901490
) # s
14911491
target_z_list = [1, 2, 4, 8, 16, 32, 64, 128, 256, 384, 512, 1024, 2048] # b
14921492

1493-
self._extrapolate_data_grid(
1494-
data_dict=data_dict, # tpsize,sb
1495-
target_x_list=target_x_list,
1496-
target_y_list=target_y_list,
1497-
target_z_list=target_z_list,
1498-
sqrt_y_value=True,
1499-
)
1500-
1493+
self._extrapolate_data_grid(
1494+
data_dict=data_dict, # tpsize,sb
1495+
target_x_list=target_x_list,
1496+
target_y_list=target_y_list,
1497+
target_z_list=target_z_list,
1498+
sqrt_y_value=True,
1499+
)
15011500
# wideep generation mla
15021501
if getattr(self, "_wideep_generation_mla_data", None) is not None:
15031502
for kernel_source in self._wideep_generation_mla_data:
@@ -1581,12 +1580,12 @@ def __init__(self, system: str, backend: str, version: str, systems_dir: str = "
15811580
2097152 * 8,
15821581
] # s
15831582

1584-
self._extrapolate_data_grid(
1585-
data_dict=data_dict, # tpsize, bs
1586-
target_x_list=target_x_list,
1587-
target_y_list=target_y_list,
1588-
target_z_list=target_z_list,
1589-
)
1583+
self._extrapolate_data_grid(
1584+
data_dict=data_dict, # tpsize, bs
1585+
target_x_list=target_x_list,
1586+
target_y_list=target_y_list,
1587+
target_z_list=target_z_list,
1588+
)
15901589

15911590
# post-correction
15921591
self._correct_data()

0 commit comments

Comments
 (0)