|
29 | 29 | import aiodocker |
30 | 30 | import attrs |
31 | 31 |
|
| 32 | +from ai.backend.agent.config.unified import ( |
| 33 | + CommonResourceConfig, |
| 34 | + ResourceAllocationMode, |
| 35 | + ResourceConfig, |
| 36 | +) |
32 | 37 | from ai.backend.common.json import dump_json_str, load_json |
33 | 38 | from ai.backend.common.plugin import AbstractPlugin, BasePluginContext |
34 | 39 | from ai.backend.common.types import ( |
|
71 | 76 | known_slot_types: Mapping[SlotName, SlotTypes] = {} |
72 | 77 |
|
73 | 78 |
|
def _combine_mappings(mappings: list[Mapping[SlotName, Decimal]]) -> dict[SlotName, Decimal]:
    """Merge several slot mappings into one dict, refusing overlapping keys.

    Args:
        mappings: Mappings to merge, processed in order.

    Returns:
        A new dict holding every key/value pair of every input mapping.

    Raises:
        ValueError: If a key of one mapping was already contributed by an
            earlier mapping.  The message names the colliding keys.
    """
    combined: dict[SlotName, Decimal] = {}
    for mapping in mappings:
        # Key views support set intersection directly; no temporary sets needed.
        duplicates = combined.keys() & mapping.keys()
        if duplicates:
            raise ValueError(
                f"Duplicate keys found in devices: {sorted(duplicates)} "
                f"(already seen: {list(combined)}, incoming: {list(mapping)})"
            )
        # Update in place instead of rebuilding the accumulator on every iteration.
        combined.update(mapping)
    return combined
| 88 | + |
| 89 | + |
74 | 90 | @attrs.define(auto_attribs=True, slots=True) |
75 | 91 | class ComputerContext: |
76 | 92 | instance: AbstractComputePlugin |
@@ -444,6 +460,160 @@ def get_additional_allowed_syscalls(self) -> list[str]: |
444 | 460 | return [] |
445 | 461 |
|
446 | 462 |
|
| 463 | +class ResourcePartitioner: |
| 464 | + def __init__( |
| 465 | + self, |
| 466 | + resource_config: ResourceConfig, |
| 467 | + num_agents: int, |
| 468 | + agent_idx: int, |
| 469 | + ) -> None: |
| 470 | + self.resource_config = resource_config |
| 471 | + self.num_agents = num_agents |
| 472 | + self.agent_idx = agent_idx |
| 473 | + self.resource_scaling_factor: Mapping[SlotName, Decimal] = {} |
| 474 | + |
| 475 | + @staticmethod |
| 476 | + def calculate_total_slots( |
| 477 | + computers: Mapping[DeviceName, ComputerContext], |
| 478 | + resource_config: CommonResourceConfig, |
| 479 | + deduct_reserved: bool = False, |
| 480 | + ) -> dict[SlotName, Decimal]: |
| 481 | + total_slots: dict[SlotName, Decimal] = defaultdict(lambda: Decimal("0")) |
| 482 | + for device in computers.values(): |
| 483 | + for slot_info in device.alloc_map.device_slots.values(): |
| 484 | + total_slots[slot_info.slot_name] += slot_info.amount |
| 485 | + if deduct_reserved: |
| 486 | + return ResourcePartitioner.deduct_reserved_resources(total_slots, resource_config) |
| 487 | + else: |
| 488 | + return total_slots |
| 489 | + |
| 490 | + @staticmethod |
| 491 | + def deduct_reserved_resources( |
| 492 | + total_slots: Mapping[SlotName, Decimal], |
| 493 | + resource_config: CommonResourceConfig, |
| 494 | + ) -> dict[SlotName, Decimal]: |
| 495 | + reserved_resources = { |
| 496 | + SlotName("cpu"): Decimal(resource_config.reserved_cpu), |
| 497 | + SlotName("mem"): Decimal(resource_config.reserved_mem), |
| 498 | + } |
| 499 | + |
| 500 | + slots: dict[SlotName, Decimal] = {} |
| 501 | + for slot_name, slot in total_slots.items(): |
| 502 | + slots[slot_name] = slot - reserved_resources.get(slot_name, Decimal("0")) |
| 503 | + return slots |
| 504 | + |
| 505 | + def restrict_computer_resources( |
| 506 | + self, |
| 507 | + computers: MutableMapping[DeviceName, ComputerContext], |
| 508 | + total_slots: Mapping[SlotName, Decimal], |
| 509 | + ) -> dict[SlotName, Decimal]: |
| 510 | + devices_allocated_slots: list[Mapping[SlotName, Decimal]] = [] |
| 511 | + devices_reserved_slots: list[Mapping[SlotName, Decimal]] = [] |
| 512 | + for device in computers.values(): |
| 513 | + device_allocated_slots = self._calculate_device_slots(device.alloc_map, total_slots) |
| 514 | + device.alloc_map.update_device_slot_amounts(device_allocated_slots) |
| 515 | + devices_allocated_slots.append(device_allocated_slots) |
| 516 | + |
| 517 | + device_reserved_slots = self._calculate_reserved_slots( |
| 518 | + device_allocated_slots, total_slots |
| 519 | + ) |
| 520 | + devices_reserved_slots.append(device_reserved_slots) |
| 521 | + |
| 522 | + allocated_slots = _combine_mappings(devices_allocated_slots) |
| 523 | + self.resource_scaling_factor = self._calculate_resource_scaling_factor( |
| 524 | + allocated_slots, total_slots |
| 525 | + ) |
| 526 | + |
| 527 | + reserved_slots = _combine_mappings(devices_reserved_slots) |
| 528 | + return reserved_slots |
| 529 | + |
| 530 | + def get_resource_scaling_factor(self, slot_name: SlotName) -> Decimal: |
| 531 | + return self.resource_scaling_factor[slot_name] |
| 532 | + |
| 533 | + def _calculate_device_slots( |
| 534 | + self, |
| 535 | + alloc_map: AbstractAllocMap, |
| 536 | + total_slots: Mapping[SlotName, Decimal], |
| 537 | + ) -> dict[SlotName, Decimal]: |
| 538 | + total_slots_no_reserved = ResourcePartitioner.deduct_reserved_resources( |
| 539 | + total_slots, self.resource_config |
| 540 | + ) |
| 541 | + return { |
| 542 | + device_slot.slot_name: self._calculate_device_slot( |
| 543 | + device_slot.slot_name, |
| 544 | + total_slots_no_reserved[device_slot.slot_name], |
| 545 | + type(alloc_map), |
| 546 | + ) |
| 547 | + for device_slot in alloc_map.device_slots.values() |
| 548 | + } |
| 549 | + |
| 550 | + def _calculate_device_slot( |
| 551 | + self, |
| 552 | + slot_name: SlotName, |
| 553 | + total_slot: Decimal, |
| 554 | + alloc_map_type: Type[AbstractAllocMap], |
| 555 | + ) -> Decimal: |
| 556 | + match self.resource_config.allocation_mode: |
| 557 | + case ResourceAllocationMode.SHARED: |
| 558 | + return total_slot |
| 559 | + case ResourceAllocationMode.AUTO_SPLIT: |
| 560 | + if alloc_map_type is DiscretePropertyAllocMap: |
| 561 | + slot, slot_extra = divmod(total_slot, self.num_agents) |
| 562 | + remainder_value = 1 if self.agent_idx < slot_extra else 0 |
| 563 | + return slot + remainder_value |
| 564 | + elif alloc_map_type is FractionAllocMap: |
| 565 | + return total_slot / self.num_agents |
| 566 | + else: |
| 567 | + raise NotImplementedError( |
| 568 | + f"Unrecognized AbstractAllocMap type {alloc_map_type}" |
| 569 | + ) |
| 570 | + case ResourceAllocationMode.MANUAL: |
| 571 | + match slot_name: |
| 572 | + case "cpu": |
| 573 | + assert self.resource_config.allocated_cpu is not None |
| 574 | + return Decimal(self.resource_config.allocated_cpu) |
| 575 | + case "mem": |
| 576 | + assert self.resource_config.allocated_mem is not None |
| 577 | + return Decimal(self.resource_config.allocated_mem) |
| 578 | + case slot_name: |
| 579 | + if slot_name not in self.resource_config.allocated_devices: |
| 580 | + raise ValueError( |
| 581 | + f"{slot_name=} not found in config {self.resource_config.allocated_devices!r}" |
| 582 | + ) |
| 583 | + return self.resource_config.allocated_devices[slot_name] |
| 584 | + |
| 585 | + def _calculate_reserved_slots( |
| 586 | + self, |
| 587 | + device_slots: Mapping[SlotName, Decimal], |
| 588 | + total_slots: Mapping[SlotName, Decimal], |
| 589 | + ) -> dict[SlotName, Decimal]: |
| 590 | + reserved_slots: dict[SlotName, Decimal] = {} |
| 591 | + for slot_name, slot in device_slots.items(): |
| 592 | + reserved_slots[slot_name] = max(total_slots[slot_name] - slot, Decimal(0)) |
| 593 | + return reserved_slots |
| 594 | + |
| 595 | + def _calculate_resource_scaling_factor( |
| 596 | + self, |
| 597 | + allocated_slots: Mapping[SlotName, Decimal], |
| 598 | + total_slots: Mapping[SlotName, Decimal], |
| 599 | + ) -> dict[SlotName, Decimal]: |
| 600 | + match self.resource_config.allocation_mode: |
| 601 | + case ResourceAllocationMode.SHARED: |
| 602 | + return defaultdict(lambda: Decimal(1.0)) |
| 603 | + case ResourceAllocationMode.AUTO_SPLIT: |
| 604 | + return defaultdict(lambda: Decimal(1.0) / Decimal(self.num_agents)) |
| 605 | + case ResourceAllocationMode.MANUAL: |
| 606 | + if SlotName("cpu") not in allocated_slots or SlotName("cpu") not in total_slots: |
| 607 | + raise ValueError("CPU not in allocated or total slots seen") |
| 608 | + if SlotName("mem") not in allocated_slots or SlotName("mem") not in total_slots: |
| 609 | + raise ValueError("Memory not in allocated or total slots seen") |
| 610 | + scaling_factor = { |
| 611 | + slot_name: slot / total_slots[slot_name] |
| 612 | + for slot_name, slot in allocated_slots.items() |
| 613 | + } |
| 614 | + return scaling_factor |
| 615 | + |
| 616 | + |
447 | 617 | class ComputePluginContext(BasePluginContext[AbstractComputePlugin]): |
448 | 618 | plugin_group = "backendai_accelerator_v21" |
449 | 619 |
|
|
0 commit comments