1+ import asyncio
12import importlib
23import signal
3- from typing import Optional
4+ from typing import TYPE_CHECKING , Optional
45
56from ai .backend .agent .agent import AbstractAgent
67from ai .backend .agent .config .unified import AgentUnifiedConfig
78from ai .backend .agent .etcd import AgentEtcdClientView
89from ai .backend .agent .kernel import KernelRegistry
910from ai .backend .agent .monitor import AgentErrorPluginContext , AgentStatsPluginContext
11+ from ai .backend .agent .types import AgentBackend
1012from ai .backend .common .auth import PublicKey
1113from ai .backend .common .etcd import AsyncEtcd , ConfigScopes
12- from ai .backend .common .types import aobject
14+ from ai .backend .common .exception import (
15+ BackendAIError ,
16+ ErrorCode ,
17+ ErrorDetail ,
18+ ErrorDomain ,
19+ ErrorOperation ,
20+ )
21+ from ai .backend .common .types import AgentId , aobject
22+
23+ if TYPE_CHECKING :
24+ from .docker .metadata .server import MetadataServer
25+
26+
27+ class AgentIdNotFoundError (BackendAIError ):
28+ @classmethod
29+ def error_code (cls ) -> ErrorCode :
30+ return ErrorCode (
31+ domain = ErrorDomain .AGENT ,
32+ operation = ErrorOperation .ACCESS ,
33+ error_detail = ErrorDetail .NOT_FOUND ,
34+ )
1335
1436
1537class AgentRuntime (aobject ):
1638 local_config : AgentUnifiedConfig
17- agent : AbstractAgent
39+ agents : dict [ AgentId , AbstractAgent ]
1840 kernel_registry : KernelRegistry
1941 etcd : AsyncEtcd
20- etcd_view : AgentEtcdClientView
42+ etcd_views : dict [AgentId , AgentEtcdClientView ]
43+ metadata_server : MetadataServer | None
2144
45+ _default_agent_id : AgentId
2246 _stop_signal : signal .Signals
2347
2448 def __init__ (
@@ -30,33 +54,68 @@ def __init__(
3054 agent_public_key : Optional [PublicKey ],
3155 ) -> None :
3256 self .local_config = local_config
57+ self .agents = {}
3358 self .kernel_registry = KernelRegistry ()
3459 self .etcd = etcd
35- self .etcd_view = AgentEtcdClientView (etcd , self .local_config )
60+ self .etcd_views = {}
61+ self .metadata_server = None
3662
63+ agent_configs = self .local_config .get_agent_configs ()
64+ self ._default_agent_id = AgentId (agent_configs [0 ].agent .id )
3765 self ._stop_signal = signal .SIGTERM
3866
3967 self .stats_monitor = stats_monitor
4068 self .error_monitor = error_monitor
4169 self .agent_public_key = agent_public_key
4270
4371 async def __ainit__ (self ) -> None :
44- self .agent = await self ._create_agent (self .etcd_view , self .local_config )
72+ if self .local_config .agent_common .backend == AgentBackend .DOCKER :
73+ await self ._initialize_metadata_server ()
4574
46- async def __aexit__ (self , * exc_info ) -> None :
47- await self .agent .shutdown (self ._stop_signal )
75+ tasks = []
76+ async with asyncio .TaskGroup () as tg :
77+ for agent_config in self .local_config .get_agent_configs ():
78+ agent_id = AgentId (agent_config .agent .id )
79+ etcd_view = AgentEtcdClientView (self .etcd , agent_config )
4880
49- def get_agent (self ) -> AbstractAgent :
50- return self .agent
81+ self .etcd_views [agent_id ] = etcd_view
82+ tasks .append (tg .create_task (self ._create_agent (etcd_view , agent_config )))
83+ self .agents = {(agent := task .result ()).id : agent for task in tasks }
5184
52- def get_etcd (self ) -> AgentEtcdClientView :
53- return self .etcd_view
85+ async def __aexit__ (self , * exc_info ) -> None :
86+ for agent in self .agents .values ():
87+ await agent .shutdown (self ._stop_signal )
88+ if self .metadata_server is not None :
89+ await self .metadata_server .cleanup ()
90+
91+ def get_agents (self ) -> list [AbstractAgent ]:
92+ return list (self .agents .values ())
93+
94+ def get_agent (self , agent_id : Optional [AgentId ]) -> AbstractAgent :
95+ if agent_id is None :
96+ agent_id = self ._default_agent_id
97+ if agent_id not in self .agents :
98+ raise AgentIdNotFoundError (
99+ f"Agent '{ agent_id } ' not found in this runtime. "
100+ f"Available agents: { ', ' .join (self .agents .keys ())} "
101+ )
102+ return self .agents [agent_id ]
103+
104+ def get_etcd (self , agent_id : Optional [AgentId ]) -> AgentEtcdClientView :
105+ if agent_id is None :
106+ agent_id = self ._default_agent_id
107+ if agent_id not in self .agents :
108+ raise AgentIdNotFoundError (
109+ f"Agent '{ agent_id } ' not found in this runtime. "
110+ f"Available agents: { ', ' .join (self .agents .keys ())} "
111+ )
112+ return self .etcd_views [agent_id ]
54113
55114 def mark_stop_signal (self , stop_signal : signal .Signals ) -> None :
56115 self ._stop_signal = stop_signal
57116
58- async def update_status (self , status ) -> None :
59- etcd = self .get_etcd ()
117+ async def update_status (self , status , agent_id : AgentId ) -> None :
118+ etcd = self .get_etcd (agent_id )
60119 await etcd .put ("" , status , scope = ConfigScopes .NODE )
61120
62121 async def _create_agent (
@@ -75,3 +134,13 @@ async def _create_agent(
75134 agent_cls = agent_mod .get_agent_cls ()
76135
77136 return agent_cls .new (etcd_view , agent_config , ** agent_kwargs )
137+
138+ async def _initialize_metadata_server (self ) -> None :
139+ from .docker .metadata .server import MetadataServer
140+
141+ self .metadata_server = await MetadataServer .new (
142+ self .local_config ,
143+ self .etcd ,
144+ kernel_registry = self .kernel_registry .global_view (),
145+ )
146+ await self .metadata_server .start_server ()
0 commit comments