Skip to content

Commit a3d0d3b

Browse files
committed
feat(BA-2753): Spawn multiple agents and route RPC appropriately
This change adds support for actually spawning multiple agents within the same agent server and adding agent_id field for all appropriate RPC calls in the agent server, then ensuring that the manager sends that info such that the agent server can correctly route the RPC calls to the correct agent.
1 parent 001bbe9 commit a3d0d3b

File tree

12 files changed

+994
-166
lines changed

12 files changed

+994
-166
lines changed

changes/6320.feature.md

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1 @@
1+
Update Agent server RPC functions to include agent ID for agent runtime with multiple agents

src/ai/backend/agent/agent.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -2276,7 +2276,7 @@ async def scan_running_kernels(self) -> None:
22762276
"""
22772277
ipc_base_path = self.local_config.agent.ipc_base_path
22782278
var_base_path = self.local_config.agent.var_base_path
2279-
last_registry_file = f"last_registry.{self.local_instance_id}.dat"
2279+
last_registry_file = f"last_registry.{self.id}.dat"
22802280
if os.path.isfile(ipc_base_path / last_registry_file):
22812281
shutil.move(ipc_base_path / last_registry_file, var_base_path / last_registry_file)
22822282
try:
@@ -3745,7 +3745,7 @@ async def save_last_registry(self, force=False) -> None:
37453745
if (not force) and (now <= self.last_registry_written_time + 60):
37463746
return # don't save too frequently
37473747
var_base_path = self.local_config.agent.var_base_path
3748-
last_registry_file = f"last_registry.{self.local_instance_id}.dat"
3748+
last_registry_file = f"last_registry.{self.id}.dat"
37493749
try:
37503750
with open(var_base_path / last_registry_file, "wb") as f:
37513751
pickle.dump(self.kernel_registry, f)

src/ai/backend/agent/docker/agent.py

Lines changed: 4 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -1357,6 +1357,7 @@ def __init__(
13571357
error_monitor: ErrorPluginContext,
13581358
skip_initial_scan: bool = False,
13591359
agent_public_key: Optional[PublicKey],
1360+
metadata_server: MetadataServer,
13601361
kernel_registry: KernelRegistry,
13611362
) -> None:
13621363
super().__init__(
@@ -1369,6 +1370,7 @@ def __init__(
13691370
kernel_registry=kernel_registry,
13701371
)
13711372
self.checked_invalid_images = set()
1373+
self.metadata_server = metadata_server
13721374

13731375
async def __ainit__(self) -> None:
13741376
async with closing_async(Docker()) as docker:
@@ -1414,10 +1416,10 @@ async def __ainit__(self) -> None:
14141416
self.gwbridge_subnet = None
14151417
ipc_base_path = self.local_config.agent.ipc_base_path
14161418
(ipc_base_path / "container").mkdir(parents=True, exist_ok=True)
1417-
self.agent_sockpath = ipc_base_path / "container" / f"agent.{self.local_instance_id}.sock"
1419+
self.agent_sockpath = ipc_base_path / "container" / f"agent.{self.id}.sock"
14181420
# Workaround for Docker Desktop for Mac's UNIX socket mount failure with virtiofs
14191421
if sys.platform != "darwin":
1420-
socket_relay_name = f"backendai-socket-relay.{self.local_instance_id}"
1422+
socket_relay_name = f"backendai-socket-relay.{self.id}"
14211423
socket_relay_container = PersistentServiceContainer(
14221424
"backendai-socket-relay:latest",
14231425
{
@@ -1443,12 +1445,6 @@ async def __ainit__(self) -> None:
14431445
self.monitor_docker_task = asyncio.create_task(self.monitor_docker_events())
14441446
self.docker_ptask_group = aiotools.PersistentTaskGroup()
14451447

1446-
self.metadata_server = await MetadataServer.new(
1447-
self.local_config,
1448-
self.etcd,
1449-
self.kernel_registry,
1450-
)
1451-
await self.metadata_server.start_server()
14521448
# For legacy accelerator plugins
14531449
self.docker = Docker()
14541450

@@ -1477,7 +1473,6 @@ async def shutdown(self, stop_signal: signal.Signals):
14771473
self.monitor_docker_task.cancel()
14781474
await self.monitor_docker_task
14791475

1480-
await self.metadata_server.cleanup()
14811476
if self.docker:
14821477
await self.docker.close()
14831478

src/ai/backend/agent/kubernetes/agent.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -844,7 +844,7 @@ def __init__(
844844
async def __ainit__(self) -> None:
845845
await super().__ainit__()
846846
ipc_base_path = self.local_config.agent.ipc_base_path
847-
self.agent_sockpath = ipc_base_path / "container" / f"agent.{self.local_instance_id}.sock"
847+
self.agent_sockpath = ipc_base_path / "container" / f"agent.{self.id}.sock"
848848

849849
await self.check_krunner_pv_status()
850850
await self.fetch_workers()

src/ai/backend/agent/runtime.py

Lines changed: 77 additions & 12 deletions
Original file line numberDiff line numberDiff line change
@@ -1,24 +1,45 @@
1+
import asyncio
12
import signal
23
from typing import Optional
34

45
from ai.backend.agent.agent import AbstractAgent
56
from ai.backend.agent.config.unified import AgentUnifiedConfig
7+
from ai.backend.agent.docker.metadata.server import MetadataServer
68
from ai.backend.agent.etcd import AgentEtcdClientView
79
from ai.backend.agent.kernel import KernelRegistry
810
from ai.backend.agent.monitor import AgentErrorPluginContext, AgentStatsPluginContext
911
from ai.backend.agent.types import AgentBackend
1012
from ai.backend.common.auth import PublicKey
1113
from ai.backend.common.etcd import AsyncEtcd
12-
from ai.backend.common.types import aobject
14+
from ai.backend.common.exception import (
15+
BackendAIError,
16+
ErrorCode,
17+
ErrorDetail,
18+
ErrorDomain,
19+
ErrorOperation,
20+
)
21+
from ai.backend.common.types import AgentId, aobject
22+
23+
24+
class AgentIdNotFoundError(BackendAIError):
25+
@classmethod
26+
def error_code(cls) -> ErrorCode:
27+
return ErrorCode(
28+
domain=ErrorDomain.AGENT,
29+
operation=ErrorOperation.ACCESS,
30+
error_detail=ErrorDetail.NOT_FOUND,
31+
)
1332

1433

1534
class AgentRuntime(aobject):
1635
local_config: AgentUnifiedConfig
17-
agent: AbstractAgent
36+
agents: dict[AgentId, AbstractAgent]
1837
kernel_registry: KernelRegistry
1938
etcd: AsyncEtcd
20-
etcd_view: AgentEtcdClientView
39+
etcd_views: dict[AgentId, AgentEtcdClientView]
40+
metadata_server: MetadataServer | None
2141

42+
_default_agent_id: AgentId
2243
_stop_signal: signal.Signals
2344

2445
def __init__(
@@ -30,27 +51,58 @@ def __init__(
3051
agent_public_key: Optional[PublicKey],
3152
) -> None:
3253
self.local_config = local_config
54+
self.agents = {}
3355
self.kernel_registry = KernelRegistry()
3456
self.etcd = etcd
35-
self.etcd_view = AgentEtcdClientView(etcd, self.local_config)
57+
self.etcd_views = {}
58+
self.metadata_server = None
3659

60+
self._default_agent_id = AgentId(self.local_config.agent_configs[0].agent.id)
3761
self._stop_signal = signal.SIGTERM
3862

3963
self.stats_monitor = stats_monitor
4064
self.error_monitor = error_monitor
4165
self.agent_public_key = agent_public_key
4266

4367
async def __ainit__(self) -> None:
44-
self.agent = await self._create_agent(self.etcd_view, self.local_config)
68+
tasks = []
69+
async with asyncio.TaskGroup() as tg:
70+
for agent_config in self.local_config.agent_configs:
71+
agent_id = AgentId(agent_config.agent.id)
72+
etcd_view = AgentEtcdClientView(self.etcd, agent_config)
4573

46-
async def __aexit__(self, *exc_info) -> None:
47-
await self.agent.shutdown(self._stop_signal)
48-
49-
def get_agent(self) -> AbstractAgent:
50-
return self.agent
74+
self.etcd_views[agent_id] = etcd_view
75+
tasks.append(tg.create_task(self._create_agent(etcd_view, agent_config)))
76+
self.agents = {(agent := task.result()).id: agent for task in tasks}
5177

52-
def get_etcd(self) -> AgentEtcdClientView:
53-
return self.etcd_view
78+
async def __aexit__(self, *exc_info) -> None:
79+
for agent in self.agents.values():
80+
await agent.shutdown(self._stop_signal)
81+
if self.metadata_server is not None:
82+
await self.metadata_server.cleanup()
83+
84+
def get_agents(self) -> list[AbstractAgent]:
85+
return list(self.agents.values())
86+
87+
def get_agent(self, agent_id: Optional[AgentId]) -> AbstractAgent:
88+
if agent_id is None:
89+
agent_id = self._default_agent_id
90+
if agent_id not in self.agents:
91+
raise AgentIdNotFoundError(
92+
f"Agent '{agent_id}' not found in this runtime. "
93+
f"Available agents: {', '.join(self.agents.keys())}"
94+
)
95+
return self.agents[agent_id]
96+
97+
def get_etcd(self, agent_id: Optional[AgentId]) -> AgentEtcdClientView:
98+
if agent_id is None:
99+
agent_id = self._default_agent_id
100+
if agent_id not in self.agents:
101+
raise AgentIdNotFoundError(
102+
f"Agent '{agent_id}' not found in this runtime. "
103+
f"Available agents: {', '.join(self.agents.keys())}"
104+
)
105+
return self.etcd_views[agent_id]
54106

55107
def mark_stop_signal(self, stop_signal: signal.Signals) -> None:
56108
self._stop_signal = stop_signal
@@ -64,12 +116,15 @@ async def _create_agent(
64116
case AgentBackend.DOCKER:
65117
from .docker.agent import DockerAgent
66118

119+
await self._initialize_metadata_server()
120+
67121
return await DockerAgent.new(
68122
etcd_view,
69123
agent_config,
70124
stats_monitor=self.stats_monitor,
71125
error_monitor=self.error_monitor,
72126
agent_public_key=self.agent_public_key,
127+
metadata_server=self.metadata_server,
73128
)
74129
case AgentBackend.KUBERNETES:
75130
from .kubernetes.agent import KubernetesAgent
@@ -91,3 +146,13 @@ async def _create_agent(
91146
error_monitor=self.error_monitor,
92147
agent_public_key=self.agent_public_key,
93148
)
149+
150+
async def _initialize_metadata_server(self) -> None:
151+
from .docker.metadata.server import MetadataServer
152+
153+
self.metadata_server = await MetadataServer.new(
154+
self.local_config,
155+
self.etcd,
156+
kernel_registry=self.kernel_registry.global_view(),
157+
)
158+
await self.metadata_server.start_server()

0 commit comments

Comments
 (0)