|
11 | 11 | import os |
12 | 12 | import sys |
13 | 13 | import textwrap |
| 14 | +from decimal import Decimal |
14 | 15 | from pathlib import Path |
15 | 16 | from typing import ( |
16 | 17 | Any, |
|
51 | 52 | BinarySizeField, |
52 | 53 | ResourceGroupType, |
53 | 54 | ServiceDiscoveryType, |
| 55 | + SlotName, |
54 | 56 | ) |
55 | 57 | from ai.backend.logging import BraceStyleAdapter |
56 | 58 | from ai.backend.logging.config import LoggingConfig |
@@ -81,6 +83,12 @@ class ScratchType(enum.StrEnum): |
81 | 83 | K8S_NFS = "k8s-nfs" |
82 | 84 |
|
83 | 85 |
|
| 86 | +class ResourceAllocationMode(enum.StrEnum): |
| 87 | + SHARED = "shared" |
| 88 | + AUTO_SPLIT = "auto-split" |
| 89 | + MANUAL = "manual" |
| 90 | + |
| 91 | + |
84 | 92 | class AgentConfigValidationContext(BaseConfigValidationContext): |
85 | 93 | is_invoked_subcommand: bool |
86 | 94 |
|
@@ -863,7 +871,7 @@ class ContainerConfig(CommonContainerConfig, OverridableContainerConfig): |
863 | 871 | pass |
864 | 872 |
|
865 | 873 |
|
866 | | -class ResourceConfig(BaseConfigSchema): |
| 874 | +class CommonResourceConfig(BaseConfigSchema): |
867 | 875 | reserved_cpu: int = Field( |
868 | 876 | default=1, |
869 | 877 | description="The number of CPU cores reserved for the operating system and the agent service.", |
@@ -895,6 +903,18 @@ class ResourceConfig(BaseConfigSchema): |
895 | 903 | validation_alias=AliasChoices("reserved-disk", "reserved_disk"), |
896 | 904 | serialization_alias="reserved-disk", |
897 | 905 | ) |
| 906 | + allocation_mode: ResourceAllocationMode = Field( |
| 907 | + default=ResourceAllocationMode.SHARED, |
| 908 | + description=textwrap.dedent(""" |
| 909 | + Resource allocation mode for multi-agent scenarios. |
| 910 | + - `shared`: All agents share the full resource pool (default, backward compatible). |
| 911 | + - `auto-split`: Automatically divide resources equally (1/N) among all agents. |
| 912 | + - `manual`: Manually specify per-agent resource allocations via config. |
| 913 | + """), |
| 914 | + examples=[item.value for item in ResourceAllocationMode], |
| 915 | + validation_alias=AliasChoices("allocation-mode", "allocation_mode"), |
| 916 | + serialization_alias="allocation-mode", |
| 917 | + ) |
898 | 918 | memory_align_size: BinarySizeField = Field( |
899 | 919 | default=BinarySize.finite_from_str("16M"), |
900 | 920 | description=( |
@@ -937,6 +957,64 @@ def _parse_affinity_policy(cls, v: Any) -> AffinityPolicy: |
937 | 957 | return v |
938 | 958 |
|
939 | 959 |
|
class OverridableResourceConfig(BaseConfigSchema):
    # Resource settings holding the explicit ("hard") allocations for an agent.
    # Per the field descriptions below, these values are only used in MANUAL
    # allocation mode.
    # NOTE: documented with comments rather than a class docstring so that the
    # generated schema description is left untouched.

    # Number of CPU cores allocated to this agent; None means "not specified".
    allocated_cpu: Optional[int] = Field(
        default=None,
        description=textwrap.dedent("""
        Hard CPU allocation for this agent (e.g., 8 cores).
        Only used in MANUAL allocation mode.
        All agents must specify this value when allocation-mode is MANUAL.
        """),
        examples=[8, 16],
        validation_alias=AliasChoices("allocated-cpu", "allocated_cpu"),
        serialization_alias="allocated-cpu",
    )
    # Amount of memory allocated to this agent; None means "not specified".
    allocated_mem: Optional[BinarySizeField] = Field(
        default=None,
        description=textwrap.dedent("""
        Hard memory allocation for this agent (e.g., "32G").
        Only used in MANUAL allocation mode.
        All agents must specify this value when allocation-mode is MANUAL.
        """),
        examples=["32G", "64G"],
        validation_alias=AliasChoices("allocated-mem", "allocated_mem"),
        serialization_alias="allocated-mem",
    )
    # Per-slot device allocations keyed by slot name; empty when unspecified.
    allocated_devices: Mapping[SlotName, Decimal] = Field(
        default_factory=dict,
        description=textwrap.dedent("""
        Device-specific per-slot resource allocations.
        Only used in MANUAL allocation mode.
        """),
        examples=[{"cuda.mem": "0.3", "cuda.shares": "0.5"}],
        validation_alias=AliasChoices("allocated-devices", "allocated_devices"),
        serialization_alias="allocated-devices",
    )

    model_config = ConfigDict(
        extra="allow",
        arbitrary_types_allowed=True,
    )

    @model_validator(mode="after")
    def validate_values_are_positive(self) -> Self:
        """Reject negative allocation values.

        Despite the method name, zero is accepted: only values below zero are
        rejected. Unset (``None``) CPU/memory values are skipped.

        Returns:
            The validated model instance itself.

        Raises:
            ValueError: If the CPU, memory, or any per-slot device allocation
                is negative.
        """
        if self.allocated_cpu is not None and self.allocated_cpu < 0:
            raise ValueError(
                f"Allocated cpu must not be a negative value, but given {self.allocated_cpu}"
            )
        if self.allocated_mem is not None and self.allocated_mem < 0:
            raise ValueError(
                f"Allocated mem must not be a negative value, but given {self.allocated_mem}"
            )
        if any(value < 0 for value in self.allocated_devices.values()):
            raise ValueError("All allocated device resource values must not be a negative value")
        return self
| 1012 | + |
| 1013 | + |
class ResourceConfig(CommonResourceConfig, OverridableResourceConfig):
    # Complete resource section: combines the fields of CommonResourceConfig
    # with the per-agent allocation fields of OverridableResourceConfig.
    # No fields or validators of its own are added here.
    pass
| 1016 | + |
| 1017 | + |
940 | 1018 | class EtcdConfig(BaseConfigSchema): |
941 | 1019 | namespace: str = Field( |
942 | 1020 | description="Etcd namespace", |
@@ -1166,7 +1244,7 @@ class AgentOverrideConfig(BaseConfigSchema): |
1166 | 1244 | default=None, |
1167 | 1245 | description="Container config overrides for the individual agent", |
1168 | 1246 | ) |
1169 | | - resource: ResourceConfig | None = Field( |
| 1247 | + resource: OverridableResourceConfig | None = Field( |
1170 | 1248 | default=None, |
1171 | 1249 | description="Resource config overrides for the individual agent", |
1172 | 1250 | ) |
@@ -1229,6 +1307,10 @@ def agent_configs(self) -> Sequence[AgentUnifiedConfig]: |
1229 | 1307 | def agent_ids(self) -> Sequence[AgentId]: |
1230 | 1308 | return [AgentId(agent_config.agent.id) for agent_config in self.agent_configs] |
1231 | 1309 |
|
@property
def resource_common(self) -> CommonResourceConfig:
    """Return this config's ``resource`` section, typed as ``CommonResourceConfig``."""
    return self.resource
| 1313 | + |
1232 | 1314 | def with_updates( |
1233 | 1315 | self, |
1234 | 1316 | *, |
@@ -1312,6 +1394,43 @@ def validate(config: AgentSpecificConfig) -> None: |
1312 | 1394 | self._for_each_agent(validate) |
1313 | 1395 | return self |
1314 | 1396 |
|
| 1397 | + @model_validator(mode="after") |
| 1398 | + def _validate_resource_allocation_mode(self) -> Self: |
| 1399 | + def validate_manual_resource_not_specified(config: AgentSpecificConfig) -> None: |
| 1400 | + resource = config.resource |
| 1401 | + if any([ |
| 1402 | + resource.allocated_cpu is not None, |
| 1403 | + resource.allocated_mem is not None, |
| 1404 | + resource.allocated_devices, |
| 1405 | + ]): |
| 1406 | + raise ValueError( |
| 1407 | + "On non-MANUAL mode, config must not specify manual resource allocations" |
| 1408 | + ) |
| 1409 | + |
| 1410 | + def validate_mandatory_manual_resource_specified(config: AgentSpecificConfig) -> None: |
| 1411 | + resource = config.resource |
| 1412 | + if any([ |
| 1413 | + resource.allocated_cpu is None, |
| 1414 | + resource.allocated_mem is None, |
| 1415 | + ]): |
| 1416 | + raise ValueError( |
| 1417 | + "On MANUAL mode, config must specify cpu and mem resource allocations" |
| 1418 | + ) |
| 1419 | + |
| 1420 | + match self.resource.allocation_mode: |
| 1421 | + case ResourceAllocationMode.SHARED | ResourceAllocationMode.AUTO_SPLIT: |
| 1422 | + self._for_each_agent(validate_manual_resource_not_specified) |
| 1423 | + case ResourceAllocationMode.MANUAL: |
| 1424 | + self._for_each_agent(validate_mandatory_manual_resource_specified) |
| 1425 | + |
| 1426 | + slot_names = self._for_each_agent( |
| 1427 | + lambda config: set(config.resource.allocated_devices.keys()) |
| 1428 | + ) |
| 1429 | + if not all(slot_name == slot_names[0] for slot_name in slot_names): |
| 1430 | + raise ValueError("All agents must have the same slots defined in the devices!") |
| 1431 | + |
| 1432 | + return self |
| 1433 | + |
1315 | 1434 | def _for_each_agent(self, func: Callable[[AgentUnifiedConfig], R]) -> list[R]: |
1316 | 1435 | agents = [agent.construct_unified_config(default=self) for agent in self.agents] |
1317 | 1436 | if not agents: |
|
0 commit comments