From 61bc5218572f746fa908f3f48eaeda9fc2f70209 Mon Sep 17 00:00:00 2001 From: Kimi Zhao Date: Fri, 21 Nov 2025 09:59:18 +0800 Subject: [PATCH 1/2] fix: sol compute and trtllm=1.2 collect Signed-off-by: Kimi Zhao --- collector/collect.py | 19 +- collector/collect_all_reduce.py | 7 +- collector/helper.py | 2 +- .../slurm_comm_collector/collect_allreduce.py | 6 +- collector/trtllm/collect_gemm_trt.py | 6 +- collector/trtllm/collect_mla.py | 27 +-- collector/trtllm/collect_mla_1_1rc2.py | 104 +++++++++ collector/trtllm/collect_moe.py | 218 ++++++++++-------- src/aiconfigurator/sdk/perf_database.py | 35 ++- .../1.2.0rc2/context_attention_perf.txt | 3 + .../trtllm/1.2.0rc2/context_mla_perf.txt | 3 + .../trtllm/1.2.0rc2/custom_allreduce_perf.txt | 3 + .../h200_sxm/trtllm/1.2.0rc2/gemm_perf.txt | 3 + .../1.2.0rc2/generation_attention_perf.txt | 3 + .../trtllm/1.2.0rc2/generation_mla_perf.txt | 3 + .../h200_sxm/trtllm/1.2.0rc2/mla_bmm_perf.txt | 3 + .../h200_sxm/trtllm/1.2.0rc2/moe_perf.txt | 3 + tools/sanity_check/validate_database.ipynb | 154 +++++++------ 18 files changed, 399 insertions(+), 203 deletions(-) create mode 100644 src/aiconfigurator/systems/data/h200_sxm/trtllm/1.2.0rc2/context_attention_perf.txt create mode 100644 src/aiconfigurator/systems/data/h200_sxm/trtllm/1.2.0rc2/context_mla_perf.txt create mode 100644 src/aiconfigurator/systems/data/h200_sxm/trtllm/1.2.0rc2/custom_allreduce_perf.txt create mode 100644 src/aiconfigurator/systems/data/h200_sxm/trtllm/1.2.0rc2/gemm_perf.txt create mode 100644 src/aiconfigurator/systems/data/h200_sxm/trtllm/1.2.0rc2/generation_attention_perf.txt create mode 100644 src/aiconfigurator/systems/data/h200_sxm/trtllm/1.2.0rc2/generation_mla_perf.txt create mode 100644 src/aiconfigurator/systems/data/h200_sxm/trtllm/1.2.0rc2/mla_bmm_perf.txt create mode 100644 src/aiconfigurator/systems/data/h200_sxm/trtllm/1.2.0rc2/moe_perf.txt diff --git a/collector/collect.py b/collector/collect.py index 60bdfe11..dba089cb 100644 --- a/collector/collect.py +++ b/collector/collect.py @@ -222,8 +222,10 @@ def create_process_exit_error(device_id, exit_code): # Stall detection unchanged... 
if progress_value.value == last_progress: stall_count += 1 - if stall_count > 30: + if stall_count > 30 and "moe" not in func.__name__: logger.warning(f"Progress stalled at {progress_value.value}/{len(tasks)}") + if stall_count > 900 and "moe" in func.__name__: + logger.warning(f"Moe Progress stalled at {progress_value.value}/{len(tasks)}") else: stall_count = 0 last_progress = progress_value.value @@ -264,7 +266,10 @@ def create_process_exit_error(device_id, exit_code): # Wait for processes for p in processes: - p.join(timeout=10) + if "moe" in func.__name__: + p.join(timeout=500) # tune + 30 tokens cases + else: + p.join(timeout=10) if p.is_alive(): logger.warning(f"Process {p.pid} did not terminate, forcing...") p.terminate() @@ -511,7 +516,9 @@ def collect_trtllm(num_processes: int, ops: list[str] | None = None): "module": "collector.trtllm.collect_mla", "get_func": "get_context_mla_test_cases", "run_func": "run_mla", - "version_handler": lambda v: "trtllm.collect_mla_1_1rc2" if v.startswith("1.1") else "trtllm.collect_mla", + "version_handler": lambda v: "trtllm.collect_mla_1_1rc2" + if v.startswith(("1.1.0", "1.2.0")) + else "trtllm.collect_mla", }, { "name": "trtllm", @@ -519,7 +526,9 @@ def collect_trtllm(num_processes: int, ops: list[str] | None = None): "module": "collector.trtllm.collect_mla", "get_func": "get_generation_mla_test_cases", "run_func": "run_mla", - "version_handler": lambda v: "trtllm.collect_mla_1_1rc2" if v.startswith("1.1") else "trtllm.collect_mla", + "version_handler": lambda v: "trtllm.collect_mla_1_1rc2" + if v.startswith(("1.1.0", "1.2.0")) + else "trtllm.collect_mla", }, # Attention collections - separate entries for context and generation { @@ -563,7 +572,7 @@ def collect_trtllm(num_processes: int, ops: list[str] | None = None): else "collector.trtllm.collect_moe_pre_1_0" if v.startswith(("0.21.0", "1.0.0")) else "collector.trtllm.collect_moe" - if v.startswith("1.1.0") + if v.startswith(("1.1.0", "1.2.0")) else None, }, ] diff --git a/collector/collect_all_reduce.py b/collector/collect_all_reduce.py index fa84bd62..fbf6c448 100644 --- a/collector/collect_all_reduce.py +++ b/collector/collect_all_reduce.py @@ -30,10 +30,9 @@ from argparse import ArgumentParser from typing import Optional -# isort: off import torch -# isort: on +from helper import log_perf def get_input_shape_and_comm_size(size, token_dim=4096): @@ -112,8 +111,6 @@ def benchmark_trtllm_allreduce( num_warmups = 3 num_runs = 20 - from helper import log_perf - size = min_size while size < max_size: input_shape = get_input_shape_and_comm_size(size) @@ -261,8 +258,6 @@ def benchmark_vllm_allreduce( num_warmups = 3 num_runs = 20 - from helper import log_perf - # Warmup communication warmup_tensor = torch.ones(1, dtype=torch_dtype, device="cuda") _ = vllm_mods["tensor_model_parallel_all_reduce"](warmup_tensor) diff --git a/collector/helper.py b/collector/helper.py index cac11183..3aa5bfde 100644 --- a/collector/helper.py +++ b/collector/helper.py @@ -13,7 +13,7 @@ try: from cuda import cuda except: - pass + from cuda.bindings import driver as cuda from datetime import datetime from pathlib import Path diff --git a/collector/slurm_comm_collector/collect_allreduce.py b/collector/slurm_comm_collector/collect_allreduce.py index 4555eeb0..3d3d7761 100755 --- a/collector/slurm_comm_collector/collect_allreduce.py +++ b/collector/slurm_comm_collector/collect_allreduce.py @@ -5,7 +5,11 @@ import os import torch -from cuda import cudart + +try: + from cuda import cudart +except: + from cuda.bindings import 
runtime as cudart from tensorrt_llm import Mapping from tensorrt_llm._torch.distributed import ( AllReduce, diff --git a/collector/trtllm/collect_gemm_trt.py b/collector/trtllm/collect_gemm_trt.py index 45e5bc38..1a6b3845 100644 --- a/collector/trtllm/collect_gemm_trt.py +++ b/collector/trtllm/collect_gemm_trt.py @@ -4,7 +4,11 @@ import tensorrt as trt import tensorrt_llm import torch -from cuda import cudart + +try: + from cuda import cudart +except: + from cuda.bindings import runtime as cudart from polygraphy.backend.trt import CreateConfig, EngineFromNetwork, TrtRunner from tensorrt_llm import Tensor from tensorrt_llm._utils import str_dtype_to_torch diff --git a/collector/trtllm/collect_mla.py b/collector/trtllm/collect_mla.py index 88ec8299..4e68486c 100644 --- a/collector/trtllm/collect_mla.py +++ b/collector/trtllm/collect_mla.py @@ -21,9 +21,9 @@ def get_context_mla_test_cases(): - dtype_list = [tensorrt_llm.bindings.DataType.FP8, tensorrt_llm.bindings.DataType.BF16] + dtype_list = [tensorrt_llm.bindings.DataType.BF16] # not support f8 for trt < v1.1 test_cases = [] - n_list = [64, 128] + n_list = [128] b_list = [1, 2, 4, 8, 16, 32, 64, 128, 256] s_list = [ 16, @@ -47,7 +47,7 @@ def get_context_mla_test_cases(): for b in b_list: for s in s_list: # [2,4,8,16,32,64,128,256,512,1024,2048,4096,8192,16384,32768,65536,131072]: for dtype in dtype_list: - for tp_size in [1, 2, 4, 8, 16, 32, 64]: + for tp_size in [1, 2, 4, 8, 16, 32, 64, 128]: if b * s > 32768: continue # (input_len, batch_size, output_len, kv_cache_dtype, num_heads, world_size, @@ -72,9 +72,9 @@ def get_context_mla_test_cases(): def get_generation_mla_test_cases(): - dtype_list = [tensorrt_llm.bindings.DataType.FP8, tensorrt_llm.bindings.DataType.BF16] + dtype_list = [tensorrt_llm.bindings.DataType.BF16] # not support f8 for trt < v1.1 test_cases = [] - n_list = [64, 128] + n_list = [128] for n in n_list: for b in [1, 2, 4, 8, 16, 32, 64, 128, 256, 512, 1024]: for s in [ @@ -97,7 +97,7 @@ def get_generation_mla_test_cases(): 131072, ]: # [target token s] is equivalent to [in: s-1, step=1] for dtype in dtype_list: - for tp_size in [1, 2, 4, 8, 16, 32, 64]: + for tp_size in [1, 2, 4, 8, 16, 32, 64, 128]: if b * s > 1024 * 4096 * 2 * 2: continue # (input_len, batch_size, output_len, kv_cache_dtype, num_heads, world_size, @@ -140,13 +140,11 @@ def run_mla( torch.cuda.set_device(device) backend_name = "TRTLLM" layer_idx = 0 - num_key_value_heads = num_heads - - assert num_key_value_heads % tp_size == 0, "num_key_value_heads != N * tp_size" - num_key_value_heads = int(num_key_value_heads / tp_size) + assert kv_cache_dtype == tensorrt_llm.bindings.DataType.BF16, "only support bfloat16 for trtllm" assert num_heads % tp_size == 0, "num_heads != N * tp_size" - num_heads = int(num_heads / tp_size) + num_heads = num_heads // tp_size + num_key_value_heads = num_heads pos_embd_params = PositionalEmbeddingParams( type=PositionEmbeddingType.yarn, @@ -170,7 +168,6 @@ def run_mla( ) quant_config = QuantConfig( - quant_algo="FP8_BLOCK_SCALES", kv_cache_quant_algo=None, group_size=None, smoothquant_val=0.5, @@ -391,15 +388,11 @@ def run_mla( isl = 1 step = input_len - dtype_str = "float16" - if kv_cache_dtype == tensorrt_llm.bindings.DataType.FP8: - dtype_str = "fp8" - log_perf( item_list=[ { "mla_dtype": "float16", - "kv_cache_dtype": dtype_str, + "kv_cache_dtype": "float16", "num_heads": num_heads, "batch_size": batch_size, "isl": isl, diff --git a/collector/trtllm/collect_mla_1_1rc2.py b/collector/trtllm/collect_mla_1_1rc2.py index 
7ea53236..6cc3481e 100644 --- a/collector/trtllm/collect_mla_1_1rc2.py +++ b/collector/trtllm/collect_mla_1_1rc2.py @@ -26,6 +26,107 @@ from helper import log_perf +def get_context_mla_test_cases(): + dtype_list = [tensorrt_llm.bindings.DataType.BF16, tensorrt_llm.bindings.DataType.FP8] + test_cases = [] + n_list = [128] + b_list = [1, 2, 4, 8, 16, 32, 64, 128, 256] + s_list = [ + 16, + 32, + 64, + 128, + 256, + 512, + 1024, + 1536, + 2048, + 3072, + 4096, + 6144, + 8192, + 10240, + 12288, + 16384, + ] + for n in n_list: + for b in b_list: + for s in s_list: # [2,4,8,16,32,64,128,256,512,1024,2048,4096,8192,16384,32768,65536,131072]: + for dtype in dtype_list: + for tp_size in [1, 2, 4, 8, 16, 32, 64, 128]: + if b * s > 32768: + continue + # (input_len, batch_size, output_len, kv_cache_dtype, num_heads, world_size, + # tp_size, tokens_per_block, warming_up, test_ite, is_context_phase) + test_cases.append( + [ + s, + b, + 1, + dtype, + n, + tp_size, + tp_size, + 64, + 10, + 6, + True, + "context_mla_perf.txt", + ] + ) + return test_cases + + +def get_generation_mla_test_cases(): + dtype_list = [tensorrt_llm.bindings.DataType.BF16, tensorrt_llm.bindings.DataType.FP8] + test_cases = [] + n_list = [128] + for n in n_list: + for b in [1, 2, 4, 8, 16, 32, 64, 128, 256, 512, 1024]: + for s in [ + 2, + 4, + 8, + 16, + 32, + 64, + 128, + 256, + 512, + 1024, + 2048, + 4096, + 8192, + 16384, + 32768, + 65536, + 131072, + ]: # [target token s] is equivalent to [in: s-1, step=1] + for dtype in dtype_list: + for tp_size in [1, 2, 4, 8, 16, 32, 64, 128]: + if b * s > 1024 * 4096 * 2 * 2: + continue + # (input_len, batch_size, output_len, kv_cache_dtype, num_heads, world_size, + # tp_size, tokens_per_block, warming_up, test_ite, is_context_phase) + test_cases.append( + [ + s - 1, + b, + 1, + dtype, + n, + tp_size, + tp_size, + 64, + 10, + 6, + False, + "generation_mla_perf.txt", + ] + ) + return test_cases + + # Copied from transformers.models.llama.modeling_llama.rotate_half def rotate_half(x): """Rotates half the hidden dims of the input.""" @@ -122,6 +223,9 @@ def run_mla( kv_cache_tokens_per_block = tokens_per_block # device = torch.device('cuda') dtype = scenario.dtype + + assert num_heads % tp_size == 0, "num_heads != N * tp_size" + num_heads = num_heads // tp_size num_kv_heads = num_heads context_sequence_lengths = [input_len for _ in range(batch_size)] diff --git a/collector/trtllm/collect_moe.py b/collector/trtllm/collect_moe.py index b4a2fea1..c45777ed 100755 --- a/collector/trtllm/collect_moe.py +++ b/collector/trtllm/collect_moe.py @@ -1,6 +1,8 @@ # SPDX-FileCopyrightText: Copyright (c) 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved. 
# SPDX-License-Identifier: Apache-2.0 +import glob +import json import math import os @@ -13,15 +15,46 @@ from tensorrt_llm._torch.modules.fused_moe import RenormalizeMoeRoutingMethod, create_moe from tensorrt_llm.mapping import Mapping from tensorrt_llm.models.modeling_utils import QuantAlgo, QuantConfig -from torch.nn.parameter import Parameter from helper import get_sm_version, log_perf aic_debug = int(os.getenv("aic_moe_debug", "0")) # noqa: SIM112 -def balanced_logits(num_tokens, num_experts, topk): - h_selected_experts = -torch.ones([num_tokens, topk]) +def cleanup_empty_json_files(directory="/tmp/moe_tune_path"): + if not os.path.exists(directory): + return + + json_files = glob.glob(os.path.join(directory, "*.json")) + deleted_count = 0 + + for file_path in json_files: + try: + if os.path.getsize(file_path) == 0: + os.remove(file_path) + deleted_count += 1 + print(f"Deleted empty file: {file_path}") + else: + with open(file_path) as f: + data = json.load(f) + if not data: + os.remove(file_path) + deleted_count += 1 + print(f"Deleted empty JSON content: {file_path}") + except (OSError, json.JSONDecodeError) as e: + try: + os.remove(file_path) + deleted_count += 1 + print(f"Deleted invalid file: {file_path} (Error: {e})") + except OSError: + pass + + if deleted_count > 0: + print(f"Total deleted {deleted_count} invalid JSON files from {directory}") + + +def balanced_logits(num_tokens, num_experts, topk, device): + h_selected_experts = -torch.ones([num_tokens, topk]).to(torch.device(device)) stride = math.ceil(num_experts / topk) for token_i in range(num_tokens): @@ -42,11 +75,11 @@ def sample_power_law(size, alpha, xmin, xmax): return inv_cdf -def power_law_logits_v3(num_tokens, num_experts, topk, ep, alpha): +def power_law_logits_v3(num_tokens, num_experts, topk, ep, alpha, device): if num_tokens * topk > num_experts: - num_tokens_per_expert = sample_power_law(num_experts, alpha, 1, num_tokens * 0.8) + num_tokens_per_expert = sample_power_law(num_experts, alpha, 1, num_tokens * 0.8).to(torch.device(device)) else: - num_tokens_per_expert = sample_power_law(num_experts, alpha, 0.01, 2) + num_tokens_per_expert = sample_power_law(num_experts, alpha, 0.01, 2).to(torch.device(device)) target_sum = num_tokens * topk @@ -85,8 +118,8 @@ def power_law_logits_v3(num_tokens, num_experts, topk, ep, alpha): stride=num_experts // ep, padding=0, bias=False, - ) - conv1d_weights = torch.tensor([1 for _ in range(num_experts // ep)]) + ).to(torch.device(device)) + conv1d_weights = torch.tensor([1 for _ in range(num_experts // ep)]).to(torch.device(device)) conv1d.weight.copy_(conv1d_weights) res = conv1d(num_tokens_per_expert.unsqueeze(0).unsqueeze(0).float()) @@ -110,7 +143,7 @@ def power_law_logits_v3(num_tokens, num_experts, topk, ep, alpha): for expert_id in num_tokens_per_expert_sorted_index_lists: expert_assignments.extend([expert_id] * num_tokens_per_expert[expert_id]) - expert_assignments = torch.tensor(expert_assignments, dtype=torch.long) + expert_assignments = torch.tensor(expert_assignments, dtype=torch.long).to(torch.device(device)) h_selected_experts = expert_assignments.reshape(topk, num_tokens).T expert_map = F.one_hot(h_selected_experts.long(), num_classes=num_experts).sum(1) @@ -154,7 +187,8 @@ def get_moe_test_cases(): tp_list = [1, 2, 4, 8, 16, 32] ep_list = [1, 2, 4, 8, 16, 32, 64, 128, 256] num_gpu_list = [1, 2, 4, 8, 16, 32, 64, 128, 256] - alpha_list = [1.01, 1.2] + alpha_list = [0.0, 1.01, 1.2] # 0.0 for balanced distribution + # alpha_list = [1.01, 1.2] # 
hidden_size,inter_s,topk,num_expert, gated act # [15360,30720,2,16],# GPT-MOE-1.8T # [15360,3840,16,128],# GPT-MOE-1.8T-FineGrained @@ -182,7 +216,7 @@ def get_moe_test_cases(): # though trtllm gen kernel source supports fp8_block, it only provides min-latency # data. not practical. moe_list += [ - "w4afp8", + # "w4afp8", FIXME: trtllm 1.2 has bugs for w4afp8 "fp8_block", "w4a16_mxfp4", ] @@ -246,23 +280,6 @@ def get_moe_test_cases(): power_law_alpha, ] ) - # test_cases.append( - # [ - # moe_type, - # num_tokens, - # hs, - # inter_s, - # topk, - # num_experts, - # tp, - # ep, - # True, - # model_name, - # "moe_perf.txt", - # "balanced", - # 0, - # ] - # ) for power_law_alpha in alpha_list: test_cases.append( @@ -282,23 +299,6 @@ def get_moe_test_cases(): power_law_alpha, ] ) - # test_cases.append( - # [ - # moe_type, - # num_tokens, - # hs, - # inter_s, - # topk, - # num_experts, - # tp, - # ep, - # False, - # model_name, - # "moe_perf.txt", - # "balanced", - # 0, - # ] - # ) return test_cases @@ -318,9 +318,6 @@ def run_moe_torch( power_law_alpha=0.0, device="cuda:0", ): - torch.cuda.set_device(device) - torch.set_default_device(device) - # moe type support float16, fp8_qdq, fp8_block, w4a8, nvfp4(not implemented yet) dtype = torch.bfloat16 quant_group_size = 128 @@ -341,6 +338,9 @@ def run_moe_torch( quant_algo = QuantAlgo.W4A16_MXFP4 quant_group_size = 32 + if power_law_alpha - 0.0 < 1e-6: + distributed = "balanced" + quant_config = QuantConfig( quant_algo=quant_algo, kv_cache_quant_algo=None, @@ -369,9 +369,11 @@ def run_moe_torch( if model_name in ["GPT_OSS_120B", "GPT_OSS_20B"]: # use triton backend for best performance on Hopper model_config.moe_backend = "triton" - swiglu_alpha = torch.tensor([1.702] * (num_experts // moe_ep_size), dtype=torch.float32).cuda() - swiglu_beta = torch.tensor([1.0] * (num_experts // moe_ep_size), dtype=torch.float32).cuda() - swiglu_limit = torch.tensor([7.0] * (num_experts // moe_ep_size), dtype=torch.float32).cuda() + swiglu_alpha = torch.tensor([1.702] * (num_experts // moe_ep_size), dtype=torch.float32).to( + torch.device(device) + ) + swiglu_beta = torch.tensor([1.0] * (num_experts // moe_ep_size), dtype=torch.float32).to(torch.device(device)) + swiglu_limit = torch.tensor([7.0] * (num_experts // moe_ep_size), dtype=torch.float32).to(torch.device(device)) if 86 < get_sm_version() < 100: model_config.moe_backend = "triton" else: @@ -416,66 +418,87 @@ def run_moe_torch( swiglu_beta=swiglu_beta, swiglu_limit=swiglu_limit, ) - - ffn1_weights = Parameter( - torch.randn(moe.w3_w1_weight.shape, dtype=torch.bfloat16, device=torch.device(device)).to( - dtype=moe.w3_w1_weight.dtype - ), - requires_grad=False, - ) - ffn2_weights = Parameter( - torch.randn(moe.w2_weight.shape, dtype=torch.bfloat16, device=torch.device(device)).to( - dtype=moe.w2_weight.dtype - ), - requires_grad=False, + moe.to(torch.device(device)) + + if moe_type == "w4a16_mxfp4": + w1_weight = torch.randn((num_experts, inter_size, hidden_size), dtype=dtype).cuda() + w2_weight = torch.randn((num_experts, hidden_size, inter_size), dtype=dtype).cuda() + w3_weight = torch.randn((num_experts, inter_size, hidden_size), dtype=dtype).cuda() + w1_bias = torch.randn((num_experts, inter_size), dtype=dtype).cuda() + w2_bias = torch.randn((num_experts, hidden_size), dtype=dtype).cuda() + w3_bias = torch.randn((num_experts, inter_size), dtype=dtype).cuda() + + from triton_kernels.numerics_details.mxfp import downcast_to_mxfp_torch + + def fp32_to_mxfp4(tensor): + tensor = tensor.transpose(1, 
2).contiguous() + tensor_fp4, tensor_scales = downcast_to_mxfp_torch(tensor, torch.uint8, axis=1) + tensor_fp4 = tensor_fp4.transpose(1, 2).contiguous() + tensor_scales = tensor_scales.transpose(1, 2).contiguous() + return tensor_fp4, tensor_scales + + w1_weight_fp4, w1_weight_scale = fp32_to_mxfp4(w1_weight) + w2_weight_fp4, w2_weight_scale = fp32_to_mxfp4(w2_weight) + w3_weight_fp4, w3_weight_scale = fp32_to_mxfp4(w3_weight) + + weights = {} + for expert_id in range(num_experts): + weights[f"{expert_id}.w1.weight"] = w1_weight_fp4[expert_id] + weights[f"{expert_id}.w2.weight"] = w2_weight_fp4[expert_id] + weights[f"{expert_id}.w3.weight"] = w3_weight_fp4[expert_id] + weights[f"{expert_id}.w1.weight_scale"] = w1_weight_scale[expert_id] + weights[f"{expert_id}.w2.weight_scale"] = w2_weight_scale[expert_id] + weights[f"{expert_id}.w3.weight_scale"] = w3_weight_scale[expert_id] + weights[f"{expert_id}.w1.bias"] = w1_bias[expert_id] + weights[f"{expert_id}.w2.bias"] = w2_bias[expert_id] + weights[f"{expert_id}.w3.bias"] = w3_bias[expert_id] + moe.load_weights([weights]) + + hidden_states_max_tokens = torch.randn([num_tokens_lists[-1], hidden_size]).bfloat16().to(torch.device(device)) + + logits_max_tokens = balanced_logits(num_tokens_lists[-1], num_experts, topk, torch.device(device)).to( + router_logits_dtype ) - moe.w3_w1_weight = ffn1_weights - moe.w2_weight = ffn2_weights + # dty run + torch.cuda.synchronize() + moe.forward(hidden_states_max_tokens, logits_max_tokens, do_finalize=not min_latency_mode) + torch.cuda.synchronize() - max_index = -1 - while True: - try: - hidden_states_max_tokens = ( - torch.randn([num_tokens_lists[max_index], hidden_size]).bfloat16().to(torch.device(device)) - ) - logits_max_tokens = ( - torch.randn([num_tokens_lists[max_index], num_experts]).to(router_logits_dtype).to(torch.device(device)) - ) + if moe_type != "w4a16_mxfp4": + cleanup_empty_json_files("/tmp/moe_tune_path") + cache_path = ( + f"/tmp/moe_tune_path/{moe_type}_{hidden_size}_{inter_size // moe_tp_size}_{num_experts // moe_ep_size}" + ) + existing_files = glob.glob(f"{cache_path}*") + cache_loaded = False + if existing_files: + json_path = existing_files[0] + try: + AutoTuner.get().profiling_cache.load_cache(json_path) + cache_loaded = True + print(f"Loaded profiling cache from {json_path}") + except (OSError, json.JSONDecodeError): + pass + + if not cache_loaded: torch.cuda.synchronize() - AutoTuner.get().clear_cache() - with torch.inference_mode(), autotune(): + with torch.inference_mode(), autotune(cache_path=cache_path, rank=torch.device(device).index): moe.forward(hidden_states_max_tokens, logits_max_tokens, do_finalize=not min_latency_mode) torch.cuda.synchronize() - if aic_debug == 1: - print(f"tune success for tokens size {num_tokens_lists[max_index]}") - break - except Exception as e: - if aic_debug == 1: - print( - f"tune failed for tokens size {num_tokens_lists[max_index]}, fallback to " - f"tokens size {num_tokens_lists[max_index - 1]}" - ) - max_index -= 1 - if max_index == -len(num_tokens_lists): - raise ValueError("tune failed for all tokens sizes") from e - continue for num_tokens in num_tokens_lists: hidden_states = torch.randn([num_tokens, hidden_size]).bfloat16().to(torch.device(device)) - num_iter = 5 if distributed == "power_law" else 1 if distributed == "power_law": actual_logits_list = [ - power_law_logits_v3(num_tokens, num_experts, topk, moe_ep_size, power_law_alpha) - .to(router_logits_dtype) - .to(torch.device(device)) + power_law_logits_v3( + num_tokens, num_experts, 
topk, moe_ep_size, power_law_alpha, torch.device(device) + ).to(router_logits_dtype) for _ in range(num_iter) ] elif distributed == "balanced": - actual_logits = ( - balanced_logits(num_tokens, num_experts, topk).to(router_logits_dtype).to(torch.device(device)) - ) + actual_logits = balanced_logits(num_tokens, num_experts, topk, torch.device(device)).to(router_logits_dtype) else: raise ValueError(f"Unsupported distributed mode: {distributed}") @@ -533,3 +556,6 @@ def run_moe_torch( kernel_source=source, perf_filename=perf_filename, ) + + del moe, hidden_states, actual_logits + torch.cuda.empty_cache() diff --git a/src/aiconfigurator/sdk/perf_database.py b/src/aiconfigurator/sdk/perf_database.py index 4145c6ea..67ce2b63 100755 --- a/src/aiconfigurator/sdk/perf_database.py +++ b/src/aiconfigurator/sdk/perf_database.py @@ -2015,10 +2015,10 @@ def get_sol( * b * ( n * (full_s - prefix) * h # Q read, assuming 16 bits - + 2 * n_kv * full_s * h # K,V read - + n * (full_s - prefix) * h + + n * (full_s - prefix) * h # Output write, assuming 16 bits ) - ) # Output write, assuming 16 bits + + kvcache_quant_mode.value.memory * b * (2 * n_kv * full_s * h) # K,V read + ) # TODO fp8 io sol_math = ops / self.system_spec["gpu"]["float16_tc_flops"] * 1000 / fmha_quant_mode.value.compute sol_mem = mem_bytes / self.system_spec["gpu"]["mem_bw"] * 1000 sol_time = max(sol_math, sol_mem) @@ -2070,6 +2070,10 @@ def get_sol( """ Get the sol time, sol math and sol mem """ + if kvcache_quant_mode == common.KVCacheQuantMode.fp8: + quant_mode_gen = common.FMHAQuantMode.fp8 + else: + quant_mode_gen = common.FMHAQuantMode.float16 if w > 0: kv_len = min(s - 1, w) else: @@ -2084,7 +2088,7 @@ def get_sol( + n * h * 2 # Output write, assuming 16bits ) - sol_math = ops / self.system_spec["gpu"]["float16_tc_flops"] * 1000 + sol_math = ops / self.system_spec["gpu"]["float16_tc_flops"] * 1000 / quant_mode_gen.value.compute sol_mem = mem_bytes / self.system_spec["gpu"]["mem_bw"] * 1000 sol_time = max(sol_math, sol_mem) return sol_time, sol_math, sol_mem @@ -2139,7 +2143,9 @@ def get_sol( b * num_heads * 2 / 2 * (192 + 128) * (full_s * full_s - prefix * prefix) ) # 2 for fma, 2 for causality. num_heads, for local heads # s * 192 for q read, full_s * 192 for k read, full_s * 128 for v read, s * 192 for write. - mem_bytes = b * num_heads * 2 * (full_s * (192 + 128) + s * (192 + 128)) # 2 for fp16, TODO + mem_bytes = ( + b * num_heads * (kvcache_quant_mode.value.memory * full_s * (192 + 128) + 2 * s * (192 + 128)) + ) # 2 for qk, TODO sol_math = ops / self.system_spec["gpu"]["float16_tc_flops"] * 1000 / fmha_quant_mode.value.compute sol_mem = mem_bytes / self.system_spec["gpu"]["mem_bw"] * 1000 sol_time = max(sol_math, sol_mem) @@ -2176,13 +2182,17 @@ def get_sol( """ Get the sol time, sol math and sol mem """ + if kvcache_quant_mode == common.KVCacheQuantMode.fp8: + quant_mode_gen = common.FMHAQuantMode.fp8 + else: + quant_mode_gen = common.FMHAQuantMode.float16 # only consider fp16 mmha ops = 2 * b * num_heads * 1088 * s # 2 for fma # kvcache load bytes will depend on kvcache quant. # while input q and output might be in fp16. 
- mem_bytes = b * (num_heads * 1088 * 2 + (s - 1) * 1088 * kvcache_quant_mode.value.memory) - - sol_math = ops / self.system_spec["gpu"]["float16_tc_flops"] * 1000 # only fp16 + mem_bytes = b * (num_heads * 1088 * 2 + (s - 1) * 576 * kvcache_quant_mode.value.memory) + # fp16 io + fp16/fp8 kv cache, TODO fp8 io + sol_math = ops / self.system_spec["gpu"]["float16_tc_flops"] * 1000 / quant_mode_gen.value.compute sol_mem = mem_bytes / self.system_spec["gpu"]["mem_bw"] * 1000 sol_time = max(sol_math, sol_mem) return sol_time, sol_math, sol_mem @@ -2559,12 +2569,17 @@ def get_sol( total_tokens = num_tokens * topk ops = total_tokens * hidden_size * inter_size * 3 * 2 // moe_ep_size // moe_tp_size # ffn1, ffn2, gate mem_bytes = quant_mode.value.memory * ( - total_tokens * hidden_size * 3 # input+output + total_tokens // moe_ep_size * hidden_size * 2 # input+output + total_tokens + // moe_ep_size * inter_size * 3 // moe_tp_size # intermediate, assume ffn1/gate all need to write results. - + hidden_size * inter_size * 3 // moe_tp_size * min(num_experts // moe_ep_size, total_tokens) + + hidden_size + * inter_size + * 3 + // moe_tp_size + * min(num_experts // moe_ep_size, total_tokens // moe_ep_size) ) sol_math = ops / (self.system_spec["gpu"]["float16_tc_flops"] * quant_mode.value.compute) * 1000 sol_mem = mem_bytes / self.system_spec["gpu"]["mem_bw"] * 1000 diff --git a/src/aiconfigurator/systems/data/h200_sxm/trtllm/1.2.0rc2/context_attention_perf.txt b/src/aiconfigurator/systems/data/h200_sxm/trtllm/1.2.0rc2/context_attention_perf.txt new file mode 100644 index 00000000..236ea91d --- /dev/null +++ b/src/aiconfigurator/systems/data/h200_sxm/trtllm/1.2.0rc2/context_attention_perf.txt @@ -0,0 +1,3 @@ +version https://git-lfs.github.com/spec/v1 +oid sha256:29e317d7aaafc8f292726658a21e39c2c276c940dbeb6c625814ad7c4c6e4b83 +size 4862872 diff --git a/src/aiconfigurator/systems/data/h200_sxm/trtllm/1.2.0rc2/context_mla_perf.txt b/src/aiconfigurator/systems/data/h200_sxm/trtllm/1.2.0rc2/context_mla_perf.txt new file mode 100644 index 00000000..7722dcee --- /dev/null +++ b/src/aiconfigurator/systems/data/h200_sxm/trtllm/1.2.0rc2/context_mla_perf.txt @@ -0,0 +1,3 @@ +version https://git-lfs.github.com/spec/v1 +oid sha256:bc3905c78252f7260b6bcd3e429544bedd83fa2c67e204da2b3b892a81d16bde +size 133390 diff --git a/src/aiconfigurator/systems/data/h200_sxm/trtllm/1.2.0rc2/custom_allreduce_perf.txt b/src/aiconfigurator/systems/data/h200_sxm/trtllm/1.2.0rc2/custom_allreduce_perf.txt new file mode 100644 index 00000000..dee1ec99 --- /dev/null +++ b/src/aiconfigurator/systems/data/h200_sxm/trtllm/1.2.0rc2/custom_allreduce_perf.txt @@ -0,0 +1,3 @@ +version https://git-lfs.github.com/spec/v1 +oid sha256:67f815d61ecde5369802254ac52364054ed43e979dd327a9b0da7662c2d82052 +size 5857 diff --git a/src/aiconfigurator/systems/data/h200_sxm/trtllm/1.2.0rc2/gemm_perf.txt b/src/aiconfigurator/systems/data/h200_sxm/trtllm/1.2.0rc2/gemm_perf.txt new file mode 100644 index 00000000..1c230fca --- /dev/null +++ b/src/aiconfigurator/systems/data/h200_sxm/trtllm/1.2.0rc2/gemm_perf.txt @@ -0,0 +1,3 @@ +version https://git-lfs.github.com/spec/v1 +oid sha256:97619374abc1ccc5925756e5c317baea32728f19539be56224e384f51a88c24d +size 2211782 diff --git a/src/aiconfigurator/systems/data/h200_sxm/trtllm/1.2.0rc2/generation_attention_perf.txt b/src/aiconfigurator/systems/data/h200_sxm/trtllm/1.2.0rc2/generation_attention_perf.txt new file mode 100644 index 00000000..064036d1 --- /dev/null +++ 
b/src/aiconfigurator/systems/data/h200_sxm/trtllm/1.2.0rc2/generation_attention_perf.txt @@ -0,0 +1,3 @@ +version https://git-lfs.github.com/spec/v1 +oid sha256:4222d821afd6fb176c07b595315812a9e3e9035df22d0b9824a71cc91cf6fa52 +size 3162873 diff --git a/src/aiconfigurator/systems/data/h200_sxm/trtllm/1.2.0rc2/generation_mla_perf.txt b/src/aiconfigurator/systems/data/h200_sxm/trtllm/1.2.0rc2/generation_mla_perf.txt new file mode 100644 index 00000000..d4101559 --- /dev/null +++ b/src/aiconfigurator/systems/data/h200_sxm/trtllm/1.2.0rc2/generation_mla_perf.txt @@ -0,0 +1,3 @@ +version https://git-lfs.github.com/spec/v1 +oid sha256:ea24a6b66274d43a13ebbde2d5ba606e28147b24c36caa8d738bf09424ab6f5b +size 288188 diff --git a/src/aiconfigurator/systems/data/h200_sxm/trtllm/1.2.0rc2/mla_bmm_perf.txt b/src/aiconfigurator/systems/data/h200_sxm/trtllm/1.2.0rc2/mla_bmm_perf.txt new file mode 100644 index 00000000..db00f5c7 --- /dev/null +++ b/src/aiconfigurator/systems/data/h200_sxm/trtllm/1.2.0rc2/mla_bmm_perf.txt @@ -0,0 +1,3 @@ +version https://git-lfs.github.com/spec/v1 +oid sha256:fdcd134be13ce1351e99d22fae35fb0d48c8b37bae02530ad01dc7906a3aebdc +size 69359 diff --git a/src/aiconfigurator/systems/data/h200_sxm/trtllm/1.2.0rc2/moe_perf.txt b/src/aiconfigurator/systems/data/h200_sxm/trtllm/1.2.0rc2/moe_perf.txt new file mode 100644 index 00000000..b04d5738 --- /dev/null +++ b/src/aiconfigurator/systems/data/h200_sxm/trtllm/1.2.0rc2/moe_perf.txt @@ -0,0 +1,3 @@ +version https://git-lfs.github.com/spec/v1 +oid sha256:748013336738d8c7173df214c13b26867b26855aec3459d73b3a937c31903267 +size 7071275 diff --git a/tools/sanity_check/validate_database.ipynb b/tools/sanity_check/validate_database.ipynb index 36e1f8e0..13fb8363 100644 --- a/tools/sanity_check/validate_database.ipynb +++ b/tools/sanity_check/validate_database.ipynb @@ -14,10 +14,10 @@ "from aiconfigurator.sdk import common\n", "from aiconfigurator.sdk.perf_database import get_database\n", "\n", - "system = \"h200_sxm\"\n", - "database = get_database(system=system, backend=\"trtllm\", version=\"1.0.0rc3\")\n", + "from aiconfigurator.sdk.common import SOLMode\n", "\n", - "from aiconfigurator.sdk.common import SOLMode" + "system = \"h200_sxm\"\n", + "database = get_database(system=system, backend=\"trtllm\", version=\"1.2.0rc2\")\n" ] }, { @@ -78,6 +78,7 @@ " ax[1, i].set_ylabel(\"mem sol %\")\n", " ax[1, i].set_ylim(0, 1)\n", " ax[1, i].legend()\n", + " plt.tight_layout()\n", " plt.show()\n", "\n", "\n", @@ -250,6 +251,7 @@ " ax[1, i].set_ylabel(\"mem sol %\")\n", " ax[1, i].set_ylim(0, 1)\n", " ax[1, i].legend()\n", + " plt.tight_layout()\n", " plt.show()\n", "\n", "\n", @@ -265,7 +267,7 @@ "def visualize_generation_attention(database):\n", " b = 64\n", " n = 32\n", - " n_kv_list = [1, 2, 4, 8, n]\n", + " n_kv_list = [1, 2, 4, 8] # mha uses sol in current version\n", " s_list = [16, 32, 64, 128, 256, 512, 1024, 2048, 4096, 8192, 16384]\n", "\n", " color_list = [\n", @@ -328,6 +330,7 @@ " ax[1, i].set_ylabel(\"mem sol %\")\n", " # ax[1,i].set_ylim(0,1)\n", " ax[1, i].legend()\n", + " plt.tight_layout()\n", " plt.show()\n", "\n", "\n", @@ -402,6 +405,7 @@ " ax[1, i].set_ylabel(\"mem sol %\")\n", " # ax[1,i].set_ylim(0,1)\n", " ax[1, i].legend()\n", + " plt.tight_layout()\n", " plt.show()\n", "\n", "\n", @@ -417,7 +421,7 @@ "def visualize_context_mla_with_prefix(database):\n", " b = 8\n", " n_list = [2, 4, 8, 16, 32]\n", - " s_list = [16, 32, 64, 128, 256, 512, 1024, 2048, 4096, 8192, 16384, 32768, 524288]\n", + " s_list = [16, 32, 64, 128, 256, 512, 
1024, 2048, 4096, 8192, 16384, 32768, 65536]\n", " prefix_scale = 0.3\n", "\n", " color_list = [\n", @@ -486,10 +490,11 @@ " ax[1, i].set_ylabel(\"mem sol %\")\n", " ax[1, i].set_ylim(0, 1)\n", " ax[1, i].legend()\n", + " plt.tight_layout()\n", " plt.show()\n", "\n", "\n", - "visualize_context_mla(database)" + "visualize_context_mla_with_prefix(database)" ] }, { @@ -561,6 +566,7 @@ " ax[1, i].set_ylabel(\"mem sol %\")\n", " # ax[1,i].set_ylim(0,1)\n", " ax[1, i].legend()\n", + " plt.tight_layout()\n", " plt.show()\n", "\n", "\n", @@ -632,6 +638,7 @@ " ax[1, i].set_ylabel(\"mem sol %\")\n", " # ax[1,i].set_ylim(0,1)\n", " ax[1, i].legend()\n", + " plt.tight_layout()\n", " plt.show()\n", "\n", "\n", @@ -645,19 +652,15 @@ "outputs": [], "source": [ "def visualize_moe(database):\n", - " topk = 8\n", - " num_experts = 256\n", - " hidden_size = 7168\n", - " inter_size = 2048\n", - " workload_distribution = \"power_law_1.01\"\n", - " tp_list = [1, 2, 4, 8]\n", - " ep_list = [1, 2, 4, 8]\n", + " workload_distributions = [\"balanced\", \"power_law_1.01\", \"power_law_1.2\"]\n", + " tp_list = [1, 2, 4]\n", + " ep_list = [1, 2, 4, 8, 16, 32]\n", " tp_ep_list = []\n", " for tp in tp_list:\n", " for ep in ep_list:\n", - " if tp * ep >= 4 and tp * ep <= 16:\n", + " if tp * ep >= 4 and tp * ep <= 32:\n", " tp_ep_list.append([tp, ep])\n", - " m_list = [1, 16, 32, 64, 128, 256, 512, 1024, 2048, 4096, 8192, 16384, 32768, 65536, 65536 * 4]\n", + " m_list = [1, 16, 32, 64, 128, 256, 512, 1024, 2048, 4096, 8192, 16384, 32768, 65536]\n", " color_list = [\n", " \"red\",\n", " \"blue\",\n", @@ -670,59 +673,76 @@ " \"olive\",\n", " \"cyan\",\n", " ]\n", - " fig, ax = plt.subplots(2, len(tp_ep_list), figsize=(5 * len(tp_ep_list), 5 * 2))\n", - " for i, (tp, ep) in enumerate(tp_ep_list):\n", - " for color_id, quant_mode in enumerate(database._moe_data.keys()):\n", - " sol_math_list = []\n", - " sol_mem_list = []\n", - " for m in m_list:\n", - " sol_time, sol_math, sol_mem = database.query_moe(\n", - " num_tokens=m,\n", - " hidden_size=hidden_size,\n", - " inter_size=inter_size,\n", - " topk=topk,\n", - " num_experts=num_experts,\n", - " moe_tp_size=tp,\n", - " moe_ep_size=ep,\n", - " quant_mode=quant_mode,\n", - " workload_distribution=workload_distribution,\n", - " sol_mode=SOLMode.SOL_FULL,\n", + " fig, ax = plt.subplots(2*len(workload_distributions), len(tp_ep_list), figsize=(5 * len(tp_ep_list), 5 * 2*len(workload_distributions)))\n", + " for workload_distribution_id, workload_distribution in enumerate(workload_distributions):\n", + " for i, (tp, ep) in enumerate(tp_ep_list):\n", + " for color_id, quant_mode in enumerate(database._moe_data.keys()):\n", + " if quant_mode == common.MoEQuantMode.w4a16_mxfp4:\n", + " topk = 4\n", + " num_experts = 128\n", + " hidden_size = 2880\n", + " inter_size = 2880\n", + " else:\n", + " topk = 8\n", + " num_experts = 256\n", + " hidden_size = 7168\n", + " inter_size = 2048\n", + " sol_math_list = []\n", + " sol_mem_list = []\n", + " for m in m_list:\n", + " sol_time, sol_math, sol_mem = database.query_moe(\n", + " num_tokens=m,\n", + " hidden_size=hidden_size,\n", + " inter_size=inter_size,\n", + " topk=topk,\n", + " num_experts=num_experts,\n", + " moe_tp_size=tp,\n", + " moe_ep_size=ep,\n", + " quant_mode=quant_mode,\n", + " workload_distribution=workload_distribution,\n", + " sol_mode=SOLMode.SOL_FULL,\n", + " )\n", + " db_time = database.query_moe(\n", + " num_tokens=m,\n", + " hidden_size=hidden_size,\n", + " inter_size=inter_size,\n", + " topk=topk,\n", + " 
num_experts=num_experts,\n", + " moe_tp_size=tp,\n", + " moe_ep_size=ep,\n", + " quant_mode=quant_mode,\n", + " workload_distribution=workload_distribution,\n", + " sol_mode=SOLMode.NON_SOL,\n", + " )\n", + " percentage_of_math = sol_math / db_time\n", + " percentage_of_mem = sol_mem / db_time\n", + " sol_math_list.append(percentage_of_math)\n", + " sol_mem_list.append(percentage_of_mem)\n", + " ax[workload_distribution_id*2, i].plot(\n", + " m_list, sol_math_list, color=color_list[color_id], label=f\"{quant_mode} math\"\n", " )\n", - " db_time = database.query_moe(\n", - " num_tokens=m,\n", - " hidden_size=hidden_size,\n", - " inter_size=inter_size,\n", - " topk=topk,\n", - " num_experts=num_experts,\n", - " moe_tp_size=tp,\n", - " moe_ep_size=ep,\n", - " quant_mode=quant_mode,\n", - " workload_distribution=workload_distribution,\n", - " sol_mode=SOLMode.NON_SOL,\n", + " ax[workload_distribution_id*2+1, i].plot(\n", + " m_list,\n", + " sol_mem_list,\n", + " color=color_list[color_id],\n", + " linestyle=\"--\",\n", + " label=f\"{quant_mode} mem\",\n", " )\n", - " percentage_of_math = sol_math / db_time\n", - " percentage_of_mem = sol_mem / db_time\n", - " sol_math_list.append(percentage_of_math)\n", - " sol_mem_list.append(percentage_of_mem)\n", - " ax[0, i].plot(\n", - " m_list, sol_math_list, color=color_list[color_id], label=f\"{quant_mode} math\"\n", - " )\n", - " ax[1, i].plot(\n", - " m_list,\n", - " sol_mem_list,\n", - " color=color_list[color_id],\n", - " linestyle=\"--\",\n", - " label=f\"{quant_mode} mem\",\n", - " )\n", - " ax[0, i].set_title(f\"topk={topk} e={num_experts} tp={tp} ep={ep}\")\n", - " ax[0, i].set_xlabel(\"s\")\n", - " ax[0, i].set_ylabel(\"math sol %\")\n", - " # ax[0,i].set_ylim(0,1)\n", - " ax[0, i].legend()\n", - " ax[1, i].set_xlabel(\"s\")\n", - " ax[1, i].set_ylabel(\"mem sol %\")\n", - " # ax[1,i].set_ylim(0,1)\n", - " ax[1, i].legend()\n", + " if workload_distribution != \"balanced\":\n", + " workload_distribution_title = workload_distribution + \" vs balanced sol\"\n", + " else:\n", + " workload_distribution_title = workload_distribution\n", + "\n", + " ax[workload_distribution_id*2, i].set_title(f\"{workload_distribution_title} \\ntopk={topk} e={num_experts} tp={tp} ep={ep}\")\n", + " ax[workload_distribution_id*2, i].set_xlabel(\"s\")\n", + " ax[workload_distribution_id*2, i].set_ylabel(\"math sol %\")\n", + " # ax[0,i].set_ylim(0,1)\n", + " ax[workload_distribution_id*2, i].legend()\n", + " ax[workload_distribution_id*2+1, i].set_xlabel(\"s\")\n", + " ax[workload_distribution_id*2+1, i].set_ylabel(\"mem sol %\")\n", + " # ax[1,i].set_ylim(0,1)\n", + " ax[workload_distribution_id*2+1, i].legend()\n", + " plt.tight_layout()\n", " plt.show()\n", "\n", "\n", @@ -790,6 +810,7 @@ " ax[i].set_ylabel(\"sol %\")\n", " # ax[i].set_ylim(0,1)\n", " ax[i].legend()\n", + " plt.tight_layout()\n", " plt.show()\n", "\n", "\n", @@ -856,6 +877,7 @@ " ax[i].set_ylabel(\"sol %\")\n", " # ax[i].set_ylim(0,1)\n", " ax[i].legend()\n", + " plt.tight_layout()\n", " plt.show()\n", "\n", "\n", From 8286c7c18d5227309dca925a3facfb248aada45d Mon Sep 17 00:00:00 2001 From: Kimi Zhao Date: Fri, 21 Nov 2025 12:26:22 +0800 Subject: [PATCH 2/2] fix: main branch mla data read Signed-off-by: Kimi Zhao --- collector/trtllm/collect_moe.py | 8 +++++--- src/aiconfigurator/sdk/perf_database.py | 27 ++++++++++++------------- 2 files changed, 18 insertions(+), 17 deletions(-) diff --git a/collector/trtllm/collect_moe.py b/collector/trtllm/collect_moe.py index c45777ed..b337066e 100755 --- 
a/collector/trtllm/collect_moe.py +++ b/collector/trtllm/collect_moe.py @@ -20,8 +20,10 @@ aic_debug = int(os.getenv("aic_moe_debug", "0")) # noqa: SIM112 +moe_tune_path = os.path.dirname(os.path.abspath(__file__)) -def cleanup_empty_json_files(directory="/tmp/moe_tune_path"): + +def cleanup_empty_json_files(directory): if not os.path.exists(directory): return @@ -466,9 +468,9 @@ def fp32_to_mxfp4(tensor): torch.cuda.synchronize() if moe_type != "w4a16_mxfp4": - cleanup_empty_json_files("/tmp/moe_tune_path") + cleanup_empty_json_files(moe_tune_path) cache_path = ( - f"/tmp/moe_tune_path/{moe_type}_{hidden_size}_{inter_size // moe_tp_size}_{num_experts // moe_ep_size}" + f"{moe_tune_path}/{moe_type}_{hidden_size}_{inter_size // moe_tp_size}_{num_experts // moe_ep_size}" ) existing_files = glob.glob(f"{cache_path}*") cache_loaded = False diff --git a/src/aiconfigurator/sdk/perf_database.py b/src/aiconfigurator/sdk/perf_database.py index 67ce2b63..b98f3ccf 100755 --- a/src/aiconfigurator/sdk/perf_database.py +++ b/src/aiconfigurator/sdk/perf_database.py @@ -1490,14 +1490,13 @@ def __init__(self, system: str, backend: str, version: str, systems_dir: str = " ) # s target_z_list = [1, 2, 4, 8, 16, 32, 64, 128, 256, 384, 512, 1024, 2048] # b - self._extrapolate_data_grid( - data_dict=data_dict, # tpsize,sb - target_x_list=target_x_list, - target_y_list=target_y_list, - target_z_list=target_z_list, - sqrt_y_value=True, - ) - + self._extrapolate_data_grid( + data_dict=data_dict, # tpsize,sb + target_x_list=target_x_list, + target_y_list=target_y_list, + target_z_list=target_z_list, + sqrt_y_value=True, + ) # wideep generation mla if getattr(self, "_wideep_generation_mla_data", None) is not None: for kernel_source in self._wideep_generation_mla_data: @@ -1581,12 +1580,12 @@ def __init__(self, system: str, backend: str, version: str, systems_dir: str = " 2097152 * 8, ] # s - self._extrapolate_data_grid( - data_dict=data_dict, # tpsize, bs - target_x_list=target_x_list, - target_y_list=target_y_list, - target_z_list=target_z_list, - ) + self._extrapolate_data_grid( + data_dict=data_dict, # tpsize, bs + target_x_list=target_x_list, + target_y_list=target_y_list, + target_z_list=target_z_list, + ) # post-correction self._correct_data()
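
Illustrative note (not part of the patch): the change to get_sol() in perf_database.py reworks the MoE speed-of-light estimate so that activation and intermediate traffic are divided by moe_ep_size and the expert-weight read is capped by the per-rank token count rather than the global one. The sketch below is a minimal standalone reading of that updated formula; the hardware numbers (float16_tc_flops, mem_bw) and the quant-mode compute/memory factors are passed in as plain arguments purely for illustration, whereas the real code reads them from the database's system spec and quant_mode enum.

def moe_sol_ms(num_tokens, hidden_size, inter_size, topk, num_experts,
               moe_tp_size, moe_ep_size,
               float16_tc_flops, mem_bw,
               quant_compute=1.0, quant_memory=2.0):
    """Speed-of-light (SOL) time in ms for one MoE layer on one rank.

    quant_compute scales tensor-core throughput relative to fp16;
    quant_memory is bytes per element for the quantized format.
    """
    total_tokens = num_tokens * topk
    # ffn1, ffn2 and gate GEMMs, 2 flops per MAC, sharded across EP and TP ranks
    ops = total_tokens * hidden_size * inter_size * 3 * 2 // moe_ep_size // moe_tp_size
    mem_bytes = quant_memory * (
        total_tokens // moe_ep_size * hidden_size * 2                   # input + output activations
        + total_tokens // moe_ep_size * inter_size * 3 // moe_tp_size   # ffn1/gate intermediates
        + hidden_size * inter_size * 3 // moe_tp_size
        * min(num_experts // moe_ep_size, total_tokens // moe_ep_size)  # expert weights actually touched
    )
    sol_math = ops / (float16_tc_flops * quant_compute) * 1000
    sol_mem = mem_bytes / mem_bw * 1000
    return max(sol_math, sol_mem), sol_math, sol_mem

# Usage example: a DeepSeek-style MoE layer (topk=8, 256 experts) with fp8 weights
# (assumed quant_memory=1 byte/elem, quant_compute=2x fp16 TC throughput) on an
# H200-class GPU. The hardware figures are illustrative placeholders, not values
# taken from the shipped database files.
print(moe_sol_ms(num_tokens=4096, hidden_size=7168, inter_size=2048, topk=8,
                 num_experts=256, moe_tp_size=1, moe_ep_size=8,
                 float16_tc_flops=989e12, mem_bw=4.8e12,
                 quant_compute=2.0, quant_memory=1.0))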