Skip to content

Commit 1169ebb

Browse files
committed
feat(BA-2753): Spawn multiple agents and route RPC appropriately
This change adds support for spawning multiple agents within the same agent server and adds an agent_id field to all appropriate RPC calls in the agent server. The manager now sends that field so that the agent server can route each RPC call to the correct agent.
1 parent 3318324 commit 1169ebb

File tree

13 files changed

+1197
-224
lines changed

13 files changed

+1197
-224
lines changed

changes/6320.feature.md

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1 @@
1+
Update Agent server RPC functions to include agent ID for agent runtime with multiple agents

src/ai/backend/agent/agent.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -2276,7 +2276,7 @@ async def scan_running_kernels(self) -> None:
22762276
"""
22772277
ipc_base_path = self.local_config.agent.ipc_base_path
22782278
var_base_path = self.local_config.agent.var_base_path
2279-
last_registry_file = f"last_registry.{self.local_instance_id}.dat"
2279+
last_registry_file = f"last_registry.{self.id}.dat"
22802280
if os.path.isfile(ipc_base_path / last_registry_file):
22812281
shutil.move(ipc_base_path / last_registry_file, var_base_path / last_registry_file)
22822282
try:
@@ -3745,7 +3745,7 @@ async def save_last_registry(self, force=False) -> None:
37453745
if (not force) and (now <= self.last_registry_written_time + 60):
37463746
return # don't save too frequently
37473747
var_base_path = self.local_config.agent.var_base_path
3748-
last_registry_file = f"last_registry.{self.local_instance_id}.dat"
3748+
last_registry_file = f"last_registry.{self.id}.dat"
37493749
try:
37503750
with open(var_base_path / last_registry_file, "wb") as f:
37513751
pickle.dump(self.kernel_registry, f)

src/ai/backend/agent/docker/agent.py

Lines changed: 5 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -1314,6 +1314,7 @@ def __init__(
13141314
error_monitor: ErrorPluginContext,
13151315
skip_initial_scan: bool = False,
13161316
agent_public_key: Optional[PublicKey],
1317+
metadata_server: MetadataServer,
13171318
) -> None:
13181319
super().__init__(
13191320
etcd,
@@ -1324,6 +1325,8 @@ def __init__(
13241325
agent_public_key=agent_public_key,
13251326
)
13261327
self.checked_invalid_images = set()
1328+
self.metadata_server = metadata_server
1329+
self.metadata_server.register_agent(self)
13271330

13281331
async def __ainit__(self) -> None:
13291332
async with closing_async(Docker()) as docker:
@@ -1369,10 +1372,10 @@ async def __ainit__(self) -> None:
13691372
self.gwbridge_subnet = None
13701373
ipc_base_path = self.local_config.agent.ipc_base_path
13711374
(ipc_base_path / "container").mkdir(parents=True, exist_ok=True)
1372-
self.agent_sockpath = ipc_base_path / "container" / f"agent.{self.local_instance_id}.sock"
1375+
self.agent_sockpath = ipc_base_path / "container" / f"agent.{self.id}.sock"
13731376
# Workaround for Docker Desktop for Mac's UNIX socket mount failure with virtiofs
13741377
if sys.platform != "darwin":
1375-
socket_relay_name = f"backendai-socket-relay.{self.local_instance_id}"
1378+
socket_relay_name = f"backendai-socket-relay.{self.id}"
13761379
socket_relay_container = PersistentServiceContainer(
13771380
"backendai-socket-relay:latest",
13781381
{
@@ -1398,12 +1401,6 @@ async def __ainit__(self) -> None:
13981401
self.monitor_docker_task = asyncio.create_task(self.monitor_docker_events())
13991402
self.docker_ptask_group = aiotools.PersistentTaskGroup()
14001403

1401-
self.metadata_server = await MetadataServer.new(
1402-
self.local_config,
1403-
self.etcd,
1404-
self.kernel_registry,
1405-
)
1406-
await self.metadata_server.start_server()
14071404
# For legacy accelerator plugins
14081405
self.docker = Docker()
14091406

@@ -1432,7 +1429,6 @@ async def shutdown(self, stop_signal: signal.Signals):
14321429
self.monitor_docker_task.cancel()
14331430
await self.monitor_docker_task
14341431

1435-
await self.metadata_server.cleanup()
14361432
if self.docker:
14371433
await self.docker.close()
14381434

src/ai/backend/agent/docker/metadata/server.py

Lines changed: 41 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -1,14 +1,15 @@
11
import logging
22
from collections.abc import Sequence
33
from http import HTTPStatus
4-
from typing import Any, List, Mapping, MutableMapping, cast
4+
from typing import Any, Iterator, List, Mapping, MutableMapping, cast, override
55
from uuid import UUID
66

77
import attr
88
from aiodocker.docker import Docker
99
from aiohttp import web
1010
from aiohttp.typedefs import Handler, Middleware
1111

12+
from ai.backend.agent.agent import AbstractAgent
1213
from ai.backend.agent.config.unified import AgentUnifiedConfig
1314
from ai.backend.agent.docker.kernel import prepare_kernel_metadata_uri_handling
1415
from ai.backend.agent.kernel import AbstractKernel
@@ -17,7 +18,7 @@
1718
from ai.backend.common.etcd import AsyncEtcd
1819
from ai.backend.common.json import dump_json_str
1920
from ai.backend.common.plugin import BasePluginContext
20-
from ai.backend.common.types import KernelId, aobject
21+
from ai.backend.common.types import AgentId, KernelId, aobject
2122
from ai.backend.logging import BraceStyleAdapter
2223

2324
from .plugin import MetadataPlugin
@@ -87,17 +88,50 @@ async def list_versions(request: web.Request) -> web.Response:
8788
return web.Response(body="latest/")
8889

8990

91+
class AggregateKernelRegistry(Mapping[KernelId, AbstractKernel]):
    """Read-only mapping that presents several agents' kernel registries as one.

    Agents are added with :meth:`register_agent`. Every mapping operation is
    delegated to the ``kernel_registry`` of each registered agent at call time,
    so nothing is copied or cached here.

    NOTE(review): ``__iter__`` and ``__len__`` simply chain/sum the per-agent
    registries; this assumes a kernel ID never appears in more than one
    agent's registry -- confirm against the agent implementation.
    """

    # Registered agents, keyed by ``agent.id``.
    _agents: dict[AgentId, AbstractAgent]

    def __init__(self) -> None:
        self._agents = {}

    def register_agent(self, agent: AbstractAgent) -> None:
        """Add *agent* under ``agent.id``, replacing any entry with the same ID."""
        self._agents[agent.id] = agent

    @override
    def __getitem__(self, kernel_id: KernelId) -> AbstractKernel:
        """Return the kernel with *kernel_id* from the first agent that has it.

        Raises:
            KeyError: if no registered agent's registry contains *kernel_id*.
        """
        for agent in self._agents.values():
            # EAFP: one lookup per agent instead of a membership test + lookup.
            try:
                return agent.kernel_registry[kernel_id]
            except KeyError:
                continue
        raise KeyError(kernel_id)

    @override
    def __iter__(self) -> Iterator[KernelId]:
        """Yield the kernel IDs of every registered agent, one agent after another."""
        for agent in self._agents.values():
            yield from agent.kernel_registry.keys()

    @override
    def __len__(self) -> int:
        """Return the total number of kernels across all registered agents."""
        return sum(len(agent.kernel_registry) for agent in self._agents.values())

    @override
    def __contains__(self, x: object, /) -> bool:
        """Return True if any registered agent's kernel registry contains *x*."""
        # Membership must be answered by the agents' kernel registries so that
        # ``x in registry`` agrees with ``registry[x]``.  Iterating the
        # ``_agents`` dict directly would yield agent IDs, not agents, and the
        # key type is not restricted to ``str`` here because this mapping is
        # keyed by ``KernelId``.
        return any(x in agent.kernel_registry for agent in self._agents.values())
121+
122+
90123
class MetadataServer(aobject):
91124
app: web.Application
92125
runner: web.AppRunner
93126
route_structure: MutableMapping[str, Any]
94127
loaded_apps: List[str]
128+
kernel_registry: AggregateKernelRegistry
95129

96130
def __init__(
97131
self,
98132
local_config: AgentUnifiedConfig,
99133
etcd: AsyncEtcd,
100-
kernel_registry: Mapping[KernelId, AbstractKernel],
134+
kernel_registry: AggregateKernelRegistry,
101135
) -> None:
102136
app = web.Application(
103137
middlewares=[
@@ -112,6 +146,7 @@ def __init__(
112146
self.app = app
113147
self.loaded_apps = []
114148
self.route_structure = {"latest": {"extension": {}}}
149+
self.kernel_registry = kernel_registry
115150

116151
async def __ainit__(self):
117152
local_config = cast(AgentUnifiedConfig, self.app["_root.context"].local_config)
@@ -132,6 +167,9 @@ async def __ainit__(self):
132167
self.app.router.add_route("GET", "/", list_versions)
133168
self.app.router.add_route("GET", "/{version}", self.list_available_apps)
134169

170+
def register_agent(self, agent: AbstractAgent) -> None:
171+
self.kernel_registry.register_agent(agent)
172+
135173
async def list_available_apps(self, request: web.Request) -> web.Response:
136174
return web.Response(body="\n".join([x + "/" for x in self.loaded_apps]))
137175

src/ai/backend/agent/etcd.py

Lines changed: 58 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,58 @@
1+
from typing import Collection
2+
3+
from ai.backend.agent.config.unified import AgentUnifiedConfig
4+
from ai.backend.common.data.config.types import EtcdConfigData
5+
from ai.backend.common.etcd import AsyncEtcd, ConfigScopes
6+
from ai.backend.common.types import AgentId
7+
8+
9+
class EtcdClientRegistry:
    """Holds one global etcd client plus one etcd client per agent ID.

    The global client is created eagerly in ``__init__`` with only the GLOBAL
    scope configured. Per-agent clients are created by :meth:`prefill_clients`
    and retrieved with :meth:`get_client`.
    """

    _etcd_config: EtcdConfigData
    _etcd_clients: dict[AgentId, AsyncEtcd]
    _global_etcd: AsyncEtcd

    def __init__(self, etcd_config: EtcdConfigData) -> None:
        self._etcd_config = etcd_config
        self._etcd_clients = {}
        self._global_etcd = self._create_client(agent_id=None, scaling_group=None)

    @property
    def global_etcd(self) -> AsyncEtcd:
        """The client that has no scaling-group or node scope configured."""
        return self._global_etcd

    async def close(self) -> None:
        """Close every per-agent client, then the global client."""
        for client in self._etcd_clients.values():
            await client.close()
        await self._global_etcd.close()

    def get_client(self, agent_id: AgentId) -> AsyncEtcd:
        """Return the client registered for *agent_id*.

        Raises:
            KeyError: if no client was prefilled for *agent_id*.
        """
        return self._etcd_clients[agent_id]

    def prefill_clients(self, prefill_data: Collection[AgentUnifiedConfig]) -> None:
        """Create (or overwrite) a client for each agent config in *prefill_data*."""
        for config in prefill_data:
            key = AgentId(config.agent.id)
            self._etcd_clients[key] = self._create_client(key, config.agent.scaling_group)

    def _create_client(self, agent_id: AgentId | None, scaling_group: str | None) -> AsyncEtcd:
        """Build an ``AsyncEtcd`` whose scope prefixes reflect the given identifiers.

        The GLOBAL scope is always present with an empty prefix; the SGROUP and
        NODE scopes are added only when the corresponding argument is not None.
        """
        prefixes = {ConfigScopes.GLOBAL: ""}
        if scaling_group is not None:
            prefixes[ConfigScopes.SGROUP] = f"sgroup/{scaling_group}"
        if agent_id is not None:
            prefixes[ConfigScopes.NODE] = f"nodes/agents/{agent_id}"

        config = self._etcd_config
        # Credentials are passed only when both user and password are truthy.
        credentials = (
            {"user": config.user, "password": config.password}
            if config.user and config.password
            else None
        )

        legacy_addrs = [addr.to_legacy() for addr in config.addrs]
        return AsyncEtcd(
            legacy_addrs,
            config.namespace,
            prefixes,
            credentials=credentials,
        )

src/ai/backend/agent/kubernetes/agent.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -842,7 +842,7 @@ def __init__(
842842
async def __ainit__(self) -> None:
843843
await super().__ainit__()
844844
ipc_base_path = self.local_config.agent.ipc_base_path
845-
self.agent_sockpath = ipc_base_path / "container" / f"agent.{self.local_instance_id}.sock"
845+
self.agent_sockpath = ipc_base_path / "container" / f"agent.{self.id}.sock"
846846

847847
await self.check_krunner_pv_status()
848848
await self.fetch_workers()

0 commit comments

Comments
 (0)