Skip to content

Commit c1295dc

Browse files
kaim-engjaywonchung
authored andcommitted
replace multi-line if-else with dict
1 parent 84e2d66 commit c1295dc

File tree

1 file changed

+30
-30
lines changed

1 file changed

+30
-30
lines changed

collector/helper.py

Lines changed: 30 additions & 30 deletions
Original file line numberDiff line numberDiff line change
@@ -294,28 +294,41 @@ def get_dtype_size(dtype: str) -> float:
294294
return DTYPE_SIZES[dtype_lower]
295295

296296

297+
def _get_system_file_for_device(device_name: str) -> str:
298+
"""Map GPU device name to system YAML filename.
299+
300+
Args:
301+
device_name: GPU device name
302+
303+
Returns:
304+
System YAML filename
305+
306+
Raises:
307+
ValueError: If GPU is not supported
308+
"""
309+
device_upper = device_name.upper()
310+
gpu_mappings = {
311+
"H200": "h200_sxm.yaml",
312+
"H100": "h100_sxm.yaml",
313+
"A100": "a100_sxm.yaml",
314+
"GB200": "gb200_sxm.yaml", # Check GB200 before B200
315+
"B200": "b200_sxm.yaml",
316+
}
317+
318+
for prefix, filename in gpu_mappings.items():
319+
if prefix in device_upper:
320+
return filename
321+
322+
raise ValueError(f"Unsupported GPU: {device_name}")
323+
324+
297325
def get_gpu_specs_from_device(device_name: str) -> dict:
298326
"""Load GPU specifications from system YAML files.
299327
300328
Dictionary keys are float16_tflops, fp8_tflops, int8_tflops, mem_bw_gbs, power_max.
301329
Keys follow system YAML files, except for power_max (which is just 'power' in YAML).
302330
"""
303-
# Map device name to system file
304-
device_upper = device_name.upper()
305-
if "H100" in device_upper:
306-
system_file = "h100_sxm.yaml"
307-
elif "H200" in device_upper:
308-
system_file = "h200_sxm.yaml"
309-
elif "A100" in device_upper:
310-
system_file = "a100_sxm.yaml"
311-
elif "B200" in device_upper:
312-
system_file = "b200_sxm.yaml"
313-
elif "GB200" in device_upper:
314-
system_file = "gb200_sxm.yaml"
315-
else:
316-
raise ValueError(f"Unsupported GPU: {device_name}")
317-
318-
# Load system YAML
331+
system_file = _get_system_file_for_device(device_name)
319332
systems_dir = pkg_resources.files("aiconfigurator") / "systems"
320333
yaml_path = systems_dir / system_file
321334

@@ -394,20 +407,7 @@ def get_system_spec_from_device(device_name: str) -> dict:
394407
Returns:
395408
Full system_spec dict with 'gpu' key
396409
"""
397-
device_upper = device_name.upper()
398-
if "H100" in device_upper:
399-
system_file = "h100_sxm.yaml"
400-
elif "H200" in device_upper:
401-
system_file = "h200_sxm.yaml"
402-
elif "A100" in device_upper:
403-
system_file = "a100_sxm.yaml"
404-
elif "B200" in device_upper:
405-
system_file = "b200_sxm.yaml"
406-
elif "GB200" in device_upper:
407-
system_file = "gb200_sxm.yaml"
408-
else:
409-
raise ValueError(f"Unsupported GPU: {device_name}")
410-
410+
system_file = _get_system_file_for_device(device_name)
411411
systems_dir = pkg_resources.files("aiconfigurator") / "systems"
412412
yaml_path = systems_dir / system_file
413413

0 commit comments

Comments
 (0)