@@ -296,13 +296,13 @@ def get_dtype_size(dtype: str) -> float:
 
 def _get_system_file_for_device(device_name: str) -> str:
     """Map GPU device name to system YAML filename.
-    
+
     Args:
         device_name: GPU device name
-    
+
     Returns:
         System YAML filename
-    
+
     Raises:
         ValueError: If GPU is not supported
     """
@@ -314,11 +314,11 @@ def _get_system_file_for_device(device_name: str) -> str:
         "GB200": "gb200_sxm.yaml",  # Check GB200 before B200
         "B200": "b200_sxm.yaml",
     }
-    
+
     for prefix, filename in gpu_mappings.items():
         if prefix in device_upper:
             return filename
-    
+
     raise ValueError(f"Unsupported GPU: {device_name}")
 
 
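A note on the mapping above: "B200" is a substring of "GB200", and the loop tests prefix membership with "prefix in device_upper", so the lookup relies on dict insertion order (guaranteed since Python 3.7) to match the more specific GB200 entry first. A quick sketch of the intended behavior, with hypothetical device strings:

    _get_system_file_for_device("NVIDIA GB200")    # -> "gb200_sxm.yaml"
    _get_system_file_for_device("NVIDIA B200")     # -> "b200_sxm.yaml"
    _get_system_file_for_device("SOME OTHER GPU")  # raises ValueError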
@@ -400,10 +400,10 @@ def measure_kernel_power(
 
 def get_system_spec_from_device(device_name: str) -> dict:
     """Load full system spec from device name.
-    
+
     Args:
         device_name: GPU device name
-    
+
     Returns:
         Full system_spec dict with 'gpu' key
     """
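This hunk only touches the docstring; per that docstring the helper returns the parsed system YAML as a dict carrying a 'gpu' key. A minimal usage sketch (the device string below is illustrative, not from the diff):

    spec = get_system_spec_from_device("NVIDIA H200")
    gpu_spec = spec["gpu"]  # 'gpu' key per the docstring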
@@ -420,24 +420,24 @@ def get_system_spec_from_device(device_name: str) -> dict:
 def _get_gemm_quant_mode(dtype_str: str):
     """Map dtype string to GEMMQuantMode enum."""
     from aiconfigurator.sdk import common
-    
+
     dtype_map = {
         "float16": common.GEMMQuantMode.float16,
         "fp8": common.GEMMQuantMode.fp8,
         "fp8_block": common.GEMMQuantMode.fp8_block,
         "nvfp4": common.GEMMQuantMode.nvfp4,
     }
-    
+
     if dtype_str not in dtype_map:
         raise ValueError(f"Unsupported dtype: {dtype_str}")
-    
+
     return dtype_map[dtype_str]
 
 
 def _get_kvcache_quant_mode(dtype_str: str, use_fp8_kv_cache: bool):
     """Map dtype and fp8 flag to KVCacheQuantMode enum."""
     from aiconfigurator.sdk import common
-    
+
     if use_fp8_kv_cache or "fp8" in dtype_str.lower():
         return common.KVCacheQuantMode.fp8
     else:
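These mappers are plain string-to-enum lookups, with one subtlety: _get_kvcache_quant_mode (and _get_fmha_quant_mode below) also substring-match "fp8" in the dtype, so an fp8 model dtype forces fp8 quantization even when the explicit flag is False. A sketch of the resulting behavior:

    _get_gemm_quant_mode("fp8_block")         # GEMMQuantMode.fp8_block
    _get_gemm_quant_mode("int8")              # raises ValueError (not in dtype_map)
    _get_kvcache_quant_mode("float16", True)  # KVCacheQuantMode.fp8 (flag wins)
    _get_kvcache_quant_mode("fp8", False)     # KVCacheQuantMode.fp8 (substring match)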
@@ -447,7 +447,7 @@ def _get_kvcache_quant_mode(dtype_str: str, use_fp8_kv_cache: bool):
 def _get_fmha_quant_mode(dtype_str: str, use_fp8_context_fmha: bool):
     """Map dtype and fp8 flag to FMHAQuantMode enum."""
     from aiconfigurator.sdk import common
-    
+
     if use_fp8_context_fmha or "fp8" in dtype_str.lower():
         return common.FMHAQuantMode.fp8
     else:
@@ -458,25 +458,25 @@ def is_gemm_compute_bound_collector(m: int, n: int, k: int, dtype: str, device_n
     """
     Determine if a GEMM operation is compute-bound.
     Wrapper for use in collectors.
-    
+
     Args:
         m, n, k: GEMM dimensions (C = A @ B, A is mxk, B is kxn)
         dtype: Data type (e.g., 'float16', 'fp8')
        device_name: GPU device name
-    
+
     Returns:
         True if compute-bound, False if memory-bound
     """
     from aiconfigurator.sdk import common
     from aiconfigurator.sdk.perf_database import PerfDatabase
-    
+
     system_spec = get_system_spec_from_device(device_name)
     quant_mode = _get_gemm_quant_mode(dtype)
-    
+
     # Create minimal PerfDatabase instance just to call query_gemm with SOL_FULL
     db = PerfDatabase.__new__(PerfDatabase)
     db.system_spec = system_spec
-    
+
     sol_time, sol_math, sol_mem = db.query_gemm(m, n, k, quant_mode, sol_mode=common.SOLMode.SOL_FULL)
     return sol_math > sol_mem
 
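The SOL_FULL query returns speed-of-light estimates, and "sol_math > sol_mem" is a roofline test: the GEMM is compute-bound when its time at peak FLOPs exceeds its time at peak bandwidth. A self-contained sketch of that arithmetic for a dense GEMM (an approximation under simple assumptions, not the actual query_gemm cost model):

    def is_compute_bound_roofline(m: int, n: int, k: int,
                                  bytes_per_elt: float,
                                  peak_tflops: float, peak_gbps: float) -> bool:
        flops = 2 * m * n * k                            # dense GEMM FLOP count
        moved = bytes_per_elt * (m * k + k * n + m * n)  # read A and B, write C
        t_math = flops / (peak_tflops * 1e12)            # seconds at peak compute
        t_mem = moved / (peak_gbps * 1e9)                # seconds at peak bandwidth
        return t_math > t_mem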
@@ -497,7 +497,7 @@ def is_context_attention_compute_bound_collector(
     """
     Determine if context (prefill) attention is compute-bound.
     Wrapper for use in collectors.
-    
+
     Args:
         b: Batch size
         s: Sequence length (input)
@@ -510,21 +510,21 @@ def is_context_attention_compute_bound_collector(
         use_fp8_context_fmha: Whether using FP8 for context FMHA
         device_name: GPU device name
         attention_window_size: Attention window size
-    
+
     Returns:
         True if compute-bound, False if memory-bound
     """
     from aiconfigurator.sdk import common
     from aiconfigurator.sdk.perf_database import PerfDatabase
-    
+
     system_spec = get_system_spec_from_device(device_name)
     kvcache_quant_mode = _get_kvcache_quant_mode(kv_cache_dtype, use_fp8_kv_cache)
     fmha_quant_mode = _get_fmha_quant_mode(dtype, use_fp8_context_fmha)
-    
+
     # Create minimal PerfDatabase instance just to call query_context_attention with SOL_FULL
     db = PerfDatabase.__new__(PerfDatabase)
     db.system_spec = system_spec
-    
+
     sol_time, sol_math, sol_mem = db.query_context_attention(
         b, s, num_heads, num_key_value_heads,
         kvcache_quant_mode, fmha_quant_mode,
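Unlike a GEMM of fixed shape, prefill attention shifts between regimes with sequence length: per head, the score and value matmuls cost about 4*s^2*d FLOPs against O(s*d) bytes of Q/K/V traffic, so arithmetic intensity grows roughly linearly in s and long prompts become compute-bound. A back-of-envelope sketch under those assumptions (fp16 operands, score matrix kept on-chip; not the query_context_attention model):

    def prefill_attention_intensity(s: int, head_dim: int) -> float:
        flops = 4 * s * s * head_dim         # QK^T plus scores @ V, per head
        bytes_moved = 2 * 3 * s * head_dim   # fp16 Q, K, V reads
        return flops / bytes_moved           # ~(2/3)*s, linear in s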
@@ -539,7 +539,7 @@ def is_generation_attention_compute_bound_collector() -> bool:
     """
     Determine if generation (decode) attention is compute-bound.
     Generation attention is ALWAYS memory-bound.
-    
+
     Returns:
         False (always memory-bound)
     """
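The constant answer follows from decode arithmetic: each step issues a single query token that reads the entire KV cache, so per head it performs roughly 4*s*d FLOPs against about 4*s*d bytes of fp16 K/V traffic, on the order of 1 FLOP per byte, far below the compute-to-bandwidth ratio of any GPU this module targets. Hence this collector skips the SOL query entirely and returns False.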