19 changes: 14 additions & 5 deletions collector/collect.py
@@ -222,8 +222,10 @@ def create_process_exit_error(device_id, exit_code):
# Stall detection unchanged...
if progress_value.value == last_progress:
stall_count += 1
if stall_count > 30:
if stall_count > 30 and "moe" not in func.__name__:
logger.warning(f"Progress stalled at {progress_value.value}/{len(tasks)}")
if stall_count > 900 and "moe" in func.__name__:
logger.warning(f"Moe Progress stalled at {progress_value.value}/{len(tasks)}")
else:
stall_count = 0
last_progress = progress_value.value
@@ -264,7 +266,10 @@ def create_process_exit_error(device_id, exit_code):

# Wait for processes
for p in processes:
p.join(timeout=10)
if "moe" in func.__name__:
p.join(timeout=500) # tune + 30 tokens cases
else:
p.join(timeout=10)
if p.is_alive():
logger.warning(f"Process {p.pid} did not terminate, forcing...")
p.terminate()
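For context: the two hunks above give MoE collections a much longer grace period, both in the stall detector (900 idle polls instead of 30) and in the final `p.join()` (500 s instead of 10 s), because MoE autotuning can sit on a single test case for a long time. A minimal sketch of the branching, assuming `func` is the collection's run function; the helper names below are hypothetical and only restate the thresholds from the diff:

```python
# Hypothetical helpers restating the MoE-aware thresholds introduced above.
STALL_LIMIT, MOE_STALL_LIMIT = 30, 900      # idle polls before a stall warning
JOIN_TIMEOUT, MOE_JOIN_TIMEOUT = 10, 500    # seconds to wait on p.join()


def stall_limit_for(func) -> int:
    """MoE collections get a ~30x larger stall budget."""
    return MOE_STALL_LIMIT if "moe" in func.__name__ else STALL_LIMIT


def join_timeout_for(func) -> int:
    """MoE workers (tuning + 30-token cases) also get a longer join timeout."""
    return MOE_JOIN_TIMEOUT if "moe" in func.__name__ else JOIN_TIMEOUT
```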
@@ -511,15 +516,19 @@ def collect_trtllm(num_processes: int, ops: list[str] | None = None):
"module": "collector.trtllm.collect_mla",
"get_func": "get_context_mla_test_cases",
"run_func": "run_mla",
"version_handler": lambda v: "trtllm.collect_mla_1_1rc2" if v.startswith("1.1") else "trtllm.collect_mla",
"version_handler": lambda v: "trtllm.collect_mla_1_1rc2"
if v.startswith(("1.1.0", "1.2.0"))
else "trtllm.collect_mla",
},
{
"name": "trtllm",
"type": "mla_generation",
"module": "collector.trtllm.collect_mla",
"get_func": "get_generation_mla_test_cases",
"run_func": "run_mla",
"version_handler": lambda v: "trtllm.collect_mla_1_1rc2" if v.startswith("1.1") else "trtllm.collect_mla",
"version_handler": lambda v: "trtllm.collect_mla_1_1rc2"
if v.startswith(("1.1.0", "1.2.0"))
else "trtllm.collect_mla",
},
# Attention collections - separate entries for context and generation
{
@@ -563,7 +572,7 @@ def collect_trtllm(num_processes: int, ops: list[str] | None = None):
else "collector.trtllm.collect_moe_pre_1_0"
if v.startswith(("0.21.0", "1.0.0"))
else "collector.trtllm.collect_moe"
if v.startswith("1.1.0")
if v.startswith(("1.1.0", "1.2.0"))
else None,
},
]
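For reference, each entry's `version_handler` maps the installed TRT-LLM version string to the collector module that should run, so 1.1.x and 1.2.x builds now pick up the `_1_1rc2` MLA path while older builds keep the legacy one. A minimal sketch of how such an entry resolves; `resolve_module` is a hypothetical helper, not part of this PR:

```python
entry = {
    "module": "collector.trtllm.collect_mla",
    # Same lambda shape as in the diff: route 1.1.x / 1.2.x to the 1.1rc2 implementation.
    "version_handler": lambda v: "trtllm.collect_mla_1_1rc2"
    if v.startswith(("1.1.0", "1.2.0"))
    else "trtllm.collect_mla",
}


def resolve_module(entry: dict, trtllm_version: str) -> str:
    """Hypothetical helper: ask the entry's version_handler which module should run."""
    handler = entry.get("version_handler")
    return handler(trtllm_version) if handler else entry["module"]


assert resolve_module(entry, "1.2.0rc1") == "trtllm.collect_mla_1_1rc2"
assert resolve_module(entry, "1.0.0") == "trtllm.collect_mla"
```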
7 changes: 1 addition & 6 deletions collector/collect_all_reduce.py
@@ -30,10 +30,9 @@
from argparse import ArgumentParser
from typing import Optional

# isort: off
import torch

# isort: on
from helper import log_perf


def get_input_shape_and_comm_size(size, token_dim=4096):
@@ -112,8 +111,6 @@ def benchmark_trtllm_allreduce(
num_warmups = 3
num_runs = 20

from helper import log_perf

size = min_size
while size < max_size:
input_shape = get_input_shape_and_comm_size(size)
@@ -261,8 +258,6 @@ def benchmark_vllm_allreduce(
num_warmups = 3
num_runs = 20

from helper import log_perf

# Warmup communication
warmup_tensor = torch.ones(1, dtype=torch_dtype, device="cuda")
_ = vllm_mods["tensor_model_parallel_all_reduce"](warmup_tensor)
2 changes: 1 addition & 1 deletion collector/helper.py
@@ -13,7 +13,7 @@
try:
from cuda import cuda
except:
pass
from cuda.bindings import driver as cuda
from datetime import datetime
from pathlib import Path

6 changes: 5 additions & 1 deletion collector/slurm_comm_collector/collect_allreduce.py
@@ -5,7 +5,11 @@
import os

import torch
from cuda import cudart

try:
from cuda import cudart
except:
from cuda.bindings import runtime as cudart
from tensorrt_llm import Mapping
from tensorrt_llm._torch.distributed import (
AllReduce,
6 changes: 5 additions & 1 deletion collector/trtllm/collect_gemm_trt.py
@@ -4,7 +4,11 @@
import tensorrt as trt
import tensorrt_llm
import torch
from cuda import cudart

try:
from cuda import cudart
except:
from cuda.bindings import runtime as cudart
from polygraphy.backend.trt import CreateConfig, EngineFromNetwork, TrtRunner
from tensorrt_llm import Tensor
from tensorrt_llm._utils import str_dtype_to_torch
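The same import shim now appears in `collector/helper.py`, `collector/slurm_comm_collector/collect_allreduce.py`, and `collector/trtllm/collect_gemm_trt.py`: newer cuda-python releases expose the driver/runtime bindings under `cuda.bindings`, while older releases still provide the legacy top-level modules. A standalone sketch of the pattern, catching `ImportError` explicitly rather than the bare `except` used in the diff:

```python
try:
    # Older cuda-python layout: legacy top-level modules.
    from cuda import cuda, cudart
except ImportError:
    # Newer cuda-python layout: explicit bindings package.
    from cuda.bindings import driver as cuda
    from cuda.bindings import runtime as cudart
```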
27 changes: 10 additions & 17 deletions collector/trtllm/collect_mla.py
@@ -21,9 +21,9 @@


def get_context_mla_test_cases():
dtype_list = [tensorrt_llm.bindings.DataType.FP8, tensorrt_llm.bindings.DataType.BF16]
dtype_list = [tensorrt_llm.bindings.DataType.BF16] # FP8 not supported for TRT-LLM < v1.1
test_cases = []
n_list = [64, 128]
n_list = [128]
b_list = [1, 2, 4, 8, 16, 32, 64, 128, 256]
s_list = [
16,
@@ -47,7 +47,7 @@ def get_context_mla_test_cases():
for b in b_list:
for s in s_list: # [2,4,8,16,32,64,128,256,512,1024,2048,4096,8192,16384,32768,65536,131072]:
for dtype in dtype_list:
for tp_size in [1, 2, 4, 8, 16, 32, 64]:
for tp_size in [1, 2, 4, 8, 16, 32, 64, 128]:
if b * s > 32768:
continue
# (input_len, batch_size, output_len, kv_cache_dtype, num_heads, world_size,
@@ -72,9 +72,9 @@


def get_generation_mla_test_cases():
dtype_list = [tensorrt_llm.bindings.DataType.FP8, tensorrt_llm.bindings.DataType.BF16]
dtype_list = [tensorrt_llm.bindings.DataType.BF16] # FP8 not supported for TRT-LLM < v1.1
test_cases = []
n_list = [64, 128]
n_list = [128]
for n in n_list:
for b in [1, 2, 4, 8, 16, 32, 64, 128, 256, 512, 1024]:
for s in [
@@ -97,7 +97,7 @@ def get_generation_mla_test_cases():
131072,
]: # [target token s] is equivalent to [in: s-1, step=1]
for dtype in dtype_list:
for tp_size in [1, 2, 4, 8, 16, 32, 64]:
for tp_size in [1, 2, 4, 8, 16, 32, 64, 128]:
if b * s > 1024 * 4096 * 2 * 2:
continue
# (input_len, batch_size, output_len, kv_cache_dtype, num_heads, world_size,
@@ -140,13 +140,11 @@ def run_mla(
torch.cuda.set_device(device)
backend_name = "TRTLLM"
layer_idx = 0
num_key_value_heads = num_heads

assert num_key_value_heads % tp_size == 0, "num_key_value_heads != N * tp_size"
num_key_value_heads = int(num_key_value_heads / tp_size)

assert kv_cache_dtype == tensorrt_llm.bindings.DataType.BF16, "only support bfloat16 for trtllm"
assert num_heads % tp_size == 0, "num_heads != N * tp_size"
num_heads = int(num_heads / tp_size)
num_heads = num_heads // tp_size
num_key_value_heads = num_heads

pos_embd_params = PositionalEmbeddingParams(
type=PositionEmbeddingType.yarn,
@@ -170,7 +168,6 @@
)

quant_config = QuantConfig(
quant_algo="FP8_BLOCK_SCALES",
kv_cache_quant_algo=None,
group_size=None,
smoothquant_val=0.5,
@@ -391,15 +388,11 @@ def run_mla(
isl = 1
step = input_len

dtype_str = "float16"
if kv_cache_dtype == tensorrt_llm.bindings.DataType.FP8:
dtype_str = "fp8"

log_perf(
item_list=[
{
"mla_dtype": "float16",
"kv_cache_dtype": dtype_str,
"kv_cache_dtype": "float16",
"num_heads": num_heads,
"batch_size": batch_size,
"isl": isl,
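In `run_mla`, the reworked block now divides the head count across tensor-parallel ranks once and reuses the per-rank value for the KV heads, instead of splitting `num_key_value_heads` separately. A small sketch of that arithmetic; the function name is illustrative only:

```python
def per_rank_heads(num_heads: int, tp_size: int) -> tuple[int, int]:
    """Return (num_heads, num_key_value_heads) per tensor-parallel rank, as in the diff."""
    assert num_heads % tp_size == 0, "num_heads must be divisible by tp_size"
    heads = num_heads // tp_size
    return heads, heads  # MLA reuses the per-rank head count for the KV heads


# e.g. 128 heads across tp_size=8 -> 16 query heads and 16 KV heads per rank
assert per_rank_heads(128, 8) == (16, 16)
```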
104 changes: 104 additions & 0 deletions collector/trtllm/collect_mla_1_1rc2.py
@@ -26,6 +26,107 @@
from helper import log_perf


def get_context_mla_test_cases():
dtype_list = [tensorrt_llm.bindings.DataType.BF16, tensorrt_llm.bindings.DataType.FP8]
test_cases = []
n_list = [128]
b_list = [1, 2, 4, 8, 16, 32, 64, 128, 256]
s_list = [
16,
32,
64,
128,
256,
512,
1024,
1536,
2048,
3072,
4096,
6144,
8192,
10240,
12288,
16384,
]
for n in n_list:
for b in b_list:
for s in s_list: # [2,4,8,16,32,64,128,256,512,1024,2048,4096,8192,16384,32768,65536,131072]:
for dtype in dtype_list:
for tp_size in [1, 2, 4, 8, 16, 32, 64, 128]:
if b * s > 32768:
continue
# (input_len, batch_size, output_len, kv_cache_dtype, num_heads, world_size,
# tp_size, tokens_per_block, warming_up, test_ite, is_context_phase)
test_cases.append(
[
s,
b,
1,
dtype,
n,
tp_size,
tp_size,
64,
10,
6,
True,
"context_mla_perf.txt",
]
)
return test_cases


def get_generation_mla_test_cases():
dtype_list = [tensorrt_llm.bindings.DataType.BF16, tensorrt_llm.bindings.DataType.FP8]
test_cases = []
n_list = [128]
for n in n_list:
for b in [1, 2, 4, 8, 16, 32, 64, 128, 256, 512, 1024]:
for s in [
2,
4,
8,
16,
32,
64,
128,
256,
512,
1024,
2048,
4096,
8192,
16384,
32768,
65536,
131072,
]: # [target token s] is equivalent to [in: s-1, step=1]
for dtype in dtype_list:
for tp_size in [1, 2, 4, 8, 16, 32, 64, 128]:
if b * s > 1024 * 4096 * 2 * 2:
continue
# (input_len, batch_size, output_len, kv_cache_dtype, num_heads, world_size,
# tp_size, tokens_per_block, warming_up, test_ite, is_context_phase)
test_cases.append(
[
s - 1,
b,
1,
dtype,
n,
tp_size,
tp_size,
64,
10,
6,
False,
"generation_mla_perf.txt",
]
)
return test_cases


# Copied from transformers.models.llama.modeling_llama.rotate_half
def rotate_half(x):
"""Rotates half the hidden dims of the input."""
@@ -122,6 +223,9 @@ def run_mla(
kv_cache_tokens_per_block = tokens_per_block
# device = torch.device('cuda')
dtype = scenario.dtype

assert num_heads % tp_size == 0, "num_heads != N * tp_size"
num_heads = num_heads // tp_size
num_kv_heads = num_heads

context_sequence_lengths = [input_len for _ in range(batch_size)]