@@ -383,3 +383,164 @@ def measure_kernel_power(
     avg_power_watts = total_energy_j / (total_time_ms / 1000)  # J / seconds
 
     return avg_latency_ms, avg_power_watts
+
+
+def get_system_spec_from_device(device_name: str) -> dict:
+    """Load full system spec from device name.
+
+    Args:
+        device_name: GPU device name
+
+    Returns:
+        Full system_spec dict with 'gpu' key
+    """
+    device_upper = device_name.upper()
+    # Check GB200 before B200: "B200" is a substring of "GB200", so the
+    # narrower match must come first or GB200 devices would load the B200 spec.
+    if "GB200" in device_upper:
+        system_file = "gb200_sxm.yaml"
+    elif "H100" in device_upper:
+        system_file = "h100_sxm.yaml"
+    elif "H200" in device_upper:
+        system_file = "h200_sxm.yaml"
+    elif "A100" in device_upper:
+        system_file = "a100_sxm.yaml"
+    elif "B200" in device_upper:
+        system_file = "b200_sxm.yaml"
+    else:
+        raise ValueError(f"Unsupported GPU: {device_name}")
+
+    systems_dir = pkg_resources.files("aiconfigurator") / "systems"
+    yaml_path = systems_dir / system_file
+
+    with open(yaml_path) as f:
+        system_spec = yaml.safe_load(f)
+
+    return system_spec
+
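+# Illustrative usage sketch (the device string is an assumption; any name
+# containing "H100" resolves to h100_sxm.yaml via the mapping above):
+#
+#     spec = get_system_spec_from_device("NVIDIA H100 80GB HBM3")
+#     gpu_spec = spec["gpu"]  # the loaded spec carries a 'gpu' key (see docstring)
+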
+
+def _get_gemm_quant_mode(dtype_str: str):
+    """Map dtype string to GEMMQuantMode enum."""
+    from aiconfigurator.sdk import common
+
+    dtype_map = {
+        "float16": common.GEMMQuantMode.float16,
+        "fp8": common.GEMMQuantMode.fp8,
+        "fp8_block": common.GEMMQuantMode.fp8_block,
+        "nvfp4": common.GEMMQuantMode.nvfp4,
+    }
+
+    if dtype_str not in dtype_map:
+        raise ValueError(f"Unsupported dtype: {dtype_str}")
+
+    return dtype_map[dtype_str]
+
+
+def _get_kvcache_quant_mode(dtype_str: str, use_fp8_kv_cache: bool):
+    """Map dtype and fp8 flag to KVCacheQuantMode enum."""
+    from aiconfigurator.sdk import common
+
+    if use_fp8_kv_cache or "fp8" in dtype_str.lower():
+        return common.KVCacheQuantMode.fp8
+    else:
+        return common.KVCacheQuantMode.float16
+
+
+def _get_fmha_quant_mode(dtype_str: str, use_fp8_context_fmha: bool):
+    """Map dtype and fp8 flag to FMHAQuantMode enum."""
+    from aiconfigurator.sdk import common
+
+    if use_fp8_context_fmha or "fp8" in dtype_str.lower():
+        return common.FMHAQuantMode.fp8
+    else:
+        return common.FMHAQuantMode.float16
+
+
+def is_gemm_compute_bound_collector(m: int, n: int, k: int, dtype: str, device_name: str) -> bool:
+    """
+    Determine if a GEMM operation is compute-bound.
+    Wrapper for use in collectors.
+
+    Args:
+        m, n, k: GEMM dimensions (C = A @ B, A is m x k, B is k x n)
+        dtype: Data type (e.g., 'float16', 'fp8')
+        device_name: GPU device name
+
+    Returns:
+        True if compute-bound, False if memory-bound
+    """
+    from aiconfigurator.sdk import common
+    from aiconfigurator.sdk.perf_database import PerfDatabase
+
+    system_spec = get_system_spec_from_device(device_name)
+    quant_mode = _get_gemm_quant_mode(dtype)
+
+    # Create a minimal PerfDatabase instance just to call query_gemm with SOL_FULL
+    db = PerfDatabase.__new__(PerfDatabase)
+    db.system_spec = system_spec
+
+    sol_time, sol_math, sol_mem = db.query_gemm(m, n, k, quant_mode, sol_mode=common.SOLMode.SOL_FULL)
+    return sol_math > sol_mem
+
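+# Illustrative usage sketch (shapes and the device name are assumptions, not
+# from the source): a large square fp8 GEMM is expected to be compute-bound,
+# while an m=1 "skinny" GEMM is expected to be memory-bound.
+#
+#     is_gemm_compute_bound_collector(4096, 4096, 4096, "fp8", "H100")  # expected: True
+#     is_gemm_compute_bound_collector(1, 4096, 4096, "fp8", "H100")     # expected: False
+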
+
+def is_context_attention_compute_bound_collector(
+    b: int,
+    s: int,
+    num_heads: int,
+    num_key_value_heads: int,
+    head_dim: int,
+    dtype: str,
+    kv_cache_dtype: str,
+    use_fp8_kv_cache: bool,
+    use_fp8_context_fmha: bool,
+    device_name: str,
+    attention_window_size: int = 0,
+) -> bool:
+    """
+    Determine if context (prefill) attention is compute-bound.
+    Wrapper for use in collectors.
+
+    Args:
+        b: Batch size
+        s: Sequence length (input)
+        num_heads: Number of query heads (H_q)
+        num_key_value_heads: Number of key/value heads (H_kv)
+        head_dim: Head dimension
+        dtype: Activation dtype
+        kv_cache_dtype: KV cache dtype
+        use_fp8_kv_cache: Whether using FP8 for KV cache
+        use_fp8_context_fmha: Whether using FP8 for context FMHA
+        device_name: GPU device name
+        attention_window_size: Attention window size
+
+    Returns:
+        True if compute-bound, False if memory-bound
+    """
+    from aiconfigurator.sdk import common
+    from aiconfigurator.sdk.perf_database import PerfDatabase
+
+    system_spec = get_system_spec_from_device(device_name)
+    kvcache_quant_mode = _get_kvcache_quant_mode(kv_cache_dtype, use_fp8_kv_cache)
+    fmha_quant_mode = _get_fmha_quant_mode(dtype, use_fp8_context_fmha)
+
+    # Create a minimal PerfDatabase instance just to call query_context_attention with SOL_FULL
+    db = PerfDatabase.__new__(PerfDatabase)
+    db.system_spec = system_spec
+
+    sol_time, sol_math, sol_mem = db.query_context_attention(
+        b, s, num_heads, num_key_value_heads,
+        kvcache_quant_mode, fmha_quant_mode,
+        sol_mode=common.SOLMode.SOL_FULL,
+        window_size=attention_window_size,
+        head_size=head_dim,
+    )
+    return sol_math > sol_mem
+
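+# Illustrative usage sketch (argument values are assumptions): prefill attention
+# FLOPs grow roughly with s**2 while KV-cache traffic grows with s, so long
+# sequences tend toward compute-bound.
+#
+#     is_context_attention_compute_bound_collector(
+#         b=1, s=4096,
+#         num_heads=32, num_key_value_heads=8, head_dim=128,
+#         dtype="float16", kv_cache_dtype="float16",
+#         use_fp8_kv_cache=False, use_fp8_context_fmha=False,
+#         device_name="H100",
+#     )
+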
+
+def is_generation_attention_compute_bound_collector() -> bool:
+    """
+    Determine if generation (decode) attention is compute-bound.
+    Generation attention is ALWAYS memory-bound.
+
+    Returns:
+        False (always memory-bound)
+    """
+    return False
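+
+
+# Rough roofline sketch of why decode attention stays memory-bound (the head
+# counts, dtype, and H100 peak figures below are illustrative assumptions):
+#   per decoded token, per layer:
+#     flops ~= 4 * H_q * s * d                # QK^T plus softmax(QK^T) @ V
+#     bytes ~= 2 * H_kv * s * d * elem_bytes  # K and V caches each read once
+#   arithmetic intensity ~= 2 * H_q / (H_kv * elem_bytes)
+#   e.g. H_q=32, H_kv=8, fp16 (elem_bytes=2) -> 4 FLOP/byte, far below the
+#   ~300 FLOP/byte ridge point of an H100 SXM (~989 TFLOPS / 3.35 TB/s), so
+#   the kernel is limited by HBM bandwidth rather than math throughput.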