
Commit 91eff54

fix: sol compute and trtllm=1.2 collect
Signed-off-by: Kimi Zhao <[email protected]>
1 parent 26eeb46 commit 91eff54

18 files changed: +398, -202 lines


collector/collect.py

Lines changed: 14 additions & 5 deletions
@@ -222,8 +222,10 @@ def create_process_exit_error(device_id, exit_code):
         # Stall detection unchanged...
         if progress_value.value == last_progress:
             stall_count += 1
-            if stall_count > 30:
+            if stall_count > 30 and "moe" not in func.__name__:
                 logger.warning(f"Progress stalled at {progress_value.value}/{len(tasks)}")
+            if stall_count > 900 and "moe" in func.__name__:
+                logger.warning(f"MoE progress stalled at {progress_value.value}/{len(tasks)}")
         else:
             stall_count = 0
         last_progress = progress_value.value
@@ -264,7 +266,10 @@ def create_process_exit_error(device_id, exit_code):
 
     # Wait for processes
     for p in processes:
-        p.join(timeout=10)
+        if "moe" in func.__name__:
+            p.join(timeout=500)  # tuning + 30-token MoE cases need a much longer join timeout
+        else:
+            p.join(timeout=10)
         if p.is_alive():
             logger.warning(f"Process {p.pid} did not terminate, forcing...")
             p.terminate()
@@ -511,15 +516,19 @@ def collect_trtllm(num_processes: int, ops: list[str] | None = None):
             "module": "collector.trtllm.collect_mla",
             "get_func": "get_context_mla_test_cases",
             "run_func": "run_mla",
-            "version_handler": lambda v: "trtllm.collect_mla_1_1rc2" if v.startswith("1.1") else "trtllm.collect_mla",
+            "version_handler": lambda v: "trtllm.collect_mla_1_1rc2"
+            if v.startswith(("1.1.0", "1.2.0"))
+            else "trtllm.collect_mla",
         },
         {
             "name": "trtllm",
             "type": "mla_generation",
             "module": "collector.trtllm.collect_mla",
             "get_func": "get_generation_mla_test_cases",
             "run_func": "run_mla",
-            "version_handler": lambda v: "trtllm.collect_mla_1_1rc2" if v.startswith("1.1") else "trtllm.collect_mla",
+            "version_handler": lambda v: "trtllm.collect_mla_1_1rc2"
+            if v.startswith(("1.1.0", "1.2.0"))
+            else "trtllm.collect_mla",
         },
         # Attention collections - separate entries for context and generation
         {
@@ -563,7 +572,7 @@ def collect_trtllm(num_processes: int, ops: list[str] | None = None):
             else "collector.trtllm.collect_moe_pre_1_0"
             if v.startswith(("0.21.0", "1.0.0"))
             else "collector.trtllm.collect_moe"
-            if v.startswith("1.1.0")
+            if v.startswith(("1.1.0", "1.2.0"))
             else None,
         },
     ]
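The version_handler entries above pick a collector module from the installed TRT-LLM version string; this commit extends the 1.1-only checks to also cover 1.2. A minimal sketch of how such a table entry could be resolved at runtime; resolve_module and the entry shape are illustrative assumptions, not the repo's actual driver code:

import importlib

# Hypothetical resolver for a collection-table entry: the static
# "module" is the default, and "version_handler" may override it
# (returning None when no collector supports the installed version).
def resolve_module(entry: dict, trtllm_version: str):
    handler = entry.get("version_handler")
    name = handler(trtllm_version) if handler is not None else entry["module"]
    if name is None:
        raise RuntimeError(f"no collector for TRT-LLM {trtllm_version}")
    return importlib.import_module(name)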

collector/collect_all_reduce.py

Lines changed: 1 addition & 6 deletions
@@ -30,10 +30,9 @@
 from argparse import ArgumentParser
 from typing import Optional
 
-# isort: off
 import torch
 
-# isort: on
+from helper import log_perf
 
 
 def get_input_shape_and_comm_size(size, token_dim=4096):
@@ -112,8 +111,6 @@ def benchmark_trtllm_allreduce(
     num_warmups = 3
     num_runs = 20
 
-    from helper import log_perf
-
     size = min_size
     while size < max_size:
         input_shape = get_input_shape_and_comm_size(size)
@@ -261,8 +258,6 @@ def benchmark_vllm_allreduce(
     num_warmups = 3
     num_runs = 20
 
-    from helper import log_perf
-
     # Warmup communication
     warmup_tensor = torch.ones(1, dtype=torch_dtype, device="cuda")
     _ = vllm_mods["tensor_model_parallel_all_reduce"](warmup_tensor)
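Both benchmark paths share the same warmup-then-measure shape (num_warmups = 3, num_runs = 20), now with log_perf imported once at module level instead of inside each function. A minimal sketch of that timing pattern with CUDA events; the helper name time_allreduce is an assumption for illustration:

import torch

# Warmup + timed-loop pattern used by the allreduce benchmarks; only
# the 3/20 iteration counts are taken from the code above.
def time_allreduce(fn, tensor, num_warmups=3, num_runs=20):
    for _ in range(num_warmups):
        fn(tensor)  # warm up kernels and communication channels
    torch.cuda.synchronize()
    start = torch.cuda.Event(enable_timing=True)
    end = torch.cuda.Event(enable_timing=True)
    start.record()
    for _ in range(num_runs):
        fn(tensor)
    end.record()
    torch.cuda.synchronize()  # wait for the timed kernels to finish
    return start.elapsed_time(end) / num_runs  # mean latency in ms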

collector/helper.py

Lines changed: 1 addition & 1 deletion
@@ -13,7 +13,7 @@
 try:
     from cuda import cuda
 except:
-    pass
+    from cuda.bindings import driver as cuda
 from datetime import datetime
 from pathlib import Path
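Instead of silently swallowing a failed import, the fallback now follows the cuda-python packaging change: newer releases expose the bindings under cuda.bindings, while older ones use the top-level cuda package. The same pattern, sketched with a narrowed exception (the diff itself uses a bare except):

# Import CUDA driver/runtime bindings across cuda-python layouts.
try:
    from cuda import cuda, cudart  # older cuda-python layout
except ImportError:
    from cuda.bindings import driver as cuda    # newer layout
    from cuda.bindings import runtime as cudart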

collector/slurm_comm_collector/collect_allreduce.py

Lines changed: 5 additions & 1 deletion
@@ -5,7 +5,11 @@
 import os
 
 import torch
-from cuda import cudart
+
+try:
+    from cuda import cudart
+except:
+    from cuda.bindings import runtime as cudart
 from tensorrt_llm import Mapping
 from tensorrt_llm._torch.distributed import (
     AllReduce,

collector/trtllm/collect_gemm_trt.py

Lines changed: 5 additions & 1 deletion
@@ -4,7 +4,11 @@
 import tensorrt as trt
 import tensorrt_llm
 import torch
-from cuda import cudart
+
+try:
+    from cuda import cudart
+except:
+    from cuda.bindings import runtime as cudart
 from polygraphy.backend.trt import CreateConfig, EngineFromNetwork, TrtRunner
 from tensorrt_llm import Tensor
 from tensorrt_llm._utils import str_dtype_to_torch

collector/trtllm/collect_mla.py

Lines changed: 10 additions & 17 deletions
@@ -21,9 +21,9 @@
 
 
 def get_context_mla_test_cases():
-    dtype_list = [tensorrt_llm.bindings.DataType.FP8, tensorrt_llm.bindings.DataType.BF16]
+    dtype_list = [tensorrt_llm.bindings.DataType.BF16]  # FP8 not supported for TRT-LLM < v1.1
     test_cases = []
-    n_list = [64, 128]
+    n_list = [128]
     b_list = [1, 2, 4, 8, 16, 32, 64, 128, 256]
     s_list = [
         16,
@@ -47,7 +47,7 @@ def get_context_mla_test_cases():
         for b in b_list:
             for s in s_list:  # [2,4,8,16,32,64,128,256,512,1024,2048,4096,8192,16384,32768,65536,131072]:
                 for dtype in dtype_list:
-                    for tp_size in [1, 2, 4, 8, 16, 32, 64]:
+                    for tp_size in [1, 2, 4, 8, 16, 32, 64, 128]:
                         if b * s > 32768:
                             continue
                         # (input_len, batch_size, output_len, kv_cache_dtype, num_heads, world_size,
@@ -72,9 +72,9 @@ def get_context_mla_test_cases():
 
 
 def get_generation_mla_test_cases():
-    dtype_list = [tensorrt_llm.bindings.DataType.FP8, tensorrt_llm.bindings.DataType.BF16]
+    dtype_list = [tensorrt_llm.bindings.DataType.BF16]  # FP8 not supported for TRT-LLM < v1.1
     test_cases = []
-    n_list = [64, 128]
+    n_list = [128]
     for n in n_list:
         for b in [1, 2, 4, 8, 16, 32, 64, 128, 256, 512, 1024]:
             for s in [
@@ -97,7 +97,7 @@ def get_generation_mla_test_cases():
                 131072,
             ]:  # [target token s] is equivalent to [in: s-1, step=1]
                 for dtype in dtype_list:
-                    for tp_size in [1, 2, 4, 8, 16, 32, 64]:
+                    for tp_size in [1, 2, 4, 8, 16, 32, 64, 128]:
                         if b * s > 1024 * 4096 * 2 * 2:
                             continue
                         # (input_len, batch_size, output_len, kv_cache_dtype, num_heads, world_size,
@@ -140,13 +140,11 @@ def run_mla(
     torch.cuda.set_device(device)
     backend_name = "TRTLLM"
     layer_idx = 0
-    num_key_value_heads = num_heads
-
-    assert num_key_value_heads % tp_size == 0, "num_key_value_heads != N * tp_size"
-    num_key_value_heads = int(num_key_value_heads / tp_size)
 
+    assert kv_cache_dtype == tensorrt_llm.bindings.DataType.BF16, "only support bfloat16 for trtllm"
     assert num_heads % tp_size == 0, "num_heads != N * tp_size"
-    num_heads = int(num_heads / tp_size)
+    num_heads = num_heads // tp_size
+    num_key_value_heads = num_heads
 
     pos_embd_params = PositionalEmbeddingParams(
         type=PositionEmbeddingType.yarn,
@@ -170,7 +168,6 @@ def run_mla(
     )
 
     quant_config = QuantConfig(
-        quant_algo="FP8_BLOCK_SCALES",
         kv_cache_quant_algo=None,
         group_size=None,
         smoothquant_val=0.5,
@@ -391,15 +388,11 @@ def run_mla(
     isl = 1
     step = input_len
 
-    dtype_str = "float16"
-    if kv_cache_dtype == tensorrt_llm.bindings.DataType.FP8:
-        dtype_str = "fp8"
-
     log_perf(
         item_list=[
             {
                 "mla_dtype": "float16",
-                "kv_cache_dtype": dtype_str,
+                "kv_cache_dtype": "float16",
                 "num_heads": num_heads,
                 "batch_size": batch_size,
                 "isl": isl,

collector/trtllm/collect_mla_1_1rc2.py

Lines changed: 104 additions & 0 deletions
@@ -26,6 +26,107 @@
 from helper import log_perf
 
 
+def get_context_mla_test_cases():
+    dtype_list = [tensorrt_llm.bindings.DataType.BF16, tensorrt_llm.bindings.DataType.FP8]
+    test_cases = []
+    n_list = [128]
+    b_list = [1, 2, 4, 8, 16, 32, 64, 128, 256]
+    s_list = [
+        16,
+        32,
+        64,
+        128,
+        256,
+        512,
+        1024,
+        1536,
+        2048,
+        3072,
+        4096,
+        6144,
+        8192,
+        10240,
+        12288,
+        16384,
+    ]
+    for n in n_list:
+        for b in b_list:
+            for s in s_list:  # [2,4,8,16,32,64,128,256,512,1024,2048,4096,8192,16384,32768,65536,131072]:
+                for dtype in dtype_list:
+                    for tp_size in [1, 2, 4, 8, 16, 32, 64, 128]:
+                        if b * s > 32768:
+                            continue
+                        # (input_len, batch_size, output_len, kv_cache_dtype, num_heads, world_size,
+                        #  tp_size, tokens_per_block, warming_up, test_ite, is_context_phase)
+                        test_cases.append(
+                            [
+                                s,
+                                b,
+                                1,
+                                dtype,
+                                n,
+                                tp_size,
+                                tp_size,
+                                64,
+                                10,
+                                6,
+                                True,
+                                "context_mla_perf.txt",
+                            ]
+                        )
+    return test_cases
+
+
+def get_generation_mla_test_cases():
+    dtype_list = [tensorrt_llm.bindings.DataType.BF16, tensorrt_llm.bindings.DataType.FP8]
+    test_cases = []
+    n_list = [128]
+    for n in n_list:
+        for b in [1, 2, 4, 8, 16, 32, 64, 128, 256, 512, 1024]:
+            for s in [
+                2,
+                4,
+                8,
+                16,
+                32,
+                64,
+                128,
+                256,
+                512,
+                1024,
+                2048,
+                4096,
+                8192,
+                16384,
+                32768,
+                65536,
+                131072,
+            ]:  # [target token s] is equivalent to [in: s-1, step=1]
+                for dtype in dtype_list:
+                    for tp_size in [1, 2, 4, 8, 16, 32, 64, 128]:
+                        if b * s > 1024 * 4096 * 2 * 2:
+                            continue
+                        # (input_len, batch_size, output_len, kv_cache_dtype, num_heads, world_size,
+                        #  tp_size, tokens_per_block, warming_up, test_ite, is_context_phase)
+                        test_cases.append(
+                            [
+                                s - 1,
+                                b,
+                                1,
+                                dtype,
+                                n,
+                                tp_size,
+                                tp_size,
+                                64,
+                                10,
+                                6,
+                                False,
+                                "generation_mla_perf.txt",
+                            ]
+                        )
+    return test_cases
+
+
 # Copied from transformers.models.llama.modeling_llama.rotate_half
 def rotate_half(x):
     """Rotates half the hidden dims of the input."""
@@ -122,6 +223,9 @@ def run_mla(
     kv_cache_tokens_per_block = tokens_per_block
     # device = torch.device('cuda')
     dtype = scenario.dtype
+
+    assert num_heads % tp_size == 0, "num_heads != N * tp_size"
+    num_heads = num_heads // tp_size
     num_kv_heads = num_heads
 
     context_sequence_lengths = [input_len for _ in range(batch_size)]
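Each generated case is a positional list that follows the commented parameter order, plus a trailing perf-file name. A hedged sketch of unpacking and dispatching one case; the loop is illustrative only, and run_mla's exact signature in the repo may differ:

# Illustrative dispatch only; parameter order follows the in-code comment.
for case in get_context_mla_test_cases():
    (input_len, batch_size, output_len, kv_cache_dtype, num_heads,
     world_size, tp_size, tokens_per_block, warming_up, test_ite,
     is_context_phase, perf_filename) = case
    run_mla(input_len, batch_size, output_len, kv_cache_dtype, num_heads,
            world_size, tp_size, tokens_per_block, warming_up, test_ite,
            is_context_phase, perf_filename)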
