11import asyncio
22import importlib
33import signal
4- from typing import TYPE_CHECKING , Optional
4+ from typing import TYPE_CHECKING , Mapping , Optional
55
66from ai .backend .agent .agent import AbstractAgent
77from ai .backend .agent .config .unified import AgentUnifiedConfig
1818 ErrorDomain ,
1919 ErrorOperation ,
2020)
21- from ai .backend .common .types import AgentId , aobject
21+ from ai .backend .common .types import AgentId
2222
2323if TYPE_CHECKING :
2424 from .docker .metadata .server import MetadataServer
@@ -34,52 +34,61 @@ def error_code(cls) -> ErrorCode:
3434 )
3535
3636
37- class AgentRuntime ( aobject ) :
37+ class AgentRuntime :
3838 _local_config : AgentUnifiedConfig
3939 _agents : dict [AgentId , AbstractAgent ]
40+ _default_agent : AbstractAgent
4041 _kernel_registry : KernelRegistry
4142 _etcd : AsyncEtcd
42- _etcd_views : dict [AgentId , AgentEtcdClientView ]
43+ _etcd_views : Mapping [AgentId , AgentEtcdClientView ]
4344
44- _default_agent_id : AgentId
4545 _stop_signal : signal .Signals
4646
def __init__(
    self,
    local_config: AgentUnifiedConfig,
    etcd: AsyncEtcd,
) -> None:
    """Set up the runtime state that can be built without awaiting.

    No agent is created here: ``_agents`` starts out empty and is
    populated later by ``create_agents``.

    Args:
        local_config: Unified configuration. One etcd view is built for
            every entry returned by its ``get_agent_configs()``.
        etcd: Shared etcd client wrapped by each per-agent view.
    """
    self._local_config = local_config
    self._agents = {}
    self._kernel_registry = KernelRegistry()
    self._etcd = etcd
    # Built eagerly, keyed by agent id, so get_etcd() can be answered
    # from this mapping alone.
    self._etcd_views = {
        AgentId(agent_config.agent.id): AgentEtcdClientView(etcd, agent_config)
        for agent_config in local_config.get_agent_configs()
    }
    self._metadata_server: MetadataServer | None = None

    self._stop_signal = signal.SIGTERM
async def create_agents(
    self,
    stats_monitor: AgentStatsPluginContext,
    error_monitor: AgentErrorPluginContext,
    agent_public_key: Optional[PublicKey],
) -> None:
    """Create one agent per configured agent entry, concurrently.

    When the configured backend is ``AgentBackend.DOCKER``,
    ``_initialize_metadata_server()`` is awaited before any agent is
    created. Afterwards ``_agents`` maps each created agent's ``id`` to
    the agent, and the agent built from the first configuration entry
    becomes ``_default_agent``.

    Args:
        stats_monitor: Forwarded unchanged to every ``_create_agent`` call.
        error_monitor: Forwarded unchanged to every ``_create_agent`` call.
        agent_public_key: Forwarded unchanged to every ``_create_agent``
            call; may be ``None``.

    Raises:
        IndexError: If ``get_agent_configs()`` yields no entries, since
            there is then no first agent to use as the default.
    """
    if self._local_config.agent_common.backend == AgentBackend.DOCKER:
        await self._initialize_metadata_server()

    async with asyncio.TaskGroup() as tg:
        # One task per agent config, kept in configuration order so that
        # index 0 below is the first configured agent.
        tasks: list[asyncio.Task] = [
            tg.create_task(
                self._create_agent(
                    self.get_etcd(AgentId(agent_config.agent.id)),
                    agent_config,
                    stats_monitor,
                    error_monitor,
                    agent_public_key,
                )
            )
            for agent_config in self._local_config.get_agent_configs()
        ]

    # Results are read only after the task group block has exited.
    agents = [task.result() for task in tasks]
    self._default_agent = agents[0]
    self._agents = {agent.id: agent for agent in agents}
8392
8493 async def __aexit__ (self , * exc_info ) -> None :
8594 for agent in self ._agents .values ():
@@ -92,21 +101,19 @@ def get_agents(self) -> list[AbstractAgent]:
92101
def get_agent(self, agent_id: Optional[AgentId]) -> AbstractAgent:
    """Return the agent registered under *agent_id*.

    Args:
        agent_id: Id of the agent to fetch, or ``None`` to get the
            default agent stored in ``_default_agent``.

    Returns:
        The matching agent from ``_agents``, or the default agent when
        *agent_id* is ``None``.

    Raises:
        AgentIdNotFoundError: If *agent_id* is not a key of ``_agents``.
    """
    if agent_id is None:
        return self._default_agent
    try:
        return self._agents[agent_id]
    except KeyError:
        # ``from None`` keeps the internal KeyError out of the traceback.
        raise AgentIdNotFoundError(
            f"Agent '{agent_id}' not found in this runtime. "
            f"Available agents: {', '.join(self._agents)}"
        ) from None
102111
def get_etcd(self, agent_id: AgentId) -> AgentEtcdClientView:
    """Return the etcd client view registered for *agent_id*.

    The lookup uses ``_etcd_views`` only, not ``_agents``.

    Args:
        agent_id: Id of the agent whose etcd view is wanted.

    Returns:
        The view stored under *agent_id* in ``_etcd_views``.

    Raises:
        AgentIdNotFoundError: If *agent_id* is not a key of
            ``_etcd_views``.
    """
    try:
        return self._etcd_views[agent_id]
    except KeyError:
        # ``from None`` keeps the internal KeyError out of the traceback.
        raise AgentIdNotFoundError(
            f"Etcd client for agent '{agent_id}' not found in this runtime. "
            f"Available agent etcd views: {', '.join(self._etcd_views)}"
        ) from None
112119
@@ -121,11 +128,14 @@ async def _create_agent(
121128 self ,
122129 etcd_view : AgentEtcdClientView ,
123130 agent_config : AgentUnifiedConfig ,
131+ stats_monitor : AgentStatsPluginContext ,
132+ error_monitor : AgentErrorPluginContext ,
133+ agent_public_key : Optional [PublicKey ],
124134 ) -> AbstractAgent :
125135 agent_kwargs = {
126- "stats_monitor" : self . stats_monitor ,
127- "error_monitor" : self . error_monitor ,
128- "agent_public_key" : self . agent_public_key ,
136+ "stats_monitor" : stats_monitor ,
137+ "error_monitor" : error_monitor ,
138+ "agent_public_key" : agent_public_key ,
129139 }
130140
131141 backend = self ._local_config .agent_common .backend
0 commit comments