3030from dataclasses import dataclass
3131from decimal import Decimal
3232from io import SEEK_END , BytesIO
33+ from itertools import chain
3334from pathlib import Path
3435from types import TracebackType
3536from typing import (
174175from ai .backend .common .types import (
175176 MODEL_SERVICE_RUNTIME_PROFILES ,
176177 AbuseReportValue ,
177- AcceleratorMetadata ,
178178 AgentId ,
179179 AutoPullBehavior ,
180180 BinarySize ,
233233from .observer .heartbeat import HeartbeatObserver
234234from .observer .host_port import HostPortObserver
235235from .resources import (
236- AbstractComputeDevice ,
237236 AbstractComputePlugin ,
238237 ComputerContext ,
239238 KernelResourceSpec ,
240239 Mount ,
240+ ResourcePartitioner ,
241241 align_memory ,
242242 allocate ,
243243 known_slot_types ,
@@ -765,7 +765,10 @@ class AbstractAgent(
765765 etcd : AsyncEtcd
766766 local_instance_id : str
767767 kernel_registry : MutableMapping [KernelId , AbstractKernel ]
768+ resource_partitioner : ResourcePartitioner
768769 computers : MutableMapping [DeviceName , ComputerContext ]
770+ total_slots : Mapping [SlotName , Decimal ]
771+ reserved_slots : Mapping [SlotName , Decimal ]
769772 images : Mapping [ImageCanonical , ScannedImage ]
770773 port_pool : set [int ]
771774
@@ -836,6 +839,7 @@ def __init__(
836839 error_monitor : ErrorPluginContext ,
837840 skip_initial_scan : bool = False ,
838841 agent_public_key : Optional [PublicKey ],
842+ resource_partitioner : ResourcePartitioner ,
839843 ) -> None :
840844 self ._skip_initial_scan = skip_initial_scan
841845 self .loop = current_loop ()
@@ -845,7 +849,10 @@ def __init__(
845849 self .local_instance_id = generate_local_instance_id (__file__ )
846850 self .agent_public_key = agent_public_key
847851 self .kernel_registry = {}
852+ self .resource_partitioner = resource_partitioner
848853 self .computers = {}
854+ self .total_slots = {}
855+ self .reserved_slots = {}
849856 self .images = {}
850857 self .restarting_kernels = {}
851858 self .stat_ctx = StatContext (
@@ -934,28 +941,34 @@ async def __ainit__(self) -> None:
934941 alloc_map_mod .log_alloc_map = self .local_config .debug .log_alloc_map
935942 computers = await self .load_resources ()
936943
937- all_devices : list [AbstractComputeDevice ] = []
938- metadatas : list [AcceleratorMetadata ] = []
939944 for name , computer in computers .items ():
940945 devices = await computer .list_devices ()
941- all_devices .extend (devices )
942946 alloc_map = await computer .create_alloc_map ()
943947 self .computers [name ] = ComputerContext (computer , devices , alloc_map )
944- metadatas .append (computer .get_metadata ())
945948
949+ self .total_slots = self .resource_partitioner .calculate_total_slots (
950+ self .computers , self .local_config .resource_common
951+ )
952+ self .reserved_slots = self .resource_partitioner .restrict_computer_resources (
953+ self .computers , self .total_slots
954+ )
946955 self .slots = await self .update_slots ()
947956 log .info ("Resource slots: {!r}" , self .slots )
948957 log .info ("Slot types: {!r}" , known_slot_types )
949958 self .timer_tasks .append (aiotools .create_timer (self .update_slots_periodically , 30.0 ))
950959
951960 # Use ValkeyStatClient batch operations for better performance
952961 field_value_map = {}
953- for metadata in metadatas :
962+ for computer_ctx in self .computers .values ():
963+ metadata = computer_ctx .instance .get_metadata ()
954964 field_value_map [metadata ["slot_name" ]] = dump_json_str (metadata ).encode ()
955965
956966 if field_value_map :
957967 await self .valkey_stat_client .store_computer_metadata (field_value_map )
958968
969+ all_devices = list (
970+ chain .from_iterable (computer .devices for computer in self .computers .values ())
971+ )
959972 self .affinity_map = AffinityMap .build (all_devices )
960973
961974 if not self ._skip_initial_scan :
@@ -1949,6 +1962,7 @@ async def load_resources(
19491962 """
19501963 Detect available resources attached on the system and load corresponding device plugin.
19511964 """
1965+ raise NotImplementedError
19521966
19531967 @abstractmethod
19541968 async def scan_available_resources (
@@ -1957,6 +1971,7 @@ async def scan_available_resources(
19571971 """
19581972 Scan and define the amount of available resource slots in this node.
19591973 """
1974+ raise NotImplementedError
19601975
19611976 async def update_slots (
19621977 self ,
@@ -1967,14 +1982,9 @@ async def update_slots(
19671982 """
19681983 scanned_slots = await self .scan_available_resources ()
19691984 usable_slots : dict [SlotName , Decimal ] = {}
1970- reserved_slots = {
1971- SlotName ("cpu" ): Decimal (self .local_config .resource .reserved_cpu ),
1972- SlotName ("mem" ): Decimal (self .local_config .resource .reserved_mem ),
1973- SlotName ("disk" ): Decimal (self .local_config .resource .reserved_disk ),
1974- }
19751985 for slot_name , slot_capacity in scanned_slots .items ():
19761986 if slot_name == SlotName ("mem" ):
1977- mem_reserved = int (reserved_slots .get (slot_name , 0 ))
1987+ mem_reserved = int (self . reserved_slots .get (slot_name , 0 ))
19781988 mem_align = int (self .local_config .resource .memory_align_size )
19791989 mem_usable , mem_reserved = align_memory (
19801990 int (slot_capacity ), mem_reserved , align = mem_align
@@ -1988,7 +1998,7 @@ async def update_slots(
19881998 )
19891999 else :
19902000 usable_capacity = max (
1991- Decimal (0 ), slot_capacity - reserved_slots .get (slot_name , Decimal (0 ))
2001+ Decimal (0 ), slot_capacity - self . reserved_slots .get (slot_name , Decimal (0 ))
19922002 )
19932003 usable_slots [slot_name ] = usable_capacity
19942004 return usable_slots
@@ -2100,6 +2110,7 @@ async def scan_images(self) -> ScanImagesResult:
21002110 This is called periodically to keep the image list up-to-date and allow
21012111 manual image addition and deletions by admins.
21022112 """
2113+ raise NotImplementedError
21032114
21042115 async def _scan_images_wrapper (self , interval : float ) -> None :
21052116 result = await self .scan_images ()
@@ -2120,6 +2131,7 @@ async def push_image(
21202131 """
21212132 Push the given image to the given registry.
21222133 """
2134+ raise NotImplementedError
21232135
21242136 @abstractmethod
21252137 async def pull_image (
@@ -2132,12 +2144,14 @@ async def pull_image(
21322144 """
21332145 Pull the given image from the given registry.
21342146 """
2147+ raise NotImplementedError
21352148
21362149 @abstractmethod
21372150 async def purge_images (self , request : PurgeImagesReq ) -> PurgeImagesResp :
21382151 """
21392152 Purge the given images from the agent.
21402153 """
2154+ raise NotImplementedError
21412155
21422156 async def check_and_pull (
21432157 self ,
@@ -2269,7 +2283,7 @@ async def check_image(
22692283 Check the availability of the image and return a boolean flag that indicates whether
22702284 the agent should try pulling the image from a registry.
22712285 """
2272- return False
2286+ raise NotImplementedError
22732287
22742288 async def scan_running_kernels (self ) -> None :
22752289 """
@@ -3491,6 +3505,7 @@ async def destroy_kernel(
34913505 * Send SIGTERM to the kernel's main process.
34923506 * Send SIGKILL if it's not terminated within a few seconds.
34933507 """
3508+ raise NotImplementedError
34943509
34953510 @abstractmethod
34963511 async def clean_kernel (
@@ -3514,6 +3529,7 @@ async def clean_kernel(
35143529 The ``container_id`` may be ``None`` if the container has already gone away.
35153530 In such cases, skip container-specific cleanups.
35163531 """
3532+ raise NotImplementedError
35173533
35183534 @abstractmethod
35193535 async def create_local_network (self , network_name : str ) -> None :
@@ -3525,6 +3541,7 @@ async def create_local_network(self, network_name: str) -> None:
35253541 It may raise :exc:`NotImplementedError` and then the manager
35263542 will cancel creation of the session.
35273543 """
3544+ raise NotImplementedError
35283545
35293546 @abstractmethod
35303547 async def destroy_local_network (self , network_name : str ) -> None :
@@ -3533,6 +3550,7 @@ async def destroy_local_network(self, network_name: str) -> None:
35333550
35343551 This is called by the manager after kernel destruction.
35353552 """
3553+ raise NotImplementedError
35363554
35373555 @abstractmethod
35383556 async def restart_kernel__load_config (
@@ -3543,7 +3561,7 @@ async def restart_kernel__load_config(
35433561 """
35443562 Restore the cluster config from a previous launch of the kernel.
35453563 """
3546- pass
3564+ raise NotImplementedError
35473565
35483566 @abstractmethod
35493567 async def restart_kernel__store_config (
@@ -3556,7 +3574,7 @@ async def restart_kernel__store_config(
35563574 Store the cluster config to a kernel-related storage (e.g., scratch space),
35573575 so that restarts of this kernel can reuse the configuration.
35583576 """
3559- pass
3577+ raise NotImplementedError
35603578
35613579 async def restart_kernel (
35623580 self ,
0 commit comments