From cabab1171b3459cb3410e3bfd04e2442453f10fe Mon Sep 17 00:00:00 2001 From: Jason Zhou Date: Wed, 15 Oct 2025 17:59:57 -0700 Subject: [PATCH 01/17] feat(model): support Qwen3 Next 80B A3B --- collector/collect.py | 20 +- collector/sglang/collect_moe.py | 1 + .../trtllm/collect_chunk_gated_delta_rule.py | 395 ++++++++++++++++++ collector/trtllm/collect_conv_1d.py | 143 +++++++ src/aiconfigurator/sdk/common.py | 24 +- src/aiconfigurator/sdk/models.py | 255 ++++++++++- src/aiconfigurator/sdk/operations.py | 45 +- src/aiconfigurator/sdk/perf_database.py | 74 +++- 8 files changed, 943 insertions(+), 14 deletions(-) create mode 100644 collector/trtllm/collect_chunk_gated_delta_rule.py create mode 100644 collector/trtllm/collect_conv_1d.py diff --git a/collector/collect.py b/collector/collect.py index f31740bd..66f5210c 100644 --- a/collector/collect.py +++ b/collector/collect.py @@ -433,7 +433,25 @@ def collect_trtllm(num_processes: int, ops: List[str]=None): else 'trtllm.collect_moe_pre_1_0' if v.startswith(('0.21.0', '1.0.0')) else 'trtllm.collect_moe' if v.startswith(('1.1.0')) else None - } + }, + + # CONV 1D collections + { + 'name': 'trtllm', + 'type': 'conv_1d', + 'module': 'trtllm.collect_conv_1d', + 'get_func': 'get_conv_1d_test_cases', + 'run_func': 'run_conv_1d' + }, + + # Chunk Gated Delta Rule collections + { + 'name': 'trtllm', + 'type': 'chunk_gated_delta_rule', + 'module': 'trtllm.collect_chunk_gated_delta_rule', + 'get_func': 'get_chunk_gated_delta_rule_test_cases', + 'run_func': 'run_chunk_gated_delta_rule' + }, ] for collection in collections: diff --git a/collector/sglang/collect_moe.py b/collector/sglang/collect_moe.py index 33dfaa4c..c471d637 100644 --- a/collector/sglang/collect_moe.py +++ b/collector/sglang/collect_moe.py @@ -47,6 +47,7 @@ def get_moe_test_cases(): [4096,1536,8,128, 'QWEN3_235B'], # qwen3-moe, 235b-a22b [6144,2560,8,160, 'QWEN3_480B'], # qwen3-moe, 480b-a35b [7168,2048,8,384, 'KIMI_K2'], # kimi k2 + [2048,5120,50,512, 'QWEN3_NEXT_80B'], # qwen3-next, 80b-a3b ] moe_list=['float16', 'fp8_block'] diff --git a/collector/trtllm/collect_chunk_gated_delta_rule.py b/collector/trtllm/collect_chunk_gated_delta_rule.py new file mode 100644 index 00000000..0f51e8c3 --- /dev/null +++ b/collector/trtllm/collect_chunk_gated_delta_rule.py @@ -0,0 +1,395 @@ +# SPDX-FileCopyrightText: Copyright (c) 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved. 
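Note on the registry entries added to collect.py above: each new entry only names a module plus a case-generator (get_func) and a runner (run_func). Below is a minimal sketch of how such an entry could be dispatched; the loop is illustrative only, and the real collect.py driver (ops filtering, multiprocessing) may differ.

```python
# Hypothetical dispatch for the registry entries above. The dict keys come
# from the patch; the consumption pattern shown here is an assumption.
import importlib

def dispatch(collections, ops=None):
    for entry in collections:
        if ops is not None and entry['type'] not in ops:
            continue  # honor an optional ops filter, if one is passed
        mod = importlib.import_module(entry['module'])
        get_cases = getattr(mod, entry['get_func'])   # e.g. the conv1d case generator
        run_case = getattr(mod, entry['run_func'])    # e.g. the conv1d runner
        for case in get_cases():
            run_case(*case)
```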
+# SPDX-License-Identifier: Apache-2.0 + +import tensorrt_llm +import torch +from cuda import cuda +from tensorrt_llm._torch.attention_backend.utils import create_attention +from tensorrt_llm._torch.attention_backend.interface import PositionalEmbeddingParams, RopeParams, AttentionRuntimeFeatures +from tensorrt_llm.functional import PositionEmbeddingType +from tensorrt_llm.models.modeling_utils import QuantConfig, QuantAlgo +from tensorrt_llm._torch.attention_backend import TrtllmAttentionMetadata +from tensorrt_llm.mapping import Mapping +from tensorrt_llm._torch.pyexecutor.resource_manager import KVCacheManager +from tensorrt_llm._torch.metadata import KVCacheParams +from tensorrt_llm.llmapi import KvCacheConfig +import os +from helper import getSMVersion, log_perf + +def run_attention_torch(batch_size, + input_len, + num_heads, + num_key_value_heads, # keep same as num_heads for MHA + head_dim, + attention_window_size, + use_fp8_kv_cache, + use_fp8_context_fmha, + is_context_phase, + perf_filename, + device='cuda:0'): + torch.cuda.set_device(device) + + # if XQA JIT is enabled, the context phase will also trigger XQA prepare which causes the error with specifc q/kv head and seq setting. + if is_context_phase: + os.environ['TRTLLM_ENABLE_XQA_JIT']= '0' + else: + os.environ['TRTLLM_ENABLE_XQA_JIT']= '1' + + backend_name = "TRTLLM" + layer_idx = 0 + world_size=1 + tp_size=1 + tokens_per_block=64 + warming_up=10 + test_ite=6 + output_len=1 + if use_fp8_context_fmha: + assert use_fp8_kv_cache==True + quant_algo = QuantAlgo.FP8 + out_scale = torch.tensor( + [1.0], + dtype=torch.float32, + device=device, + )# fp8 fmha + else: + quant_algo = None + out_scale = None + + if use_fp8_kv_cache: + kv_cache_dtype = tensorrt_llm.bindings.DataType.FP8 + else: + kv_cache_dtype = tensorrt_llm.bindings.DataType.BF16 + + pos_embd_params = PositionalEmbeddingParams( + type=PositionEmbeddingType.rope_gpt_neox, + rope=RopeParams(dim=128)) + + + quant_config=QuantConfig(quant_algo=quant_algo, # fp8 fmha + kv_cache_quant_algo=QuantAlgo.FP8 if use_fp8_kv_cache else None, # fp8 kv, + group_size=128, + smoothquant_val=0.5, + clamp_val=None, + use_meta_recipe=False, + has_zero_point=False, + pre_quant_scale=False, + exclude_modules=None) + + attn = create_attention(backend_name = backend_name, + layer_idx = layer_idx, + num_heads = num_heads, + head_dim = head_dim, + num_kv_heads = num_key_value_heads, + pos_embd_params=pos_embd_params, + quant_config=quant_config, + is_mla_enable=False) + + total_num_tokens = (input_len + output_len) * batch_size + + mapping = Mapping(world_size=world_size, rank=0, tp_size=tp_size) + + num_hidden_layers = 1 + + kv_cache_config = KvCacheConfig( + max_tokens=int((input_len + output_len - 1)/tokens_per_block + 1) * tokens_per_block * batch_size * 2, #num_bloacks * block_size + enable_block_reuse=False) + + kv_cache_manager = KVCacheManager( + kv_cache_config=kv_cache_config, + kv_cache_type=tensorrt_llm.bindings.internal.batch_manager.CacheType.SELFKONLY, + num_layers=num_hidden_layers, + num_kv_heads=num_key_value_heads, + head_dim=head_dim, + tokens_per_block=tokens_per_block, + max_seq_len=input_len + output_len + 1, # +1 for the magic fixme mentioned in trtllm xqa JIT path impl. 
+ max_batch_size=batch_size, + mapping=mapping, + dtype=kv_cache_dtype) + + input_seq_lens = [input_len for _ in range(batch_size)] + total_seq_lens = [input_len + output_len for _ in range(batch_size)] + request_ids = [i for i in range(batch_size)] + kv_cache_manager.add_dummy_requests(request_ids, total_seq_lens) + + if is_context_phase: + num_cached_tokens_per_seq = [0 for _ in range(batch_size)] + attn_metadata = TrtllmAttentionMetadata(max_num_requests=batch_size, + max_num_tokens=total_num_tokens, + kv_cache_manager=kv_cache_manager, + mapping=mapping, + enable_flash_mla=False, + seq_lens=torch.tensor(input_seq_lens, dtype=torch.int32), + num_contexts=batch_size, + position_ids=None, + kv_cache_params=KVCacheParams(use_cache=True, num_cached_tokens_per_seq=num_cached_tokens_per_seq, block_ids_per_seq=None, host_max_attention_window_sizes=None, host_sink_token_length=None), + cross=None, + request_ids=request_ids, + prompt_lens=input_seq_lens, + runtime_features=AttentionRuntimeFeatures(chunked_prefill=False, cache_reuse=False, has_speculative_draft_tokens=False), + all_rank_num_tokens=None, + workspace=torch.tensor([], device=device, dtype=torch.int8)) + else: + gen_seq_lens = [1 for _ in range(batch_size)] + attn_metadata = TrtllmAttentionMetadata(max_num_requests=batch_size, + max_num_tokens=total_num_tokens, + kv_cache_manager=kv_cache_manager, + mapping=mapping, + enable_flash_mla=False, + seq_lens=torch.tensor(gen_seq_lens, dtype=torch.int32), + position_ids=None, + num_contexts=0, + kv_cache_params=KVCacheParams(use_cache=True, num_cached_tokens_per_seq=input_seq_lens, block_ids_per_seq=None, host_max_attention_window_sizes=None, host_sink_token_length=None), + cross=None, + request_ids=request_ids, + prompt_lens=input_seq_lens, + runtime_features=AttentionRuntimeFeatures(chunked_prefill=False, cache_reuse=False), + all_rank_num_tokens=None, + workspace=torch.tensor([], device=device, dtype=torch.int8)) + + attn_metadata.prepare() + + if is_context_phase: + num_tokens = input_len * batch_size + else: + num_tokens = batch_size + + sinks = torch.randn(num_heads, dtype=torch.float32) if head_dim == 64 else None + q = torch.randn([num_tokens, num_heads*128]).bfloat16().to(torch.device(device)) + kv = torch.randn([num_tokens, 2*num_key_value_heads*128]).bfloat16().to(torch.device(device)) + input_qkv = torch.concat([q, kv], dim=-1) + attn.forward( + input_qkv, + None, + None, + attn_metadata, + attention_window_size=attention_window_size if attention_window_size>0 else None, + attention_sinks=sinks, + out_scale=out_scale + ) + out_dtype = None if not use_fp8_context_fmha else torch.float8_e4m3fn + + # capture + g = torch.cuda.CUDAGraph() + with torch.cuda.graph(g): + attn.forward( + input_qkv, + None, + None, + attn_metadata, + attention_window_size=attention_window_size if attention_window_size>0 else None, + attention_sinks=sinks, + out_scale=out_scale + ) + # warmup + for i in range(warming_up): + g.replay() + + start_event = torch.cuda.Event(enable_timing=True) + end_event = torch.cuda.Event(enable_timing=True) + start_event.record() + for i in range(test_ite): + g.replay() + end_event.record() + torch.cuda.synchronize() + latency = start_event.elapsed_time(end_event)/test_ite + + # write result + if is_context_phase: + isl = input_len + step = 0 + op_name = 'context_attention' + else: + isl = 1 + step = input_len + op_name = 'generation_attention' + kv_cache_dtype_str = 'float16' + if use_fp8_kv_cache: + kv_cache_dtype_str = 'fp8' + if use_fp8_context_fmha: + dtype_str = 'fp8' 
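The capture-and-replay loop above is the timing pattern shared by these collectors: dry-run the op, capture it into a CUDA graph, warm up with replays, then average CUDA-event time over a few more replays. A self-contained sketch of the same pattern with a plain matmul as a stand-in workload (shapes and iteration counts are arbitrary):

```python
# Minimal sketch of the CUDA-graph timing pattern used above; the matmul is
# only a placeholder workload.
import torch

def time_with_cuda_graph(fn, warmup=10, iters=6):
    fn()  # dry run so lazy initialization happens outside graph capture
    g = torch.cuda.CUDAGraph()
    with torch.cuda.graph(g):
        fn()
    for _ in range(warmup):
        g.replay()
    start = torch.cuda.Event(enable_timing=True)
    end = torch.cuda.Event(enable_timing=True)
    start.record()
    for _ in range(iters):
        g.replay()
    end.record()
    torch.cuda.synchronize()
    return start.elapsed_time(end) / iters  # milliseconds per replay

a = torch.randn(4096, 4096, device='cuda', dtype=torch.bfloat16)
b = torch.randn(4096, 4096, device='cuda', dtype=torch.bfloat16)
print(time_with_cuda_graph(lambda: a @ b))
```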
+ else: + dtype_str = 'float16' + + log_perf(item_list=[{ + 'batch_size': batch_size, + 'isl': isl, + 'num_heads': num_heads, + 'num_key_value_heads': num_key_value_heads, + 'head_dim': head_dim, + 'window_size': attention_window_size, + 'beam_width': 1, + 'attn_dtype': dtype_str, + 'kv_cache_dtype': kv_cache_dtype_str, + 'step': step, + 'latency': latency + }], + framework='TRTLLM', + version=tensorrt_llm.__version__, + device_name=torch.cuda.get_device_name(device), + op_name=op_name, + kernel_source='torch_flow', + perf_filename=perf_filename) + kv_cache_manager.shutdown() + + +def get_context_attention_test_cases(): + has_fp8 = getSMVersion() > 86 + test_cases = [] + b_list = [1,2,4,8,16,32,64,128,256] + s_list = [16,32,64,128,256,512,1024,1536,2048,3072,4096,6144,8192,10240,12288,16384,262144] + n_list = [4,8,12,16,24,32,40,48,64,96] + n_kv_list = [0,1,2,4,8] + head_dim = [64,128] + + for h in head_dim: + for n in sorted(n_list, reverse=True): + for s in sorted(s_list, reverse=True): + for b in sorted(b_list, reverse=True): + for n_kv in n_kv_list: + if n_kv != 0: + if n_kv >= n or n%n_kv != 0: + continue + num_kv_heads = n_kv if n_kv !=0 else n + + if num_kv_heads == n: + if b*s > 65536 or b >128: + continue + else: + if b*s > 131072: + continue + if b*s*num_kv_heads*128*2 >= 2147483647: + continue + if getSMVersion() >= 100: + # TLLM_CHECK_WITH_INFO((params.mNumHeadsQPerKv < maxNumHeadsQPerKvInCta || params.mNumHeadsQPerKv % maxNumHeadsQPerKvInCta == 0), + if n >= 32 and n % 32 != 0: + continue + + #print(f'collecting heads: {n} kv_heads: {num_kv_heads} seq: {s} batchsize: {b}') + # use fp8 kv cache, fp8 context fmha, is_context_phase. in torch flow, int8 kvcache is not supported yet. + # fp16 kv cache, fp16 context fmha, is_context_phase + if head_dim == 64: + test_cases.append([b, s, n, num_kv_heads, h, 128, False, False, True, 'context_attention_perf.txt']) + if has_fp8: + test_cases.append([b, s, n, num_kv_heads, h, 128, True, False, True, 'context_attention_perf.txt']) + test_cases.append([b, s, n, num_kv_heads, h, 128, True, True, True, 'context_attention_perf.txt']) + else: + test_cases.append([b, s, n, num_kv_heads, h, 0, False, False, True, 'context_attention_perf.txt']) + if has_fp8: + test_cases.append([b, s, n, num_kv_heads, h, 0, True, False, True, 'context_attention_perf.txt']) + test_cases.append([b, s, n, num_kv_heads, h, 0, True, True, True, 'context_attention_perf.txt']) + + return test_cases + +def get_generation_attention_test_cases(): + has_fp8 = getSMVersion() > 86 + test_cases = [] + + # generation + isl = 1 + b_list = [1,2,4,8,16,32,64,128,256,512,1024,2048,] + b_list_xqa = [1,2,4,8,16,32,64,128,256,512,1024,2048] + # the i-th token to record. 1 for context phase. 
mapping to osl definition + s_list = [2,4,8,16,32,64,128,256,512,1024,2048,4096,8192,16384,32768,65536,131072] + n_list = [4,8,12,16,24,32,40,48,64] + n_list_xqa = [4,8,16,32,64,96,128] + n_kv_list = [1,2,4,8] + head_dim = [64,128] + + # MHA + max_bsn = 8192*1024 #2*1024*1024*1024/128/2 INT32MAX/128/2 + max_bsn_largeb = max_bsn//2 + for n in sorted(n_list, reverse=True): + b_s_dict = {} + s_b_dict = {} + for s in s_list: + max_b = max_bsn // s // n # b*s*n*byte <= max_bsn + for b in b_list: + if b > max_b: + break + if s not in s_b_dict.keys(): + s_b_dict[s] = {b} + else: + s_b_dict[s].add(b) + for s, b_set in s_b_dict.items(): + if len(b_set) < 4: + continue + for b in b_set: + if b not in b_s_dict.keys(): + b_s_dict[b] = {s-1} + b_s_dict[b].add(s-1) + for h in head_dim: + for b, s_list_limited in b_s_dict.items(): + target_s_list = sorted(s_list_limited) + if b >= 256: + target_s_list = target_s_list[:-1] + #print(f'collecting MHA heads: {n} batchsize: {b} steps: {s_list_limited}') + # fp8 kv cache, fp8 context fmha, is_context_phase + for s in target_s_list: + if getSMVersion() >= 100: + # TLLM_CHECK_WITH_INFO((params.mNumHeadsQPerKv < maxNumHeadsQPerKvInCta || params.mNumHeadsQPerKv % maxNumHeadsQPerKvInCta == 0), + if n >= 32 and n % 32 != 0: + continue + + test_cases.append([b, s, n, n, h, 0, False, False, False, 'generation_attention_perf.txt']) + + if has_fp8: + test_cases.append([b, s, n, n, h, 0, True, False, False, 'generation_attention_perf.txt']) + # currently, fp8 is not for generation compute + #test_cases.append([b, s, n, n, 128, True, True, False, 'generation_attention_perf.txt']) + + # XQA + max_bsn = 8192*1024*2 #2*1024*1024*1024/128/2 + for n in sorted(n_list_xqa, reverse=True): + b_s_dict = {} + s_b_dict = {} + for s in s_list: + max_b = max_bsn // s // n + for b in b_list: + if b > max_b: + break + if s not in s_b_dict.keys(): + s_b_dict[s] = {b} + else: + s_b_dict[s].add(b) + for s, b_set in s_b_dict.items(): + if len(b_set) < 4: + continue + for b in b_set: + if b not in b_s_dict.keys(): + b_s_dict[b] = {s-1} + b_s_dict[b].add(s-1) + for h in head_dim: + for b, s_list_limited in b_s_dict.items(): + target_s_list = sorted(s_list_limited) + if b >= 256: + target_s_list = target_s_list[:-1] + for n_kv in n_kv_list: + if n_kv >= n: + continue + + # fp8 kv cache, fp8 context fmha, is_context_phase + for s in target_s_list: + if getSMVersion() >= 100: + # TLLM_CHECK_WITH_INFO((params.mNumHeadsQPerKv < maxNumHeadsQPerKvInCta || params.mNumHeadsQPerKv % maxNumHeadsQPerKvInCta == 0), + if n >= 32 and n % 32 != 0: + continue + if head_dim == 64: + test_cases.append([b, s, n, n_kv, h, 128, False, False, False, 'generation_attention_perf.txt']) + if has_fp8: + test_cases.append([b, s, n, n_kv, h, 128, True, False, False, 'generation_attention_perf.txt']) + # currently, fp8 is not for generation compute + #test_cases.append([b, s, n, n_kv, 128, True, True, False, 'generation_attention_perf.txt']) + else: + test_cases.append([b, s, n, n_kv, h, 0, False, False, False, 'generation_attention_perf.txt']) + if has_fp8: + test_cases.append([b, s, n, n_kv, h, 0, True, False, False, 'generation_attention_perf.txt']) + return test_cases + +if __name__ == '__main__': + test_cases = get_context_attention_test_cases() + for test_case in test_cases: + run_attention_torch(*test_case) + + test_cases = get_generation_attention_test_cases() + for test_case in test_cases: + run_attention_torch(*test_case) \ No newline at end of file diff --git a/collector/trtllm/collect_conv_1d.py 
b/collector/trtllm/collect_conv_1d.py new file mode 100644 index 00000000..b9f61da7 --- /dev/null +++ b/collector/trtllm/collect_conv_1d.py @@ -0,0 +1,143 @@ +# SPDX-FileCopyrightText: Copyright (c) 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# SPDX-License-Identifier: Apache-2.0 + +import os +from cuda import cuda +import torch +import torch.nn as nn +import tensorrt_llm +import math +from helper import getSMVersion, log_perf + +def get_conv1d_test_cases(): + """ + Generate test cases for Conv1D operations. + + Test parameters: + - batch_size: batch size (analogous to 'm' in GEMM) + - in_channels: number of input channels + - out_channels: number of output channels + - kernel_size: size of the convolution kernel + - seq_length: sequence length + """ + batch_sizes = [1, 2, 4, 8, 16, 32, 64, 128, 256, 512, 1024] + channel_sizes = [64, 128, 256, 512, 768, 1024, 1536, 2048, 3072, 4096] + kernel_sizes = [1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 16] + seq_lengths = [32, 64, 128, 256, 512, 1024, 2048, 4096] + + conv1d_types = ['float16'] + if getSMVersion() > 86: + conv1d_types += ['fp8'] + if getSMVersion() < 100: + conv1d_types += ['fp8_block'] + if getSMVersion() >= 100: + conv1d_types += ['nvfp4'] + + test_cases = [] + for conv_type in conv1d_types: + # Generate test cases with various combinations + for batch_size in batch_sizes: + for in_ch in channel_sizes: + for out_ch in channel_sizes: + for kernel_size in kernel_sizes: + for seq_len in seq_lengths: + # Skip extremely large cases + if batch_size * in_ch * seq_len > 16777216: + continue + if conv_type == 'nvfp4' or conv_type == 'fp8_block': + if in_ch < 128 or out_ch < 128: + continue + test_cases.append([conv_type, batch_size, in_ch, out_ch, kernel_size, seq_len, 'conv1d_perf.txt']) + + return test_cases + + +def run_conv1d(conv_type, batch_size, in_channels, out_channels, kernel_size, seq_length, perf_filename, device='cuda:0'): + """ + Run Conv1D performance benchmarking. 
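get_conv1d_test_cases above gates the candidate dtypes on getSMVersion(). A rough, assumed equivalent using torch's own capability query is sketched here; the thresholds follow the patch, the nesting of the fp8_block branch and the helper-free form are assumptions.

```python
# Assumed helper-free variant of the SM-version gating above; getSMVersion()
# in helper.py may compute or threshold the value differently.
import torch

major, minor = torch.cuda.get_device_capability()
sm = major * 10 + minor          # e.g. 89 = Ada, 90 = Hopper, 100 = Blackwell

conv1d_types = ['float16']
if sm > 86:
    conv1d_types.append('fp8')
    if sm < 100:
        conv1d_types.append('fp8_block')
if sm >= 100:
    conv1d_types.append('nvfp4')
```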
+ + Args: + conv_type: Type of convolution ('float16', 'fp8', 'fp8_block', 'nvfp4') + batch_size: Batch size + in_channels: Number of input channels + out_channels: Number of output channels + kernel_size: Size of the convolution kernel + seq_length: Sequence length + perf_filename: Output file for performance results + device: CUDA device to use + """ + device = torch.device(device) + torch.cuda.set_device(device) + torch.set_default_device(device) + + # For now, we focus on float16/bfloat16 benchmarking + # Quantized Conv1D support can be added when available in TensorRT-LLM + if conv_type != 'float16': + # Skip non-float16 for now as Conv1D quantization in TensorRT-LLM needs verification + return + + dtype = torch.bfloat16 + # Conv1D expects input shape: (batch_size, in_channels, seq_length) + x = torch.randn((batch_size, in_channels, seq_length), dtype=dtype).to(torch.device(device)) + + repeat_n = 5 # to reduce impact of L2 cache hit + op_list = [] + + for i in range(repeat_n): + # Use PyTorch's native Conv1d + conv1d = nn.Conv1d( + in_channels, + out_channels, + kernel_size, + stride=1, + padding=kernel_size//2, # 'same' padding + bias=False, + dtype=dtype, + ) + + # Initialize weights randomly + conv1d.weight.data = torch.randn((out_channels, in_channels, kernel_size), dtype=dtype, device=device) + + conv1d.to(torch.device(device)) + conv1d(x) # dry run to init + op_list.append(conv1d) + + num_warmups = 3 + num_runs = 6 + + # capture + g = torch.cuda.CUDAGraph() + with torch.cuda.graph(g): + for op in op_list: + op.forward(x) + + # warmup + for i in range(num_warmups): + g.replay() + + start_event = torch.cuda.Event(enable_timing=True) + end_event = torch.cuda.Event(enable_timing=True) + start_event.record() + for i in range(num_runs): + g.replay() + end_event.record() + torch.cuda.synchronize() + latency = start_event.elapsed_time(end_event)/num_runs/len(op_list) + + log_perf( + item_list=[{ + 'conv_dtype': conv_type, + 'batch_size': batch_size, + 'in_channels': in_channels, + 'out_channels': out_channels, + 'kernel_size': kernel_size, + 'seq_length': seq_length, + 'latency': latency + }], + framework='TRTLLM', + version=tensorrt_llm.__version__, + device_name=torch.cuda.get_device_name(device), + op_name='conv1d', + kernel_source='torch_flow', + perf_filename=perf_filename + ) diff --git a/src/aiconfigurator/sdk/common.py b/src/aiconfigurator/sdk/common.py index 9bf4ea67..531075ce 100644 --- a/src/aiconfigurator/sdk/common.py +++ b/src/aiconfigurator/sdk/common.py @@ -23,6 +23,27 @@ class BlockConfig: ffn_no_op: bool = False num_inst: int = 0 +@dataclass(frozen=True) +class LinearAttentionConfig: + """ + Configuration for a single linear attention block in Qwen3Next. 
+ + Attributes: + used_ratio (float): Used ratio of the linear attention block within all attention blocks + linear_conv_kernel_dim (int): Kernel dimension for the linear convolution + linear_key_head_dim (int): Head dimension for the linear key + linear_num_key_heads (int): Number of key heads for the linear attention + linear_num_value_heads (int): Number of value heads for the linear attention + linear_value_head_dim (int): Head dimension for the linear value + """ + used_ratio: float = 0.75 + linear_conv_kernel_dim: int = 4 + linear_key_head_dim: int = 128 + linear_num_key_heads: int = 16 + linear_num_value_heads: int = 32 + linear_value_head_dim: int = 128 + + """ Supported models model name: model_family,l,n,n_kv,d,hidden_size,inter_size,vocab,context,topk,num_experts,moe_inter_size,extra_params @@ -59,6 +80,7 @@ class BlockConfig: 'QWEN3_8B':['LLAMA', 36,32,8,128,32*128,12288,151936,40960, 0, 0, 0, None], 'QWEN3_235B':['MOE', 94,64,4,128,4096,12288,151936,40960, 8, 128, 1536, None], 'QWEN3_480B':['MOE', 62,96,8,128,6144,8192,151936,262144,8,160,2560, None], + 'QWEN3_NEXT_80B':['QWEN3NEXT', 48,16,2,256,2048,5120,151936,262144,10,512,512, LinearAttentionConfig(0.75, 4, 128, 16, 32, 128)], 'Nemotron_super_v1.1':['NEMOTRONNAS', 80, 64, 0, 128, 8192, 0, 128256, 131072, 0, 0, 0, [ BlockConfig(8, False, 5.25, False, 48), @@ -78,7 +100,7 @@ class BlockConfig: """ Model family for model definition """ -ModelFamily = {'GPT', 'LLAMA', 'MOE', 'DEEPSEEK', 'NEMOTRONNAS'} +ModelFamily = {'GPT', 'LLAMA', 'MOE', 'DEEPSEEK', 'NEMOTRONNAS', 'QWEN3NEXT'} """ All reduce strategy for trtllm custom allreduce diff --git a/src/aiconfigurator/sdk/models.py b/src/aiconfigurator/sdk/models.py index 3aa3d8df..72a96ff8 100755 --- a/src/aiconfigurator/sdk/models.py +++ b/src/aiconfigurator/sdk/models.py @@ -18,7 +18,7 @@ def get_model(model_name: str, model_config: config.ModelConfig, backend_name: s """ assert(model_name in common.SupportedModels), f"unsupport model {model_name}" model_family,l,n,n_kv,d,hidden,inter,vocab,context,topk,num_experts,moe_inter_size, extra_params = common.SupportedModels[model_name] - assert(model_family in common.ModelFamily), f"model is not in ModelFamily(GPT, LLAMA, MOE, DEEPSEEK, NEMOTRONNAS)" + assert(model_family in common.ModelFamily), f"model is not in ModelFamily(GPT, LLAMA, MOE, DEEPSEEK, NEMOTRONNAS, QWEN3NEXT)" if model_config.overwrite_num_layers > 0: l = model_config.overwrite_num_layers @@ -53,6 +53,13 @@ def get_model(model_name: str, model_config: config.ModelConfig, backend_name: s model_config) model.context_ops = extra_params model.generation_ops = extra_params + elif model_family == 'QWEN3NEXT': + model = Qwen3NextModel(topk, num_experts, moe_inter_size, \ + model_name, model_family, l, n, n_kv, d, \ + hidden, inter, vocab, context, \ + model_config) + model.context_ops = extra_params + model.generation_ops = extra_params return model @@ -68,7 +75,7 @@ def check_is_moe(model_name: str) -> bool: """ Check if the model is a MoE model. 
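The Qwen3Next model code later in this patch sizes the fused qkvz/ba projection of the gated delta net as key_dim*2 + value_dim*2 + num_v_heads*2. With the defaults registered above, that works out as follows; the local dataclass copy exists only so the snippet runs standalone.

```python
# Standalone mirror of LinearAttentionConfig above, for a quick width check.
from dataclasses import dataclass

@dataclass(frozen=True)
class LinearAttentionConfig:
    used_ratio: float = 0.75
    linear_conv_kernel_dim: int = 4
    linear_key_head_dim: int = 128
    linear_num_key_heads: int = 16
    linear_num_value_heads: int = 32
    linear_value_head_dim: int = 128

cfg = LinearAttentionConfig()
key_dim = cfg.linear_key_head_dim * cfg.linear_num_key_heads        # 128 * 16 = 2048
value_dim = cfg.linear_value_head_dim * cfg.linear_num_value_heads  # 128 * 32 = 4096
qkvz_ba_width = key_dim * 2 + value_dim * 2 + cfg.linear_num_value_heads * 2
print(qkvz_ba_width)  # 12352 = 12288 for q/k/v/z plus 64 for b/a, before the tp split
```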
""" - return get_model_family(model_name) == 'MOE' or get_model_family(model_name) == 'DEEPSEEK' + return get_model_family(model_name) == 'MOE' or get_model_family(model_name) == 'DEEPSEEK' or get_model_family(model_name) == 'QWEN3NEXT' def calc_expectation(nextn: int, nextn_accept_rates: list[float]) -> float: """ @@ -732,7 +739,249 @@ def _ffn_mult_to_intermediate_size(self, ffn_mult: float) -> int: if inter_size % 256 == 0: return inter_size return inter_size + 256 - (inter_size % 256) - + + +class Qwen3NextModel(BaseModel): + """ + Qwen3Next model uses this model impl. + Currently Qwen3Next only has a series of 80B A3B models, which is similar to MOEModel but with different attention: + 1/4 of the layers are the same to MOE model, using self attention. + 3/4 of the layers are using linear attention with convolution 1d operation. + Some rules to follow, + Due to implementation, attn layer name needs to be context_attention or generation_attention, exact match is required. Same for logits_gemm. + """ + def __init__(self, topk: int, num_experts: int, moe_inter_size: int, *args) -> None: + super().__init__(*args) + assert self._nextn == 0, 'Qwen3Next only supports mtp=0' + + # make sure the parallel width is same + assert(self.config.tp_size * self.config.attention_dp_size == self.config.moe_tp_size * self.config.moe_ep_size), \ + f"tp_size ({self.config.tp_size}) * attention_dp_size ({self.config.attention_dp_size}) should be equal to moe_tp_size ({self.config.moe_tp_size}) * moe_ep_size ({self.config.moe_ep_size})" + + assert(num_experts >= self.config.moe_ep_size), f"ep size cannot be larger than num_experts {num_experts}" + assert(self.config.tp_size * self.config.attention_dp_size <= 256), f"moe ep size {self.config.moe_ep_size} * moe tp size {self.config.moe_tp_size} should not be larger than 256" + assert(self._num_layers % 4 == 0), f"num_layers {self._num_layers} should be divisible by 4" + + self._topk = topk + self._num_experts = num_experts + self._moe_inter_size = moe_inter_size + + self._power_law_alpha = 1.2 + + @property + def context_ops(self): + """ + Get the context(prefill) processing operations pipeline. + + Returns: + List[ops.Operation]: List of operations for processing context + sequences, including: + - embedding, + - attention blocks, + - FFN blocks, + - P2P communication, + - all reduce communication + - logits computation. + """ + return self._context_ops + + @context_ops.setter + def context_ops(self, linear_attention_config: common.LinearAttentionConfig): + """ + Set the context(prefill) processing operations pipeline based on linear attention configurations. + + Constructs a pipeline of operations for processing input context by creating operations + for each configured transformer block. The pipeline includes embedding lookup, + transformer blocks (with optional attention and FFN components), pipeline parallel + communication, and final logits computation. 
+ + Args: + linear_attention_config (common.LinearAttentionConfig): Linear attention configuration + """ + num_v_heads = linear_attention_config.linear_num_value_heads + num_k_heads = linear_attention_config.linear_num_key_heads + head_k_dim = linear_attention_config.linear_key_head_dim + head_v_dim = linear_attention_config.linear_value_head_dim + key_dim = head_k_dim * num_k_heads + value_dim = head_v_dim * num_v_heads + moe_quant_mode = self.config.moe_quant_mode + h = self._hidden_size + tp_size = self.config.tp_size + moe_tp_size = self.config.moe_tp_size + moe_ep_size = self.config.moe_ep_size + attention_dp_size = self.config.attention_dp_size + pp_size = self.config.pp_size + num_kv_heads_per_GPU = self._num_kv_heads_per_GPU + gemm_quant_mode = self.config.gemm_quant_mode + kvcache_quant_mode = self.config.kvcache_quant_mode + fmha_quant_mode = self.config.fmha_quant_mode + workload_distribution = self.config.workload_distribution + f"_{self._power_law_alpha}" + + # 1 embedding for all layers + # 1 norm before attention per layer + self.context_ops.extend([ops.Embedding(f'context_embedding', 1, self._vocab_size, h, 0.3), + ops.ElementWise(f'context_add_norm_1', self._num_layers, 2*h, 2*h, 0.8)]) + + # self attention + # (self_attn): Qwen3NextAttention( + # (q_proj): Linear(in_features=2048, out_features=8192, bias=False) out_features = num_heads * head dimension(head_size) + # (k_proj): Linear(in_features=2048, out_features=512, bias=False) out_features = num_key_value_heads * head dimension(head_size) = 2 * 256 = 512 + # (v_proj): Linear(in_features=2048, out_features=512, bias=False) out_features = num_key_value_heads * head dimension(head_size) = 2 * 256 = 512 + # (o_proj): Linear(in_features=4096, out_features=2048, bias=False) out_features = hidden size = 2048 + # (q_norm): Qwen3NextRMSNorm((256,), eps=1e-06) ignored since the latency is ignorable + # (k_norm): Qwen3NextRMSNorm((256,), eps=1e-06) ignored since the latency is ignorable + # context_qkv_gemm: q + k + v + # context_attention: num of heads / tp_size + # context_proj_gemm: num of heads * head_size / tp_size + self.context_ops.extend([ + ops.GEMM(f'context_qkv_gemm', self._num_layers * (1 - linear_attention_config.used_ratio), self._num_heads*self._head_size//tp_size+self._head_size*num_kv_heads_per_GPU*2, h, gemm_quant_mode), + ops.ContextAttention(f'context_attention', self._num_layers * (1 - linear_attention_config.used_ratio), self._num_heads//tp_size, num_kv_heads_per_GPU, kvcache_quant_mode, fmha_quant_mode), + ops.GEMM(f'context_proj_gemm', self._num_layers * (1 - linear_attention_config.used_ratio), h, self._num_heads*self._head_size//tp_size, gemm_quant_mode), + ]) + + # linear attention (Qwen3NextGatedDeltaNet) + self.context_ops.extend([ + ops.GEMM(f'context_qkvz_ba_gemm', self._num_layers * linear_attention_config.used_ratio, (key_dim*2 + value_dim*2 + num_v_heads*2)//tp_size, h, gemm_quant_mode), + ops.Conv1D(f'context_conv1d', self._num_layers * linear_attention_config.used_ratio, self._num_heads*self._head_size//tp_size+self._head_size*num_kv_heads_per_GPU*2, h, kernel_size=4, stride=1, padding=3, groups=self._num_heads*self._head_size//tp_size+self._head_size*num_kv_heads_per_GPU*2, bias=False, gemm_quant_mode), + ops.ChunkGatedDeltaRule(f'context_chunk_gated_delta_rule', self._num_layers * linear_attention_config.used_ratio, in_channels=self._num_heads*self._head_size//tp_size+self._head_size*num_kv_heads_per_GPU*2, out_channels=h, kernel_size=4, seq_length=seq_length), + 
ops.GEMM(f'context_proj_gemm', self._num_layers * linear_attention_config.used_ratio, h, value_dim//tp_size, gemm_quant_mode), + ]) + + # 1 norm before MOE per layer + self.context_ops.extend([ + ops.ElementWise(f'context_add_norm_2', self._num_layers // linear_attention_config.used_ratio, 2*h, 2*h, 0.8)]) + + #router, only take it into account when num_experts >= 128 + if self._num_experts >= 128: + self.context_ops.extend([ + ops.GEMM(f'context_router_gemm', self._num_layers, self._num_experts, h, common.GEMMQuantMode.float16) + ]) + + # dispatch tokens to experts, moe calc and get tokens back + # Qwen3Next has one more shared expert. + self.context_ops.extend([ + ops.MoEDispatch(f'context_moe_pre_dispatch', self._num_layers, h, self._topk, self._num_experts + 1, moe_tp_size, moe_ep_size, attention_dp_size, True), + ops.MoE(f'context_moe', self._num_layers, h, self._moe_inter_size, self._topk, self._num_experts + 1, moe_tp_size, moe_ep_size, moe_quant_mode, workload_distribution, attention_dp_size), + ops.MoEDispatch(f'context_moe_post_dispatch', self._num_layers, h, self._topk, self._num_experts + 1, moe_tp_size, moe_ep_size, attention_dp_size, False)]) + + self.context_ops.extend([ops.GEMM(f'context_logits_gemm', 1, self._vocab_size//tp_size, h, common.GEMMQuantMode.float16)]) + + # # # when tp_size=0, the comm part will be 0 + # self.context_ops.append(ops.AllReduce('context_ar_1', self._num_layers, h, tp_size)) + # self.context_ops.append(ops.AllReduce('context_ar_2', self._num_layers, h, tp_size)) + + # pp + pp_scale_factor = pp_size-1 + self.context_ops.append(ops.P2P('context_p2p', pp_scale_factor, h, pp_size)) + + @property + def generation_ops(self): + """ + Get the generation (decoding) operations pipeline. + + Returns: + List[ops.Operation]: List of operations for the decoding phase + including: + - embedding, + - attention blocks, + - FFN blocks, + - P2P communication, + - all reduce communication + - logits computation. + """ + return self._generation_ops + + @generation_ops.setter + def generation_ops(self, linear_attention_config: common.LinearAttentionConfig): + """ + Set the generation (decoding) operations pipeline based on linear attention configurations. + + Constructs a pipeline of operations for generating output tokens by creating operations + for each configured transformer block. The pipeline includes embedding lookup, + transformer blocks (with optional attention and FFN components), pipeline parallel + communication, and final logits computation. 
+ + Args: + linear_attention_config (common.LinearAttentionConfig): Linear attention configuration + """ + num_v_heads = linear_attention_config.linear_num_value_heads + num_k_heads = linear_attention_config.linear_num_key_heads + head_k_dim = linear_attention_config.linear_key_head_dim + head_v_dim = linear_attention_config.linear_value_head_dim + key_dim = head_k_dim * num_k_heads + value_dim = head_v_dim * num_v_heads + moe_quant_mode = self.config.moe_quant_mode + h = self._hidden_size + tp_size = self.config.tp_size + moe_tp_size = self.config.moe_tp_size + moe_ep_size = self.config.moe_ep_size + attention_dp_size = self.config.attention_dp_size + pp_size = self.config.pp_size + num_kv_heads_per_GPU = self._num_kv_heads_per_GPU + gemm_quant_mode = self.config.gemm_quant_mode + kvcache_quant_mode = self.config.kvcache_quant_mode + fmha_quant_mode = self.config.fmha_quant_mode + workload_distribution = self.config.workload_distribution + f"_{self._power_law_alpha}" + + # 1 embedding for all layers + # 1 norm before attention per layer + self.generation_ops.extend([ops.Embedding(f'context_embedding', 1, self._vocab_size, h, 0.3), + ops.ElementWise(f'context_add_norm_1', self._num_layers, 2*h, 2*h, 0.8)]) + + # self attention + # (self_attn): Qwen3NextAttention( + # (q_proj): Linear(in_features=2048, out_features=8192, bias=False) out_features = num_heads * head dimension(head_size) + # (k_proj): Linear(in_features=2048, out_features=512, bias=False) out_features = num_key_value_heads * head dimension(head_size) = 2 * 256 = 512 + # (v_proj): Linear(in_features=2048, out_features=512, bias=False) out_features = num_key_value_heads * head dimension(head_size) = 2 * 256 = 512 + # (o_proj): Linear(in_features=4096, out_features=2048, bias=False) out_features = hidden size = 2048 + # (q_norm): Qwen3NextRMSNorm((256,), eps=1e-06) ignored since the latency is ignorable + # (k_norm): Qwen3NextRMSNorm((256,), eps=1e-06) ignored since the latency is ignorable + # context_qkv_gemm: q + k + v + # context_attention: num of heads / tp_size + # context_proj_gemm: num of heads * head_size / tp_size + self.generation_ops.extend([ + ops.GEMM(f'generation_qkv_gemm', self._num_layers * (1 - linear_attention_config.used_ratio), self._num_heads*self._head_size//tp_size+self._head_size*num_kv_heads_per_GPU*2, h, gemm_quant_mode), + ops.GenerationAttention(f'generation_attention', self._num_layers * (1 - linear_attention_config.used_ratio), self._num_heads//tp_size, num_kv_heads_per_GPU, kvcache_quant_mode), + ops.GEMM(f'generation_proj_gemm', self._num_layers * (1 - linear_attention_config.used_ratio), h, self._num_heads*self._head_size//tp_size, gemm_quant_mode), + ]) + + # linear attention (Qwen3NextGatedDeltaNet) + self.generation_ops.extend([ + ops.GEMM(f'generation_qkvz_ba_gemm', self._num_layers * linear_attention_config.used_ratio, (key_dim*2 + value_dim*2 + num_v_heads*2)//tp_size, h, gemm_quant_mode), + ops.Conv1D(f'generation_conv1d', self._num_layers * linear_attention_config.used_ratio, self._num_heads*self._head_size//tp_size+self._head_size*num_kv_heads_per_GPU*2, h, kernel_size=4, stride=1, padding=3, groups=self._num_heads*self._head_size//tp_size+self._head_size*num_kv_heads_per_GPU*2, bias=False, gemm_quant_mode), + ops.ChunkGatedDeltaRule(f'generation_chunk_gated_delta_rule', self._num_layers * linear_attention_config.used_ratio, in_channels=self._num_heads*self._head_size//tp_size+self._head_size*num_kv_heads_per_GPU*2, out_channels=h, kernel_size=4, seq_length=seq_length), + 
ops.GEMM(f'generation_proj_gemm', self._num_layers * linear_attention_config.used_ratio, h, value_dim//tp_size, gemm_quant_mode), + ]) + + # 1 norm before MOE per layer + self.generation_ops.extend([ + ops.ElementWise(f'context_add_norm_2', self._num_layers // linear_attention_config.used_ratio, 2*h, 2*h, 0.8)]) + + #router, only take it into account when num_experts >= 128 + if self._num_experts >= 128: + self.generation_ops.extend([ + ops.GEMM(f'generation_router_gemm', self._num_layers, self._num_experts, h, common.GEMMQuantMode.float16) + ]) + + # dispatch tokens to experts, moe calc and get tokens back + # Qwen3Next has one more shared expert. + self.generation_ops.extend([ + ops.MoEDispatch(f'generation_moe_pre_dispatch', self._num_layers, h, self._topk, self._num_experts + 1, moe_tp_size, moe_ep_size, attention_dp_size, True), + ops.MoE(f'generation_moe', self._num_layers, h, self._moe_inter_size, self._topk, self._num_experts + 1, moe_tp_size, moe_ep_size, moe_quant_mode, workload_distribution, attention_dp_size), + ops.MoEDispatch(f'generation_moe_post_dispatch', self._num_layers, h, self._topk, self._num_experts + 1, moe_tp_size, moe_ep_size, attention_dp_size, False) + ]) + # logits gemm + self.generation_ops.extend([ops.GEMM(f'generation_logits_gemm', 1, self._vocab_size//tp_size, h, common.GEMMQuantMode.float16)]) + + # # # when tp_size=0, the comm part will be 0 + # self.generation_ops.append(ops.AllReduce('generation_ar_1', self._num_layers, h, tp_size)) + # self.generation_ops.append(ops.AllReduce('generation_ar_2', self._num_layers, h, tp_size)) + + # pp + pp_scale_factor = pp_size-1 + self.generation_ops.append(ops.P2P('generation_p2p', pp_scale_factor, h, pp_size)) + + if __name__ == '__main__': # TODO, move to unit tests model = get_model('DEEPSEEK_V3', config.ModelConfig( diff --git a/src/aiconfigurator/sdk/operations.py b/src/aiconfigurator/sdk/operations.py index e69db25d..03d9df87 100755 --- a/src/aiconfigurator/sdk/operations.py +++ b/src/aiconfigurator/sdk/operations.py @@ -90,7 +90,8 @@ def __init__(self, name: str, scale_factor: float, n: int, k: int, quant_mode: c self._n = n self._k = k self._quant_mode = quant_mode - self._weights = self._n*self._k*quant_mode.value.memory + self._weights = self._n*self._k*quant_mode.value.memory + def query(self, database:PerfDatabase, **kwargs): x = kwargs.get('x') overwrite_quant_mode = kwargs.get('quant_mode', None) @@ -99,7 +100,7 @@ def query(self, database:PerfDatabase, **kwargs): return database.query_gemm(x, self._n, self._k, quant_mode)*self._scale_factor def get_weights(self, **kwargs): - return self._weights * self._scale_factor + return self._weights * self._scale_factor class MoE(Operation): """ @@ -577,4 +578,42 @@ def query(self, database:PerfDatabase, **kwargs): return database.query_context_mla_sglang(batch_size, isl, self._tp_size, self._kvcache_quant_mode, self._fmha_quant_mode, self._attn_backend) * self._scale_factor def get_weights(self, **kwargs): - return self._weights * self._scale_factor \ No newline at end of file + return self._weights * self._scale_factor + +class Conv1D(Operation): + """ + Conv1D operation. 
+ """ + def __init__(self, name: str, scale_factor: float, in_channels: int, out_channels: int, kernel_size: int, seq_length: int) -> None: + super().__init__(name, scale_factor) + self._in_channels = in_channels + self._out_channels = out_channels + self._kernel_size = kernel_size + self._seq_length = seq_length + self._weights = in_channels * out_channels * kernel_size * seq_length + + def query(self, database:PerfDatabase, **kwargs): + x = kwargs.get('x') + return database.query_conv_1d(x, self._in_channels, self._out_channels, self._kernel_size, self._seq_length)*self._scale_factor + + def get_weights(self, **kwargs): + return self._weights * self._scale_factor + +class ChunkGatedDeltaRule(Operation): + """ + Chunk gated delta rule operation. + """ + def __init__(self, name: str, scale_factor: float, in_channels: int, out_channels: int, kernel_size: int, seq_length: int) -> None: + super().__init__(name, scale_factor) + self._in_channels = in_channels + self._out_channels = out_channels + self._kernel_size = kernel_size + self._seq_length = seq_length + self._weights = in_channels * out_channels * kernel_size * seq_length + + def query(self, database:PerfDatabase, **kwargs): + x = kwargs.get('x') + return database.query_chunk_gated_delta_rule(x, self._in_channels, self._out_channels, self._kernel_size, self._seq_length)*self._scale_factor + + def get_weights(self, **kwargs): + return self._weights * self._scale_factor diff --git a/src/aiconfigurator/sdk/perf_database.py b/src/aiconfigurator/sdk/perf_database.py index 7e1853d1..233e702f 100755 --- a/src/aiconfigurator/sdk/perf_database.py +++ b/src/aiconfigurator/sdk/perf_database.py @@ -1939,12 +1939,17 @@ def get_sol(num_tokens: int, hidden_size: int, intermediate_size: int, quant_mod ) sol_math = ops / (self.system_spec['gpu']['float16_tc_flops'] * quant_mode.value.compute) * 1000 sol_mem = mem_bytes / self.system_spec['gpu']['mem_bw'] * 1000 - sol_time = max(sol_math, sol_mem) - return sol_time, sol_math, sol_mem - - if sol_mode is None: - sol_mode = self._default_sol_mode - if sol_mode == common.SOLMode.SOL: + for n in self._generation_attention_data[quant_mode][n_kv].keys(): + for b in self._generation_attention_data[quant_mode][n_kv][n].keys(): + for s in self._generation_attention_data[quant_mode][n_kv][n][b].keys(): + if n_kv == 0: + n_kv_local = n + else: + n_kv_local = n_kv + sol = self.query_generation_attention(b, s, n, n_kv_local, quant_mode, sol_mode=common.SOLMode.SOL) + if sol > self._generation_attention_data[quant_mode][n_kv][n][b][s]: + logger.debug('generation attention quant {} n{} n_kv{} b{} s{}: sol {} > perf_db {}'.format(quant_mode, n, n_kv_local, b, s, sol, self._generation_attention_data[quant_mode][n_kv][n][b][s])) + self._generation_attention_data[quant_mode][n_kv][n][b][s] = sol return get_sol(num_tokens, hidden_size, intermediate_size, quant_mode)[0] elif sol_mode == common.SOLMode.SOL_FULL: return get_sol(num_tokens, hidden_size, intermediate_size, quant_mode) @@ -2014,6 +2019,63 @@ def get_sol(num_tokens: int, num_experts: int, topk: int, hidden_size: int) -> T data = self._deepep_normal_data[node_num][hidden_size][topk][num_experts] lat = self._interp_2d_linear(sms, num_tokens, data) return lat / 1000.0 + + def query_conv_1d(self, + in_channels : int, + out_channels : int, + kernel_size : int, + seq_length : int, + sol_mode : Optional[common.SOLMode] = None) -> float: + """ + Query the conv1d data + """ + def get_sol(in_channels : int, out_channels : int, kernel_size : int, seq_length : int) -> 
Tuple[float, float, float]: + """ + Get the sol time, sol math and sol mem + """ + sol_math = in_channels * out_channels * kernel_size * seq_length / (self.system_spec['gpu']['float16_tc_flops']*1) * 1000 + sol_mem = in_channels * out_channels * kernel_size * seq_length / self.system_spec['gpu']['mem_bw'] * 1000 + sol_time = max(sol_math, sol_mem) + return sol_time, sol_math, sol_mem + if sol_mode is None: + sol_mode = self._default_sol_mode + if sol_mode == common.SOLMode.SOL: + return get_sol(in_channels, out_channels, kernel_size, seq_length)[0] + elif sol_mode == common.SOLMode.SOL_FULL: + return get_sol(in_channels, out_channels, kernel_size, seq_length) + else: + result = self._interp_3d(in_channels, out_channels, kernel_size, seq_length, self._conv_1d_data, 'cubic') + return result + + def query_chunk_gated_delta_rule(self, + + num_tokens : int, + in_channels : int, + out_channels : int, + kernel_size : int, + seq_length : int, + sol_mode : Optional[common.SOLMode] = None) -> float: + """ + Query the chunk gated delta rule data + """ + def get_sol(num_tokens : int, in_channels : int, out_channels : int, kernel_size : int, seq_length : int) -> Tuple[float, float, float]: + """ + Get the sol time, sol math and sol mem + """ + sol_math = num_tokens * in_channels * out_channels * kernel_size * seq_length / (self.system_spec['gpu']['float16_tc_flops']*1) * 1000 + sol_mem = in_channels * out_channels * kernel_size * seq_length / self.system_spec['gpu']['mem_bw'] * 1000 + sol_time = max(sol_math, sol_mem) + return sol_time, sol_math, sol_mem + if sol_mode is None: + sol_mode = self._default_sol_mode + if sol_mode == common.SOLMode.SOL: + return get_sol(num_tokens, in_channels, out_channels, kernel_size, seq_length)[0] + elif sol_mode == common.SOLMode.SOL_FULL: + return get_sol(num_tokens, in_channels, out_channels, kernel_size, seq_length) + else: + result = self._interp_3d(num_tokens, in_channels, out_channels, kernel_size, seq_length, self._chunk_gated_delta_rule_data, 'cubic') + return result + if __name__ == '__main__': database_dict = get_all_databases() \ No newline at end of file From b27380245a520b64e2a8fdc219670d92ded0efa4 Mon Sep 17 00:00:00 2001 From: Jason Zhou Date: Wed, 15 Oct 2025 18:03:46 -0700 Subject: [PATCH 02/17] minor --- src/aiconfigurator/sdk/perf_database.py | 19 +++++++------------ 1 file changed, 7 insertions(+), 12 deletions(-) diff --git a/src/aiconfigurator/sdk/perf_database.py b/src/aiconfigurator/sdk/perf_database.py index 233e702f..e3e37331 100755 --- a/src/aiconfigurator/sdk/perf_database.py +++ b/src/aiconfigurator/sdk/perf_database.py @@ -1939,17 +1939,12 @@ def get_sol(num_tokens: int, hidden_size: int, intermediate_size: int, quant_mod ) sol_math = ops / (self.system_spec['gpu']['float16_tc_flops'] * quant_mode.value.compute) * 1000 sol_mem = mem_bytes / self.system_spec['gpu']['mem_bw'] * 1000 - for n in self._generation_attention_data[quant_mode][n_kv].keys(): - for b in self._generation_attention_data[quant_mode][n_kv][n].keys(): - for s in self._generation_attention_data[quant_mode][n_kv][n][b].keys(): - if n_kv == 0: - n_kv_local = n - else: - n_kv_local = n_kv - sol = self.query_generation_attention(b, s, n, n_kv_local, quant_mode, sol_mode=common.SOLMode.SOL) - if sol > self._generation_attention_data[quant_mode][n_kv][n][b][s]: - logger.debug('generation attention quant {} n{} n_kv{} b{} s{}: sol {} > perf_db {}'.format(quant_mode, n, n_kv_local, b, s, sol, self._generation_attention_data[quant_mode][n_kv][n][b][s])) - 
self._generation_attention_data[quant_mode][n_kv][n][b][s] = sol + sol_time = max(sol_math, sol_mem) + return sol_time, sol_math, sol_mem + + if sol_mode is None: + sol_mode = self._default_sol_mode + if sol_mode == common.SOLMode.SOL: return get_sol(num_tokens, hidden_size, intermediate_size, quant_mode)[0] elif sol_mode == common.SOLMode.SOL_FULL: return get_sol(num_tokens, hidden_size, intermediate_size, quant_mode) @@ -1964,7 +1959,7 @@ def get_sol(num_tokens: int, hidden_size: int, intermediate_size: int, quant_mod num_left, num_right = self._nearest_1d_point_helper(num_tokens, list(mlp_dict.keys()), inner_only=False) lat = self._interp_1d([num_left, num_right], [mlp_dict[num_left], mlp_dict[num_right]], num_tokens) return lat - + def query_deepep_ll(self, node_num: int, num_tokens: int, From f712dfa8417b4c0f6df2680d21380b62a0ee08eb Mon Sep 17 00:00:00 2001 From: Jason Zhou Date: Fri, 17 Oct 2025 16:24:58 -0700 Subject: [PATCH 03/17] add ops --- collector/collect.py | 24 +- .../trtllm/collect_chunk_gated_delta_rule.py | 395 ------------------ collector/trtllm/collect_conv_1d.py | 218 ++++++---- collector/trtllm/collect_gated_delta_rule.py | 210 ++++++++++ .../sdk/backends/base_backend.py | 1 + .../sdk/backends/trtllm_backend.py | 3 +- src/aiconfigurator/sdk/models.py | 114 ++--- src/aiconfigurator/sdk/operations.py | 76 +++- src/aiconfigurator/sdk/perf_database.py | 223 ++++++++-- 9 files changed, 674 insertions(+), 590 deletions(-) delete mode 100644 collector/trtllm/collect_chunk_gated_delta_rule.py create mode 100644 collector/trtllm/collect_gated_delta_rule.py diff --git a/collector/collect.py b/collector/collect.py index 66f5210c..bd1d568f 100644 --- a/collector/collect.py +++ b/collector/collect.py @@ -438,20 +438,34 @@ def collect_trtllm(num_processes: int, ops: List[str]=None): # CONV 1D collections { 'name': 'trtllm', - 'type': 'conv_1d', + 'type': 'conv_1d_fn', 'module': 'trtllm.collect_conv_1d', - 'get_func': 'get_conv_1d_test_cases', - 'run_func': 'run_conv_1d' + 'get_func': 'get_conv_1d_fn_test_cases', + 'run_func': 'run_conv_1d_fn' + }, + { + 'name': 'trtllm', + 'type': 'conv_1d_update', + 'module': 'trtllm.collect_conv_1d', + 'get_func': 'get_conv_1d_update_test_cases', + 'run_func': 'run_conv_1d_update' }, - # Chunk Gated Delta Rule collections + # Gated Delta Rule collections { 'name': 'trtllm', 'type': 'chunk_gated_delta_rule', - 'module': 'trtllm.collect_chunk_gated_delta_rule', + 'module': 'trtllm.collect_gated_delta_rule', 'get_func': 'get_chunk_gated_delta_rule_test_cases', 'run_func': 'run_chunk_gated_delta_rule' }, + { + 'name': 'trtllm', + 'type': 'gated_delta_rule_update', + 'module': 'trtllm.collect_gated_delta_rule', + 'get_func': 'get_gated_delta_rule_update_test_cases', + 'run_func': 'run_gated_delta_rule_update' + }, ] for collection in collections: diff --git a/collector/trtllm/collect_chunk_gated_delta_rule.py b/collector/trtllm/collect_chunk_gated_delta_rule.py deleted file mode 100644 index 0f51e8c3..00000000 --- a/collector/trtllm/collect_chunk_gated_delta_rule.py +++ /dev/null @@ -1,395 +0,0 @@ -# SPDX-FileCopyrightText: Copyright (c) 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved. 
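The get_sol helpers introduced for query_conv_1d and query_chunk_gated_delta_rule follow the same speed-of-light pattern used elsewhere in perf_database.py: the latency floor is whichever of compute time and memory time is larger. A generic sketch of that roofline bound, with placeholder peak numbers standing in for system_spec['gpu']['float16_tc_flops'] and ['mem_bw']:

```python
# Roofline-style lower bound; peak_flops and mem_bw below are illustrative
# placeholders, not values read from the system spec.
def sol_latency_ms(flops, bytes_moved, peak_flops, mem_bw):
    sol_math = flops / peak_flops * 1000.0   # compute-bound time, ms
    sol_mem = bytes_moved / mem_bw * 1000.0  # memory-bound time, ms
    return max(sol_math, sol_mem), sol_math, sol_mem

sol, sol_math, sol_mem = sol_latency_ms(
    flops=2 * 4096 * 4096 * 4096,      # e.g. a 4096^3 GEMM
    bytes_moved=3 * 4096 * 4096 * 2,   # rough bf16 input/weight/output traffic
    peak_flops=1.0e15,
    mem_bw=3.3e12,
)
```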
-# SPDX-License-Identifier: Apache-2.0 - -import tensorrt_llm -import torch -from cuda import cuda -from tensorrt_llm._torch.attention_backend.utils import create_attention -from tensorrt_llm._torch.attention_backend.interface import PositionalEmbeddingParams, RopeParams, AttentionRuntimeFeatures -from tensorrt_llm.functional import PositionEmbeddingType -from tensorrt_llm.models.modeling_utils import QuantConfig, QuantAlgo -from tensorrt_llm._torch.attention_backend import TrtllmAttentionMetadata -from tensorrt_llm.mapping import Mapping -from tensorrt_llm._torch.pyexecutor.resource_manager import KVCacheManager -from tensorrt_llm._torch.metadata import KVCacheParams -from tensorrt_llm.llmapi import KvCacheConfig -import os -from helper import getSMVersion, log_perf - -def run_attention_torch(batch_size, - input_len, - num_heads, - num_key_value_heads, # keep same as num_heads for MHA - head_dim, - attention_window_size, - use_fp8_kv_cache, - use_fp8_context_fmha, - is_context_phase, - perf_filename, - device='cuda:0'): - torch.cuda.set_device(device) - - # if XQA JIT is enabled, the context phase will also trigger XQA prepare which causes the error with specifc q/kv head and seq setting. - if is_context_phase: - os.environ['TRTLLM_ENABLE_XQA_JIT']= '0' - else: - os.environ['TRTLLM_ENABLE_XQA_JIT']= '1' - - backend_name = "TRTLLM" - layer_idx = 0 - world_size=1 - tp_size=1 - tokens_per_block=64 - warming_up=10 - test_ite=6 - output_len=1 - if use_fp8_context_fmha: - assert use_fp8_kv_cache==True - quant_algo = QuantAlgo.FP8 - out_scale = torch.tensor( - [1.0], - dtype=torch.float32, - device=device, - )# fp8 fmha - else: - quant_algo = None - out_scale = None - - if use_fp8_kv_cache: - kv_cache_dtype = tensorrt_llm.bindings.DataType.FP8 - else: - kv_cache_dtype = tensorrt_llm.bindings.DataType.BF16 - - pos_embd_params = PositionalEmbeddingParams( - type=PositionEmbeddingType.rope_gpt_neox, - rope=RopeParams(dim=128)) - - - quant_config=QuantConfig(quant_algo=quant_algo, # fp8 fmha - kv_cache_quant_algo=QuantAlgo.FP8 if use_fp8_kv_cache else None, # fp8 kv, - group_size=128, - smoothquant_val=0.5, - clamp_val=None, - use_meta_recipe=False, - has_zero_point=False, - pre_quant_scale=False, - exclude_modules=None) - - attn = create_attention(backend_name = backend_name, - layer_idx = layer_idx, - num_heads = num_heads, - head_dim = head_dim, - num_kv_heads = num_key_value_heads, - pos_embd_params=pos_embd_params, - quant_config=quant_config, - is_mla_enable=False) - - total_num_tokens = (input_len + output_len) * batch_size - - mapping = Mapping(world_size=world_size, rank=0, tp_size=tp_size) - - num_hidden_layers = 1 - - kv_cache_config = KvCacheConfig( - max_tokens=int((input_len + output_len - 1)/tokens_per_block + 1) * tokens_per_block * batch_size * 2, #num_bloacks * block_size - enable_block_reuse=False) - - kv_cache_manager = KVCacheManager( - kv_cache_config=kv_cache_config, - kv_cache_type=tensorrt_llm.bindings.internal.batch_manager.CacheType.SELFKONLY, - num_layers=num_hidden_layers, - num_kv_heads=num_key_value_heads, - head_dim=head_dim, - tokens_per_block=tokens_per_block, - max_seq_len=input_len + output_len + 1, # +1 for the magic fixme mentioned in trtllm xqa JIT path impl. 
- max_batch_size=batch_size, - mapping=mapping, - dtype=kv_cache_dtype) - - input_seq_lens = [input_len for _ in range(batch_size)] - total_seq_lens = [input_len + output_len for _ in range(batch_size)] - request_ids = [i for i in range(batch_size)] - kv_cache_manager.add_dummy_requests(request_ids, total_seq_lens) - - if is_context_phase: - num_cached_tokens_per_seq = [0 for _ in range(batch_size)] - attn_metadata = TrtllmAttentionMetadata(max_num_requests=batch_size, - max_num_tokens=total_num_tokens, - kv_cache_manager=kv_cache_manager, - mapping=mapping, - enable_flash_mla=False, - seq_lens=torch.tensor(input_seq_lens, dtype=torch.int32), - num_contexts=batch_size, - position_ids=None, - kv_cache_params=KVCacheParams(use_cache=True, num_cached_tokens_per_seq=num_cached_tokens_per_seq, block_ids_per_seq=None, host_max_attention_window_sizes=None, host_sink_token_length=None), - cross=None, - request_ids=request_ids, - prompt_lens=input_seq_lens, - runtime_features=AttentionRuntimeFeatures(chunked_prefill=False, cache_reuse=False, has_speculative_draft_tokens=False), - all_rank_num_tokens=None, - workspace=torch.tensor([], device=device, dtype=torch.int8)) - else: - gen_seq_lens = [1 for _ in range(batch_size)] - attn_metadata = TrtllmAttentionMetadata(max_num_requests=batch_size, - max_num_tokens=total_num_tokens, - kv_cache_manager=kv_cache_manager, - mapping=mapping, - enable_flash_mla=False, - seq_lens=torch.tensor(gen_seq_lens, dtype=torch.int32), - position_ids=None, - num_contexts=0, - kv_cache_params=KVCacheParams(use_cache=True, num_cached_tokens_per_seq=input_seq_lens, block_ids_per_seq=None, host_max_attention_window_sizes=None, host_sink_token_length=None), - cross=None, - request_ids=request_ids, - prompt_lens=input_seq_lens, - runtime_features=AttentionRuntimeFeatures(chunked_prefill=False, cache_reuse=False), - all_rank_num_tokens=None, - workspace=torch.tensor([], device=device, dtype=torch.int8)) - - attn_metadata.prepare() - - if is_context_phase: - num_tokens = input_len * batch_size - else: - num_tokens = batch_size - - sinks = torch.randn(num_heads, dtype=torch.float32) if head_dim == 64 else None - q = torch.randn([num_tokens, num_heads*128]).bfloat16().to(torch.device(device)) - kv = torch.randn([num_tokens, 2*num_key_value_heads*128]).bfloat16().to(torch.device(device)) - input_qkv = torch.concat([q, kv], dim=-1) - attn.forward( - input_qkv, - None, - None, - attn_metadata, - attention_window_size=attention_window_size if attention_window_size>0 else None, - attention_sinks=sinks, - out_scale=out_scale - ) - out_dtype = None if not use_fp8_context_fmha else torch.float8_e4m3fn - - # capture - g = torch.cuda.CUDAGraph() - with torch.cuda.graph(g): - attn.forward( - input_qkv, - None, - None, - attn_metadata, - attention_window_size=attention_window_size if attention_window_size>0 else None, - attention_sinks=sinks, - out_scale=out_scale - ) - # warmup - for i in range(warming_up): - g.replay() - - start_event = torch.cuda.Event(enable_timing=True) - end_event = torch.cuda.Event(enable_timing=True) - start_event.record() - for i in range(test_ite): - g.replay() - end_event.record() - torch.cuda.synchronize() - latency = start_event.elapsed_time(end_event)/test_ite - - # write result - if is_context_phase: - isl = input_len - step = 0 - op_name = 'context_attention' - else: - isl = 1 - step = input_len - op_name = 'generation_attention' - kv_cache_dtype_str = 'float16' - if use_fp8_kv_cache: - kv_cache_dtype_str = 'fp8' - if use_fp8_context_fmha: - dtype_str = 'fp8' 
- else: - dtype_str = 'float16' - - log_perf(item_list=[{ - 'batch_size': batch_size, - 'isl': isl, - 'num_heads': num_heads, - 'num_key_value_heads': num_key_value_heads, - 'head_dim': head_dim, - 'window_size': attention_window_size, - 'beam_width': 1, - 'attn_dtype': dtype_str, - 'kv_cache_dtype': kv_cache_dtype_str, - 'step': step, - 'latency': latency - }], - framework='TRTLLM', - version=tensorrt_llm.__version__, - device_name=torch.cuda.get_device_name(device), - op_name=op_name, - kernel_source='torch_flow', - perf_filename=perf_filename) - kv_cache_manager.shutdown() - - -def get_context_attention_test_cases(): - has_fp8 = getSMVersion() > 86 - test_cases = [] - b_list = [1,2,4,8,16,32,64,128,256] - s_list = [16,32,64,128,256,512,1024,1536,2048,3072,4096,6144,8192,10240,12288,16384,262144] - n_list = [4,8,12,16,24,32,40,48,64,96] - n_kv_list = [0,1,2,4,8] - head_dim = [64,128] - - for h in head_dim: - for n in sorted(n_list, reverse=True): - for s in sorted(s_list, reverse=True): - for b in sorted(b_list, reverse=True): - for n_kv in n_kv_list: - if n_kv != 0: - if n_kv >= n or n%n_kv != 0: - continue - num_kv_heads = n_kv if n_kv !=0 else n - - if num_kv_heads == n: - if b*s > 65536 or b >128: - continue - else: - if b*s > 131072: - continue - if b*s*num_kv_heads*128*2 >= 2147483647: - continue - if getSMVersion() >= 100: - # TLLM_CHECK_WITH_INFO((params.mNumHeadsQPerKv < maxNumHeadsQPerKvInCta || params.mNumHeadsQPerKv % maxNumHeadsQPerKvInCta == 0), - if n >= 32 and n % 32 != 0: - continue - - #print(f'collecting heads: {n} kv_heads: {num_kv_heads} seq: {s} batchsize: {b}') - # use fp8 kv cache, fp8 context fmha, is_context_phase. in torch flow, int8 kvcache is not supported yet. - # fp16 kv cache, fp16 context fmha, is_context_phase - if head_dim == 64: - test_cases.append([b, s, n, num_kv_heads, h, 128, False, False, True, 'context_attention_perf.txt']) - if has_fp8: - test_cases.append([b, s, n, num_kv_heads, h, 128, True, False, True, 'context_attention_perf.txt']) - test_cases.append([b, s, n, num_kv_heads, h, 128, True, True, True, 'context_attention_perf.txt']) - else: - test_cases.append([b, s, n, num_kv_heads, h, 0, False, False, True, 'context_attention_perf.txt']) - if has_fp8: - test_cases.append([b, s, n, num_kv_heads, h, 0, True, False, True, 'context_attention_perf.txt']) - test_cases.append([b, s, n, num_kv_heads, h, 0, True, True, True, 'context_attention_perf.txt']) - - return test_cases - -def get_generation_attention_test_cases(): - has_fp8 = getSMVersion() > 86 - test_cases = [] - - # generation - isl = 1 - b_list = [1,2,4,8,16,32,64,128,256,512,1024,2048,] - b_list_xqa = [1,2,4,8,16,32,64,128,256,512,1024,2048] - # the i-th token to record. 1 for context phase. 
mapping to osl definition - s_list = [2,4,8,16,32,64,128,256,512,1024,2048,4096,8192,16384,32768,65536,131072] - n_list = [4,8,12,16,24,32,40,48,64] - n_list_xqa = [4,8,16,32,64,96,128] - n_kv_list = [1,2,4,8] - head_dim = [64,128] - - # MHA - max_bsn = 8192*1024 #2*1024*1024*1024/128/2 INT32MAX/128/2 - max_bsn_largeb = max_bsn//2 - for n in sorted(n_list, reverse=True): - b_s_dict = {} - s_b_dict = {} - for s in s_list: - max_b = max_bsn // s // n # b*s*n*byte <= max_bsn - for b in b_list: - if b > max_b: - break - if s not in s_b_dict.keys(): - s_b_dict[s] = {b} - else: - s_b_dict[s].add(b) - for s, b_set in s_b_dict.items(): - if len(b_set) < 4: - continue - for b in b_set: - if b not in b_s_dict.keys(): - b_s_dict[b] = {s-1} - b_s_dict[b].add(s-1) - for h in head_dim: - for b, s_list_limited in b_s_dict.items(): - target_s_list = sorted(s_list_limited) - if b >= 256: - target_s_list = target_s_list[:-1] - #print(f'collecting MHA heads: {n} batchsize: {b} steps: {s_list_limited}') - # fp8 kv cache, fp8 context fmha, is_context_phase - for s in target_s_list: - if getSMVersion() >= 100: - # TLLM_CHECK_WITH_INFO((params.mNumHeadsQPerKv < maxNumHeadsQPerKvInCta || params.mNumHeadsQPerKv % maxNumHeadsQPerKvInCta == 0), - if n >= 32 and n % 32 != 0: - continue - - test_cases.append([b, s, n, n, h, 0, False, False, False, 'generation_attention_perf.txt']) - - if has_fp8: - test_cases.append([b, s, n, n, h, 0, True, False, False, 'generation_attention_perf.txt']) - # currently, fp8 is not for generation compute - #test_cases.append([b, s, n, n, 128, True, True, False, 'generation_attention_perf.txt']) - - # XQA - max_bsn = 8192*1024*2 #2*1024*1024*1024/128/2 - for n in sorted(n_list_xqa, reverse=True): - b_s_dict = {} - s_b_dict = {} - for s in s_list: - max_b = max_bsn // s // n - for b in b_list: - if b > max_b: - break - if s not in s_b_dict.keys(): - s_b_dict[s] = {b} - else: - s_b_dict[s].add(b) - for s, b_set in s_b_dict.items(): - if len(b_set) < 4: - continue - for b in b_set: - if b not in b_s_dict.keys(): - b_s_dict[b] = {s-1} - b_s_dict[b].add(s-1) - for h in head_dim: - for b, s_list_limited in b_s_dict.items(): - target_s_list = sorted(s_list_limited) - if b >= 256: - target_s_list = target_s_list[:-1] - for n_kv in n_kv_list: - if n_kv >= n: - continue - - # fp8 kv cache, fp8 context fmha, is_context_phase - for s in target_s_list: - if getSMVersion() >= 100: - # TLLM_CHECK_WITH_INFO((params.mNumHeadsQPerKv < maxNumHeadsQPerKvInCta || params.mNumHeadsQPerKv % maxNumHeadsQPerKvInCta == 0), - if n >= 32 and n % 32 != 0: - continue - if head_dim == 64: - test_cases.append([b, s, n, n_kv, h, 128, False, False, False, 'generation_attention_perf.txt']) - if has_fp8: - test_cases.append([b, s, n, n_kv, h, 128, True, False, False, 'generation_attention_perf.txt']) - # currently, fp8 is not for generation compute - #test_cases.append([b, s, n, n_kv, 128, True, True, False, 'generation_attention_perf.txt']) - else: - test_cases.append([b, s, n, n_kv, h, 0, False, False, False, 'generation_attention_perf.txt']) - if has_fp8: - test_cases.append([b, s, n, n_kv, h, 0, True, False, False, 'generation_attention_perf.txt']) - return test_cases - -if __name__ == '__main__': - test_cases = get_context_attention_test_cases() - for test_case in test_cases: - run_attention_torch(*test_case) - - test_cases = get_generation_attention_test_cases() - for test_case in test_cases: - run_attention_torch(*test_case) \ No newline at end of file diff --git a/collector/trtllm/collect_conv_1d.py 
b/collector/trtllm/collect_conv_1d.py index b9f61da7..bf91ff19 100644 --- a/collector/trtllm/collect_conv_1d.py +++ b/collector/trtllm/collect_conv_1d.py @@ -1,143 +1,185 @@ # SPDX-FileCopyrightText: Copyright (c) 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved. # SPDX-License-Identifier: Apache-2.0 + import os from cuda import cuda import torch -import torch.nn as nn import tensorrt_llm -import math -from helper import getSMVersion, log_perf +from tensorrt_llm.modules.mamba.causal_conv1d import causal_conv1d_fn, causal_conv1d_update +from helper import log_perf -def get_conv1d_test_cases(): +def get_conv1d_fn_test_cases(): """ - Generate test cases for Conv1D operations. + Generate test cases for Conv1DFn operations. Test parameters: - - batch_size: batch size (analogous to 'm' in GEMM) - - in_channels: number of input channels - - out_channels: number of output channels - - kernel_size: size of the convolution kernel - - seq_length: sequence length + - batch_size: batch size + - isl: sequence length + - conv_kernel_size: size of the convolution kernel + - conv_dim: dimension of the convolution + - tp_size: attention tensor parallel size """ - batch_sizes = [1, 2, 4, 8, 16, 32, 64, 128, 256, 512, 1024] - channel_sizes = [64, 128, 256, 512, 768, 1024, 1536, 2048, 3072, 4096] - kernel_sizes = [1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 16] - seq_lengths = [32, 64, 128, 256, 512, 1024, 2048, 4096] - - conv1d_types = ['float16'] - if getSMVersion() > 86: - conv1d_types += ['fp8'] - if getSMVersion() < 100: - conv1d_types += ['fp8_block'] - if getSMVersion() >= 100: - conv1d_types += ['nvfp4'] + b_list = [1,2,4,8,16,32,64,128,256,512,1024,2048] + s_list = [2,4,8,16,32,64,128,256,512,1024,2048,4096,8192,16384,32768,65536,131072] + tp_sizes = [1, 2, 4, 8] + conv_dims = [64, 128, 256, 512, 768, 1024, 1536, 2048, 3072, 4096] + kernel_sizes = [1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 16] test_cases = [] - for conv_type in conv1d_types: - # Generate test cases with various combinations - for batch_size in batch_sizes: - for in_ch in channel_sizes: - for out_ch in channel_sizes: + for batch_size in b_list: + for isl in s_list: + for tp_size in tp_sizes: + for conv_dim in conv_dims: for kernel_size in kernel_sizes: - for seq_len in seq_lengths: - # Skip extremely large cases - if batch_size * in_ch * seq_len > 16777216: - continue - if conv_type == 'nvfp4' or conv_type == 'fp8_block': - if in_ch < 128 or out_ch < 128: - continue - test_cases.append([conv_type, batch_size, in_ch, out_ch, kernel_size, seq_len, 'conv1d_perf.txt']) + test_cases.append([batch_size, isl, kernel_size, conv_dim, tp_size, 'conv1d_fn_perf.txt']) return test_cases -def run_conv1d(conv_type, batch_size, in_channels, out_channels, kernel_size, seq_length, perf_filename, device='cuda:0'): +def run_conv1d_fn(batch_size, isl, conv_kernel_size, conv_dim, tp_size, perf_filename, device='cuda:0'): """ - Run Conv1D performance benchmarking. + Run Conv1DFn performance benchmarking. 
Args: - conv_type: Type of convolution ('float16', 'fp8', 'fp8_block', 'nvfp4') batch_size: Batch size - in_channels: Number of input channels - out_channels: Number of output channels - kernel_size: Size of the convolution kernel - seq_length: Sequence length + isl: Sequence length + conv_kernel_size: Size of the convolution kernel + conv_dim: Dimension of the convolution + tp_size: Attention tensor parallel size perf_filename: Output file for performance results device: CUDA device to use """ - device = torch.device(device) - torch.cuda.set_device(device) - torch.set_default_device(device) - - # For now, we focus on float16/bfloat16 benchmarking - # Quantized Conv1D support can be added when available in TensorRT-LLM - if conv_type != 'float16': - # Skip non-float16 for now as Conv1D quantization in TensorRT-LLM needs verification - return - dtype = torch.bfloat16 - # Conv1D expects input shape: (batch_size, in_channels, seq_length) - x = torch.randn((batch_size, in_channels, seq_length), dtype=dtype).to(torch.device(device)) + mixed_qkv = torch.randn((batch_size * isl, conv_dim // tp_size), dtype=dtype).to(torch.device(device)) + conv1d_weights = torch.randn((conv_dim // tp_size, conv_kernel_size), dtype=dtype, device=device) - repeat_n = 5 # to reduce impact of L2 cache hit - op_list = [] - - for i in range(repeat_n): - # Use PyTorch's native Conv1d - conv1d = nn.Conv1d( - in_channels, - out_channels, - kernel_size, - stride=1, - padding=kernel_size//2, # 'same' padding - bias=False, - dtype=dtype, - ) - - # Initialize weights randomly - conv1d.weight.data = torch.randn((out_channels, in_channels, kernel_size), dtype=dtype, device=device) - - conv1d.to(torch.device(device)) - conv1d(x) # dry run to init - op_list.append(conv1d) + g = torch.cuda.CUDAGraph() + with torch.cuda.graph(g): + mixed_qkv_trans = mixed_qkv.transpose(0, 1) + # TODO: measure optional arguments + causal_conv1d_fn( + mixed_qkv_trans, + conv1d_weights, + ).transpose(0, 1) num_warmups = 3 num_runs = 6 + + # warmup + for _ in range(num_warmups): + g.replay() + + # measure + start_event = torch.cuda.Event(enable_timing=True) + end_event = torch.cuda.Event(enable_timing=True) + start_event.record() + for _ in range(num_runs): + g.replay() + end_event.record() + torch.cuda.synchronize() + latency = start_event.elapsed_time(end_event)/num_runs + + log_perf( + item_list=[{ + 'batch_size': batch_size, + 'isl': isl, + 'conv_kernel_size': conv_kernel_size, + 'conv_dim': conv_dim, + 'tp_size': tp_size, + 'latency': latency + }], + framework='TRTLLM', + version=tensorrt_llm.__version__, + device_name=torch.cuda.get_device_name(device), + op_name='conv1d_fn', + kernel_source='default', + perf_filename=perf_filename + ) + +def get_conv1d_update_test_cases(): + """ + Generate test cases for Conv1DUpdate operations. 
+ + Test parameters: + - batch_size: batch size + - isl: sequence length + - conv_kernel_size: size of the convolution kernel + - conv_dim: dimension of the convolution + - tp_size: attention tensor parallel size + """ + b_list = [1,2,4,8,16,32,64,128,256,512,1024,2048] + s_list = [2,4,8,16,32,64,128,256,512,1024,2048,4096,8192,16384,32768,65536,131072] + tp_sizes = [1, 2, 4, 8] + conv_dims = [1,2,4,8,16,32] + kernel_sizes = [1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 16] + + test_cases = [] + for batch_size in b_list: + for isl in s_list: + for tp_size in tp_sizes: + for conv_dim in conv_dims: + for kernel_size in kernel_sizes: + test_cases.append([batch_size, isl, kernel_size, conv_dim, tp_size, 'conv1d_update_perf.txt']) + + return test_cases + + +def run_conv1d_update(batch_size, isl, conv_kernel_size, conv_dim, tp_size, perf_filename, device='cuda:0'): + """ + Run Conv1DUpdate performance benchmarking. + + Args: + batch_size: Batch size + isl: Sequence length + conv_kernel_size: Size of the convolution kernel + conv_dim: Dimension of the convolution + tp_size: Attention tensor parallel size + perf_filename: Output file for performance results + device: CUDA device to use + """ + dtype = torch.bfloat16 + mixed_qkv = torch.randn((batch_size * isl, conv_dim // tp_size), dtype=dtype).to(torch.device(device)) + conv1d_weights = torch.randn((conv_dim // tp_size, conv_kernel_size), dtype=dtype, device=device) - # capture g = torch.cuda.CUDAGraph() with torch.cuda.graph(g): - for op in op_list: - op.forward(x) + # TODO: measure optional arguments + causal_conv1d_update( + mixed_qkv, + conv1d_weights, + ) + + num_warmups = 3 + num_runs = 6 # warmup - for i in range(num_warmups): + for _ in range(num_warmups): g.replay() + # measure start_event = torch.cuda.Event(enable_timing=True) end_event = torch.cuda.Event(enable_timing=True) start_event.record() - for i in range(num_runs): + for _ in range(num_runs): g.replay() end_event.record() torch.cuda.synchronize() - latency = start_event.elapsed_time(end_event)/num_runs/len(op_list) + latency = start_event.elapsed_time(end_event)/num_runs log_perf( item_list=[{ - 'conv_dtype': conv_type, 'batch_size': batch_size, - 'in_channels': in_channels, - 'out_channels': out_channels, - 'kernel_size': kernel_size, - 'seq_length': seq_length, + 'isl': isl, + 'conv_kernel_size': conv_kernel_size, + 'conv_dim': conv_dim, + 'tp_size': tp_size, 'latency': latency }], framework='TRTLLM', version=tensorrt_llm.__version__, device_name=torch.cuda.get_device_name(device), - op_name='conv1d', - kernel_source='torch_flow', + op_name='conv1d_fn', + kernel_source='default', perf_filename=perf_filename ) diff --git a/collector/trtllm/collect_gated_delta_rule.py b/collector/trtllm/collect_gated_delta_rule.py new file mode 100644 index 00000000..51a62e5d --- /dev/null +++ b/collector/trtllm/collect_gated_delta_rule.py @@ -0,0 +1,210 @@ +# SPDX-FileCopyrightText: Copyright (c) 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# SPDX-License-Identifier: Apache-2.0 + + +import os +from cuda import cuda +import torch +import tensorrt_llm +from tensorrt_llm.modules.flash_attention import chunk_gated_delta_rule, fused_sigmoid_gating_delta_rule_update +from helper import log_perf + +def get_chunk_gated_delta_rule_test_cases(): + """ + Generate test cases for chunk_gated_delta_rule() operations. 
+ + Test parameters: + - num_heads: number of heads + - head_k_dim: dimension of the key heads + - head_v_dim: dimension of the value heads + - num_value_heads: number of value heads + - isl: sequence length + """ + num_heads_list = [1,2,4,8,16,32,64,128,256,512,1024,2048] + head_k_dim_list = [1,2,4,8,16,32,64,128] + head_v_dim_list = [1,2,4,8,16,32,64,128] + num_value_heads_list = [1,2,4,8,16,32,64,128] + isl_list = [2,4,8,16,32,64,128,256,512,1024,2048,4096,8192,16384,32768,65536,131072] + + test_cases = [] + for num_heads in num_heads_list: + for head_k_dim in head_k_dim_list: + for head_v_dim in head_v_dim_list: + for num_value_heads in num_value_heads_list: + for isl in isl_list: + test_cases.append([num_heads, head_k_dim, head_v_dim, num_value_heads, isl, 'chunk_gated_delta_rule_perf.txt']) + + return test_cases + + +def run_chunk_gated_delta_rule(num_heads, head_k_dim, head_v_dim, num_value_heads, isl, perf_filename, device='cuda:0'): + """ + Run chunk_gated_delta_rule() performance benchmarking. + + Args: + num_heads: Number of heads + head_k_dim: Dimension of the key heads + head_v_dim: Dimension of the value heads + num_value_heads: Number of value heads + isl: Sequence length + perf_filename: Output file for performance results + device: CUDA device to use + """ + # NOTICE: ignored fused_gdn_gating operation + dtype = torch.bfloat16 + q = torch.randn((1, isl, num_heads, head_k_dim), dtype=dtype).to(torch.device(device)) + k = torch.randn((1, isl, num_heads, head_k_dim), dtype=dtype).to(torch.device(device)) # k shares the key head dim with q + v = torch.randn((1, isl, num_value_heads, head_v_dim), dtype=dtype).to(torch.device(device)) + gate = torch.randn((1, isl, num_heads), dtype=dtype).to(torch.device(device)) + beta = torch.randn((1, isl, num_heads), dtype=dtype).to(torch.device(device)) + + g = torch.cuda.CUDAGraph() + with torch.cuda.graph(g): + chunk_gated_delta_rule(q, k, v, gate, beta) + + num_warmups = 3 + num_runs = 6 + + # warmup + for _ in range(num_warmups): + g.replay() + + # measure + start_event = torch.cuda.Event(enable_timing=True) + end_event = torch.cuda.Event(enable_timing=True) + start_event.record() + for _ in range(num_runs): + g.replay() + end_event.record() + torch.cuda.synchronize() + latency = start_event.elapsed_time(end_event)/num_runs + + log_perf( + item_list=[{ + 'num_heads': num_heads, + 'head_k_dim': head_k_dim, + 'head_v_dim': head_v_dim, + 'num_value_heads': num_value_heads, + 'isl': isl, + 'latency': latency + }], + framework='TRTLLM', + version=tensorrt_llm.__version__, + device_name=torch.cuda.get_device_name(device), + op_name='chunk_gated_delta_rule', + kernel_source='default', + perf_filename=perf_filename + ) + +def get_gated_delta_rule_update_test_cases(): + """ + Generate test cases for gated_delta_rule_update() operations.
+ + Test parameters: + - batch_size: batch size + - isl: sequence length + - num_heads: number of heads + - head_k_dim: dimension of the key heads + - head_v_dim: dimension of the value heads + - num_value_heads: number of value heads + - max_batch_size: maximum batch size + """ + b_list = [1,2,4,8,16,32,64,128,256,512,1024,2048] + s_list = [2,4,8,16,32,64,128,256,512,1024,2048,4096,8192,16384,32768,65536,131072] + num_heads_list = [1,2,4,8,16,32,64,128,256,512,1024,2048] + head_k_dim_list = [1,2,4,8,16,32,64,128] + head_v_dim_list = [1,2,4,8,16,32,64,128] + num_value_heads_list = [1,2,4,8,16,32,64,128] + max_batch_size_list = [1,2,4,8,16,32,64,128,256,512,1024,2048] + + test_cases = [] + for batch_size in b_list: + for isl in s_list: + for num_heads in num_heads_list: + for head_k_dim in head_k_dim_list: + for head_v_dim in head_v_dim_list: + for num_value_heads in num_value_heads_list: + for max_batch_size in max_batch_size_list: + test_cases.append([batch_size, isl, num_heads, head_k_dim, head_v_dim, num_value_heads, max_batch_size, 'gated_delta_rule_update_perf.txt']) + + return test_cases + + +def run_gated_delta_rule_update(batch_size, isl, num_heads, head_k_dim, head_v_dim, num_value_heads, max_batch_size, perf_filename, device='cuda:0'): + """ + Run fused_sigmoid_gating_delta_rule_update() performance benchmarking. + + Args: + batch_size: Batch size + isl: Sequence length + num_heads: Number of heads + head_k_dim: Dimension of the key heads + head_v_dim: Dimension of the value heads + num_value_heads: Number of value heads + max_batch_size: Maximum batch size + perf_filename: Output file for performance results + device: CUDA device to use + """ + dtype = torch.bfloat16 + A_log = torch.randn((num_heads * num_value_heads), dtype=dtype).to(torch.device(device)) + dt_bias = torch.randn((num_heads * num_value_heads), dtype=dtype).to(torch.device(device)) + q = torch.randn((batch_size, isl, num_heads, head_k_dim), dtype=dtype).to(torch.device(device)) + k = torch.randn((batch_size, isl, num_heads, head_k_dim), dtype=dtype).to(torch.device(device)) + v = torch.randn((batch_size, isl, num_value_heads, head_v_dim), dtype=dtype).to(torch.device(device)) + a = torch.randn((batch_size * isl, num_heads * num_value_heads), dtype=dtype).to(torch.device(device)) + b = torch.randn((batch_size, isl, num_heads * num_value_heads), dtype=dtype).to(torch.device(device)) + initial_state_source = torch.randn((max_batch_size, num_heads * num_value_heads, head_k_dim, head_v_dim), dtype=dtype).to(torch.device(device)) + initial_state_indices = torch.arange(batch_size, dtype=torch.int32, device=device) % max_batch_size # indices into initial_state_source must be integers (int32 assumed); modulo keeps them within max_batch_size + softplus_beta = 1.0 + softplus_threshold = 20.0 + + g = torch.cuda.CUDAGraph() + with torch.cuda.graph(g): + # TODO: measure optional arguments + fused_sigmoid_gating_delta_rule_update( + A_log=A_log, + dt_bias=dt_bias, + q=q, + k=k, + v=v, + a=a, + b=b, + initial_state_source=initial_state_source, + initial_state_indices=initial_state_indices, + softplus_beta=softplus_beta, + softplus_threshold=softplus_threshold, + ) + + num_warmups = 3 + num_runs = 6 + + # warmup + for _ in range(num_warmups): + g.replay() + + # measure + start_event = torch.cuda.Event(enable_timing=True) + end_event = torch.cuda.Event(enable_timing=True) + start_event.record() + for _ in range(num_runs): + g.replay() + end_event.record() + torch.cuda.synchronize() + latency = start_event.elapsed_time(end_event)/num_runs + + log_perf( + item_list=[{ + 'batch_size': batch_size, + 'isl': isl, + 'num_heads': num_heads, + 'head_k_dim': head_k_dim, + 'head_v_dim':
head_v_dim, + 'num_value_heads': num_value_heads, + 'max_batch_size': max_batch_size, + 'latency': latency + }], + framework='TRTLLM', + version=tensorrt_llm.__version__, + device_name=torch.cuda.get_device_name(device), + op_name='gated_delta_rule_update', + kernel_source='default', + perf_filename=perf_filename + ) diff --git a/src/aiconfigurator/sdk/backends/base_backend.py b/src/aiconfigurator/sdk/backends/base_backend.py index 7bc778d3..e0558500 100644 --- a/src/aiconfigurator/sdk/backends/base_backend.py +++ b/src/aiconfigurator/sdk/backends/base_backend.py @@ -46,6 +46,7 @@ def run_static(self, step, default is 32. latency_correction_scale (float): the correction scale to adjust the latency, default is 1.0. corrected latency = latency * latency_correction_scale """ + def _run_context(batch_size: int, isl: int) -> dict[str, float]: context_latency_dict = defaultdict(float) diff --git a/src/aiconfigurator/sdk/backends/trtllm_backend.py b/src/aiconfigurator/sdk/backends/trtllm_backend.py index d5131b40..9ae3d38c 100644 --- a/src/aiconfigurator/sdk/backends/trtllm_backend.py +++ b/src/aiconfigurator/sdk/backends/trtllm_backend.py @@ -317,7 +317,8 @@ def _get_memory_usage(self, c_dict = {1:11, 2:6.5, 4:5, 8:5} activations = 2*num_tokens*h*c_dict[min(model.config.tp_size, 8)] activations = max(activations, 70*1024*1024) # minimum act - elif get_model_family(model.model_name) == 'MOE': + elif get_model_family(model.model_name) in ['MOE', 'QWEN3NEXT']: + # TODO: Qwen3Next has different activation memory calculation. c_dict = {1:22, 2:13, 4:10, 8:10} activations = 2*num_tokens*h*c_dict[min(model.config.tp_size, 8)] activations = max(activations, 70*1024*1024) # minimum act diff --git a/src/aiconfigurator/sdk/models.py b/src/aiconfigurator/sdk/models.py index 72a96ff8..db0fbce5 100755 --- a/src/aiconfigurator/sdk/models.py +++ b/src/aiconfigurator/sdk/models.py @@ -749,6 +749,8 @@ class Qwen3NextModel(BaseModel): 3/4 of the layers are using linear attention with convolution 1d operation. Some rules to follow, Due to implementation, attn layer name needs to be context_attention or generation_attention, exact match is required. Same for logits_gemm. + + Refer to tensorrt_llm/_torch/models/modeling_qwen3_next.py for more details. """ def __init__(self, topk: int, num_experts: int, moe_inter_size: int, *args) -> None: super().__init__(*args) @@ -796,14 +798,21 @@ def context_ops(self, linear_attention_config: common.LinearAttentionConfig): communication, and final logits computation. 
Args: - linear_attention_config (common.LinearAttentionConfig): Linear attention configuration + linear_attention_config (common.LinearAttentionConfig or list): Linear attention configuration + or empty list for initialization """ + self._context_ops = [] + if not isinstance(linear_attention_config, common.LinearAttentionConfig): + return + num_v_heads = linear_attention_config.linear_num_value_heads num_k_heads = linear_attention_config.linear_num_key_heads head_k_dim = linear_attention_config.linear_key_head_dim head_v_dim = linear_attention_config.linear_value_head_dim key_dim = head_k_dim * num_k_heads value_dim = head_v_dim * num_v_heads + conv_kernel_size = linear_attention_config.linear_conv_kernel_dim + conv_dim = key_dim * 2 + value_dim moe_quant_mode = self.config.moe_quant_mode h = self._hidden_size tp_size = self.config.tp_size @@ -819,60 +828,56 @@ def context_ops(self, linear_attention_config: common.LinearAttentionConfig): # 1 embedding for all layers # 1 norm before attention per layer - self.context_ops.extend([ops.Embedding(f'context_embedding', 1, self._vocab_size, h, 0.3), + self._context_ops.extend([ops.Embedding(f'context_embedding', 1, self._vocab_size, h, 0.3), ops.ElementWise(f'context_add_norm_1', self._num_layers, 2*h, 2*h, 0.8)]) # self attention - # (self_attn): Qwen3NextAttention( - # (q_proj): Linear(in_features=2048, out_features=8192, bias=False) out_features = num_heads * head dimension(head_size) - # (k_proj): Linear(in_features=2048, out_features=512, bias=False) out_features = num_key_value_heads * head dimension(head_size) = 2 * 256 = 512 - # (v_proj): Linear(in_features=2048, out_features=512, bias=False) out_features = num_key_value_heads * head dimension(head_size) = 2 * 256 = 512 - # (o_proj): Linear(in_features=4096, out_features=2048, bias=False) out_features = hidden size = 2048 - # (q_norm): Qwen3NextRMSNorm((256,), eps=1e-06) ignored since the latency is ignorable - # (k_norm): Qwen3NextRMSNorm((256,), eps=1e-06) ignored since the latency is ignorable - # context_qkv_gemm: q + k + v - # context_attention: num of heads / tp_size - # context_proj_gemm: num of heads * head_size / tp_size - self.context_ops.extend([ + self._context_ops.extend([ ops.GEMM(f'context_qkv_gemm', self._num_layers * (1 - linear_attention_config.used_ratio), self._num_heads*self._head_size//tp_size+self._head_size*num_kv_heads_per_GPU*2, h, gemm_quant_mode), ops.ContextAttention(f'context_attention', self._num_layers * (1 - linear_attention_config.used_ratio), self._num_heads//tp_size, num_kv_heads_per_GPU, kvcache_quant_mode, fmha_quant_mode), ops.GEMM(f'context_proj_gemm', self._num_layers * (1 - linear_attention_config.used_ratio), h, self._num_heads*self._head_size//tp_size, gemm_quant_mode), ]) # linear attention (Qwen3NextGatedDeltaNet) - self.context_ops.extend([ - ops.GEMM(f'context_qkvz_ba_gemm', self._num_layers * linear_attention_config.used_ratio, (key_dim*2 + value_dim*2 + num_v_heads*2)//tp_size, h, gemm_quant_mode), - ops.Conv1D(f'context_conv1d', self._num_layers * linear_attention_config.used_ratio, self._num_heads*self._head_size//tp_size+self._head_size*num_kv_heads_per_GPU*2, h, kernel_size=4, stride=1, padding=3, groups=self._num_heads*self._head_size//tp_size+self._head_size*num_kv_heads_per_GPU*2, bias=False, gemm_quant_mode), - ops.ChunkGatedDeltaRule(f'context_chunk_gated_delta_rule', self._num_layers * linear_attention_config.used_ratio, in_channels=self._num_heads*self._head_size//tp_size+self._head_size*num_kv_heads_per_GPU*2, out_channels=h, 
kernel_size=4, seq_length=seq_length), + self._context_ops.extend([ + # Input projections for qkvz and ba + ops.GEMM(f'context_qkvz_gemm', self._num_layers * linear_attention_config.used_ratio, (key_dim*2 + value_dim*2)//tp_size, h, gemm_quant_mode), + ops.GEMM(f'context_ba_gemm', self._num_layers * linear_attention_config.used_ratio, num_v_heads*2//tp_size, h, gemm_quant_mode), + # Conv1D and gated delta rule operations - weights handled internally + # Conv1DFn(name, scale_factor, conv_kernel_size, conv_dim, tp_size) - batch_size and isl from kwargs + ops.Conv1DFn(f'context_conv1d_fn', self._num_layers * linear_attention_config.used_ratio, conv_kernel_size, conv_dim, tp_size), + # ChunkGatedDeltaRule(name, scale_factor, num_heads, head_k_dim, head_v_dim, num_value_heads) - isl from kwargs + ops.ChunkGatedDeltaRule(f'context_chunk_gated_delta_rule', self._num_layers * linear_attention_config.used_ratio, num_k_heads, head_k_dim, head_v_dim, num_v_heads), + # Output projection ops.GEMM(f'context_proj_gemm', self._num_layers * linear_attention_config.used_ratio, h, value_dim//tp_size, gemm_quant_mode), ]) # 1 norm before MOE per layer - self.context_ops.extend([ + self._context_ops.extend([ ops.ElementWise(f'context_add_norm_2', self._num_layers // linear_attention_config.used_ratio, 2*h, 2*h, 0.8)]) #router, only take it into account when num_experts >= 128 if self._num_experts >= 128: - self.context_ops.extend([ + self._context_ops.extend([ ops.GEMM(f'context_router_gemm', self._num_layers, self._num_experts, h, common.GEMMQuantMode.float16) ]) # dispatch tokens to experts, moe calc and get tokens back # Qwen3Next has one more shared expert. - self.context_ops.extend([ + self._context_ops.extend([ ops.MoEDispatch(f'context_moe_pre_dispatch', self._num_layers, h, self._topk, self._num_experts + 1, moe_tp_size, moe_ep_size, attention_dp_size, True), ops.MoE(f'context_moe', self._num_layers, h, self._moe_inter_size, self._topk, self._num_experts + 1, moe_tp_size, moe_ep_size, moe_quant_mode, workload_distribution, attention_dp_size), ops.MoEDispatch(f'context_moe_post_dispatch', self._num_layers, h, self._topk, self._num_experts + 1, moe_tp_size, moe_ep_size, attention_dp_size, False)]) - self.context_ops.extend([ops.GEMM(f'context_logits_gemm', 1, self._vocab_size//tp_size, h, common.GEMMQuantMode.float16)]) + self._context_ops.extend([ops.GEMM(f'context_logits_gemm', 1, self._vocab_size//tp_size, h, common.GEMMQuantMode.float16)]) # # # when tp_size=0, the comm part will be 0 - # self.context_ops.append(ops.AllReduce('context_ar_1', self._num_layers, h, tp_size)) - # self.context_ops.append(ops.AllReduce('context_ar_2', self._num_layers, h, tp_size)) + # self._context_ops.append(ops.AllReduce('context_ar_1', self._num_layers, h, tp_size)) + # self._context_ops.append(ops.AllReduce('context_ar_2', self._num_layers, h, tp_size)) # pp pp_scale_factor = pp_size-1 - self.context_ops.append(ops.P2P('context_p2p', pp_scale_factor, h, pp_size)) + self._context_ops.append(ops.P2P('context_p2p', pp_scale_factor, h, pp_size)) @property def generation_ops(self): @@ -903,13 +908,20 @@ def generation_ops(self, linear_attention_config: common.LinearAttentionConfig): Args: linear_attention_config (common.LinearAttentionConfig): Linear attention configuration + or empty list for initialization """ + self._generation_ops = [] + if not isinstance(linear_attention_config, common.LinearAttentionConfig): + return + num_v_heads = linear_attention_config.linear_num_value_heads num_k_heads = 
linear_attention_config.linear_num_key_heads head_k_dim = linear_attention_config.linear_key_head_dim head_v_dim = linear_attention_config.linear_value_head_dim key_dim = head_k_dim * num_k_heads value_dim = head_v_dim * num_v_heads + conv_kernel_size = linear_attention_config.linear_conv_kernel_dim + conv_dim = key_dim * 2 + value_dim moe_quant_mode = self.config.moe_quant_mode h = self._hidden_size tp_size = self.config.tp_size @@ -923,63 +935,61 @@ def generation_ops(self, linear_attention_config: common.LinearAttentionConfig): fmha_quant_mode = self.config.fmha_quant_mode workload_distribution = self.config.workload_distribution + f"_{self._power_law_alpha}" - # 1 embedding for all layers + # 1 embedding for all layers # 1 norm before attention per layer - self.generation_ops.extend([ops.Embedding(f'context_embedding', 1, self._vocab_size, h, 0.3), - ops.ElementWise(f'context_add_norm_1', self._num_layers, 2*h, 2*h, 0.8)]) + self._generation_ops.extend([ops.Embedding(f'generation_embedding', 1, self._vocab_size, h, 0.3), + ops.ElementWise(f'generation_add_norm_1', self._num_layers, 2*h, 2*h, 0.8)]) # self attention - # (self_attn): Qwen3NextAttention( - # (q_proj): Linear(in_features=2048, out_features=8192, bias=False) out_features = num_heads * head dimension(head_size) - # (k_proj): Linear(in_features=2048, out_features=512, bias=False) out_features = num_key_value_heads * head dimension(head_size) = 2 * 256 = 512 - # (v_proj): Linear(in_features=2048, out_features=512, bias=False) out_features = num_key_value_heads * head dimension(head_size) = 2 * 256 = 512 - # (o_proj): Linear(in_features=4096, out_features=2048, bias=False) out_features = hidden size = 2048 - # (q_norm): Qwen3NextRMSNorm((256,), eps=1e-06) ignored since the latency is ignorable - # (k_norm): Qwen3NextRMSNorm((256,), eps=1e-06) ignored since the latency is ignorable - # context_qkv_gemm: q + k + v - # context_attention: num of heads / tp_size - # context_proj_gemm: num of heads * head_size / tp_size - self.generation_ops.extend([ + self._generation_ops.extend([ ops.GEMM(f'generation_qkv_gemm', self._num_layers * (1 - linear_attention_config.used_ratio), self._num_heads*self._head_size//tp_size+self._head_size*num_kv_heads_per_GPU*2, h, gemm_quant_mode), - ops.GenerationAttention(f'generation_attention', self._num_layers * (1 - linear_attention_config.used_ratio), self._num_heads//tp_size, num_kv_heads_per_GPU, kvcache_quant_mode), + ops.GenerationAttention(f'generation_attention', self._num_layers * (1 - linear_attention_config.used_ratio), self._num_heads//tp_size, num_kv_heads_per_GPU, kvcache_quant_mode, fmha_quant_mode), ops.GEMM(f'generation_proj_gemm', self._num_layers * (1 - linear_attention_config.used_ratio), h, self._num_heads*self._head_size//tp_size, gemm_quant_mode), ]) - + # linear attention (Qwen3NextGatedDeltaNet) - self.generation_ops.extend([ - ops.GEMM(f'generation_qkvz_ba_gemm', self._num_layers * linear_attention_config.used_ratio, (key_dim*2 + value_dim*2 + num_v_heads*2)//tp_size, h, gemm_quant_mode), - ops.Conv1D(f'generation_conv1d', self._num_layers * linear_attention_config.used_ratio, self._num_heads*self._head_size//tp_size+self._head_size*num_kv_heads_per_GPU*2, h, kernel_size=4, stride=1, padding=3, groups=self._num_heads*self._head_size//tp_size+self._head_size*num_kv_heads_per_GPU*2, bias=False, gemm_quant_mode), - ops.ChunkGatedDeltaRule(f'generation_chunk_gated_delta_rule', self._num_layers * linear_attention_config.used_ratio, 
in_channels=self._num_heads*self._head_size//tp_size+self._head_size*num_kv_heads_per_GPU*2, out_channels=h, kernel_size=4, seq_length=seq_length), + self._generation_ops.extend([ + # Input projections for qkvz and ba + ops.GEMM(f'generation_qkvz_gemm', self._num_layers * linear_attention_config.used_ratio, (key_dim*2 + value_dim*2)//tp_size, h, gemm_quant_mode), + ops.GEMM(f'generation_ba_gemm', self._num_layers * linear_attention_config.used_ratio, num_v_heads*2//tp_size, h, gemm_quant_mode), + # Conv1D and gated delta rule operations - weights handled internally + # TODO: for mixed steps, add ops.Conv1DFn(...) + # Conv1DUpdate(name, scale_factor, conv_kernel_size, conv_dim, tp_size) - batch_size and isl from kwargs + ops.Conv1DUpdate(f'generation_conv1d_update', self._num_layers * linear_attention_config.used_ratio, conv_kernel_size, conv_dim, tp_size), + # GatedDeltaRuleUpdate(name, scale_factor, num_heads, head_k_dim, head_v_dim, num_value_heads, max_batch_size) - batch_size and isl from kwargs + # max_batch_size is dynamic, need to determine appropriate value + ops.GatedDeltaRuleUpdate(f'generation_gated_delta_rule_update', self._num_layers * linear_attention_config.used_ratio, num_k_heads, head_k_dim, head_v_dim, num_v_heads, 1024), + # Output projection ops.GEMM(f'generation_proj_gemm', self._num_layers * linear_attention_config.used_ratio, h, value_dim//tp_size, gemm_quant_mode), ]) # 1 norm before MOE per layer - self.generation_ops.extend([ - ops.ElementWise(f'context_add_norm_2', self._num_layers // linear_attention_config.used_ratio, 2*h, 2*h, 0.8)]) + self._generation_ops.extend([ + ops.ElementWise(f'generation_add_norm_2', self._num_layers // linear_attention_config.used_ratio, 2*h, 2*h, 0.8)]) #router, only take it into account when num_experts >= 128 if self._num_experts >= 128: - self.generation_ops.extend([ + self._generation_ops.extend([ ops.GEMM(f'generation_router_gemm', self._num_layers, self._num_experts, h, common.GEMMQuantMode.float16) ]) # dispatch tokens to experts, moe calc and get tokens back # Qwen3Next has one more shared expert. 
- self.generation_ops.extend([ + self._generation_ops.extend([ ops.MoEDispatch(f'generation_moe_pre_dispatch', self._num_layers, h, self._topk, self._num_experts + 1, moe_tp_size, moe_ep_size, attention_dp_size, True), ops.MoE(f'generation_moe', self._num_layers, h, self._moe_inter_size, self._topk, self._num_experts + 1, moe_tp_size, moe_ep_size, moe_quant_mode, workload_distribution, attention_dp_size), ops.MoEDispatch(f'generation_moe_post_dispatch', self._num_layers, h, self._topk, self._num_experts + 1, moe_tp_size, moe_ep_size, attention_dp_size, False) ]) # logits gemm - self.generation_ops.extend([ops.GEMM(f'generation_logits_gemm', 1, self._vocab_size//tp_size, h, common.GEMMQuantMode.float16)]) + self._generation_ops.extend([ops.GEMM(f'generation_logits_gemm', 1, self._vocab_size//tp_size, h, common.GEMMQuantMode.float16)]) # # # when tp_size=0, the comm part will be 0 - # self.generation_ops.append(ops.AllReduce('generation_ar_1', self._num_layers, h, tp_size)) - # self.generation_ops.append(ops.AllReduce('generation_ar_2', self._num_layers, h, tp_size)) + # self._generation_ops.append(ops.AllReduce('generation_ar_1', self._num_layers, h, tp_size)) + # self._generation_ops.append(ops.AllReduce('generation_ar_2', self._num_layers, h, tp_size)) # pp pp_scale_factor = pp_size-1 - self.generation_ops.append(ops.P2P('generation_p2p', pp_scale_factor, h, pp_size)) + self._generation_ops.append(ops.P2P('generation_p2p', pp_scale_factor, h, pp_size)) if __name__ == '__main__': diff --git a/src/aiconfigurator/sdk/operations.py b/src/aiconfigurator/sdk/operations.py index 03d9df87..9072eb78 100755 --- a/src/aiconfigurator/sdk/operations.py +++ b/src/aiconfigurator/sdk/operations.py @@ -580,40 +580,80 @@ def query(self, database:PerfDatabase, **kwargs): def get_weights(self, **kwargs): return self._weights * self._scale_factor -class Conv1D(Operation): +class Conv1DFn(Operation): """ - Conv1D operation. + Conv1DFn operation. """ - def __init__(self, name: str, scale_factor: float, in_channels: int, out_channels: int, kernel_size: int, seq_length: int) -> None: + def __init__(self, name: str, scale_factor: float, conv_kernel_size: int, conv_dim: int, tp_size: int) -> None: super().__init__(name, scale_factor) - self._in_channels = in_channels - self._out_channels = out_channels - self._kernel_size = kernel_size - self._seq_length = seq_length - self._weights = in_channels * out_channels * kernel_size * seq_length + self._conv_kernel_size = conv_kernel_size + self._conv_dim = conv_dim + self._tp_size = tp_size + self._weights = 0.0 def query(self, database:PerfDatabase, **kwargs): - x = kwargs.get('x') - return database.query_conv_1d(x, self._in_channels, self._out_channels, self._kernel_size, self._seq_length)*self._scale_factor + batch_size = kwargs.get('batch_size') + isl = kwargs.get('s') + return database.query_conv1d_fn(batch_size, isl, self._conv_kernel_size, self._conv_dim, self._tp_size)*self._scale_factor def get_weights(self, **kwargs): return self._weights * self._scale_factor +class Conv1DUpdate(Operation): + """ + Conv1DUpdate operation. 
+ """ + def __init__(self, name: str, scale_factor: float, conv_kernel_size: int, conv_dim: int, tp_size: int) -> None: + super().__init__(name, scale_factor) + self._conv_kernel_size = conv_kernel_size + self._conv_dim = conv_dim + self._tp_size = tp_size + self._weights = 0.0 + + def query(self, database:PerfDatabase, **kwargs): + batch_size = kwargs.get('batch_size') + isl = kwargs.get('s') + return database.query_conv1d_update(batch_size, isl, self._conv_kernel_size, self._conv_dim, self._tp_size)*self._scale_factor + + def get_weights(self, **kwargs): + return self._weights * self._scale_factor + class ChunkGatedDeltaRule(Operation): """ Chunk gated delta rule operation. """ - def __init__(self, name: str, scale_factor: float, in_channels: int, out_channels: int, kernel_size: int, seq_length: int) -> None: + def __init__(self, name: str, scale_factor: float, num_heads: int, head_k_dim: int, head_v_dim: int, num_value_heads: int) -> None: super().__init__(name, scale_factor) - self._in_channels = in_channels - self._out_channels = out_channels - self._kernel_size = kernel_size - self._seq_length = seq_length - self._weights = in_channels * out_channels * kernel_size * seq_length + self._num_heads = num_heads + self._head_k_dim = head_k_dim + self._head_v_dim = head_v_dim + self._num_value_heads = num_value_heads + self._weights = 0.0 def query(self, database:PerfDatabase, **kwargs): - x = kwargs.get('x') - return database.query_chunk_gated_delta_rule(x, self._in_channels, self._out_channels, self._kernel_size, self._seq_length)*self._scale_factor + isl = kwargs.get('s') + return database.query_chunk_gated_delta_rule(self._num_heads, self._head_k_dim, self._head_v_dim, self._num_value_heads, isl)*self._scale_factor + + def get_weights(self, **kwargs): + return self._weights * self._scale_factor + +class GatedDeltaRuleUpdate(Operation): + """ + Gated delta rule update operation. + """ + def __init__(self, name: str, scale_factor: float, num_heads: int, head_k_dim: int, head_v_dim: int, num_value_heads: int, max_batch_size: int) -> None: + super().__init__(name, scale_factor) + self._num_heads = num_heads + self._head_k_dim = head_k_dim + self._head_v_dim = head_v_dim + self._num_value_heads = num_value_heads + self._max_batch_size = max_batch_size + self._weights = 0.0 + + def query(self, database:PerfDatabase, **kwargs): + batch_size = kwargs.get('batch_size') + isl = kwargs.get('s') + return database.query_gated_delta_rule_update(batch_size, isl, self._num_heads, self._head_k_dim, self._head_v_dim, self._num_value_heads, self._max_batch_size)*self._scale_factor def get_weights(self, **kwargs): return self._weights * self._scale_factor diff --git a/src/aiconfigurator/sdk/perf_database.py b/src/aiconfigurator/sdk/perf_database.py index e3e37331..9068a821 100755 --- a/src/aiconfigurator/sdk/perf_database.py +++ b/src/aiconfigurator/sdk/perf_database.py @@ -2015,62 +2015,223 @@ def get_sol(num_tokens: int, num_experts: int, topk: int, hidden_size: int) -> T lat = self._interp_2d_linear(sms, num_tokens, data) return lat / 1000.0 - def query_conv_1d(self, - in_channels : int, - out_channels : int, - kernel_size : int, - seq_length : int, - sol_mode : Optional[common.SOLMode] = None) -> float: + def query_conv1d_fn(self, + batch_size: int, + isl: int, + conv_kernel_size: int, + conv_dim: int, + tp_size: int, + sol_mode: Optional[common.SOLMode] = None) -> float: """ - Query the conv1d data + Query the Conv1D Fn operation data. 
+ + Args: + batch_size: Batch size + isl: Sequence length + conv_kernel_size: Size of the convolution kernel + conv_dim: Dimension of the convolution + tp_size: Tensor parallel size + sol_mode: SOL mode for theoretical performance calculation + + Returns: + Latency in milliseconds """ - def get_sol(in_channels : int, out_channels : int, kernel_size : int, seq_length : int) -> Tuple[float, float, float]: + def get_sol(batch_size: int, isl: int, conv_kernel_size: int, conv_dim: int, tp_size: int) -> Tuple[float, float, float]: """ - Get the sol time, sol math and sol mem + Get the sol time, sol math and sol mem for Conv1D Fn """ - sol_math = in_channels * out_channels * kernel_size * seq_length / (self.system_spec['gpu']['float16_tc_flops']*1) * 1000 - sol_mem = in_channels * out_channels * kernel_size * seq_length / self.system_spec['gpu']['mem_bw'] * 1000 + # Conv1D operations: batch_size * isl * (conv_dim // tp_size) * conv_kernel_size + ops = batch_size * isl * (conv_dim // tp_size) * conv_kernel_size * 2 # 2 for FMA + mem_bytes = 2 * ( # Assuming fp16/bf16 + batch_size * isl * (conv_dim // tp_size) + # Input + (conv_dim // tp_size) * conv_kernel_size + # Weights + batch_size * isl * (conv_dim // tp_size) # Output + ) + sol_math = ops / self.system_spec['gpu']['float16_tc_flops'] * 1000 + sol_mem = mem_bytes / self.system_spec['gpu']['mem_bw'] * 1000 sol_time = max(sol_math, sol_mem) return sol_time, sol_math, sol_mem if sol_mode is None: sol_mode = self._default_sol_mode if sol_mode == common.SOLMode.SOL: - return get_sol(in_channels, out_channels, kernel_size, seq_length)[0] + return get_sol(batch_size, isl, conv_kernel_size, conv_dim, tp_size)[0] elif sol_mode == common.SOLMode.SOL_FULL: - return get_sol(in_channels, out_channels, kernel_size, seq_length) + return get_sol(batch_size, isl, conv_kernel_size, conv_dim, tp_size) else: - result = self._interp_3d(in_channels, out_channels, kernel_size, seq_length, self._conv_1d_data, 'cubic') - return result + # TODO: Add actual data interpolation when measurement data is available + # For now, return SOL estimation + return get_sol(batch_size, isl, conv_kernel_size, conv_dim, tp_size)[0] + + def query_conv1d_update(self, + batch_size: int, + isl: int, + conv_kernel_size: int, + conv_dim: int, + tp_size: int, + sol_mode: Optional[common.SOLMode] = None) -> float: + """ + Query the Conv1D Update operation data. 
+ + Args: + batch_size: Batch size + isl: Sequence length + conv_kernel_size: Size of the convolution kernel + conv_dim: Dimension of the convolution + tp_size: Tensor parallel size + sol_mode: SOL mode for theoretical performance calculation + + Returns: + Latency in milliseconds + """ + def get_sol(batch_size: int, isl: int, conv_kernel_size: int, conv_dim: int, tp_size: int) -> Tuple[float, float, float]: + """ + Get the sol time, sol math and sol mem for Conv1D Update + """ + # Conv1D update is typically lighter than full conv1d_fn + ops = batch_size * isl * (conv_dim // tp_size) * conv_kernel_size * 2 # 2 for FMA + mem_bytes = 2 * ( # Assuming fp16/bf16 + batch_size * isl * (conv_dim // tp_size) + # Input + (conv_dim // tp_size) * conv_kernel_size + # Weights + batch_size * isl * (conv_dim // tp_size) # Output + ) + sol_math = ops / self.system_spec['gpu']['float16_tc_flops'] * 1000 + sol_mem = mem_bytes / self.system_spec['gpu']['mem_bw'] * 1000 + sol_time = max(sol_math, sol_mem) + return sol_time, sol_math, sol_mem + + if sol_mode is None: + sol_mode = self._default_sol_mode + if sol_mode == common.SOLMode.SOL: + return get_sol(batch_size, isl, conv_kernel_size, conv_dim, tp_size)[0] + elif sol_mode == common.SOLMode.SOL_FULL: + return get_sol(batch_size, isl, conv_kernel_size, conv_dim, tp_size) + else: + # TODO: Add actual data interpolation when measurement data is available + # For now, return SOL estimation + return get_sol(batch_size, isl, conv_kernel_size, conv_dim, tp_size)[0] def query_chunk_gated_delta_rule(self, - - num_tokens : int, - in_channels : int, - out_channels : int, - kernel_size : int, - seq_length : int, - sol_mode : Optional[common.SOLMode] = None) -> float: + num_heads: int, + head_k_dim: int, + head_v_dim: int, + num_value_heads: int, + isl: int, + sol_mode: Optional[common.SOLMode] = None) -> float: + """ + Query the Chunk Gated Delta Rule operation data. 
+ + Args: + num_heads: Number of heads + head_k_dim: Dimension of the key heads + head_v_dim: Dimension of the value heads + num_value_heads: Number of value heads + isl: Sequence length + sol_mode: SOL mode for theoretical performance calculation + + Returns: + Latency in milliseconds """ - Query the chunk gated delta rule data + def get_sol(num_heads: int, head_k_dim: int, head_v_dim: int, num_value_heads: int, isl: int) -> Tuple[float, float, float]: + """ + Get the sol time, sol math and sol mem for Chunk Gated Delta Rule + """ + # Gated delta rule involves attention-like operations + # Operations: q*k^T, gating, and weighted sum with values + ops = ( + num_heads * isl * isl * head_k_dim * 2 + # q*k^T + num_heads * isl * isl * 2 + # gating operations + num_value_heads * isl * isl * head_v_dim * 2 # weighted sum with values + ) + mem_bytes = 2 * ( # Assuming fp16/bf16 + num_heads * isl * head_k_dim + # Q + num_heads * isl * head_k_dim + # K + num_value_heads * isl * head_v_dim + # V + num_heads * isl + # gate + num_heads * isl + # beta + num_value_heads * isl * head_v_dim # output + ) + sol_math = ops / self.system_spec['gpu']['float16_tc_flops'] * 1000 + sol_mem = mem_bytes / self.system_spec['gpu']['mem_bw'] * 1000 + sol_time = max(sol_math, sol_mem) + return sol_time, sol_math, sol_mem + + if sol_mode is None: + sol_mode = self._default_sol_mode + if sol_mode == common.SOLMode.SOL: + return get_sol(num_heads, head_k_dim, head_v_dim, num_value_heads, isl)[0] + elif sol_mode == common.SOLMode.SOL_FULL: + return get_sol(num_heads, head_k_dim, head_v_dim, num_value_heads, isl) + else: + # TODO: Add actual data interpolation when measurement data is available + # For now, return SOL estimation + return get_sol(num_heads, head_k_dim, head_v_dim, num_value_heads, isl)[0] + + def query_gated_delta_rule_update(self, + batch_size: int, + isl: int, + num_heads: int, + head_k_dim: int, + head_v_dim: int, + num_value_heads: int, + max_batch_size: int, + sol_mode: Optional[common.SOLMode] = None) -> float: + """ + Query the Gated Delta Rule Update operation data. 
+ + Args: + batch_size: Batch size + isl: Sequence length + num_heads: Number of heads + head_k_dim: Dimension of the key heads + head_v_dim: Dimension of the value heads + num_value_heads: Number of value heads + max_batch_size: Maximum batch size + sol_mode: SOL mode for theoretical performance calculation + + Returns: + Latency in milliseconds """ - def get_sol(num_tokens : int, in_channels : int, out_channels : int, kernel_size : int, seq_length : int) -> Tuple[float, float, float]: + def get_sol(batch_size: int, isl: int, num_heads: int, head_k_dim: int, + head_v_dim: int, num_value_heads: int, max_batch_size: int) -> Tuple[float, float, float]: """ - Get the sol time, sol math and sol mem + Get the sol time, sol math and sol mem for Gated Delta Rule Update """ - sol_math = num_tokens * in_channels * out_channels * kernel_size * seq_length / (self.system_spec['gpu']['float16_tc_flops']*1) * 1000 - sol_mem = in_channels * out_channels * kernel_size * seq_length / self.system_spec['gpu']['mem_bw'] * 1000 + # Fused sigmoid gating delta rule update involves state updates + ops = ( + batch_size * isl * num_heads * head_k_dim * 2 + # q processing + batch_size * isl * num_heads * head_k_dim * 2 + # k processing + batch_size * isl * num_value_heads * head_v_dim * 2 + # v processing + batch_size * isl * num_heads * num_value_heads * 2 + # gating operations + max_batch_size * num_heads * num_value_heads * head_k_dim * head_v_dim * 2 # state operations + ) + mem_bytes = 2 * ( # Assuming fp16/bf16 + num_heads * num_value_heads + # A_log + num_heads * num_value_heads + # dt_bias + batch_size * isl * num_heads * head_k_dim + # q + batch_size * isl * num_heads * head_k_dim + # k + batch_size * isl * num_value_heads * head_v_dim + # v + batch_size * isl * num_heads * num_value_heads + # a + batch_size * isl * num_heads * num_value_heads + # b + max_batch_size * num_heads * num_value_heads * head_k_dim * head_v_dim + # initial_state_source + batch_size # initial_state_indices + ) + sol_math = ops / self.system_spec['gpu']['float16_tc_flops'] * 1000 + sol_mem = mem_bytes / self.system_spec['gpu']['mem_bw'] * 1000 sol_time = max(sol_math, sol_mem) return sol_time, sol_math, sol_mem + if sol_mode is None: sol_mode = self._default_sol_mode if sol_mode == common.SOLMode.SOL: - return get_sol(num_tokens, in_channels, out_channels, kernel_size, seq_length)[0] + return get_sol(batch_size, isl, num_heads, head_k_dim, head_v_dim, num_value_heads, max_batch_size)[0] elif sol_mode == common.SOLMode.SOL_FULL: - return get_sol(num_tokens, in_channels, out_channels, kernel_size, seq_length) + return get_sol(batch_size, isl, num_heads, head_k_dim, head_v_dim, num_value_heads, max_batch_size) else: - result = self._interp_3d(num_tokens, in_channels, out_channels, kernel_size, seq_length, self._chunk_gated_delta_rule_data, 'cubic') - return result + # TODO: Add actual data interpolation when measurement data is available + # For now, return SOL estimation + return get_sol(batch_size, isl, num_heads, head_k_dim, head_v_dim, num_value_heads, max_batch_size)[0] + if __name__ == '__main__': database_dict = get_all_databases() \ No newline at end of file From d61fac5c0d45264db4281d3596868b45a50c1e94 Mon Sep 17 00:00:00 2001 From: Jason Zhou Date: Sun, 19 Oct 2025 10:18:38 -0700 Subject: [PATCH 04/17] comment out other collections --- collector/collect.py | 152 +++++++++++++++++++++---------------------- 1 file changed, 76 insertions(+), 76 deletions(-) diff --git a/collector/collect.py b/collector/collect.py index 
bd1d568f..2d5a83f9 100644 --- a/collector/collect.py +++ b/collector/collect.py @@ -354,86 +354,86 @@ def collect_trtllm(num_processes: int, ops: List[str]=None): # Define collection modules - each test type as separate entry collections = [ - # GEMM collections - { - 'name': 'trtllm', - 'type': 'gemm_trt', - 'module': 'trtllm.collect_gemm_trt', - 'get_func': 'get_gemm_test_cases', - 'run_func': 'run_gemm' - }, - { - 'name': 'trtllm', - 'type': 'gemm', - 'module': 'trtllm.collect_gemm', - 'get_func': 'get_gemm_test_cases', - 'run_func': 'run_gemm' - }, + # # GEMM collections + # { + # 'name': 'trtllm', + # 'type': 'gemm_trt', + # 'module': 'trtllm.collect_gemm_trt', + # 'get_func': 'get_gemm_test_cases', + # 'run_func': 'run_gemm' + # }, + # { + # 'name': 'trtllm', + # 'type': 'gemm', + # 'module': 'trtllm.collect_gemm', + # 'get_func': 'get_gemm_test_cases', + # 'run_func': 'run_gemm' + # }, - # MLA collections - { - 'name': 'trtllm', - 'type': 'mla_context', - 'module': 'trtllm.collect_mla', - 'get_func': 'get_context_mla_test_cases', - 'run_func': 'run_mla', - 'version_handler': lambda v: 'trtllm.collect_mla_1_1rc2' if v.startswith('1.1') - else 'trtllm.collect_mla' - }, - { - 'name': 'trtllm', - 'type': 'mla_generation', - 'module': 'trtllm.collect_mla', - 'get_func': 'get_generation_mla_test_cases', - 'run_func': 'run_mla', - 'version_handler': lambda v: 'trtllm.collect_mla_1_1rc2' if v.startswith('1.1') - else 'trtllm.collect_mla' - }, + # # MLA collections + # { + # 'name': 'trtllm', + # 'type': 'mla_context', + # 'module': 'trtllm.collect_mla', + # 'get_func': 'get_context_mla_test_cases', + # 'run_func': 'run_mla', + # 'version_handler': lambda v: 'trtllm.collect_mla_1_1rc2' if v.startswith('1.1') + # else 'trtllm.collect_mla' + # }, + # { + # 'name': 'trtllm', + # 'type': 'mla_generation', + # 'module': 'trtllm.collect_mla', + # 'get_func': 'get_generation_mla_test_cases', + # 'run_func': 'run_mla', + # 'version_handler': lambda v: 'trtllm.collect_mla_1_1rc2' if v.startswith('1.1') + # else 'trtllm.collect_mla' + # }, - # Attention collections - separate entries for context and generation - { - 'name': 'trtllm', - 'type': 'attention_context', - 'module': 'trtllm.collect_attn', - 'get_func': 'get_context_attention_test_cases', - 'run_func': 'run_attention_torch' - }, - { - 'name': 'trtllm', - 'type': 'attention_generation', - 'module': 'trtllm.collect_attn', - 'get_func': 'get_generation_attention_test_cases', - 'run_func': 'run_attention_torch' - }, + # # Attention collections - separate entries for context and generation + # { + # 'name': 'trtllm', + # 'type': 'attention_context', + # 'module': 'trtllm.collect_attn', + # 'get_func': 'get_context_attention_test_cases', + # 'run_func': 'run_attention_torch' + # }, + # { + # 'name': 'trtllm', + # 'type': 'attention_generation', + # 'module': 'trtllm.collect_attn', + # 'get_func': 'get_generation_attention_test_cases', + # 'run_func': 'run_attention_torch' + # }, - # MLA BMM collections - { - 'name': 'trtllm', - 'type': 'mla_bmm_gen_pre', - 'module': 'trtllm.collect_mla_bmm', - 'get_func': 'get_mla_gen_pre_test_cases', - 'run_func': 'run_mla_gen_pre' - }, - { - 'name': 'trtllm', - 'type': 'mla_bmm_gen_post', - 'module': 'trtllm.collect_mla_bmm', - 'get_func': 'get_mla_gen_post_test_cases', - 'run_func': 'run_mla_gen_post' - }, + # # MLA BMM collections + # { + # 'name': 'trtllm', + # 'type': 'mla_bmm_gen_pre', + # 'module': 'trtllm.collect_mla_bmm', + # 'get_func': 'get_mla_gen_pre_test_cases', + # 'run_func': 'run_mla_gen_pre' + # 
}, + # { + # 'name': 'trtllm', + # 'type': 'mla_bmm_gen_post', + # 'module': 'trtllm.collect_mla_bmm', + # 'get_func': 'get_mla_gen_post_test_cases', + # 'run_func': 'run_mla_gen_post' + # }, - # MOE collection (with version handling) - { - 'name': 'trtllm', - 'type': 'moe', - 'module': None, # Will be determined based on version - 'get_func': 'get_moe_test_cases', - 'run_func': 'run_moe_torch', - 'version_handler': lambda v: 'trtllm.collect_moe_pre_0_20' if v.startswith('0.20.0') - else 'trtllm.collect_moe_pre_1_0' if v.startswith(('0.21.0', '1.0.0')) - else 'trtllm.collect_moe' if v.startswith(('1.1.0')) - else None - }, + # # MOE collection (with version handling) + # { + # 'name': 'trtllm', + # 'type': 'moe', + # 'module': None, # Will be determined based on version + # 'get_func': 'get_moe_test_cases', + # 'run_func': 'run_moe_torch', + # 'version_handler': lambda v: 'trtllm.collect_moe_pre_0_20' if v.startswith('0.20.0') + # else 'trtllm.collect_moe_pre_1_0' if v.startswith(('0.21.0', '1.0.0')) + # else 'trtllm.collect_moe' if v.startswith(('1.1.0')) + # else None + # }, # CONV 1D collections { From d1ec7cbf302c42bfecdae42178dba828bdae2f9f Mon Sep 17 00:00:00 2001 From: Jason Zhou Date: Sun, 19 Oct 2025 11:00:05 -0700 Subject: [PATCH 05/17] fix --- collector/collect.py | 44 +++++++++---------- .../{collect_conv_1d.py => collect_conv1d.py} | 2 +- collector/trtllm/collect_gated_delta_rule.py | 2 +- 3 files changed, 24 insertions(+), 24 deletions(-) rename collector/trtllm/{collect_conv_1d.py => collect_conv1d.py} (98%) diff --git a/collector/collect.py b/collector/collect.py index 2d5a83f9..031e635f 100644 --- a/collector/collect.py +++ b/collector/collect.py @@ -438,34 +438,34 @@ def collect_trtllm(num_processes: int, ops: List[str]=None): # CONV 1D collections { 'name': 'trtllm', - 'type': 'conv_1d_fn', - 'module': 'trtllm.collect_conv_1d', - 'get_func': 'get_conv_1d_fn_test_cases', - 'run_func': 'run_conv_1d_fn' + 'type': 'conv1d_fn', + 'module': 'trtllm.collect_conv1d', + 'get_func': 'get_conv1d_fn_test_cases', + 'run_func': 'run_conv1d_fn' }, { 'name': 'trtllm', - 'type': 'conv_1d_update', - 'module': 'trtllm.collect_conv_1d', - 'get_func': 'get_conv_1d_update_test_cases', - 'run_func': 'run_conv_1d_update' + 'type': 'conv1d_update', + 'module': 'trtllm.collect_conv1d', + 'get_func': 'get_conv1d_update_test_cases', + 'run_func': 'run_conv1d_update' }, # Gated Delta Rule collections - { - 'name': 'trtllm', - 'type': 'chunk_gated_delta_rule', - 'module': 'trtllm.collect_gated_delta_rule', - 'get_func': 'get_chunk_gated_delta_rule_test_cases', - 'run_func': 'run_chunk_gated_delta_rule' - }, - { - 'name': 'trtllm', - 'type': 'gated_delta_rule_update', - 'module': 'trtllm.collect_gated_delta_rule', - 'get_func': 'get_gated_delta_rule_update_test_cases', - 'run_func': 'run_gated_delta_rule_update' - }, + # { + # 'name': 'trtllm', + # 'type': 'chunk_gated_delta_rule', + # 'module': 'trtllm.collect_gated_delta_rule', + # 'get_func': 'get_chunk_gated_delta_rule_test_cases', + # 'run_func': 'run_chunk_gated_delta_rule' + # }, + # { + # 'name': 'trtllm', + # 'type': 'gated_delta_rule_update', + # 'module': 'trtllm.collect_gated_delta_rule', + # 'get_func': 'get_gated_delta_rule_update_test_cases', + # 'run_func': 'run_gated_delta_rule_update' + # }, ] for collection in collections: diff --git a/collector/trtllm/collect_conv_1d.py b/collector/trtllm/collect_conv1d.py similarity index 98% rename from collector/trtllm/collect_conv_1d.py rename to collector/trtllm/collect_conv1d.py index 
bf91ff19..25e20df5 100644 --- a/collector/trtllm/collect_conv_1d.py +++ b/collector/trtllm/collect_conv1d.py @@ -6,7 +6,7 @@ from cuda import cuda import torch import tensorrt_llm -from tensorrt_llm.modules.mamba.causal_conv1d import causal_conv1d_fn, causal_conv1d_update +from tensorrt_llm._torch.modules.mamba.causal_conv1d import causal_conv1d_fn, causal_conv1d_update from helper import log_perf def get_conv1d_fn_test_cases(): diff --git a/collector/trtllm/collect_gated_delta_rule.py b/collector/trtllm/collect_gated_delta_rule.py index 51a62e5d..a724f103 100644 --- a/collector/trtllm/collect_gated_delta_rule.py +++ b/collector/trtllm/collect_gated_delta_rule.py @@ -6,7 +6,7 @@ from cuda import cuda import torch import tensorrt_llm -from tensorrt_llm.modules.flash_attention import chunk_gated_delta_rule, fused_sigmoid_gating_delta_rule_update +from tensorrt_llm._torch.modules.flash_attention import chunk_gated_delta_rule, fused_sigmoid_gating_delta_rule_update from helper import log_perf def get_chunk_gated_delta_rule_test_cases(): From 9e813f79e69946823c1090fa2449e18d436c75b1 Mon Sep 17 00:00:00 2001 From: Jason Zhou Date: Sun, 19 Oct 2025 11:03:55 -0700 Subject: [PATCH 06/17] add fix --- collector/trtllm/collect_conv1d.py | 8 ++++---- 1 file changed, 4 insertions(+), 4 deletions(-) diff --git a/collector/trtllm/collect_conv1d.py b/collector/trtllm/collect_conv1d.py index 25e20df5..5aed92dc 100644 --- a/collector/trtllm/collect_conv1d.py +++ b/collector/trtllm/collect_conv1d.py @@ -51,17 +51,17 @@ def run_conv1d_fn(batch_size, isl, conv_kernel_size, conv_dim, tp_size, perf_fil device: CUDA device to use """ dtype = torch.bfloat16 - mixed_qkv = torch.randn((batch_size * isl, conv_dim // tp_size), dtype=dtype).to(torch.device(device)) + # Create input with proper 3D shape: (batch_size, dim, seqlen) + mixed_qkv = torch.randn((batch_size, conv_dim // tp_size, isl), dtype=dtype, device=device) conv1d_weights = torch.randn((conv_dim // tp_size, conv_kernel_size), dtype=dtype, device=device) g = torch.cuda.CUDAGraph() with torch.cuda.graph(g): - mixed_qkv_trans = mixed_qkv.transpose(0, 1) # TODO: measure optional arguments causal_conv1d_fn( - mixed_qkv_trans, + mixed_qkv, conv1d_weights, - ).transpose(0, 1) + ) num_warmups = 3 num_runs = 6 From 6b819659f028c9fc6e70f17daf4157d83f18cc15 Mon Sep 17 00:00:00 2001 From: Jason Zhou Date: Sun, 19 Oct 2025 17:04:12 -0700 Subject: [PATCH 07/17] conv1d_update only --- collector/collect.py | 14 +++++++------- collector/trtllm/collect_conv1d.py | 2 +- 2 files changed, 8 insertions(+), 8 deletions(-) diff --git a/collector/collect.py b/collector/collect.py index 031e635f..efac7048 100644 --- a/collector/collect.py +++ b/collector/collect.py @@ -436,13 +436,13 @@ def collect_trtllm(num_processes: int, ops: List[str]=None): # }, # CONV 1D collections - { - 'name': 'trtllm', - 'type': 'conv1d_fn', - 'module': 'trtllm.collect_conv1d', - 'get_func': 'get_conv1d_fn_test_cases', - 'run_func': 'run_conv1d_fn' - }, + # { + # 'name': 'trtllm', + # 'type': 'conv1d_fn', + # 'module': 'trtllm.collect_conv1d', + # 'get_func': 'get_conv1d_fn_test_cases', + # 'run_func': 'run_conv1d_fn' + # }, { 'name': 'trtllm', 'type': 'conv1d_update', diff --git a/collector/trtllm/collect_conv1d.py b/collector/trtllm/collect_conv1d.py index 5aed92dc..26c702a1 100644 --- a/collector/trtllm/collect_conv1d.py +++ b/collector/trtllm/collect_conv1d.py @@ -179,7 +179,7 @@ def run_conv1d_update(batch_size, isl, conv_kernel_size, conv_dim, tp_size, perf framework='TRTLLM', 
version=tensorrt_llm.__version__, device_name=torch.cuda.get_device_name(device), - op_name='conv1d_fn', + op_name='conv1d_update', kernel_source='default', perf_filename=perf_filename ) From 10d56df1095fb4cd1b11fe508812b836a5cb50a2 Mon Sep 17 00:00:00 2001 From: Jason Zhou Date: Sun, 19 Oct 2025 17:11:06 -0700 Subject: [PATCH 08/17] fix conv1d_update collector --- collector/trtllm/collect_conv1d.py | 27 ++++++++++++++++----------- 1 file changed, 16 insertions(+), 11 deletions(-) diff --git a/collector/trtllm/collect_conv1d.py b/collector/trtllm/collect_conv1d.py index 26c702a1..1b55b644 100644 --- a/collector/trtllm/collect_conv1d.py +++ b/collector/trtllm/collect_conv1d.py @@ -103,43 +103,48 @@ def get_conv1d_update_test_cases(): Test parameters: - batch_size: batch size - - isl: sequence length - conv_kernel_size: size of the convolution kernel - conv_dim: dimension of the convolution - tp_size: attention tensor parallel size + + Note: isl (sequence length) is not used for conv1d_update as it processes + individual tokens in incremental/streaming inference mode. """ b_list = [1,2,4,8,16,32,64,128,256,512,1024,2048] - s_list = [2,4,8,16,32,64,128,256,512,1024,2048,4096,8192,16384,32768,65536,131072] tp_sizes = [1, 2, 4, 8] conv_dims = [1,2,4,8,16,32] kernel_sizes = [1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 16] test_cases = [] for batch_size in b_list: - for isl in s_list: - for tp_size in tp_sizes: - for conv_dim in conv_dims: - for kernel_size in kernel_sizes: - test_cases.append([batch_size, isl, kernel_size, conv_dim, tp_size, 'conv1d_update_perf.txt']) + for tp_size in tp_sizes: + for conv_dim in conv_dims: + for kernel_size in kernel_sizes: + test_cases.append([batch_size, kernel_size, conv_dim, tp_size, 'conv1d_update_perf.txt']) return test_cases -def run_conv1d_update(batch_size, isl, conv_kernel_size, conv_dim, tp_size, perf_filename, device='cuda:0'): +def run_conv1d_update(batch_size, conv_kernel_size, conv_dim, tp_size, perf_filename, device='cuda:0'): """ Run Conv1DUpdate performance benchmarking. Args: batch_size: Batch size - isl: Sequence length conv_kernel_size: Size of the convolution kernel conv_dim: Dimension of the convolution tp_size: Attention tensor parallel size perf_filename: Output file for performance results device: CUDA device to use + + Note: isl (sequence length) is not used as conv1d_update processes individual + tokens in incremental/streaming inference mode. 
""" dtype = torch.bfloat16 - mixed_qkv = torch.randn((batch_size * isl, conv_dim // tp_size), dtype=dtype).to(torch.device(device)) + # Create input with shape (batch_size, dim) + mixed_qkv = torch.randn((batch_size, conv_dim // tp_size), dtype=dtype, device=device) + # Create conv_state with shape (batch_size, dim, kernel_size - 1) + conv_state = torch.randn((batch_size, conv_dim // tp_size, conv_kernel_size - 1), dtype=dtype, device=device) conv1d_weights = torch.randn((conv_dim // tp_size, conv_kernel_size), dtype=dtype, device=device) g = torch.cuda.CUDAGraph() @@ -147,6 +152,7 @@ def run_conv1d_update(batch_size, isl, conv_kernel_size, conv_dim, tp_size, perf # TODO: measure optional arguments causal_conv1d_update( mixed_qkv, + conv_state, conv1d_weights, ) @@ -170,7 +176,6 @@ def run_conv1d_update(batch_size, isl, conv_kernel_size, conv_dim, tp_size, perf log_perf( item_list=[{ 'batch_size': batch_size, - 'isl': isl, 'conv_kernel_size': conv_kernel_size, 'conv_dim': conv_dim, 'tp_size': tp_size, From 632b06065886b18578bc62211321cde218cde001 Mon Sep 17 00:00:00 2001 From: Jason Zhou Date: Sun, 19 Oct 2025 17:13:25 -0700 Subject: [PATCH 09/17] 2-4 --- collector/trtllm/collect_conv1d.py | 5 +++-- 1 file changed, 3 insertions(+), 2 deletions(-) diff --git a/collector/trtllm/collect_conv1d.py b/collector/trtllm/collect_conv1d.py index 1b55b644..32a936a5 100644 --- a/collector/trtllm/collect_conv1d.py +++ b/collector/trtllm/collect_conv1d.py @@ -103,17 +103,18 @@ def get_conv1d_update_test_cases(): Test parameters: - batch_size: batch size - - conv_kernel_size: size of the convolution kernel + - conv_kernel_size: size of the convolution kernel (must be between 2 and 4) - conv_dim: dimension of the convolution - tp_size: attention tensor parallel size Note: isl (sequence length) is not used for conv1d_update as it processes individual tokens in incremental/streaming inference mode. + Note: causal_conv1d_update only supports kernel widths between 2 and 4. 
""" b_list = [1,2,4,8,16,32,64,128,256,512,1024,2048] tp_sizes = [1, 2, 4, 8] conv_dims = [1,2,4,8,16,32] - kernel_sizes = [1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 16] + kernel_sizes = [2, 3, 4] # causal_conv1d_update only supports widths 2-4 test_cases = [] for batch_size in b_list: From f80eb6353c3ee3d1d0845db3c9bfec7b6d16db70 Mon Sep 17 00:00:00 2001 From: Jason Zhou Date: Sun, 19 Oct 2025 17:19:12 -0700 Subject: [PATCH 10/17] collect chunk_gated_delta_rule --- collector/collect.py | 26 +++++++++++++------------- 1 file changed, 13 insertions(+), 13 deletions(-) diff --git a/collector/collect.py b/collector/collect.py index efac7048..b8f9f090 100644 --- a/collector/collect.py +++ b/collector/collect.py @@ -443,22 +443,22 @@ def collect_trtllm(num_processes: int, ops: List[str]=None): # 'get_func': 'get_conv1d_fn_test_cases', # 'run_func': 'run_conv1d_fn' # }, - { - 'name': 'trtllm', - 'type': 'conv1d_update', - 'module': 'trtllm.collect_conv1d', - 'get_func': 'get_conv1d_update_test_cases', - 'run_func': 'run_conv1d_update' - }, - - # Gated Delta Rule collections # { # 'name': 'trtllm', - # 'type': 'chunk_gated_delta_rule', - # 'module': 'trtllm.collect_gated_delta_rule', - # 'get_func': 'get_chunk_gated_delta_rule_test_cases', - # 'run_func': 'run_chunk_gated_delta_rule' + # 'type': 'conv1d_update', + # 'module': 'trtllm.collect_conv1d', + # 'get_func': 'get_conv1d_update_test_cases', + # 'run_func': 'run_conv1d_update' # }, + + # Gated Delta Rule collections + { + 'name': 'trtllm', + 'type': 'chunk_gated_delta_rule', + 'module': 'trtllm.collect_gated_delta_rule', + 'get_func': 'get_chunk_gated_delta_rule_test_cases', + 'run_func': 'run_chunk_gated_delta_rule' + }, # { # 'name': 'trtllm', # 'type': 'gated_delta_rule_update', From 8a26b63bdf6d95e6e67fa5c4434e46e74031dbfa Mon Sep 17 00:00:00 2001 From: Jason Zhou Date: Sun, 19 Oct 2025 17:20:53 -0700 Subject: [PATCH 11/17] fix dependencies --- collector/trtllm/collect_gated_delta_rule.py | 3 ++- 1 file changed, 2 insertions(+), 1 deletion(-) diff --git a/collector/trtllm/collect_gated_delta_rule.py b/collector/trtllm/collect_gated_delta_rule.py index a724f103..3cbc2501 100644 --- a/collector/trtllm/collect_gated_delta_rule.py +++ b/collector/trtllm/collect_gated_delta_rule.py @@ -6,7 +6,8 @@ from cuda import cuda import torch import tensorrt_llm -from tensorrt_llm._torch.modules.flash_attention import chunk_gated_delta_rule, fused_sigmoid_gating_delta_rule_update +from tensorrt_llm._torch.modules.fla.chunk import chunk_gated_delta_rule +from tensorrt_llm._torch.modules.fla.fused_sigmoid_gating_recurrent import fused_sigmoid_gating_delta_rule_update from helper import log_perf def get_chunk_gated_delta_rule_test_cases(): From 4c3cdfb60fcf307a0410f5343d00b3becbdecc69 Mon Sep 17 00:00:00 2001 From: Jason Zhou Date: Sun, 19 Oct 2025 19:42:25 -0700 Subject: [PATCH 12/17] fix k --- collector/trtllm/collect_gated_delta_rule.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/collector/trtllm/collect_gated_delta_rule.py b/collector/trtllm/collect_gated_delta_rule.py index 3cbc2501..3574dc0c 100644 --- a/collector/trtllm/collect_gated_delta_rule.py +++ b/collector/trtllm/collect_gated_delta_rule.py @@ -54,7 +54,7 @@ def run_chunk_gated_delta_rule(num_heads, head_k_dim, head_v_dim, num_value_head # NOTICE: ignored fused_gdn_gating operation dtype = torch.bfloat16 q = torch.randn((1, isl, num_heads, head_k_dim), dtype=dtype).to(torch.device(device)) - k = torch.randn((1, isl, num_heads, head_v_dim), 
dtype=dtype).to(torch.device(device)) + k = torch.randn((1, isl, num_heads, head_k_dim), dtype=dtype).to(torch.device(device)) v = torch.randn((1, isl, num_value_heads, head_v_dim), dtype=dtype).to(torch.device(device)) gate = torch.randn((1, isl, num_heads), dtype=dtype).to(torch.device(device)) beta = torch.randn((1, isl, num_heads), dtype=dtype).to(torch.device(device)) From bfba6f7139f8ba857f3268b4546f2c6dd58df070 Mon Sep 17 00:00:00 2001 From: Jason Zhou Date: Sun, 19 Oct 2025 20:05:03 -0700 Subject: [PATCH 13/17] fix i --- collector/trtllm/collect_gated_delta_rule.py | 14 +++++++++++++- 1 file changed, 13 insertions(+), 1 deletion(-) diff --git a/collector/trtllm/collect_gated_delta_rule.py b/collector/trtllm/collect_gated_delta_rule.py index 3574dc0c..2315e909 100644 --- a/collector/trtllm/collect_gated_delta_rule.py +++ b/collector/trtllm/collect_gated_delta_rule.py @@ -32,6 +32,10 @@ def get_chunk_gated_delta_rule_test_cases(): for head_k_dim in head_k_dim_list: for head_v_dim in head_v_dim_list: for num_value_heads in num_value_heads_list: + # Skip invalid combinations: num_heads must be >= num_value_heads and divisible by it + # This constraint is typical for Grouped-Query Attention (GQA) + if num_heads < num_value_heads or num_heads % num_value_heads != 0: + continue for isl in isl_list: test_cases.append([num_heads, head_k_dim, head_v_dim, num_value_heads, isl, 'chunk_gated_delta_rule_perf.txt']) @@ -125,7 +129,14 @@ def get_gated_delta_rule_update_test_cases(): for head_k_dim in head_k_dim_list: for head_v_dim in head_v_dim_list: for num_value_heads in num_value_heads_list: + # Skip invalid combinations: num_heads must be >= num_value_heads and divisible by it + # This constraint is typical for Grouped-Query Attention (GQA) + if num_heads < num_value_heads or num_heads % num_value_heads != 0: + continue for max_batch_size in max_batch_size_list: + # max_batch_size must be >= batch_size + if max_batch_size < batch_size: + continue test_cases.append([batch_size, isl, num_heads, head_k_dim, head_v_dim, num_value_heads, max_batch_size, 'gated_delta_rule_update_perf.txt']) return test_cases @@ -153,7 +164,8 @@ def run_gated_delta_rule_update(batch_size, isl, num_heads, head_k_dim, head_v_d a = torch.randn((batch_size * isl, num_heads * num_value_heads), dtype=dtype).to(torch.device(device)) b = torch.randn((batch_size, isl, num_heads * num_value_heads), dtype=dtype).to(torch.device(device)) initial_state_source = torch.randn((max_batch_size, num_heads * num_value_heads, head_k_dim, head_v_dim), dtype=dtype).to(torch.device(device)) - initial_state_indices = torch.randn((batch_size), dtype=dtype).to(torch.device(device)) + # initial_state_indices should be integers, not floats - they index into initial_state_source + initial_state_indices = torch.randint(0, max_batch_size, (batch_size,), dtype=torch.int32, device=device) softplus_beta = 1.0 softplus_threshold = 20.0 From f1ac24dd3956965599a707fca59f5a19d1b77237 Mon Sep 17 00:00:00 2001 From: Jason Zhou Date: Sun, 19 Oct 2025 20:55:14 -0700 Subject: [PATCH 14/17] collect gated_delta_rule_update --- collector/collect.py | 20 ++++++++++---------- 1 file changed, 10 insertions(+), 10 deletions(-) diff --git a/collector/collect.py b/collector/collect.py index b8f9f090..3215ba2c 100644 --- a/collector/collect.py +++ b/collector/collect.py @@ -452,20 +452,20 @@ def collect_trtllm(num_processes: int, ops: List[str]=None): # }, # Gated Delta Rule collections - { - 'name': 'trtllm', - 'type': 'chunk_gated_delta_rule', - 'module': 
'trtllm.collect_gated_delta_rule', - 'get_func': 'get_chunk_gated_delta_rule_test_cases', - 'run_func': 'run_chunk_gated_delta_rule' - }, # { # 'name': 'trtllm', - # 'type': 'gated_delta_rule_update', + # 'type': 'chunk_gated_delta_rule', # 'module': 'trtllm.collect_gated_delta_rule', - # 'get_func': 'get_gated_delta_rule_update_test_cases', - # 'run_func': 'run_gated_delta_rule_update' + # 'get_func': 'get_chunk_gated_delta_rule_test_cases', + # 'run_func': 'run_chunk_gated_delta_rule' # }, + { + 'name': 'trtllm', + 'type': 'gated_delta_rule_update', + 'module': 'trtllm.collect_gated_delta_rule', + 'get_func': 'get_gated_delta_rule_update_test_cases', + 'run_func': 'run_gated_delta_rule_update' + }, ] for collection in collections: From ffde28c599bb653855b0f0c490a865d3be77c6ca Mon Sep 17 00:00:00 2001 From: Jason Zhou Date: Mon, 20 Oct 2025 14:25:47 -0700 Subject: [PATCH 15/17] collect moe for qwen3 next --- collector/sglang/collect_moe.py | 17 +++++++++-------- 1 file changed, 9 insertions(+), 8 deletions(-) diff --git a/collector/sglang/collect_moe.py b/collector/sglang/collect_moe.py index c471d637..f445828f 100644 --- a/collector/sglang/collect_moe.py +++ b/collector/sglang/collect_moe.py @@ -41,14 +41,15 @@ def get_moe_test_cases(): #[2048,1408,4,60], #qwen1.5_moe #[2048,1408,6,64], #deepseekv1_moe #[5120,1536,6,160], #deepseekv2 - model_config_list=[[4096,14336,2,8,'MOE_Mixtral8x7B'],# mixtral_8x7b - [6144,16384,2,8,'MOE_Mixtral8x22B'],# mixtral_8x22b - [7168,2048,8,256,'DEEPSEEK_V3'], # deepseekv3, will have 1 shared expert - [4096,1536,8,128, 'QWEN3_235B'], # qwen3-moe, 235b-a22b - [6144,2560,8,160, 'QWEN3_480B'], # qwen3-moe, 480b-a35b - [7168,2048,8,384, 'KIMI_K2'], # kimi k2 - [2048,5120,50,512, 'QWEN3_NEXT_80B'], # qwen3-next, 80b-a3b - ] + model_config_list=[ + # [4096,14336,2,8,'MOE_Mixtral8x7B'],# mixtral_8x7b + # [6144,16384,2,8,'MOE_Mixtral8x22B'],# mixtral_8x22b + # [7168,2048,8,256,'DEEPSEEK_V3'], # deepseekv3, will have 1 shared expert + # [4096,1536,8,128, 'QWEN3_235B'], # qwen3-moe, 235b-a22b + # [6144,2560,8,160, 'QWEN3_480B'], # qwen3-moe, 480b-a35b + # [7168,2048,8,384, 'KIMI_K2'], # kimi k2 + [2048,5120,50,512, 'QWEN3_NEXT_80B'], # qwen3-next, 80b-a3b + ] moe_list=['float16', 'fp8_block'] test_cases=[] From 46ca6934248692fb3dddc9ceabee18add171be85 Mon Sep 17 00:00:00 2001 From: Jason Zhou Date: Mon, 20 Oct 2025 14:28:40 -0700 Subject: [PATCH 16/17] proper collect --- collector/collect.py | 197 ++++++++++++++++++++++--------------------- 1 file changed, 99 insertions(+), 98 deletions(-) diff --git a/collector/collect.py b/collector/collect.py index 3215ba2c..af9d7d9a 100644 --- a/collector/collect.py +++ b/collector/collect.py @@ -354,111 +354,111 @@ def collect_trtllm(num_processes: int, ops: List[str]=None): # Define collection modules - each test type as separate entry collections = [ - # # GEMM collections - # { - # 'name': 'trtllm', - # 'type': 'gemm_trt', - # 'module': 'trtllm.collect_gemm_trt', - # 'get_func': 'get_gemm_test_cases', - # 'run_func': 'run_gemm' - # }, - # { - # 'name': 'trtllm', - # 'type': 'gemm', - # 'module': 'trtllm.collect_gemm', - # 'get_func': 'get_gemm_test_cases', - # 'run_func': 'run_gemm' - # }, + # GEMM collections + { + 'name': 'trtllm', + 'type': 'gemm_trt', + 'module': 'trtllm.collect_gemm_trt', + 'get_func': 'get_gemm_test_cases', + 'run_func': 'run_gemm' + }, + { + 'name': 'trtllm', + 'type': 'gemm', + 'module': 'trtllm.collect_gemm', + 'get_func': 'get_gemm_test_cases', + 'run_func': 'run_gemm' + }, - # # MLA collections - # 
{
-        #     'name': 'trtllm',
-        #     'type': 'mla_context',
-        #     'module': 'trtllm.collect_mla',
-        #     'get_func': 'get_context_mla_test_cases',
-        #     'run_func': 'run_mla',
-        #     'version_handler': lambda v: 'trtllm.collect_mla_1_1rc2' if v.startswith('1.1')
-        #                                  else 'trtllm.collect_mla'
-        # },
-        # {
-        #     'name': 'trtllm',
-        #     'type': 'mla_generation',
-        #     'module': 'trtllm.collect_mla',
-        #     'get_func': 'get_generation_mla_test_cases',
-        #     'run_func': 'run_mla',
-        #     'version_handler': lambda v: 'trtllm.collect_mla_1_1rc2' if v.startswith('1.1')
-        #                                  else 'trtllm.collect_mla'
-        # },
+        # MLA collections
+        {
+            'name': 'trtllm',
+            'type': 'mla_context',
+            'module': 'trtllm.collect_mla',
+            'get_func': 'get_context_mla_test_cases',
+            'run_func': 'run_mla',
+            'version_handler': lambda v: 'trtllm.collect_mla_1_1rc2' if v.startswith('1.1')
+                                         else 'trtllm.collect_mla'
+        },
+        {
+            'name': 'trtllm',
+            'type': 'mla_generation',
+            'module': 'trtllm.collect_mla',
+            'get_func': 'get_generation_mla_test_cases',
+            'run_func': 'run_mla',
+            'version_handler': lambda v: 'trtllm.collect_mla_1_1rc2' if v.startswith('1.1')
+                                         else 'trtllm.collect_mla'
+        },

-        # # Attention collections - separate entries for context and generation
-        # {
-        #     'name': 'trtllm',
-        #     'type': 'attention_context',
-        #     'module': 'trtllm.collect_attn',
-        #     'get_func': 'get_context_attention_test_cases',
-        #     'run_func': 'run_attention_torch'
-        # },
-        # {
-        #     'name': 'trtllm',
-        #     'type': 'attention_generation',
-        #     'module': 'trtllm.collect_attn',
-        #     'get_func': 'get_generation_attention_test_cases',
-        #     'run_func': 'run_attention_torch'
-        # },
+        # Attention collections - separate entries for context and generation
+        {
+            'name': 'trtllm',
+            'type': 'attention_context',
+            'module': 'trtllm.collect_attn',
+            'get_func': 'get_context_attention_test_cases',
+            'run_func': 'run_attention_torch'
+        },
+        {
+            'name': 'trtllm',
+            'type': 'attention_generation',
+            'module': 'trtllm.collect_attn',
+            'get_func': 'get_generation_attention_test_cases',
+            'run_func': 'run_attention_torch'
+        },

-        # # MLA BMM collections
-        # {
-        #     'name': 'trtllm',
-        #     'type': 'mla_bmm_gen_pre',
-        #     'module': 'trtllm.collect_mla_bmm',
-        #     'get_func': 'get_mla_gen_pre_test_cases',
-        #     'run_func': 'run_mla_gen_pre'
-        # },
-        # {
-        #     'name': 'trtllm',
-        #     'type': 'mla_bmm_gen_post',
-        #     'module': 'trtllm.collect_mla_bmm',
-        #     'get_func': 'get_mla_gen_post_test_cases',
-        #     'run_func': 'run_mla_gen_post'
-        # },
+        # MLA BMM collections
+        {
+            'name': 'trtllm',
+            'type': 'mla_bmm_gen_pre',
+            'module': 'trtllm.collect_mla_bmm',
+            'get_func': 'get_mla_gen_pre_test_cases',
+            'run_func': 'run_mla_gen_pre'
+        },
+        {
+            'name': 'trtllm',
+            'type': 'mla_bmm_gen_post',
+            'module': 'trtllm.collect_mla_bmm',
+            'get_func': 'get_mla_gen_post_test_cases',
+            'run_func': 'run_mla_gen_post'
+        },

-        # # MOE collection (with version handling)
-        # {
-        #     'name': 'trtllm',
-        #     'type': 'moe',
-        #     'module': None, # Will be determined based on version
-        #     'get_func': 'get_moe_test_cases',
-        #     'run_func': 'run_moe_torch',
-        #     'version_handler': lambda v: 'trtllm.collect_moe_pre_0_20' if v.startswith('0.20.0')
-        #                                  else 'trtllm.collect_moe_pre_1_0' if v.startswith(('0.21.0', '1.0.0'))
-        #                                  else 'trtllm.collect_moe' if v.startswith(('1.1.0'))
-        #                                  else None
-        # },
+        # MOE collection (with version handling)
+        {
+            'name': 'trtllm',
+            'type': 'moe',
+            'module': None, # Will be determined based on version
+            'get_func': 'get_moe_test_cases',
+            'run_func': 'run_moe_torch',
+            'version_handler': lambda v: 'trtllm.collect_moe_pre_0_20' if v.startswith('0.20.0')
+                                        else 'trtllm.collect_moe_pre_1_0' if v.startswith(('0.21.0', '1.0.0'))
+                                        else 'trtllm.collect_moe' if v.startswith(('1.1.0'))
+                                        else None
+        },

         # CONV 1D collections
-        # {
-        #     'name': 'trtllm',
-        #     'type': 'conv1d_fn',
-        #     'module': 'trtllm.collect_conv1d',
-        #     'get_func': 'get_conv1d_fn_test_cases',
-        #     'run_func': 'run_conv1d_fn'
-        # },
-        # {
-        #     'name': 'trtllm',
-        #     'type': 'conv1d_update',
-        #     'module': 'trtllm.collect_conv1d',
-        #     'get_func': 'get_conv1d_update_test_cases',
-        #     'run_func': 'run_conv1d_update'
-        # },
+        {
+            'name': 'trtllm',
+            'type': 'conv1d_fn',
+            'module': 'trtllm.collect_conv1d',
+            'get_func': 'get_conv1d_fn_test_cases',
+            'run_func': 'run_conv1d_fn'
+        },
+        {
+            'name': 'trtllm',
+            'type': 'conv1d_update',
+            'module': 'trtllm.collect_conv1d',
+            'get_func': 'get_conv1d_update_test_cases',
+            'run_func': 'run_conv1d_update'
+        },

         # Gated Delta Rule collections
-        # {
-        #     'name': 'trtllm',
-        #     'type': 'chunk_gated_delta_rule',
-        #     'module': 'trtllm.collect_gated_delta_rule',
-        #     'get_func': 'get_chunk_gated_delta_rule_test_cases',
-        #     'run_func': 'run_chunk_gated_delta_rule'
-        # },
+        {
+            'name': 'trtllm',
+            'type': 'chunk_gated_delta_rule',
+            'module': 'trtllm.collect_gated_delta_rule',
+            'get_func': 'get_chunk_gated_delta_rule_test_cases',
+            'run_func': 'run_chunk_gated_delta_rule'
+        },
         {
             'name': 'trtllm',
             'type': 'gated_delta_rule_update',
@@ -561,7 +561,8 @@ def main():
     parser.add_argument('--ops', nargs='*', type=str,
                         choices=['gemm_trt', 'gemm', 'mla_context', 'mla_generation',
                                  'attention_context', 'attention_generation', 'mla_bmm_gen_pre',
-                                 'mla_bmm_gen_post', 'moe'],
+                                 'mla_bmm_gen_post', 'moe', 'conv1d_fn', 'conv1d_update',
+                                 'chunk_gated_delta_rule', 'gated_delta_rule_update'],
                         help='Run only specified collection items. Leave empty to run all.', default=None)
     args = parser.parse_args()

From 36aef0d79c8f2f2c437e3a0b217fed986e82eb07 Mon Sep 17 00:00:00 2001
From: Jason Zhou
Date: Mon, 20 Oct 2025 14:38:27 -0700
Subject: [PATCH 17/17] update versions

---
 collector/collect.py | 6 +++---
 1 file changed, 3 insertions(+), 3 deletions(-)

diff --git a/collector/collect.py b/collector/collect.py
index af9d7d9a..9ac6682c 100644
--- a/collector/collect.py
+++ b/collector/collect.py
@@ -377,7 +377,7 @@ def collect_trtllm(num_processes: int, ops: List[str]=None):
             'module': 'trtllm.collect_mla',
             'get_func': 'get_context_mla_test_cases',
             'run_func': 'run_mla',
-            'version_handler': lambda v: 'trtllm.collect_mla_1_1rc2' if v.startswith('1.1')
+            'version_handler': lambda v: 'trtllm.collect_mla_1_1rc2' if v.startswith(('1.1', '1.2'))
                                          else 'trtllm.collect_mla'
         },
         {
@@ -386,7 +386,7 @@ def collect_trtllm(num_processes: int, ops: List[str]=None):
             'module': 'trtllm.collect_mla',
             'get_func': 'get_generation_mla_test_cases',
             'run_func': 'run_mla',
-            'version_handler': lambda v: 'trtllm.collect_mla_1_1rc2' if v.startswith('1.1')
+            'version_handler': lambda v: 'trtllm.collect_mla_1_1rc2' if v.startswith(('1.1', '1.2'))
                                          else 'trtllm.collect_mla'
         },
@@ -431,7 +431,7 @@ def collect_trtllm(num_processes: int, ops: List[str]=None):
             'run_func': 'run_moe_torch',
             'version_handler': lambda v: 'trtllm.collect_moe_pre_0_20' if v.startswith('0.20.0')
                                         else 'trtllm.collect_moe_pre_1_0' if v.startswith(('0.21.0', '1.0.0'))
-                                        else 'trtllm.collect_moe' if v.startswith(('1.1.0'))
+                                        else 'trtllm.collect_moe' if v.startswith(('1.1.', '1.2.'))
                                         else None
         },
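
For reference, the version_handler lambdas above depend on str.startswith accepting either a single prefix or a tuple of prefixes; passing a second string positionally is interpreted as a start index and raises TypeError at dispatch time, which is why the tuple form is used. A minimal standalone sketch of the dispatch pattern (illustrative only, not part of the collector code; module names are the ones used in the collections table above):

    # Minimal sketch of the version-dispatch pattern used by version_handler.
    def pick_mla_module(version: str) -> str:
        # str.startswith accepts one prefix or a tuple of prefixes.
        if version.startswith(('1.1', '1.2')):
            return 'trtllm.collect_mla_1_1rc2'
        return 'trtllm.collect_mla'

    assert pick_mla_module('1.2.0rc1') == 'trtllm.collect_mla_1_1rc2'
    assert pick_mla_module('1.0.0') == 'trtllm.collect_mla'
    # By contrast, version.startswith('1.1', '1.2') raises TypeError: the second
    # positional argument of str.startswith is a start index, not another prefix.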