Skip to content

Commit 001bbe9

Browse files
committed
refactor(BA-3028): Move kernel registry ownership to AgentRuntime
This change moves where the agent's kernel registry is stored from within the agent to outside the agent at a global level in AgentRuntime. This is to make future change of pulling out Metadata server outside of DockerAgent easier, which needs to take in a global view of all kernels across all agents.
1 parent 0249c25 commit 001bbe9

File tree

8 files changed

+572
-3
lines changed

8 files changed

+572
-3
lines changed

changes/6730.enhance.md

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1 @@
1+
Move kernel registry ownership to agent runtime

src/ai/backend/agent/agent.py

Lines changed: 3 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -229,6 +229,7 @@
229229
from .kernel import (
230230
RUN_ID_FOR_BATCH_JOB,
231231
AbstractKernel,
232+
KernelRegistry,
232233
match_distro_data,
233234
)
234235
from .observer.heartbeat import HeartbeatObserver
@@ -836,6 +837,7 @@ def __init__(
836837
error_monitor: ErrorPluginContext,
837838
skip_initial_scan: bool = False,
838839
agent_public_key: Optional[PublicKey],
840+
kernel_registry: KernelRegistry,
839841
) -> None:
840842
self._skip_initial_scan = skip_initial_scan
841843
self.loop = current_loop()
@@ -844,7 +846,7 @@ def __init__(
844846
self.id = AgentId(local_config.agent.id or f"agent-{uuid4()}")
845847
self.local_instance_id = generate_local_instance_id(__file__)
846848
self.agent_public_key = agent_public_key
847-
self.kernel_registry = {}
849+
self.kernel_registry = kernel_registry.agent_view(self.id)
848850
self.computers = {}
849851
self.images = {}
850852
self.restarting_kernels = {}

src/ai/backend/agent/docker/agent.py

Lines changed: 3 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -109,7 +109,7 @@
109109
from ..config.unified import AgentUnifiedConfig, ContainerSandboxType, ScratchType
110110
from ..exception import ContainerCreationError, UnsupportedResource
111111
from ..fs import create_scratch_filesystem, destroy_scratch_filesystem
112-
from ..kernel import AbstractKernel
112+
from ..kernel import AbstractKernel, KernelRegistry
113113
from ..plugin.network import ContainerNetworkCapability, ContainerNetworkInfo, NetworkPluginContext
114114
from ..proxy import DomainSocketProxy, proxy_connection
115115
from ..resources import (
@@ -1357,6 +1357,7 @@ def __init__(
13571357
error_monitor: ErrorPluginContext,
13581358
skip_initial_scan: bool = False,
13591359
agent_public_key: Optional[PublicKey],
1360+
kernel_registry: KernelRegistry,
13601361
) -> None:
13611362
super().__init__(
13621363
etcd,
@@ -1365,6 +1366,7 @@ def __init__(
13651366
error_monitor=error_monitor,
13661367
skip_initial_scan=skip_initial_scan,
13671368
agent_public_key=agent_public_key,
1369+
kernel_registry=kernel_registry,
13681370
)
13691371
self.checked_invalid_images = set()
13701372

src/ai/backend/agent/kernel.py

Lines changed: 99 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -21,15 +21,18 @@
2121
Any,
2222
Dict,
2323
FrozenSet,
24+
Iterator,
2425
List,
2526
Literal,
27+
MutableMapping,
2628
NotRequired,
2729
Optional,
2830
Set,
2931
Tuple,
3032
TypedDict,
3133
Union,
3234
cast,
35+
overload,
3336
)
3437

3538
import zmq
@@ -452,6 +455,102 @@ async def execute(
452455
raise
453456

454457

458+
@dataclass(frozen=True)
459+
class AgentKernelRegistryKey:
460+
agent_id: AgentId
461+
kernel_id: KernelId
462+
463+
464+
class KernelRegistry(MutableMapping[AgentKernelRegistryKey, AbstractKernel]):
465+
_registry: MutableMapping[AgentKernelRegistryKey, AbstractKernel]
466+
_global_registry: MutableMapping[KernelId, AbstractKernel]
467+
468+
def __init__(self) -> None:
469+
super().__init__()
470+
471+
self._registry = {}
472+
self._global_registry = {}
473+
474+
def agent_view(self, agent_id: AgentId) -> "KernelRegistryAgentView":
475+
return KernelRegistryAgentView(self, agent_id)
476+
477+
def global_view(self) -> "KernelRegistryGlobalView":
478+
return KernelRegistryGlobalView(self)
479+
480+
@overload
481+
def __getitem__(self, key: KernelId) -> AbstractKernel: ...
482+
483+
@overload
484+
def __getitem__(self, key: AgentKernelRegistryKey) -> AbstractKernel: ...
485+
486+
def __getitem__(self, key: KernelId | AgentKernelRegistryKey) -> AbstractKernel:
487+
if isinstance(key, AgentKernelRegistryKey):
488+
return self._registry[key]
489+
else:
490+
return self._global_registry[key]
491+
492+
def __setitem__(self, key: AgentKernelRegistryKey, value: AbstractKernel) -> None:
493+
self._registry[key] = value
494+
self._global_registry[key.kernel_id] = value
495+
496+
def __delitem__(self, key: AgentKernelRegistryKey) -> None:
497+
del self._registry[key]
498+
del self._global_registry[key.kernel_id]
499+
500+
def __iter__(self) -> Iterator[AgentKernelRegistryKey]:
501+
return iter(self._registry)
502+
503+
def __len__(self) -> int:
504+
return len(self._registry)
505+
506+
507+
class KernelRegistryAgentView(MutableMapping[KernelId, AbstractKernel]):
508+
_registry: KernelRegistry
509+
_agent_id: AgentId
510+
511+
def __init__(self, kernel_registry: KernelRegistry, agent_id: AgentId) -> None:
512+
super().__init__()
513+
514+
self._registry = kernel_registry
515+
self._agent_id = agent_id
516+
517+
def __getitem__(self, key: KernelId) -> AbstractKernel:
518+
return self._registry[AgentKernelRegistryKey(self._agent_id, key)]
519+
520+
def __setitem__(self, key: KernelId, value: AbstractKernel) -> None:
521+
self._registry[AgentKernelRegistryKey(self._agent_id, key)] = value
522+
523+
def __delitem__(self, key: KernelId) -> None:
524+
del self._registry[AgentKernelRegistryKey(self._agent_id, key)]
525+
526+
def __iter__(self) -> Iterator[KernelId]:
527+
for registry_key in self._registry:
528+
if registry_key.agent_id == self._agent_id:
529+
yield registry_key.kernel_id
530+
531+
def __len__(self) -> int:
532+
return sum(1 for key in self._registry if key.agent_id == self._agent_id)
533+
534+
535+
class KernelRegistryGlobalView(Mapping[KernelId, AbstractKernel]):
536+
_registry: KernelRegistry
537+
538+
def __init__(self, kernel_registry: KernelRegistry) -> None:
539+
super().__init__()
540+
541+
self._registry = kernel_registry
542+
543+
def __getitem__(self, key: KernelId) -> AbstractKernel:
544+
return self._registry[key]
545+
546+
def __iter__(self) -> Iterator[KernelId]:
547+
for registry_key in self._registry:
548+
yield registry_key.kernel_id
549+
550+
def __len__(self) -> int:
551+
return len(self._registry)
552+
553+
455554
_zctx = None
456555

457556

src/ai/backend/agent/kubernetes/agent.py

Lines changed: 3 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -68,7 +68,7 @@
6868
)
6969
from ..config.unified import AgentUnifiedConfig, ScratchType
7070
from ..exception import K8sError, UnsupportedResource
71-
from ..kernel import AbstractKernel
71+
from ..kernel import AbstractKernel, KernelRegistry
7272
from ..resources import (
7373
AbstractComputePlugin,
7474
ComputerContext,
@@ -829,6 +829,7 @@ def __init__(
829829
error_monitor: ErrorPluginContext,
830830
skip_initial_scan: bool = False,
831831
agent_public_key: Optional[PublicKey],
832+
kernel_registry: KernelRegistry,
832833
) -> None:
833834
super().__init__(
834835
etcd,
@@ -837,6 +838,7 @@ def __init__(
837838
error_monitor=error_monitor,
838839
skip_initial_scan=skip_initial_scan,
839840
agent_public_key=agent_public_key,
841+
kernel_registry=kernel_registry,
840842
)
841843

842844
async def __ainit__(self) -> None:

src/ai/backend/agent/runtime.py

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -4,6 +4,7 @@
44
from ai.backend.agent.agent import AbstractAgent
55
from ai.backend.agent.config.unified import AgentUnifiedConfig
66
from ai.backend.agent.etcd import AgentEtcdClientView
7+
from ai.backend.agent.kernel import KernelRegistry
78
from ai.backend.agent.monitor import AgentErrorPluginContext, AgentStatsPluginContext
89
from ai.backend.agent.types import AgentBackend
910
from ai.backend.common.auth import PublicKey
@@ -14,6 +15,7 @@
1415
class AgentRuntime(aobject):
1516
local_config: AgentUnifiedConfig
1617
agent: AbstractAgent
18+
kernel_registry: KernelRegistry
1719
etcd: AsyncEtcd
1820
etcd_view: AgentEtcdClientView
1921

@@ -28,6 +30,7 @@ def __init__(
2830
agent_public_key: Optional[PublicKey],
2931
) -> None:
3032
self.local_config = local_config
33+
self.kernel_registry = KernelRegistry()
3134
self.etcd = etcd
3235
self.etcd_view = AgentEtcdClientView(etcd, self.local_config)
3336

tests/agent/docker/test_agent.py

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -11,6 +11,7 @@
1111
from aiodocker.exceptions import DockerError
1212

1313
from ai.backend.agent.docker.agent import DockerAgent
14+
from ai.backend.agent.kernel import KernelRegistry
1415
from ai.backend.common.arch import DEFAULT_IMAGE_ARCH
1516
from ai.backend.common.docker import ImageRef
1617
from ai.backend.common.exception import ImageNotAvailable
@@ -28,13 +29,15 @@ async def agent(local_config, test_id, mocker, socket_relay_image):
2829
mocked_etcd_get_prefix = AsyncMock(return_value={})
2930
mocker.patch.object(dummy_etcd, "get_prefix", new=mocked_etcd_get_prefix)
3031
test_case_id = secrets.token_hex(8)
32+
kernel_registry = KernelRegistry()
3133
agent = await DockerAgent.new(
3234
dummy_etcd,
3335
local_config,
3436
stats_monitor=None,
3537
error_monitor=None,
3638
skip_initial_scan=True,
3739
agent_public_key=None,
40+
kernel_registry=kernel_registry,
3841
) # for faster test iteration
3942
agent.local_instance_id = test_case_id # use per-test private registry file
4043
try:

0 commit comments

Comments
 (0)