Skip to content

Commit ae23173

Browse files
committed
feat(BA-2753): Spawn multiple agents and route RPC appropriately
This change adds support for actually spawning multiple agents within the same agent server and adding agent_id field for all appropriate RPC calls in the agent server, then ensuring that the manager sends that info such that the agent server can correctly route the RPC calls to the correct agent.
1 parent 09b2462 commit ae23173

File tree

12 files changed

+1002
-170
lines changed

12 files changed

+1002
-170
lines changed

changes/6320.feature.md

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1 @@
1+
Update Agent server RPC functions to include agent ID for agent runtime with multiple agents

src/ai/backend/agent/agent.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -2276,7 +2276,7 @@ async def scan_running_kernels(self) -> None:
22762276
"""
22772277
ipc_base_path = self.local_config.agent.ipc_base_path
22782278
var_base_path = self.local_config.agent.var_base_path
2279-
last_registry_file = f"last_registry.{self.local_instance_id}.dat"
2279+
last_registry_file = f"last_registry.{self.id}.dat"
22802280
if os.path.isfile(ipc_base_path / last_registry_file):
22812281
shutil.move(ipc_base_path / last_registry_file, var_base_path / last_registry_file)
22822282
try:
@@ -3745,7 +3745,7 @@ async def save_last_registry(self, force=False) -> None:
37453745
if (not force) and (now <= self.last_registry_written_time + 60):
37463746
return # don't save too frequently
37473747
var_base_path = self.local_config.agent.var_base_path
3748-
last_registry_file = f"last_registry.{self.local_instance_id}.dat"
3748+
last_registry_file = f"last_registry.{self.id}.dat"
37493749
try:
37503750
with open(var_base_path / last_registry_file, "wb") as f:
37513751
pickle.dump(self.kernel_registry, f)

src/ai/backend/agent/docker/agent.py

Lines changed: 2 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -138,7 +138,6 @@
138138
update_nested_dict,
139139
)
140140
from .kernel import DockerKernel
141-
from .metadata.server import MetadataServer
142141
from .resources import load_resources, scan_available_resources
143142
from .utils import PersistentServiceContainer
144143

@@ -1341,7 +1340,6 @@ class DockerAgent(AbstractAgent[DockerKernel, DockerKernelCreationContext]):
13411340
monitor_docker_task: asyncio.Task
13421341
agent_sockpath: Path
13431342
agent_sock_task: asyncio.Task
1344-
metadata_server: MetadataServer
13451343
docker_ptask_group: aiotools.PersistentTaskGroup
13461344
gwbridge_subnet: Optional[str]
13471345
checked_invalid_images: Set[str]
@@ -1414,10 +1412,10 @@ async def __ainit__(self) -> None:
14141412
self.gwbridge_subnet = None
14151413
ipc_base_path = self.local_config.agent.ipc_base_path
14161414
(ipc_base_path / "container").mkdir(parents=True, exist_ok=True)
1417-
self.agent_sockpath = ipc_base_path / "container" / f"agent.{self.local_instance_id}.sock"
1415+
self.agent_sockpath = ipc_base_path / "container" / f"agent.{self.id}.sock"
14181416
# Workaround for Docker Desktop for Mac's UNIX socket mount failure with virtiofs
14191417
if sys.platform != "darwin":
1420-
socket_relay_name = f"backendai-socket-relay.{self.local_instance_id}"
1418+
socket_relay_name = f"backendai-socket-relay.{self.id}"
14211419
socket_relay_container = PersistentServiceContainer(
14221420
"backendai-socket-relay:latest",
14231421
{
@@ -1443,12 +1441,6 @@ async def __ainit__(self) -> None:
14431441
self.monitor_docker_task = asyncio.create_task(self.monitor_docker_events())
14441442
self.docker_ptask_group = aiotools.PersistentTaskGroup()
14451443

1446-
self.metadata_server = await MetadataServer.new(
1447-
self.local_config,
1448-
self.etcd,
1449-
self.kernel_registry,
1450-
)
1451-
await self.metadata_server.start_server()
14521444
# For legacy accelerator plugins
14531445
self.docker = Docker()
14541446

@@ -1477,7 +1469,6 @@ async def shutdown(self, stop_signal: signal.Signals):
14771469
self.monitor_docker_task.cancel()
14781470
await self.monitor_docker_task
14791471

1480-
await self.metadata_server.cleanup()
14811472
if self.docker:
14821473
await self.docker.close()
14831474

src/ai/backend/agent/kubernetes/agent.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -844,7 +844,7 @@ def __init__(
844844
async def __ainit__(self) -> None:
845845
await super().__ainit__()
846846
ipc_base_path = self.local_config.agent.ipc_base_path
847-
self.agent_sockpath = ipc_base_path / "container" / f"agent.{self.local_instance_id}.sock"
847+
self.agent_sockpath = ipc_base_path / "container" / f"agent.{self.id}.sock"
848848

849849
await self.check_krunner_pv_status()
850850
await self.fetch_workers()

src/ai/backend/agent/runtime.py

Lines changed: 83 additions & 14 deletions
Original file line numberDiff line numberDiff line change
@@ -1,24 +1,48 @@
1+
import asyncio
12
import importlib
23
import signal
3-
from typing import Optional
4+
from typing import TYPE_CHECKING, Optional
45

56
from ai.backend.agent.agent import AbstractAgent
67
from ai.backend.agent.config.unified import AgentUnifiedConfig
78
from ai.backend.agent.etcd import AgentEtcdClientView
89
from ai.backend.agent.kernel import KernelRegistry
910
from ai.backend.agent.monitor import AgentErrorPluginContext, AgentStatsPluginContext
11+
from ai.backend.agent.types import AgentBackend
1012
from ai.backend.common.auth import PublicKey
1113
from ai.backend.common.etcd import AsyncEtcd, ConfigScopes
12-
from ai.backend.common.types import aobject
14+
from ai.backend.common.exception import (
15+
BackendAIError,
16+
ErrorCode,
17+
ErrorDetail,
18+
ErrorDomain,
19+
ErrorOperation,
20+
)
21+
from ai.backend.common.types import AgentId, aobject
22+
23+
if TYPE_CHECKING:
24+
from .docker.metadata.server import MetadataServer
25+
26+
27+
class AgentIdNotFoundError(BackendAIError):
28+
@classmethod
29+
def error_code(cls) -> ErrorCode:
30+
return ErrorCode(
31+
domain=ErrorDomain.AGENT,
32+
operation=ErrorOperation.ACCESS,
33+
error_detail=ErrorDetail.NOT_FOUND,
34+
)
1335

1436

1537
class AgentRuntime(aobject):
1638
local_config: AgentUnifiedConfig
17-
agent: AbstractAgent
39+
agents: dict[AgentId, AbstractAgent]
1840
kernel_registry: KernelRegistry
1941
etcd: AsyncEtcd
20-
etcd_view: AgentEtcdClientView
42+
etcd_views: dict[AgentId, AgentEtcdClientView]
43+
metadata_server: MetadataServer | None
2144

45+
_default_agent_id: AgentId
2246
_stop_signal: signal.Signals
2347

2448
def __init__(
@@ -30,33 +54,68 @@ def __init__(
3054
agent_public_key: Optional[PublicKey],
3155
) -> None:
3256
self.local_config = local_config
57+
self.agents = {}
3358
self.kernel_registry = KernelRegistry()
3459
self.etcd = etcd
35-
self.etcd_view = AgentEtcdClientView(etcd, self.local_config)
60+
self.etcd_views = {}
61+
self.metadata_server = None
3662

63+
agent_configs = self.local_config.get_agent_configs()
64+
self._default_agent_id = AgentId(agent_configs[0].agent.id)
3765
self._stop_signal = signal.SIGTERM
3866

3967
self.stats_monitor = stats_monitor
4068
self.error_monitor = error_monitor
4169
self.agent_public_key = agent_public_key
4270

4371
async def __ainit__(self) -> None:
44-
self.agent = await self._create_agent(self.etcd_view, self.local_config)
72+
if self.local_config.agent_common.backend == AgentBackend.DOCKER:
73+
await self._initialize_metadata_server()
4574

46-
async def __aexit__(self, *exc_info) -> None:
47-
await self.agent.shutdown(self._stop_signal)
75+
tasks = []
76+
async with asyncio.TaskGroup() as tg:
77+
for agent_config in self.local_config.get_agent_configs():
78+
agent_id = AgentId(agent_config.agent.id)
79+
etcd_view = AgentEtcdClientView(self.etcd, agent_config)
4880

49-
def get_agent(self) -> AbstractAgent:
50-
return self.agent
81+
self.etcd_views[agent_id] = etcd_view
82+
tasks.append(tg.create_task(self._create_agent(etcd_view, agent_config)))
83+
self.agents = {(agent := task.result()).id: agent for task in tasks}
5184

52-
def get_etcd(self) -> AgentEtcdClientView:
53-
return self.etcd_view
85+
async def __aexit__(self, *exc_info) -> None:
86+
for agent in self.agents.values():
87+
await agent.shutdown(self._stop_signal)
88+
if self.metadata_server is not None:
89+
await self.metadata_server.cleanup()
90+
91+
def get_agents(self) -> list[AbstractAgent]:
92+
return list(self.agents.values())
93+
94+
def get_agent(self, agent_id: Optional[AgentId]) -> AbstractAgent:
95+
if agent_id is None:
96+
agent_id = self._default_agent_id
97+
if agent_id not in self.agents:
98+
raise AgentIdNotFoundError(
99+
f"Agent '{agent_id}' not found in this runtime. "
100+
f"Available agents: {', '.join(self.agents.keys())}"
101+
)
102+
return self.agents[agent_id]
103+
104+
def get_etcd(self, agent_id: Optional[AgentId]) -> AgentEtcdClientView:
105+
if agent_id is None:
106+
agent_id = self._default_agent_id
107+
if agent_id not in self.agents:
108+
raise AgentIdNotFoundError(
109+
f"Agent '{agent_id}' not found in this runtime. "
110+
f"Available agents: {', '.join(self.agents.keys())}"
111+
)
112+
return self.etcd_views[agent_id]
54113

55114
def mark_stop_signal(self, stop_signal: signal.Signals) -> None:
56115
self._stop_signal = stop_signal
57116

58-
async def update_status(self, status) -> None:
59-
etcd = self.get_etcd()
117+
async def update_status(self, status, agent_id: AgentId) -> None:
118+
etcd = self.get_etcd(agent_id)
60119
await etcd.put("", status, scope=ConfigScopes.NODE)
61120

62121
async def _create_agent(
@@ -75,3 +134,13 @@ async def _create_agent(
75134
agent_cls = agent_mod.get_agent_cls()
76135

77136
return agent_cls.new(etcd_view, agent_config, **agent_kwargs)
137+
138+
async def _initialize_metadata_server(self) -> None:
139+
from .docker.metadata.server import MetadataServer
140+
141+
self.metadata_server = await MetadataServer.new(
142+
self.local_config,
143+
self.etcd,
144+
kernel_registry=self.kernel_registry.global_view(),
145+
)
146+
await self.metadata_server.start_server()

0 commit comments

Comments
 (0)