
Commit ff7a82a

Re-use perf_database function to calculate SOL in collector
Signed-off-by: Kai Ma <[email protected]>
1 parent 844cc25 commit ff7a82a

5 files changed: 254 additions and 83 deletions

collector/collect.py

Lines changed: 11 additions & 1 deletion

@@ -134,6 +134,16 @@ def worker(
        worker_logger.exception("Failed to initialize NVML power monitor")
        raise  # Fail if power measurement requested but NVML unavailable

+    # Get default power limit if measuring power but no limits specified
+    default_power_limit = None
+    if measure_power and not power_limits:
+        try:
+            from nvml_power_monitor import get_power_management_limit
+            default_power_limit = get_power_management_limit(device_id)
+            worker_logger.info(f"Auto-detected power limit: {default_power_limit}W on device {device_id}")
+        except Exception as e:
+            worker_logger.warning(f"Could not get power limit, power data will not be recorded: {e}")
+
     # Process tasks
     while True:
         task_info = queue.get()
@@ -150,7 +160,7 @@ def worker(
         task_id = create_test_case_id(task, "unknown", module_name)

         # Sweep power limits
-        for power_limit in power_limits or [None]:
+        for power_limit in power_limits or [default_power_limit]:
             with lock:
                 progress_value.value += 1

collector/helper.py

Lines changed: 161 additions & 0 deletions

@@ -383,3 +383,164 @@ def measure_kernel_power(
     avg_power_watts = total_energy_j / (total_time_ms / 1000)  # J / seconds

     return avg_latency_ms, avg_power_watts
+
+
+def get_system_spec_from_device(device_name: str) -> dict:
+    """Load full system spec from device name.
+
+    Args:
+        device_name: GPU device name
+
+    Returns:
+        Full system_spec dict with 'gpu' key
+    """
+    device_upper = device_name.upper()
+    # Check "GB200" before "B200", since "B200" is a substring of "GB200".
+    if "H100" in device_upper:
+        system_file = "h100_sxm.yaml"
+    elif "H200" in device_upper:
+        system_file = "h200_sxm.yaml"
+    elif "A100" in device_upper:
+        system_file = "a100_sxm.yaml"
+    elif "GB200" in device_upper:
+        system_file = "gb200_sxm.yaml"
+    elif "B200" in device_upper:
+        system_file = "b200_sxm.yaml"
+    else:
+        raise ValueError(f"Unsupported GPU: {device_name}")
+
+    systems_dir = pkg_resources.files("aiconfigurator") / "systems"
+    yaml_path = systems_dir / system_file
+
+    with open(yaml_path) as f:
+        system_spec = yaml.safe_load(f)
+
+    return system_spec
+
+
+def _get_gemm_quant_mode(dtype_str: str):
+    """Map dtype string to GEMMQuantMode enum."""
+    from aiconfigurator.sdk import common
+
+    dtype_map = {
+        "float16": common.GEMMQuantMode.float16,
+        "fp8": common.GEMMQuantMode.fp8,
+        "fp8_block": common.GEMMQuantMode.fp8_block,
+        "nvfp4": common.GEMMQuantMode.nvfp4,
+    }
+
+    if dtype_str not in dtype_map:
+        raise ValueError(f"Unsupported dtype: {dtype_str}")
+
+    return dtype_map[dtype_str]
+
+
+def _get_kvcache_quant_mode(dtype_str: str, use_fp8_kv_cache: bool):
+    """Map dtype and fp8 flag to KVCacheQuantMode enum."""
+    from aiconfigurator.sdk import common
+
+    if use_fp8_kv_cache or "fp8" in dtype_str.lower():
+        return common.KVCacheQuantMode.fp8
+    else:
+        return common.KVCacheQuantMode.float16
+
+
+def _get_fmha_quant_mode(dtype_str: str, use_fp8_context_fmha: bool):
+    """Map dtype and fp8 flag to FMHAQuantMode enum."""
+    from aiconfigurator.sdk import common
+
+    if use_fp8_context_fmha or "fp8" in dtype_str.lower():
+        return common.FMHAQuantMode.fp8
+    else:
+        return common.FMHAQuantMode.float16
+
+
+def is_gemm_compute_bound_collector(m: int, n: int, k: int, dtype: str, device_name: str) -> bool:
+    """Determine if a GEMM operation is compute-bound.
+
+    Wrapper for use in collectors.
+
+    Args:
+        m, n, k: GEMM dimensions (C = A @ B, A is m x k, B is k x n)
+        dtype: Data type (e.g., 'float16', 'fp8')
+        device_name: GPU device name
+
+    Returns:
+        True if compute-bound, False if memory-bound
+    """
+    from aiconfigurator.sdk import common
+    from aiconfigurator.sdk.perf_database import PerfDatabase
+
+    system_spec = get_system_spec_from_device(device_name)
+    quant_mode = _get_gemm_quant_mode(dtype)
+
+    # Create a minimal PerfDatabase instance just to call query_gemm with SOL_FULL
+    db = PerfDatabase.__new__(PerfDatabase)
+    db.system_spec = system_spec
+
+    sol_time, sol_math, sol_mem = db.query_gemm(m, n, k, quant_mode, sol_mode=common.SOLMode.SOL_FULL)
+    return sol_math > sol_mem
+
+
+def is_context_attention_compute_bound_collector(
+    b: int,
+    s: int,
+    num_heads: int,
+    num_key_value_heads: int,
+    head_dim: int,
+    dtype: str,
+    kv_cache_dtype: str,
+    use_fp8_kv_cache: bool,
+    use_fp8_context_fmha: bool,
+    device_name: str,
+    attention_window_size: int = 0,
+) -> bool:
+    """Determine if context (prefill) attention is compute-bound.
+
+    Wrapper for use in collectors.
+
+    Args:
+        b: Batch size
+        s: Sequence length (input)
+        num_heads: Number of query heads (H_q)
+        num_key_value_heads: Number of key/value heads (H_kv)
+        head_dim: Head dimension
+        dtype: Activation dtype
+        kv_cache_dtype: KV cache dtype
+        use_fp8_kv_cache: Whether using FP8 for KV cache
+        use_fp8_context_fmha: Whether using FP8 for context FMHA
+        device_name: GPU device name
+        attention_window_size: Attention window size
+
+    Returns:
+        True if compute-bound, False if memory-bound
+    """
+    from aiconfigurator.sdk import common
+    from aiconfigurator.sdk.perf_database import PerfDatabase
+
+    system_spec = get_system_spec_from_device(device_name)
+    kvcache_quant_mode = _get_kvcache_quant_mode(kv_cache_dtype, use_fp8_kv_cache)
+    fmha_quant_mode = _get_fmha_quant_mode(dtype, use_fp8_context_fmha)
+
+    # Create a minimal PerfDatabase instance just to call query_context_attention with SOL_FULL
+    db = PerfDatabase.__new__(PerfDatabase)
+    db.system_spec = system_spec
+
+    sol_time, sol_math, sol_mem = db.query_context_attention(
+        b, s, num_heads, num_key_value_heads,
+        kvcache_quant_mode, fmha_quant_mode,
+        sol_mode=common.SOLMode.SOL_FULL,
+        window_size=attention_window_size,
+        head_size=head_dim,
+    )
+    return sol_math > sol_mem
+
+
+def is_generation_attention_compute_bound_collector() -> bool:
+    """Determine if generation (decode) attention is compute-bound.
+
+    Generation attention is ALWAYS memory-bound.
+
+    Returns:
+        False (always memory-bound)
+    """
+    return False
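
A hedged usage sketch of the new wrappers (the import path, device string, and shapes below are illustrative assumptions; the verdicts come from the SOL math/memory estimates that PerfDatabase derives from the packaged system YAMLs):

    # Illustrative only: module path, device string, and shapes are assumptions.
    from helper import (
        is_context_attention_compute_bound_collector,
        is_gemm_compute_bound_collector,
        is_generation_attention_compute_bound_collector,
    )

    # Large square fp16 GEMM on an H100-class device: typically compute-bound.
    print(is_gemm_compute_bound_collector(4096, 4096, 4096, "float16", "NVIDIA H100"))

    # Prefill GQA attention: batch 1, 4k tokens, 32 query / 8 KV heads, head_dim 128.
    print(is_context_attention_compute_bound_collector(
        b=1, s=4096, num_heads=32, num_key_value_heads=8, head_dim=128,
        dtype="float16", kv_cache_dtype="float16",
        use_fp8_kv_cache=False, use_fp8_context_fmha=False,
        device_name="NVIDIA H100",
    ))

    # Decode attention takes no arguments: it is treated as always memory-bound.
    print(is_generation_attention_compute_bound_collector())  # False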

collector/trtllm/collect_attn.py

Lines changed: 15 additions & 51 deletions

@@ -23,59 +23,13 @@
     get_dtype_size,
     get_gpu_specs_from_device,
     get_sm_version,
+    is_context_attention_compute_bound_collector,
+    is_generation_attention_compute_bound_collector,
     log_perf,
     measure_kernel_power,
 )


-def is_context_attention_compute_bound(b, s, num_heads, num_key_value_heads, d, dtype, kv_cache_dtype, device_name):
-    """
-    Determine if context (prefill) attention is compute-bound with Grouped-Query Attention.
-
-    Args:
-        b: Batch size
-        s: Sequence length (input)
-        num_heads: Number of query heads (H_q)
-        num_key_value_heads: Number of key/value heads (H_kv)
-        d: Head dimension
-        dtype: Activation dtype
-        kv_cache_dtype: KV cache dtype
-        device_name: GPU device name
-
-    Returns:
-        True if compute-bound, False if memory-bound
-    """
-    gpu_specs = get_gpu_specs_from_device(device_name)
-    dtype_size = get_dtype_size(dtype)
-    kv_dtype_size = get_dtype_size(kv_cache_dtype)
-
-    # Hardware intensity
-    if "fp8" in dtype.lower():
-        hardware_tflops = gpu_specs["fp8_tflops"]
-    else:
-        hardware_tflops = gpu_specs["float16_tflops"]
-
-    hardware_intensity = (hardware_tflops * 1e12) / (gpu_specs["mem_bw_gbs"] * 1e9)
-
-    # GQA Attention FLOPs
-    total_flops = 4 * b * num_heads * s * s * d
-
-    # GQA Attention Memory Movement
-    memory_bytes = (
-        dtype_size * b * s * num_heads * d  # Q read (all query heads)
-        + kv_dtype_size * b * s * num_key_value_heads * d * 2  # K read and write (KV heads)
-        + kv_dtype_size * b * s * num_key_value_heads * d * 2  # V read and write (KV heads)
-        + dtype_size * b * s * num_heads * d  # Output write
-    )
-
-    arithmetic_intensity = total_flops / memory_bytes
-
-    return arithmetic_intensity > hardware_intensity
-
-
-def is_generation_attention_compute_bound():
-    """Generation (decode) attention is ALWAYS memory-bound"""
-    return False


 def run_attention_torch(
@@ -303,14 +257,24 @@ def run_attention_torch(

     # Determine if compute-bound
     if is_context_phase:
-        compute_bound = is_context_attention_compute_bound(
-            batch_size, input_len, num_heads, num_key_value_heads, head_dim, dtype_str, kv_cache_dtype_str, device_name
+        compute_bound = is_context_attention_compute_bound_collector(
+            batch_size,
+            input_len,
+            num_heads,
+            num_key_value_heads,
+            head_dim,
+            dtype_str,
+            kv_cache_dtype_str,
+            use_fp8_kv_cache,
+            use_fp8_context_fmha,
+            device_name,
+            attention_window_size,
         )
         isl = input_len
         step = 0
         op_name = "context_attention"
     else:
-        compute_bound = is_generation_attention_compute_bound()
+        compute_bound = is_generation_attention_compute_bound_collector()
         isl = 1
         step = input_len
         op_name = "generation_attention"
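
As a sanity check on the heuristic being replaced, a back-of-the-envelope roofline calculation in the spirit of the deleted is_context_attention_compute_bound, using approximate H100 SXM numbers (~989 dense FP16 TFLOPS, ~3.35 TB/s HBM bandwidth; these are assumptions, not values from the repo's spec files):

    # Roofline sketch of prefill GQA attention (approximate hardware numbers).
    b, s, num_heads, num_kv_heads, d = 1, 4096, 32, 8, 128
    dtype_size = kv_dtype_size = 2  # fp16 bytes per element

    flops = 4 * b * num_heads * s * s * d                    # ~2.7e11
    bytes_moved = (
        dtype_size * b * s * num_heads * d                   # Q read
        + kv_dtype_size * b * s * num_kv_heads * d * 2       # K read + write
        + kv_dtype_size * b * s * num_kv_heads * d * 2       # V read + write
        + dtype_size * b * s * num_heads * d                 # output write
    )                                                        # ~1.0e8
    arithmetic_intensity = flops / bytes_moved               # ~2.7e3 FLOPs/byte
    hardware_intensity = 989e12 / 3.35e12                    # ~295 FLOPs/byte

    print(arithmetic_intensity > hardware_intensity)         # True: prefill is compute-bound here

For decode, the query length is 1 while the whole KV cache is streamed, so the FLOPs-per-byte ratio collapses to order 1; that is why is_generation_attention_compute_bound_collector simply returns False.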

collector/trtllm/collect_gemm.py

Lines changed: 2 additions & 31 deletions

@@ -12,6 +12,7 @@
     get_dtype_size,
     get_gpu_specs_from_device,
     get_sm_version,
+    is_gemm_compute_bound_collector,
     log_perf,
     measure_kernel_power,
 )
@@ -86,36 +87,6 @@ def get_gemm_test_cases():
     return test_cases


-def is_gemm_compute_bound(m, n, k, dtype, device_name):
-    """
-    Determine if a GEMM operation is compute-bound.
-
-    Args:
-        m, n, k: GEMM dimensions (C = A @ B, A is m x k, B is k x n)
-        dtype: Data type (e.g., 'float16', 'fp8')
-        device_name: GPU device name
-
-    Returns:
-        True if compute-bound, False if memory-bound
-    """
-    gpu_specs = get_gpu_specs_from_device(device_name)
-    dtype_size = get_dtype_size(dtype)
-
-    # Hardware intensity (FLOPs per byte)
-    if "fp8" in dtype.lower():
-        hardware_tflops = gpu_specs["fp8_tflops"]
-    else:
-        hardware_tflops = gpu_specs["float16_tflops"]
-
-    hardware_intensity = (hardware_tflops * 1e12) / (gpu_specs["mem_bw_gbs"] * 1e9)
-
-    # GEMM arithmetic intensity
-    total_flops = 2 * m * n * k
-    memory_bytes = dtype_size * (m * k + k * n + m * n)
-    arithmetic_intensity = total_flops / memory_bytes
-
-    # Compute-bound if arithmetic intensity > hardware intensity
-    return arithmetic_intensity > hardware_intensity


 def run_gemm(
@@ -222,7 +193,7 @@ def run_gemm(
         op.forward(x)

     # Determine if compute-bound
-    compute_bound = is_gemm_compute_bound(m, n, k, gemm_type, device_name)
+    compute_bound = is_gemm_compute_bound_collector(m, n, k, gemm_type, device_name)

     # Benchmarking
     if measure_power and power_monitor is not None and not compute_bound:
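
The same kind of roofline reasoning, in the spirit of the deleted is_gemm_compute_bound, shows why the check matters across GEMM shapes (again with approximate H100 SXM fp16 numbers as assumptions):

    # GEMM roofline sketch: arithmetic intensity vs. an approximate hardware ridge point.
    def gemm_arithmetic_intensity(m, n, k, dtype_size=2):
        flops = 2 * m * n * k
        bytes_moved = dtype_size * (m * k + k * n + m * n)
        return flops / bytes_moved

    hardware_intensity = 989e12 / 3.35e12               # ~295 FLOPs/byte

    print(gemm_arithmetic_intensity(4096, 4096, 4096))  # ~1365 -> compute-bound
    print(gemm_arithmetic_intensity(1, 4096, 4096))     # ~1    -> memory-bound (GEMV-like)

The SOL-based wrapper should reach the same verdicts, but from the same system specs the rest of aiconfigurator already uses rather than a second copy of the roofline math, which is the point of re-using the perf_database query here.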
