Skip to content

Commit 04a0f3a

Browse files
committed
feat(BA-2753): Spawn multiple agents and route RPC appropriately
This change adds support for spawning multiple agents within the same agent server. It adds an agent_id field to all relevant RPC calls in the agent server and ensures that the manager sends that field, so that the agent server can route each RPC call to the correct agent.
1 parent 2f60616 commit 04a0f3a

File tree

12 files changed

+1251
-269
lines changed

12 files changed

+1251
-269
lines changed

changes/6320.feature.md

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1 @@
1+
Update agent server RPC functions to include the agent ID, for agent runtimes with multiple agents

src/ai/backend/agent/agent.py

Lines changed: 4 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -888,6 +888,8 @@ async def __ainit__(self) -> None:
888888
"AbstractAgent.__ainit__": "Redis runtime configuration is not set."
889889
})
890890

891+
self.local_config.agent.image_commit_path.mkdir(parents=True, exist_ok=True)
892+
891893
redis_profile_target = self.local_config.redis.to_redis_profile_target()
892894
stream_redis_target = redis_profile_target.profile_target(RedisRole.STREAM)
893895
mq = await self._make_message_queue(stream_redis_target)
@@ -2276,7 +2278,7 @@ async def scan_running_kernels(self) -> None:
22762278
"""
22772279
ipc_base_path = self.local_config.agent.ipc_base_path
22782280
var_base_path = self.local_config.agent.var_base_path
2279-
last_registry_file = f"last_registry.{self.local_instance_id}.dat"
2281+
last_registry_file = f"last_registry.{self.id}.dat"
22802282
if os.path.isfile(ipc_base_path / last_registry_file):
22812283
shutil.move(ipc_base_path / last_registry_file, var_base_path / last_registry_file)
22822284
try:
@@ -3745,7 +3747,7 @@ async def save_last_registry(self, force=False) -> None:
37453747
if (not force) and (now <= self.last_registry_written_time + 60):
37463748
return # don't save too frequently
37473749
var_base_path = self.local_config.agent.var_base_path
3748-
last_registry_file = f"last_registry.{self.local_instance_id}.dat"
3750+
last_registry_file = f"last_registry.{self.id}.dat"
37493751
try:
37503752
with open(var_base_path / last_registry_file, "wb") as f:
37513753
pickle.dump(self.kernel_registry, f)

src/ai/backend/agent/config/unified.py

Lines changed: 10 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -1140,7 +1140,12 @@ class AgentOverrideConfig(BaseConfigSchema):
11401140
description="Resource config overrides for the individual agent",
11411141
)
11421142

1143-
def construct_unified_config(self, *, default: AgentUnifiedConfig) -> AgentUnifiedConfig:
1143+
def construct_unified_config(
1144+
self,
1145+
*,
1146+
default: AgentUnifiedConfig,
1147+
agent_idx: int,
1148+
) -> AgentUnifiedConfig:
11441149
agent_updates: dict[str, Any] = {}
11451150
if self.agent is not None:
11461151
agent_override_fields = self.agent.model_dump(include=self.agent.model_fields_set)
@@ -1282,7 +1287,10 @@ def validate(config: AgentSpecificConfig) -> None:
12821287
return self
12831288

12841289
def _for_each_agent(self, func: Callable[[AgentUnifiedConfig], R]) -> list[R]:
1285-
agents = [agent.construct_unified_config(default=self) for agent in self.agents]
1290+
agents = [
1291+
agent.construct_unified_config(default=self, agent_idx=i)
1292+
for i, agent in enumerate(self.agents)
1293+
]
12861294
if not agents:
12871295
agents.append(self)
12881296

src/ai/backend/agent/docker/agent.py

Lines changed: 9 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -1298,7 +1298,7 @@ class DockerAgent(AbstractAgent[DockerKernel, DockerKernelCreationContext]):
12981298
monitor_docker_task: asyncio.Task
12991299
agent_sockpath: Path
13001300
agent_sock_task: asyncio.Task
1301-
metadata_server: MetadataServer
1301+
metadata_server: Optional[MetadataServer]
13021302
docker_ptask_group: aiotools.PersistentTaskGroup
13031303
gwbridge_subnet: Optional[str]
13041304
checked_invalid_images: Set[str]
@@ -1324,6 +1324,9 @@ def __init__(
13241324
agent_public_key=agent_public_key,
13251325
)
13261326
self.checked_invalid_images = set()
1327+
# MetadataServer must be shared across all instances of DockerAgent.
1328+
# metadata_server is initialized by AgentRPCServer and assigned via assign_metadata_server()
1329+
self.metadata_server = None
13271330

13281331
async def __ainit__(self) -> None:
13291332
async with closing_async(Docker()) as docker:
@@ -1369,10 +1372,10 @@ async def __ainit__(self) -> None:
13691372
self.gwbridge_subnet = None
13701373
ipc_base_path = self.local_config.agent.ipc_base_path
13711374
(ipc_base_path / "container").mkdir(parents=True, exist_ok=True)
1372-
self.agent_sockpath = ipc_base_path / "container" / f"agent.{self.local_instance_id}.sock"
1375+
self.agent_sockpath = ipc_base_path / "container" / f"agent.{self.id}.sock"
13731376
# Workaround for Docker Desktop for Mac's UNIX socket mount failure with virtiofs
13741377
if sys.platform != "darwin":
1375-
socket_relay_name = f"backendai-socket-relay.{self.local_instance_id}"
1378+
socket_relay_name = f"backendai-socket-relay.{self.id}"
13761379
socket_relay_container = PersistentServiceContainer(
13771380
"backendai-socket-relay:latest",
13781381
{
@@ -1398,12 +1401,6 @@ async def __ainit__(self) -> None:
13981401
self.monitor_docker_task = asyncio.create_task(self.monitor_docker_events())
13991402
self.docker_ptask_group = aiotools.PersistentTaskGroup()
14001403

1401-
self.metadata_server = await MetadataServer.new(
1402-
self.local_config,
1403-
self.etcd,
1404-
self.kernel_registry,
1405-
)
1406-
await self.metadata_server.start_server()
14071404
# For legacy accelerator plugins
14081405
self.docker = Docker()
14091406

@@ -1416,6 +1413,9 @@ async def __ainit__(self) -> None:
14161413
blocklist=self.local_config.agent.block_network_plugins,
14171414
)
14181415

1416+
def assign_metadata_server(self, metadata_server: MetadataServer) -> None:
1417+
self.metadata_server = metadata_server
1418+
14191419
async def shutdown(self, stop_signal: signal.Signals):
14201420
# Stop handling agent sock.
14211421
if self.agent_sock_task is not None:
@@ -1432,7 +1432,6 @@ async def shutdown(self, stop_signal: signal.Signals):
14321432
self.monitor_docker_task.cancel()
14331433
await self.monitor_docker_task
14341434

1435-
await self.metadata_server.cleanup()
14361435
if self.docker:
14371436
await self.docker.close()
14381437

src/ai/backend/agent/etcd.py

Lines changed: 58 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,58 @@
1+
from typing import Collection
2+
3+
from ai.backend.agent.config.unified import AgentUnifiedConfig
4+
from ai.backend.common.data.config.types import EtcdConfigData
5+
from ai.backend.common.etcd import AsyncEtcd, ConfigScopes
6+
from ai.backend.common.types import AgentId
7+
8+
9+
class EtcdClientRegistry:
    """Keep one global etcd client plus one scoped etcd client per agent.

    The global client is created eagerly in ``__init__`` with only the
    ``GLOBAL`` scope prefix. Per-agent clients are created by
    ``prefill_clients()`` and looked up with ``get_client()``.
    """

    _etcd_config: EtcdConfigData
    _etcd_clients: dict[AgentId, AsyncEtcd]
    _global_etcd: AsyncEtcd

    def __init__(self, etcd_config: EtcdConfigData) -> None:
        """Store the etcd connection settings and build the global client."""
        self._etcd_config = etcd_config
        self._etcd_clients = {}
        # No agent ID and no scaling group: only the GLOBAL scope is mapped.
        self._global_etcd = self._create_client(agent_id=None, scaling_group=None)

    @property
    def global_etcd(self) -> AsyncEtcd:
        """The client that is not bound to any agent or scaling group."""
        return self._global_etcd

    async def close(self) -> None:
        """Close each registered per-agent client, then the global client."""
        for client in self._etcd_clients.values():
            await client.close()
        await self._global_etcd.close()

    def get_client(self, agent_id: AgentId) -> AsyncEtcd:
        """Return the client registered for ``agent_id``.

        Raises ``KeyError`` when no client has been registered for that ID.
        """
        return self._etcd_clients[agent_id]

    def prefill_clients(self, prefill_data: Collection[AgentUnifiedConfig]) -> None:
        """Create and register one client for every given agent config.

        Each client is keyed by the config's ``agent.id`` and scoped with the
        config's ``agent.scaling_group``. An existing entry with the same ID
        is overwritten.
        """
        for config in prefill_data:
            key = AgentId(config.agent.id)
            self._etcd_clients[key] = self._create_client(key, config.agent.scaling_group)

    def _create_client(self, agent_id: AgentId | None, scaling_group: str | None) -> AsyncEtcd:
        """Build an ``AsyncEtcd`` client with scope prefixes for the given target.

        ``SGROUP`` and ``NODE`` prefixes are added only when the matching
        argument is not ``None``. Credentials are passed only when both the
        configured user and password are truthy.
        """
        scope_prefix_map = {ConfigScopes.GLOBAL: ""}
        if scaling_group is not None:
            scope_prefix_map[ConfigScopes.SGROUP] = f"sgroup/{scaling_group}"
        if agent_id is not None:
            scope_prefix_map[ConfigScopes.NODE] = f"nodes/agents/{agent_id}"

        settings = self._etcd_config
        etcd_credentials = (
            {"user": settings.user, "password": settings.password}
            if settings.user and settings.password
            else None
        )

        return AsyncEtcd(
            [addr.to_legacy() for addr in settings.addrs],
            settings.namespace,
            scope_prefix_map,
            credentials=etcd_credentials,
        )

src/ai/backend/agent/kubernetes/agent.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -842,7 +842,7 @@ def __init__(
842842
async def __ainit__(self) -> None:
843843
await super().__ainit__()
844844
ipc_base_path = self.local_config.agent.ipc_base_path
845-
self.agent_sockpath = ipc_base_path / "container" / f"agent.{self.local_instance_id}.sock"
845+
self.agent_sockpath = ipc_base_path / "container" / f"agent.{self.id}.sock"
846846

847847
await self.check_krunner_pv_status()
848848
await self.fetch_workers()

0 commit comments

Comments
 (0)