Skip to content

Commit 3b906c6

Browse files
committed
feat(BA-2851): Add resource isolation options for multi-agent setup
This change adds a configurable allocation mode for partitioning hardware resources among multiple agents. SHARED mode lets every agent see the full set of resources (useful for stress testing); this matches the behavior before this change. AUTO_SPLIT mode automatically divides resources equally among the agents. MANUAL mode lets users specify exact per-agent allocations for all resources. Single-agent deployments are unaffected and retain access to all available hardware resources.
1 parent 6d2e3d2 commit 3b906c6

File tree

10 files changed

+922
-19
lines changed

10 files changed

+922
-19
lines changed

changes/6498.feature.md

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1 @@
1+
Add resource isolation options for multi-agent setup

src/ai/backend/agent/agent.py

Lines changed: 16 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -238,6 +238,7 @@
238238
ComputerContext,
239239
KernelResourceSpec,
240240
Mount,
241+
ResourcePartitioner,
241242
align_memory,
242243
allocate,
243244
known_slot_types,
@@ -765,7 +766,10 @@ class AbstractAgent(
765766
etcd: AsyncEtcd
766767
local_instance_id: str
767768
kernel_registry: MutableMapping[KernelId, AbstractKernel]
769+
resource_partitioner: ResourcePartitioner
768770
computers: MutableMapping[DeviceName, ComputerContext]
771+
total_slots: Mapping[SlotName, Decimal]
772+
reserved_slots: Mapping[SlotName, Decimal]
769773
images: Mapping[ImageCanonical, ScannedImage]
770774
port_pool: set[int]
771775

@@ -836,6 +840,7 @@ def __init__(
836840
error_monitor: ErrorPluginContext,
837841
skip_initial_scan: bool = False,
838842
agent_public_key: Optional[PublicKey],
843+
resource_partitioner: ResourcePartitioner,
839844
) -> None:
840845
self._skip_initial_scan = skip_initial_scan
841846
self.loop = current_loop()
@@ -845,7 +850,10 @@ def __init__(
845850
self.local_instance_id = generate_local_instance_id(__file__)
846851
self.agent_public_key = agent_public_key
847852
self.kernel_registry = {}
853+
self.resource_partitioner = resource_partitioner
848854
self.computers = {}
855+
self.total_slots = {}
856+
self.reserved_slots = {}
849857
self.images = {}
850858
self.restarting_kernels = {}
851859
self.stat_ctx = StatContext(
@@ -941,6 +949,12 @@ async def __ainit__(self) -> None:
941949
self.computers[name] = ComputerContext(computer, devices, alloc_map)
942950
metadatas.append(computer.get_metadata())
943951

952+
self.total_slots = self.resource_partitioner.calculate_total_slots(
953+
self.computers, self.local_config.resource_common
954+
)
955+
self.reserved_slots = self.resource_partitioner.restrict_computer_resources(
956+
self.computers, self.total_slots
957+
)
944958
self.slots = await self.update_slots()
945959
log.info("Resource slots: {!r}", self.slots)
946960
log.info("Slot types: {!r}", known_slot_types)
@@ -1965,14 +1979,9 @@ async def update_slots(
19651979
"""
19661980
scanned_slots = await self.scan_available_resources()
19671981
usable_slots: dict[SlotName, Decimal] = {}
1968-
reserved_slots = {
1969-
SlotName("cpu"): Decimal(self.local_config.resource.reserved_cpu),
1970-
SlotName("mem"): Decimal(self.local_config.resource.reserved_mem),
1971-
SlotName("disk"): Decimal(self.local_config.resource.reserved_disk),
1972-
}
19731982
for slot_name, slot_capacity in scanned_slots.items():
19741983
if slot_name == SlotName("mem"):
1975-
mem_reserved = int(reserved_slots.get(slot_name, 0))
1984+
mem_reserved = int(self.reserved_slots.get(slot_name, 0))
19761985
mem_align = int(self.local_config.resource.memory_align_size)
19771986
mem_usable, mem_reserved = align_memory(
19781987
int(slot_capacity), mem_reserved, align=mem_align
@@ -1986,7 +1995,7 @@ async def update_slots(
19861995
)
19871996
else:
19881997
usable_capacity = max(
1989-
Decimal(0), slot_capacity - reserved_slots.get(slot_name, Decimal(0))
1998+
Decimal(0), slot_capacity - self.reserved_slots.get(slot_name, Decimal(0))
19901999
)
19912000
usable_slots[slot_name] = usable_capacity
19922001
return usable_slots
@@ -2267,7 +2276,6 @@ async def check_image(
22672276
Check the availability of the image and return a boolean flag that indicates whether
22682277
the agent should try pulling the image from a registry.
22692278
"""
2270-
return False
22712279

22722280
async def scan_running_kernels(self) -> None:
22732281
"""

src/ai/backend/agent/alloc_map.py

Lines changed: 11 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -236,6 +236,17 @@ def update_affinity_hint(
236236
hint_for_next_allocation.append(dev)
237237
affinity_hint.devices = hint_for_next_allocation
238238

239+
@final
def update_device_slot_amounts(self, slot_amounts: Mapping[SlotName, Decimal]) -> None:
    """
    Replace the amount of every entry in ``self.device_slots``.

    Each entry keeps its slot type and slot name; its amount becomes the value
    stored in ``slot_amounts`` under that entry's slot name.  The new mapping
    is built completely before it is assigned, so ``self.device_slots`` is
    left untouched if a slot name is missing from ``slot_amounts``
    (which raises :exc:`KeyError`).

    :param slot_amounts: The new amount for each slot name.
    """
    refreshed = {}
    for device_id, current in self.device_slots.items():
        refreshed[device_id] = DeviceSlotInfo(
            slot_type=current.slot_type,
            slot_name=current.slot_name,
            amount=slot_amounts[current.slot_name],
        )
    self.device_slots = refreshed
249+
239250
@abstractmethod
240251
def allocate(
241252
self,

src/ai/backend/agent/docker/agent.py

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -116,6 +116,7 @@
116116
ComputerContext,
117117
KernelResourceSpec,
118118
Mount,
119+
ResourcePartitioner,
119120
known_slot_types,
120121
)
121122
from ..scratch import create_loop_filesystem, destroy_loop_filesystem
@@ -1315,6 +1316,7 @@ def __init__(
13151316
skip_initial_scan: bool = False,
13161317
agent_public_key: Optional[PublicKey],
13171318
metadata_server: MetadataServer,
1319+
resource_partitioner: ResourcePartitioner,
13181320
) -> None:
13191321
super().__init__(
13201322
etcd,
@@ -1323,6 +1325,7 @@ def __init__(
13231325
error_monitor=error_monitor,
13241326
skip_initial_scan=skip_initial_scan,
13251327
agent_public_key=agent_public_key,
1328+
resource_partitioner=resource_partitioner,
13261329
)
13271330
self.checked_invalid_images = set()
13281331
self.metadata_server = metadata_server

src/ai/backend/agent/kubernetes/agent.py

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -74,6 +74,7 @@
7474
ComputerContext,
7575
KernelResourceSpec,
7676
Mount,
77+
ResourcePartitioner,
7778
known_slot_types,
7879
)
7980
from ..types import Container, KernelOwnershipData, MountInfo, Port
@@ -829,6 +830,7 @@ def __init__(
829830
error_monitor: ErrorPluginContext,
830831
skip_initial_scan: bool = False,
831832
agent_public_key: Optional[PublicKey],
833+
resource_partitioner: ResourcePartitioner,
832834
) -> None:
833835
super().__init__(
834836
etcd,
@@ -837,6 +839,7 @@ def __init__(
837839
error_monitor=error_monitor,
838840
skip_initial_scan=skip_initial_scan,
839841
agent_public_key=agent_public_key,
842+
resource_partitioner=resource_partitioner,
840843
)
841844

842845
async def __ainit__(self) -> None:

src/ai/backend/agent/resources.py

Lines changed: 170 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -29,6 +29,11 @@
2929
import aiodocker
3030
import attrs
3131

32+
from ai.backend.agent.config.unified import (
33+
CommonResourceConfig,
34+
ResourceAllocationMode,
35+
ResourceConfig,
36+
)
3237
from ai.backend.common.json import dump_json_str, load_json
3338
from ai.backend.common.plugin import AbstractPlugin, BasePluginContext
3439
from ai.backend.common.types import (
@@ -71,6 +76,17 @@
7176
known_slot_types: Mapping[SlotName, SlotTypes] = {}
7277

7378

79+
def _combine_mappings(mappings: list[Mapping[SlotName, Decimal]]) -> dict[SlotName, Decimal]:
    """
    Merge several slot mappings into one new dictionary.

    Every slot name must appear in at most one of the given mappings.
    The input mappings are not modified.

    :param mappings: The mappings to merge, processed in order.
    :return: A new dictionary holding every key/value pair of the inputs.
    :raises ValueError: If a slot name appears in more than one mapping.
    """
    combined: dict[SlotName, Decimal] = {}
    for mapping in mappings:
        # Report only the colliding names so the offending device is easy to spot.
        duplicates = combined.keys() & mapping.keys()
        if duplicates:
            raise ValueError(f"Duplicate keys found in devices: {sorted(duplicates)}")
        # Update in place instead of rebuilding the accumulator on each iteration.
        combined.update(mapping)
    return combined
88+
89+
7490
@attrs.define(auto_attribs=True, slots=True)
7591
class ComputerContext:
7692
instance: AbstractComputePlugin
@@ -444,6 +460,160 @@ def get_additional_allowed_syscalls(self) -> list[str]:
444460
return []
445461

446462

463+
class ResourcePartitioner:
464+
def __init__(
465+
self,
466+
resource_config: ResourceConfig,
467+
num_agents: int,
468+
agent_idx: int,
469+
) -> None:
470+
self.resource_config = resource_config
471+
self.num_agents = num_agents
472+
self.agent_idx = agent_idx
473+
self.resource_scaling_factor: Mapping[SlotName, Decimal] = {}
474+
475+
@staticmethod
476+
def calculate_total_slots(
477+
computers: Mapping[DeviceName, ComputerContext],
478+
resource_config: CommonResourceConfig,
479+
deduct_reserved: bool = False,
480+
) -> dict[SlotName, Decimal]:
481+
total_slots: dict[SlotName, Decimal] = defaultdict(lambda: Decimal("0"))
482+
for device in computers.values():
483+
for slot_info in device.alloc_map.device_slots.values():
484+
total_slots[slot_info.slot_name] += slot_info.amount
485+
if deduct_reserved:
486+
return ResourcePartitioner.deduct_reserved_resources(total_slots, resource_config)
487+
else:
488+
return total_slots
489+
490+
@staticmethod
491+
def deduct_reserved_resources(
492+
total_slots: Mapping[SlotName, Decimal],
493+
resource_config: CommonResourceConfig,
494+
) -> dict[SlotName, Decimal]:
495+
reserved_resources = {
496+
SlotName("cpu"): Decimal(resource_config.reserved_cpu),
497+
SlotName("mem"): Decimal(resource_config.reserved_mem),
498+
}
499+
500+
slots: dict[SlotName, Decimal] = {}
501+
for slot_name, slot in total_slots.items():
502+
slots[slot_name] = slot - reserved_resources.get(slot_name, Decimal("0"))
503+
return slots
504+
505+
def restrict_computer_resources(
506+
self,
507+
computers: MutableMapping[DeviceName, ComputerContext],
508+
total_slots: Mapping[SlotName, Decimal],
509+
) -> dict[SlotName, Decimal]:
510+
devices_allocated_slots: list[Mapping[SlotName, Decimal]] = []
511+
devices_reserved_slots: list[Mapping[SlotName, Decimal]] = []
512+
for device in computers.values():
513+
device_allocated_slots = self._calculate_device_slots(device.alloc_map, total_slots)
514+
device.alloc_map.update_device_slot_amounts(device_allocated_slots)
515+
devices_allocated_slots.append(device_allocated_slots)
516+
517+
device_reserved_slots = self._calculate_reserved_slots(
518+
device_allocated_slots, total_slots
519+
)
520+
devices_reserved_slots.append(device_reserved_slots)
521+
522+
allocated_slots = _combine_mappings(devices_allocated_slots)
523+
self.resource_scaling_factor = self._calculate_resource_scaling_factor(
524+
allocated_slots, total_slots
525+
)
526+
527+
reserved_slots = _combine_mappings(devices_reserved_slots)
528+
return reserved_slots
529+
530+
def get_resource_scaling_factor(self, slot_name: SlotName) -> Decimal:
531+
return self.resource_scaling_factor[slot_name]
532+
533+
def _calculate_device_slots(
534+
self,
535+
alloc_map: AbstractAllocMap,
536+
total_slots: Mapping[SlotName, Decimal],
537+
) -> dict[SlotName, Decimal]:
538+
total_slots_no_reserved = ResourcePartitioner.deduct_reserved_resources(
539+
total_slots, self.resource_config
540+
)
541+
return {
542+
device_slot.slot_name: self._calculate_device_slot(
543+
device_slot.slot_name,
544+
total_slots_no_reserved[device_slot.slot_name],
545+
type(alloc_map),
546+
)
547+
for device_slot in alloc_map.device_slots.values()
548+
}
549+
550+
def _calculate_device_slot(
551+
self,
552+
slot_name: SlotName,
553+
total_slot: Decimal,
554+
alloc_map_type: Type[AbstractAllocMap],
555+
) -> Decimal:
556+
match self.resource_config.allocation_mode:
557+
case ResourceAllocationMode.SHARED:
558+
return total_slot
559+
case ResourceAllocationMode.AUTO_SPLIT:
560+
if alloc_map_type is DiscretePropertyAllocMap:
561+
slot, slot_extra = divmod(total_slot, self.num_agents)
562+
remainder_value = 1 if self.agent_idx < slot_extra else 0
563+
return slot + remainder_value
564+
elif alloc_map_type is FractionAllocMap:
565+
return total_slot / self.num_agents
566+
else:
567+
raise NotImplementedError(
568+
f"Unrecognized AbstractAllocMap type {alloc_map_type}"
569+
)
570+
case ResourceAllocationMode.MANUAL:
571+
match slot_name:
572+
case "cpu":
573+
assert self.resource_config.allocated_cpu is not None
574+
return Decimal(self.resource_config.allocated_cpu)
575+
case "mem":
576+
assert self.resource_config.allocated_mem is not None
577+
return Decimal(self.resource_config.allocated_mem)
578+
case slot_name:
579+
if slot_name not in self.resource_config.allocated_devices:
580+
raise ValueError(
581+
f"{slot_name=} not found in config {self.resource_config.allocated_devices!r}"
582+
)
583+
return self.resource_config.allocated_devices[slot_name]
584+
585+
def _calculate_reserved_slots(
586+
self,
587+
device_slots: Mapping[SlotName, Decimal],
588+
total_slots: Mapping[SlotName, Decimal],
589+
) -> dict[SlotName, Decimal]:
590+
reserved_slots: dict[SlotName, Decimal] = {}
591+
for slot_name, slot in device_slots.items():
592+
reserved_slots[slot_name] = max(total_slots[slot_name] - slot, Decimal(0))
593+
return reserved_slots
594+
595+
def _calculate_resource_scaling_factor(
596+
self,
597+
allocated_slots: Mapping[SlotName, Decimal],
598+
total_slots: Mapping[SlotName, Decimal],
599+
) -> dict[SlotName, Decimal]:
600+
match self.resource_config.allocation_mode:
601+
case ResourceAllocationMode.SHARED:
602+
return defaultdict(lambda: Decimal(1.0))
603+
case ResourceAllocationMode.AUTO_SPLIT:
604+
return defaultdict(lambda: Decimal(1.0) / Decimal(self.num_agents))
605+
case ResourceAllocationMode.MANUAL:
606+
if SlotName("cpu") not in allocated_slots or SlotName("cpu") not in total_slots:
607+
raise ValueError("CPU not in allocated or total slots seen")
608+
if SlotName("mem") not in allocated_slots or SlotName("mem") not in total_slots:
609+
raise ValueError("Memory not in allocated or total slots seen")
610+
scaling_factor = {
611+
slot_name: slot / total_slots[slot_name]
612+
for slot_name, slot in allocated_slots.items()
613+
}
614+
return scaling_factor
615+
616+
447617
class ComputePluginContext(BasePluginContext[AbstractComputePlugin]):
448618
plugin_group = "backendai_accelerator_v21"
449619

0 commit comments

Comments
 (0)