Skip to content

Commit 19fa3cf

Browse files
committed
update format3
1 parent 1f54ef4 commit 19fa3cf

File tree

4 files changed

+16
-12
lines changed

4 files changed

+16
-12
lines changed

collector/collect.py

Lines changed: 6 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -267,9 +267,9 @@ def create_process_exit_error(device_id, exit_code):
267267
# Wait for processes
268268
for p in processes:
269269
if "moe" in func.__name__:
270-
p.join(timeout = 2000)
270+
p.join(timeout=2000)
271271
else:
272-
p.join(timeout = 10)
272+
p.join(timeout=10)
273273
if p.is_alive():
274274
logger.warning(f"Process {p.pid} did not terminate, forcing...")
275275
p.terminate()
@@ -376,7 +376,8 @@ def collect_trtllm(num_processes: int, ops: list[str] | None = None):
376376
"get_func": "get_context_mla_test_cases",
377377
"run_func": "run_mla",
378378
"version_handler": lambda v: "trtllm.collect_mla_1_1rc2"
379-
if v.startswith(("1.1.0", "1.2.0")) else "trtllm.collect_mla",
379+
if v.startswith(("1.1.0", "1.2.0"))
380+
else "trtllm.collect_mla",
380381
},
381382
{
382383
"name": "trtllm",
@@ -385,7 +386,8 @@ def collect_trtllm(num_processes: int, ops: list[str] | None = None):
385386
"get_func": "get_generation_mla_test_cases",
386387
"run_func": "run_mla",
387388
"version_handler": lambda v: "trtllm.collect_mla_1_1rc2"
388-
if v.startswith(("1.1.0", "1.2.0")) else "trtllm.collect_mla",
389+
if v.startswith(("1.1.0", "1.2.0"))
390+
else "trtllm.collect_mla",
389391
},
390392
# Attention collections - separate entries for context and generation
391393
{

collector/trtllm/collect_mla.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -21,7 +21,7 @@
2121

2222

2323
def get_context_mla_test_cases():
24-
dtype_list = [tensorrt_llm.bindings.DataType.BF16] # not support f8 for trt < v1.1
24+
dtype_list = [tensorrt_llm.bindings.DataType.BF16] # not support f8 for trt < v1.1
2525
test_cases = []
2626
n_list = [128]
2727
b_list = [1, 2, 4, 8, 16, 32, 64, 128, 256]
@@ -72,7 +72,7 @@ def get_context_mla_test_cases():
7272

7373

7474
def get_generation_mla_test_cases():
75-
dtype_list = [tensorrt_llm.bindings.DataType.BF16] # not support f8 for trt < v1.1
75+
dtype_list = [tensorrt_llm.bindings.DataType.BF16] # not support f8 for trt < v1.1
7676
test_cases = []
7777
n_list = [128]
7878
for n in n_list:

collector/trtllm/collect_mla_1_1rc2.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -126,6 +126,7 @@ def get_generation_mla_test_cases():
126126
)
127127
return test_cases
128128

129+
129130
# Copied from transformers.models.llama.modeling_llama.rotate_half
130131
def rotate_half(x):
131132
"""Rotates half the hidden dims of the input."""

src/aiconfigurator/sdk/perf_database.py

Lines changed: 7 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -1916,9 +1916,10 @@ def get_sol(
19161916
* b
19171917
* (
19181918
n * s * h # Q read, assuming 16 bits
1919-
+ n * s * h # Output write, assuming 16 bits
1920-
) + kvcache_quant_mode.value.memory * b * (2 * n_kv * s * h) # K,V read
1921-
) #TODO fp8 io
1919+
+ n * s * h # Output write, assuming 16 bits
1920+
)
1921+
+ kvcache_quant_mode.value.memory * b * (2 * n_kv * s * h) # K,V read
1922+
) # TODO fp8 io
19221923
sol_math = ops / self.system_spec["gpu"]["float16_tc_flops"] * 1000 / fmha_quant_mode.value.compute
19231924
sol_mem = mem_bytes / self.system_spec["gpu"]["mem_bw"] * 1000
19241925
sol_time = max(sol_math, sol_mem)
@@ -2035,9 +2036,9 @@ def get_sol(
20352036
ops = (
20362037
b * num_heads * 2 / 2 * (s * s * 192 + s * s * 128)
20372038
) # 2 for fma, 2 for causality. num_heads, for local heads
2038-
mem_bytes = b * num_heads * (
2039-
kvcache_quant_mode.value.memory * (s * 192 + s * 128) + 2 * (s * 192 + s * 128)
2040-
) # fp16 io + fp16/fp8 kv cache, TODO fp8 io
2039+
mem_bytes = (
2040+
b * num_heads * (kvcache_quant_mode.value.memory * (s * 192 + s * 128) + 2 * (s * 192 + s * 128))
2041+
) # fp16 io + fp16/fp8 kv cache, TODO fp8 io
20412042
sol_math = ops / self.system_spec["gpu"]["float16_tc_flops"] * 1000 / fmha_quant_mode.value.compute
20422043
sol_mem = mem_bytes / self.system_spec["gpu"]["mem_bw"] * 1000
20432044
sol_time = max(sol_math, sol_mem)

0 commit comments

Comments
 (0)