1+ import asyncio
12import signal
23from typing import Optional
34
45from ai .backend .agent .agent import AbstractAgent
56from ai .backend .agent .config .unified import AgentUnifiedConfig
7+ from ai .backend .agent .docker .metadata .server import MetadataServer
68from ai .backend .agent .etcd import AgentEtcdClientView
79from ai .backend .agent .kernel import KernelRegistry
810from ai .backend .agent .monitor import AgentErrorPluginContext , AgentStatsPluginContext
911from ai .backend .agent .types import AgentBackend
1012from ai .backend .common .auth import PublicKey
1113from ai .backend .common .etcd import AsyncEtcd
12- from ai .backend .common .types import aobject
14+ from ai .backend .common .exception import (
15+ BackendAIError ,
16+ ErrorCode ,
17+ ErrorDetail ,
18+ ErrorDomain ,
19+ ErrorOperation ,
20+ )
21+ from ai .backend .common .types import AgentId , aobject
22+
23+
24+ class AgentIdNotFoundError (BackendAIError ):
25+ @classmethod
26+ def error_code (cls ) -> ErrorCode :
27+ return ErrorCode (
28+ domain = ErrorDomain .AGENT ,
29+ operation = ErrorOperation .ACCESS ,
30+ error_detail = ErrorDetail .NOT_FOUND ,
31+ )
1332
1433
1534class AgentRuntime (aobject ):
1635 local_config : AgentUnifiedConfig
17- agent : AbstractAgent
36+ agents : dict [ AgentId , AbstractAgent ]
1837 kernel_registry : KernelRegistry
1938 etcd : AsyncEtcd
20- etcd_view : AgentEtcdClientView
39+ etcd_views : dict [AgentId , AgentEtcdClientView ]
40+ metadata_server : MetadataServer | None
2141
42+ _default_agent_id : AgentId
2243 _stop_signal : signal .Signals
2344
2445 def __init__ (
@@ -30,27 +51,58 @@ def __init__(
3051 agent_public_key : Optional [PublicKey ],
3152 ) -> None :
3253 self .local_config = local_config
54+ self .agents = {}
3355 self .kernel_registry = KernelRegistry ()
3456 self .etcd = etcd
35- self .etcd_view = AgentEtcdClientView (etcd , self .local_config )
57+ self .etcd_views = {}
58+ self .metadata_server = None
3659
60+ self ._default_agent_id = AgentId (self .local_config .agent_configs [0 ].agent .id )
3761 self ._stop_signal = signal .SIGTERM
3862
3963 self .stats_monitor = stats_monitor
4064 self .error_monitor = error_monitor
4165 self .agent_public_key = agent_public_key
4266
4367 async def __ainit__ (self ) -> None :
44- self .agent = await self ._create_agent (self .etcd_view , self .local_config )
68+ tasks = []
69+ async with asyncio .TaskGroup () as tg :
70+ for agent_config in self .local_config .agent_configs :
71+ agent_id = AgentId (agent_config .agent .id )
72+ etcd_view = AgentEtcdClientView (self .etcd , agent_config )
4573
46- async def __aexit__ (self , * exc_info ) -> None :
47- await self .agent .shutdown (self ._stop_signal )
48-
49- def get_agent (self ) -> AbstractAgent :
50- return self .agent
74+ self .etcd_views [agent_id ] = etcd_view
75+ tasks .append (tg .create_task (self ._create_agent (etcd_view , agent_config )))
76+ self .agents = {(agent := task .result ()).id : agent for task in tasks }
5177
52- def get_etcd (self ) -> AgentEtcdClientView :
53- return self .etcd_view
78+ async def __aexit__ (self , * exc_info ) -> None :
79+ for agent in self .agents .values ():
80+ await agent .shutdown (self ._stop_signal )
81+ if self .metadata_server is not None :
82+ await self .metadata_server .cleanup ()
83+
84+ def get_agents (self ) -> list [AbstractAgent ]:
85+ return list (self .agents .values ())
86+
87+ def get_agent (self , agent_id : Optional [AgentId ]) -> AbstractAgent :
88+ if agent_id is None :
89+ agent_id = self ._default_agent_id
90+ if agent_id not in self .agents :
91+ raise AgentIdNotFoundError (
92+ f"Agent '{ agent_id } ' not found in this runtime. "
93+ f"Available agents: { ', ' .join (self .agents .keys ())} "
94+ )
95+ return self .agents [agent_id ]
96+
97+ def get_etcd (self , agent_id : Optional [AgentId ]) -> AgentEtcdClientView :
98+ if agent_id is None :
99+ agent_id = self ._default_agent_id
100+ if agent_id not in self .agents :
101+ raise AgentIdNotFoundError (
102+ f"Agent '{ agent_id } ' not found in this runtime. "
103+ f"Available agents: { ', ' .join (self .agents .keys ())} "
104+ )
105+ return self .etcd_views [agent_id ]
54106
55107 def mark_stop_signal (self , stop_signal : signal .Signals ) -> None :
56108 self ._stop_signal = stop_signal
@@ -64,12 +116,15 @@ async def _create_agent(
64116 case AgentBackend .DOCKER :
65117 from .docker .agent import DockerAgent
66118
119+ await self ._initialize_metadata_server ()
120+
67121 return await DockerAgent .new (
68122 etcd_view ,
69123 agent_config ,
70124 stats_monitor = self .stats_monitor ,
71125 error_monitor = self .error_monitor ,
72126 agent_public_key = self .agent_public_key ,
127+ metadata_server = self .metadata_server ,
73128 )
74129 case AgentBackend .KUBERNETES :
75130 from .kubernetes .agent import KubernetesAgent
@@ -91,3 +146,13 @@ async def _create_agent(
91146 error_monitor = self .error_monitor ,
92147 agent_public_key = self .agent_public_key ,
93148 )
149+
150+ async def _initialize_metadata_server (self ) -> None :
151+ from .docker .metadata .server import MetadataServer
152+
153+ self .metadata_server = await MetadataServer .new (
154+ self .local_config ,
155+ self .etcd ,
156+ kernel_registry = self .kernel_registry .global_view (),
157+ )
158+ await self .metadata_server .start_server ()
0 commit comments