Commit d63c47d

fix Linting complaints

Signed-off-by: Kai Ma <[email protected]>
1 parent 47829b1 commit d63c47d

File tree: 6 files changed, +362 -39 lines changed

collector/collect_all_reduce.py

Lines changed: 21 additions & 7 deletions
@@ -74,7 +74,13 @@ def import_trtllm():
 
 
 def benchmark_trtllm_allreduce(
-    dtype: str, test_range: str, world_size: int, rank: int, use_slurm: bool, perf_filename: str, measure_power: bool = False
+    dtype: str,
+    test_range: str,
+    world_size: int,
+    rank: int,
+    use_slurm: bool,
+    perf_filename: str,
+    measure_power: bool = False,
 ):
     """Benchmark TensorRT-LLM AllReduce implementation"""
     trtllm_mods = import_trtllm()
@@ -100,7 +106,7 @@ def benchmark_trtllm_allreduce(
 
             power_monitor = NVMLPowerMonitor(gpu_indices=[local_rank])
             if rank == 0:
-                print(f"NVML power monitoring enabled on all ranks")
+                print("NVML power monitoring enabled on all ranks")
         except Exception as e:
             if rank == 0:
                 print(f"Warning: Failed to initialize NVML power monitor: {e}")
@@ -189,7 +195,8 @@ def benchmark_trtllm_allreduce(
             avg_power = None
 
         if rank == 0 and local_rank == 0:
-            print(f"[TensorRT-LLM] Size: {size}, Latency: {latency:.4f} ms" + (f", Power: {avg_power:.2f} W" if avg_power is not None else ""))
+            power_str = f", Power: {avg_power:.2f} W" if avg_power is not None else ""
+            print(f"[TensorRT-LLM] Size: {size}, Latency: {latency:.4f} ms{power_str}")
 
     # Get TensorRT-LLM version
     trtllm_version = tllm.__version__ if hasattr(tllm, "__version__") else "unknown"
@@ -202,7 +209,7 @@ def benchmark_trtllm_allreduce(
             "latency": latency,
             "implementation": "trtllm",
         }
-
+
         if avg_power is not None:
             item["power"] = avg_power
         item["compute_bound"] = 0  # Communication is always memory/bandwidth-bound
@@ -289,7 +296,13 @@ def setup_vllm_distributed(world_size, rank, use_slurm):
 
 
 def benchmark_vllm_allreduce(
-    dtype: str, test_range: str, world_size: int, rank: int, use_slurm: bool, perf_filename: str, measure_power: bool = False
+    dtype: str,
+    test_range: str,
+    world_size: int,
+    rank: int,
+    use_slurm: bool,
+    perf_filename: str,
+    measure_power: bool = False,
 ):
     """Benchmark vLLM custom AllReduce backend"""
     vllm_mods, local_rank = setup_vllm_distributed(world_size, rank, use_slurm)
@@ -302,7 +315,7 @@ def benchmark_vllm_allreduce(
 
             power_monitor = NVMLPowerMonitor(gpu_indices=[local_rank])
             if rank == 0:
-                print(f"NVML power monitoring enabled on all ranks")
+                print("NVML power monitoring enabled on all ranks")
         except Exception as e:
             if rank == 0:
                 print(f"Warning: Failed to initialize NVML power monitor: {e}")
@@ -431,7 +444,8 @@ def benchmark_vllm_allreduce(
             avg_power = None
 
         if rank == 0:
-            print(f"[vLLM-{mode_str}] Size: {size}, Latency: {latency:.4f} ms" + (f", Power: {avg_power:.2f} W" if avg_power is not None else ""))
+            power_str = f", Power: {avg_power:.2f} W" if avg_power is not None else ""
+            print(f"[vLLM-{mode_str}] Size: {size}, Latency: {latency:.4f} ms{power_str}")
 
     # Get vLLM version
     try:
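
For readers skimming this hunk, here is a minimal standalone sketch of the two lint patterns fixed above: an f-string with no placeholders (commonly flagged as F541) and a long print that concatenated a conditional f-string inline. The values below are dummies, not benchmark output.

# Sketch of the lint fixes above; latency and avg_power are placeholder values.
latency = 0.1234   # ms (dummy)
avg_power = None   # W, or None when power monitoring is disabled

# Before: print(f"NVML power monitoring enabled on all ranks")  # f-string without placeholders
print("NVML power monitoring enabled on all ranks")

# Before: message + (conditional f-string) concatenated on one long line.
# After: pre-compute the optional suffix, then use a single f-string.
power_str = f", Power: {avg_power:.2f} W" if avg_power is not None else ""
print(f"Size: 1024, Latency: {latency:.4f} ms{power_str}")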

collector/helper.py

Lines changed: 23 additions & 23 deletions
@@ -296,13 +296,13 @@ def get_dtype_size(dtype: str) -> float:
 
 def _get_system_file_for_device(device_name: str) -> str:
     """Map GPU device name to system YAML filename.
-
+
     Args:
         device_name: GPU device name
-
+
     Returns:
         System YAML filename
-
+
     Raises:
         ValueError: If GPU is not supported
     """
@@ -314,11 +314,11 @@ def _get_system_file_for_device(device_name: str) -> str:
         "GB200": "gb200_sxm.yaml",  # Check GB200 before B200
         "B200": "b200_sxm.yaml",
     }
-
+
     for prefix, filename in gpu_mappings.items():
         if prefix in device_upper:
             return filename
-
+
     raise ValueError(f"Unsupported GPU: {device_name}")
 
 
@@ -400,10 +400,10 @@ def measure_kernel_power(
 
 def get_system_spec_from_device(device_name: str) -> dict:
     """Load full system spec from device name.
-
+
     Args:
         device_name: GPU device name
-
+
     Returns:
         Full system_spec dict with 'gpu' key
     """
@@ -420,24 +420,24 @@ def get_system_spec_from_device(device_name: str) -> dict:
 def _get_gemm_quant_mode(dtype_str: str):
     """Map dtype string to GEMMQuantMode enum."""
     from aiconfigurator.sdk import common
-
+
     dtype_map = {
         "float16": common.GEMMQuantMode.float16,
         "fp8": common.GEMMQuantMode.fp8,
         "fp8_block": common.GEMMQuantMode.fp8_block,
         "nvfp4": common.GEMMQuantMode.nvfp4,
     }
-
+
     if dtype_str not in dtype_map:
         raise ValueError(f"Unsupported dtype: {dtype_str}")
-
+
     return dtype_map[dtype_str]
 
 
 def _get_kvcache_quant_mode(dtype_str: str, use_fp8_kv_cache: bool):
     """Map dtype and fp8 flag to KVCacheQuantMode enum."""
     from aiconfigurator.sdk import common
-
+
     if use_fp8_kv_cache or "fp8" in dtype_str.lower():
         return common.KVCacheQuantMode.fp8
     else:
@@ -447,7 +447,7 @@ def _get_kvcache_quant_mode(dtype_str: str, use_fp8_kv_cache: bool):
 def _get_fmha_quant_mode(dtype_str: str, use_fp8_context_fmha: bool):
     """Map dtype and fp8 flag to FMHAQuantMode enum."""
     from aiconfigurator.sdk import common
-
+
     if use_fp8_context_fmha or "fp8" in dtype_str.lower():
         return common.FMHAQuantMode.fp8
     else:
@@ -458,25 +458,25 @@ def is_gemm_compute_bound_collector(m: int, n: int, k: int, dtype: str, device_n
     """
     Determine if a GEMM operation is compute-bound.
     Wrapper for use in collectors.
-
+
     Args:
         m, n, k: GEMM dimensions (C = A @ B, A is mxk, B is kxn)
         dtype: Data type (e.g., 'float16', 'fp8')
         device_name: GPU device name
-
+
     Returns:
         True if compute-bound, False if memory-bound
     """
     from aiconfigurator.sdk import common
     from aiconfigurator.sdk.perf_database import PerfDatabase
-
+
     system_spec = get_system_spec_from_device(device_name)
     quant_mode = _get_gemm_quant_mode(dtype)
-
+
     # Create minimal PerfDatabase instance just to call query_gemm with SOL_FULL
     db = PerfDatabase.__new__(PerfDatabase)
     db.system_spec = system_spec
-
+
     sol_time, sol_math, sol_mem = db.query_gemm(m, n, k, quant_mode, sol_mode=common.SOLMode.SOL_FULL)
     return sol_math > sol_mem
 
@@ -497,7 +497,7 @@ def is_context_attention_compute_bound_collector(
     """
     Determine if context (prefill) attention is compute-bound.
     Wrapper for use in collectors.
-
+
     Args:
         b: Batch size
         s: Sequence length (input)
@@ -510,21 +510,21 @@ def is_context_attention_compute_bound_collector(
         use_fp8_context_fmha: Whether using FP8 for context FMHA
         device_name: GPU device name
         attention_window_size: Attention window size
-
+
     Returns:
         True if compute-bound, False if memory-bound
     """
     from aiconfigurator.sdk import common
     from aiconfigurator.sdk.perf_database import PerfDatabase
-
+
     system_spec = get_system_spec_from_device(device_name)
     kvcache_quant_mode = _get_kvcache_quant_mode(kv_cache_dtype, use_fp8_kv_cache)
     fmha_quant_mode = _get_fmha_quant_mode(dtype, use_fp8_context_fmha)
-
+
     # Create minimal PerfDatabase instance just to call query_context_attention with SOL_FULL
     db = PerfDatabase.__new__(PerfDatabase)
     db.system_spec = system_spec
-
+
     sol_time, sol_math, sol_mem = db.query_context_attention(
         b, s, num_heads, num_key_value_heads,
         kvcache_quant_mode, fmha_quant_mode,
@@ -539,7 +539,7 @@ def is_generation_attention_compute_bound_collector() -> bool:
     """
     Determine if generation (decode) attention is compute-bound.
    Generation attention is ALWAYS memory-bound.
-
+
     Returns:
         False (always memory-bound)
     """
