Skip to content

Commit 64852fe

Browse files
committed
refactor(BA-2753): Respond to feedback
1 parent 456e028 commit 64852fe

File tree

2 files changed

44 additions & 35 deletions

2 files changed

44 additions & 35 deletions

src/ai/backend/agent/runtime.py

Lines changed: 42 additions & 32 deletions
Original file line number | Diff line number | Diff line change
@@ -1,7 +1,7 @@
11
import asyncio
22
import importlib
33
import signal
4-
from typing import TYPE_CHECKING, Optional
4+
from typing import TYPE_CHECKING, Mapping, Optional
55

66
from ai.backend.agent.agent import AbstractAgent
77
from ai.backend.agent.config.unified import AgentUnifiedConfig
@@ -18,7 +18,7 @@
1818
ErrorDomain,
1919
ErrorOperation,
2020
)
21-
from ai.backend.common.types import AgentId, aobject
21+
from ai.backend.common.types import AgentId
2222

2323
if TYPE_CHECKING:
2424
from .docker.metadata.server import MetadataServer
@@ -34,52 +34,61 @@ def error_code(cls) -> ErrorCode:
3434
)
3535

3636

37-
class AgentRuntime(aobject):
37+
class AgentRuntime:
3838
_local_config: AgentUnifiedConfig
3939
_agents: dict[AgentId, AbstractAgent]
40+
_default_agent: AbstractAgent
4041
_kernel_registry: KernelRegistry
4142
_etcd: AsyncEtcd
42-
_etcd_views: dict[AgentId, AgentEtcdClientView]
43+
_etcd_views: Mapping[AgentId, AgentEtcdClientView]
4344

44-
_default_agent_id: AgentId
4545
_stop_signal: signal.Signals
4646

4747
def __init__(
4848
self,
4949
local_config: AgentUnifiedConfig,
5050
etcd: AsyncEtcd,
51-
stats_monitor: AgentStatsPluginContext,
52-
error_monitor: AgentErrorPluginContext,
53-
agent_public_key: Optional[PublicKey],
5451
) -> None:
5552
self._local_config = local_config
5653
self._agents = {}
5754
self._kernel_registry = KernelRegistry()
5855
self._etcd = etcd
59-
self._etcd_views = {}
56+
self._etcd_views = {
57+
AgentId(agent_config.agent.id): AgentEtcdClientView(self._etcd, agent_config)
58+
for agent_config in self._local_config.get_agent_configs()
59+
}
6060
self._metadata_server: MetadataServer | None = None
6161

62-
agent_configs = self._local_config.get_agent_configs()
63-
self._default_agent_id = AgentId(agent_configs[0].agent.id)
6462
self._stop_signal = signal.SIGTERM
6563

66-
self.stats_monitor = stats_monitor
67-
self.error_monitor = error_monitor
68-
self.agent_public_key = agent_public_key
69-
70-
async def __ainit__(self) -> None:
64+
async def create_agents(
65+
self,
66+
stats_monitor: AgentStatsPluginContext,
67+
error_monitor: AgentErrorPluginContext,
68+
agent_public_key: Optional[PublicKey],
69+
) -> None:
7170
if self._local_config.agent_common.backend == AgentBackend.DOCKER:
7271
await self._initialize_metadata_server()
7372

74-
tasks = []
73+
tasks: list[asyncio.Task] = []
7574
async with asyncio.TaskGroup() as tg:
7675
for agent_config in self._local_config.get_agent_configs():
7776
agent_id = AgentId(agent_config.agent.id)
78-
etcd_view = AgentEtcdClientView(self._etcd, agent_config)
79-
80-
self._etcd_views[agent_id] = etcd_view
81-
tasks.append(tg.create_task(self._create_agent(etcd_view, agent_config)))
82-
self._agents = {(agent := task.result()).id: agent for task in tasks}
77+
tasks.append(
78+
tg.create_task(
79+
self._create_agent(
80+
self.get_etcd(agent_id),
81+
agent_config,
82+
stats_monitor,
83+
error_monitor,
84+
agent_public_key,
85+
)
86+
)
87+
)
88+
89+
agents = [task.result() for task in tasks]
90+
self._default_agent = agents[0]
91+
self._agents = {agent.id: agent for agent in agents}
8392

8493
async def __aexit__(self, *exc_info) -> None:
8594
for agent in self._agents.values():
@@ -92,21 +101,19 @@ def get_agents(self) -> list[AbstractAgent]:
92101

93102
def get_agent(self, agent_id: Optional[AgentId]) -> AbstractAgent:
94103
if agent_id is None:
95-
agent_id = self._default_agent_id
104+
return self._default_agent
96105
if agent_id not in self._agents:
97106
raise AgentIdNotFoundError(
98107
f"Agent '{agent_id}' not found in this runtime. "
99108
f"Available agents: {', '.join(self._agents.keys())}"
100109
)
101110
return self._agents[agent_id]
102111

103-
def get_etcd(self, agent_id: Optional[AgentId]) -> AgentEtcdClientView:
104-
if agent_id is None:
105-
agent_id = self._default_agent_id
106-
if agent_id not in self._agents:
112+
def get_etcd(self, agent_id: AgentId) -> AgentEtcdClientView:
113+
if agent_id not in self._etcd_views:
107114
raise AgentIdNotFoundError(
108-
f"Agent '{agent_id}' not found in this runtime. "
109-
f"Available agents: {', '.join(self._agents.keys())}"
115+
f"Etcd client for agent '{agent_id}' not found in this runtime. "
116+
f"Available agent etcd views: {', '.join(self._etcd_views.keys())}"
110117
)
111118
return self._etcd_views[agent_id]
112119

@@ -121,11 +128,14 @@ async def _create_agent(
121128
self,
122129
etcd_view: AgentEtcdClientView,
123130
agent_config: AgentUnifiedConfig,
131+
stats_monitor: AgentStatsPluginContext,
132+
error_monitor: AgentErrorPluginContext,
133+
agent_public_key: Optional[PublicKey],
124134
) -> AbstractAgent:
125135
agent_kwargs = {
126-
"stats_monitor": self.stats_monitor,
127-
"error_monitor": self.error_monitor,
128-
"agent_public_key": self.agent_public_key,
136+
"stats_monitor": stats_monitor,
137+
"error_monitor": error_monitor,
138+
"agent_public_key": agent_public_key,
129139
}
130140

131141
backend = self._local_config.agent_common.backend

src/ai/backend/agent/server.py

Lines changed: 2 additions & 3 deletions
Original file line number | Diff line number | Diff line change
@@ -285,6 +285,7 @@ def __init__(
285285
self.loop = current_loop()
286286
self.etcd = etcd
287287
self.local_config = local_config
288+
self.runtime = AgentRuntime(self.local_config, self.etcd)
288289
self.skip_detect_manager = skip_detect_manager
289290

290291
async def __ainit__(self) -> None:
@@ -337,9 +338,7 @@ async def __ainit__(self) -> None:
337338
self.rpc_auth_agent_secret_key = None
338339
auth_handler = None
339340

340-
self.runtime = await AgentRuntime.new(
341-
self.local_config,
342-
self.etcd,
341+
await self.runtime.create_agents(
343342
self.stats_monitor,
344343
self.error_monitor,
345344
self.rpc_auth_agent_public_key,

0 commit comments

Comments
 (0)