Skip to content

Commit 242db6e

Browse files
committed
feat(BA-3023): Add unified multi agent etcd client handling
This change introduces EtcdClientRegistry, which is a class that can contain etcd clients for multiple agents, each with its own prefix information at scaling group and individual node level.
1 parent 3318324 commit 242db6e

File tree

4 files changed

+118
-51
lines changed

4 files changed

+118
-51
lines changed

changes/6721.feature.md

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1 @@
1+
Add EtcdClientRegistry for clean handling of etcd clients for multi agents

src/ai/backend/agent/etcd.py

Lines changed: 58 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,58 @@
1+
from typing import Collection
2+
3+
from ai.backend.agent.config.unified import AgentUnifiedConfig
4+
from ai.backend.common.data.config.types import EtcdConfigData
5+
from ai.backend.common.etcd import AsyncEtcd, ConfigScopes
6+
from ai.backend.common.types import AgentId
7+
8+
9+
class EtcdClientRegistry:
10+
_etcd_config: EtcdConfigData
11+
_etcd_clients: dict[AgentId, AsyncEtcd]
12+
_global_etcd: AsyncEtcd
13+
14+
@property
15+
def global_etcd(self) -> AsyncEtcd:
16+
return self._global_etcd
17+
18+
def __init__(self, etcd_config: EtcdConfigData) -> None:
19+
self._etcd_config = etcd_config
20+
self._etcd_clients = {}
21+
self._global_etcd = self._create_client(agent_id=None, scaling_group=None)
22+
23+
async def close(self) -> None:
24+
for etcd in self._etcd_clients.values():
25+
await etcd.close()
26+
await self._global_etcd.close()
27+
28+
def get_client(self, agent_id: AgentId) -> AsyncEtcd:
29+
return self._etcd_clients[agent_id]
30+
31+
def prefill_clients(self, prefill_data: Collection[AgentUnifiedConfig]) -> None:
32+
for agent_config in prefill_data:
33+
agent_id = AgentId(agent_config.agent.id)
34+
self._etcd_clients[agent_id] = self._create_client(
35+
agent_id, agent_config.agent.scaling_group
36+
)
37+
38+
def _create_client(self, agent_id: AgentId | None, scaling_group: str | None) -> AsyncEtcd:
39+
scope_prefix_map = {ConfigScopes.GLOBAL: ""}
40+
if scaling_group is not None:
41+
scope_prefix_map[ConfigScopes.SGROUP] = f"sgroup/{scaling_group}"
42+
if agent_id is not None:
43+
scope_prefix_map[ConfigScopes.NODE] = f"nodes/agents/{agent_id}"
44+
45+
if self._etcd_config.user and self._etcd_config.password:
46+
etcd_credentials = {
47+
"user": self._etcd_config.user,
48+
"password": self._etcd_config.password,
49+
}
50+
else:
51+
etcd_credentials = None
52+
53+
return AsyncEtcd(
54+
[addr.to_legacy() for addr in self._etcd_config.addrs],
55+
self._etcd_config.namespace,
56+
scope_prefix_map,
57+
credentials=etcd_credentials,
58+
)

src/ai/backend/agent/server.py

Lines changed: 38 additions & 43 deletions
Original file line numberDiff line numberDiff line change
@@ -49,6 +49,7 @@
4949
from setproctitle import setproctitle
5050
from zmq.auth.certs import load_certificate
5151

52+
from ai.backend.agent.etcd import EtcdClientRegistry
5253
from ai.backend.agent.metrics.metric import RPCMetricObserver
5354
from ai.backend.agent.resources import scan_gpu_alloc_map
5455
from ai.backend.common import config, identity, msgpack, utils
@@ -96,6 +97,7 @@
9697
ServiceMetadata,
9798
)
9899
from ai.backend.common.types import (
100+
AgentId,
99101
ClusterInfo,
100102
CommitStatus,
101103
ContainerId,
@@ -309,6 +311,7 @@ class AgentRPCServer(aobject):
309311

310312
loop: asyncio.AbstractEventLoop
311313
agent: AbstractAgent
314+
etcd_client_registry: EtcdClientRegistry
312315
rpc_server: Peer
313316
rpc_addr: str
314317
agent_addr: str
@@ -318,17 +321,19 @@ class AgentRPCServer(aobject):
318321

319322
def __init__(
320323
self,
321-
etcd: AsyncEtcd,
324+
etcd_client_registry: EtcdClientRegistry,
322325
local_config: AgentUnifiedConfig,
323326
*,
324327
skip_detect_manager: bool = False,
325328
) -> None:
326329
self.loop = current_loop()
327-
self.etcd = etcd
330+
self.etcd_client_registry = etcd_client_registry
328331
self.local_config = local_config
329332
self.skip_detect_manager = skip_detect_manager
330333
self._stop_signal = signal.SIGTERM
331334

335+
self.etcd_client_registry.prefill_clients(self.local_config.agent_configs)
336+
332337
async def __ainit__(self) -> None:
333338
# Start serving requests.
334339
await self.update_status("starting")
@@ -339,11 +344,12 @@ async def __ainit__(self) -> None:
339344
await self.read_agent_config()
340345
await self.read_agent_config_container()
341346

347+
global_etcd = self.etcd_client_registry.global_etcd
342348
self.stats_monitor = AgentStatsPluginContext(
343-
self.etcd, self.local_config.model_dump(by_alias=True)
349+
global_etcd, self.local_config.model_dump(by_alias=True)
344350
)
345351
self.error_monitor = AgentErrorPluginContext(
346-
self.etcd, self.local_config.model_dump(by_alias=True)
352+
global_etcd, self.local_config.model_dump(by_alias=True)
347353
)
348354
await self.stats_monitor.init()
349355
await self.error_monitor.init()
@@ -380,7 +386,7 @@ async def __ainit__(self) -> None:
380386
backend = self.local_config.agent_common.backend
381387
agent_mod = importlib.import_module(f"ai.backend.agent.{backend.value}")
382388
self.agent = await agent_mod.get_agent_cls().new( # type: ignore
383-
self.etcd,
389+
self.etcd_client_registry.get_client(AgentId(self.local_config.agent.id)),
384390
self.local_config,
385391
stats_monitor=self.stats_monitor,
386392
error_monitor=self.error_monitor,
@@ -422,13 +428,13 @@ async def _debug_server_task():
422428

423429
self.debug_server_task = asyncio.create_task(_debug_server_task())
424430

425-
await self.etcd.put("ip", rpc_addr.host, scope=ConfigScopes.NODE)
431+
await self.agent.etcd.put("ip", rpc_addr.host, scope=ConfigScopes.NODE)
426432

427433
watcher_port = utils.nmget(
428434
self.local_config.model_dump(), "watcher.service-addr.port", None
429435
)
430436
if watcher_port is not None:
431-
await self.etcd.put("watcher_port", watcher_port, scope=ConfigScopes.NODE)
437+
await self.agent.etcd.put("watcher_port", watcher_port, scope=ConfigScopes.NODE)
432438

433439
await self.update_status("running")
434440

@@ -472,10 +478,11 @@ def _ensure_serializable(o) -> Any:
472478

473479
async def detect_manager(self):
474480
log.info("detecting the manager...")
475-
manager_instances = await self.etcd.get_prefix("nodes/manager")
481+
global_etcd = self.etcd_client_registry.global_etcd
482+
manager_instances = await global_etcd.get_prefix("nodes/manager")
476483
if not manager_instances:
477484
log.warning("watching etcd to wait for the manager being available")
478-
async with aclosing(self.etcd.watch_prefix("nodes/manager")) as agen:
485+
async with aclosing(global_etcd.watch_prefix("nodes/manager")) as agen: # type: ignore
479486
async for ev in agen:
480487
match ev:
481488
case QueueSentinel.CLOSED | QueueSentinel.TIMEOUT:
@@ -486,9 +493,11 @@ async def detect_manager(self):
486493
log.info("detected at least one manager running")
487494

488495
async def read_agent_config(self):
496+
global_etcd = self.etcd_client_registry.global_etcd
497+
489498
# Fill up Redis configs from etcd and store as separate attributes
490499
self._redis_config = config.redis_config_iv.check(
491-
await self.etcd.get_prefix("config/redis"),
500+
await global_etcd.get_prefix("config/redis"),
492501
)
493502
log.info("configured redis: {0}", self._redis_config)
494503

@@ -506,7 +515,7 @@ async def read_agent_config(self):
506515
# Fill up vfolder configs from etcd and store as separate attributes
507516
# TODO: Integrate vfolder_config into local_config
508517
self._vfolder_config = config.vfolder_config_iv.check(
509-
await self.etcd.get_prefix("volumes"),
518+
await global_etcd.get_prefix("volumes"),
510519
)
511520
if self._vfolder_config["mount"] is None:
512521
log.info(
@@ -517,7 +526,7 @@ async def read_agent_config(self):
517526
log.info("configured vfolder fs prefix: {0}", self._vfolder_config["fsprefix"])
518527

519528
# Fill up shared agent configurations from etcd.
520-
agent_etcd_config_raw = await self.etcd.get_prefix("config/agent")
529+
agent_etcd_config_raw = await global_etcd.get_prefix("config/agent")
521530
if agent_etcd_config_raw:
522531
try:
523532
# Parse specific etcd configs and update the unified config
@@ -541,7 +550,9 @@ async def read_agent_config(self):
541550
async def read_agent_config_container(self):
542551
# Fill up global container configurations from etcd.
543552
try:
544-
container_etcd_config_raw = await self.etcd.get_prefix("config/container")
553+
container_etcd_config_raw = await self.etcd_client_registry.global_etcd.get_prefix(
554+
"config/container"
555+
)
545556
if container_etcd_config_raw:
546557
# Update config by creating a new instance with modified values
547558
container_updates = {}
@@ -586,7 +597,7 @@ async def __aexit__(self, *exc_info) -> None:
586597

587598
@collect_error
588599
async def update_status(self, status):
589-
await self.etcd.put("", status, scope=ConfigScopes.NODE)
600+
await self.agent.etcd.put("", status, scope=ConfigScopes.NODE)
590601

591602
@rpc_function
592603
@collect_error
@@ -1221,29 +1232,12 @@ async def aiomonitor_ctx(
12211232

12221233

12231234
@asynccontextmanager
1224-
async def etcd_ctx(local_config: AgentUnifiedConfig) -> AsyncGenerator[AsyncEtcd]:
1225-
etcd_credentials = None
1226-
if local_config.etcd.user and local_config.etcd.password:
1227-
etcd_credentials = {
1228-
"user": local_config.etcd.user,
1229-
"password": local_config.etcd.password,
1230-
}
1231-
scope_prefix_map = {
1232-
ConfigScopes.GLOBAL: "",
1233-
ConfigScopes.SGROUP: f"sgroup/{local_config.agent.scaling_group}",
1234-
ConfigScopes.NODE: f"nodes/agents/{local_config.agent.id}",
1235-
}
1236-
etcd_config_data = local_config.etcd.to_dataclass()
1237-
etcd = AsyncEtcd(
1238-
[addr.to_legacy() for addr in etcd_config_data.addrs],
1239-
local_config.etcd.namespace,
1240-
scope_prefix_map,
1241-
credentials=etcd_credentials,
1242-
)
1235+
async def etcd_ctx(local_config: AgentUnifiedConfig) -> AsyncGenerator[EtcdClientRegistry]:
1236+
etcd_client_registry = EtcdClientRegistry(local_config.etcd.to_dataclass())
12431237
try:
1244-
yield etcd
1238+
yield etcd_client_registry
12451239
finally:
1246-
await etcd.close()
1240+
await etcd_client_registry.close()
12471241

12481242

12491243
async def prepare_krunner_volumes(local_config: AgentUnifiedConfig) -> AgentUnifiedConfig:
@@ -1253,7 +1247,7 @@ async def prepare_krunner_volumes(local_config: AgentUnifiedConfig) -> AgentUnif
12531247
)
12541248
krunner_volumes: Mapping[str, str] = await kernel_mod.prepare_krunner_env(
12551249
local_config.model_dump(by_alias=True)
1256-
) # type: ignore
1250+
)
12571251
# TODO: merge k8s branch: nfs_mount_path = local_config['baistatic']['mounted-at']
12581252
log.info("Kernel runner environments: {}", [*krunner_volumes.keys()])
12591253
return local_config.with_updates(container_update={"krunner_volumes": krunner_volumes})
@@ -1317,10 +1311,10 @@ async def auto_detect_agent_network(
13171311

13181312
@asynccontextmanager
13191313
async def agent_server_ctx(
1320-
local_config: AgentUnifiedConfig, etcd: AsyncEtcd
1314+
local_config: AgentUnifiedConfig, etcd_client_registry: EtcdClientRegistry
13211315
) -> AsyncGenerator[AgentRPCServer]:
13221316
agent_server = await AgentRPCServer.new(
1323-
etcd,
1317+
etcd_client_registry,
13241318
local_config,
13251319
skip_detect_manager=local_config.agent_common.skip_manager_detection,
13261320
)
@@ -1419,18 +1413,19 @@ async def server_main(
14191413
local_config = await auto_detect_agent_identity(local_config)
14201414

14211415
# etcd's scope-prefix map depends on the auto-detected identity info.
1422-
etcd = await agent_init_stack.enter_async_context(etcd_ctx(local_config))
1423-
local_config = await auto_detect_agent_network(local_config, etcd)
1424-
plugins = await etcd.get_prefix_dict("config/plugins/accelerator")
1416+
etcd_client_registry = await agent_init_stack.enter_async_context(etcd_ctx(local_config))
1417+
global_etcd = etcd_client_registry.global_etcd
1418+
local_config = await auto_detect_agent_network(local_config, global_etcd)
1419+
plugins = await global_etcd.get_prefix_dict("config/plugins/accelerator")
14251420
local_config = local_config.with_changes(plugins=plugins)
14261421

14271422
# Start RPC server.
14281423
agent_server = await agent_init_stack.enter_async_context(
1429-
agent_server_ctx(local_config, etcd)
1424+
agent_server_ctx(local_config, etcd_client_registry)
14301425
)
14311426
monitor.console_locals["agent_server"] = agent_server
14321427

1433-
await agent_init_stack.enter_async_context(service_discovery_ctx(etcd, agent_server))
1428+
await agent_init_stack.enter_async_context(service_discovery_ctx(global_etcd, agent_server))
14341429
log.info("Started the agent service.")
14351430
except Exception:
14361431
log.exception("Server initialization failure; triggering shutdown...")

tests/agent/test_agent.py

Lines changed: 21 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -19,7 +19,9 @@
1919

2020

2121
class Dummy:
22-
pass
22+
def prefill_clients(self, prefill_data):
23+
"""Mock implementation of prefill_clients for testing"""
24+
pass
2325

2426

2527
kgid = "kernel-gid"
@@ -29,8 +31,9 @@ class Dummy:
2931

3032
@pytest.fixture
3133
async def arpcs_no_ainit(test_id, redis_container):
32-
etcd = Dummy()
33-
etcd.get_prefix = None
34+
etcd_client_registry = Dummy()
35+
etcd_client_registry.global_etcd = Dummy()
36+
etcd_client_registry.global_etcd.get_prefix = None
3437

3538
# Create a minimal pydantic config for testing
3639
config = AgentUnifiedConfig(
@@ -40,14 +43,18 @@ async def arpcs_no_ainit(test_id, redis_container):
4043
etcd=EtcdConfig(namespace="test", addr=HostPortPair(host="127.0.0.1", port=2379)),
4144
)
4245

43-
ars = AgentRPCServer(etcd=etcd, local_config=config, skip_detect_manager=True)
46+
ars = AgentRPCServer(
47+
etcd_client_registry=etcd_client_registry, local_config=config, skip_detect_manager=True
48+
)
4449
yield ars
4550

4651

4752
@pytest.mark.asyncio
4853
async def test_read_agent_config_container_invalid01(arpcs_no_ainit, mocker):
4954
inspect_mock = AsyncMock(return_value={"a": 1, "b": 2})
50-
mocker.patch.object(arpcs_no_ainit.etcd, "get_prefix", new=inspect_mock)
55+
mocker.patch.object(
56+
arpcs_no_ainit.etcd_client_registry.global_etcd, "get_prefix", new=inspect_mock
57+
)
5158
await arpcs_no_ainit.read_agent_config_container()
5259
# Check that kernel-gid and kernel-uid are still at their default values (converted from -1)
5360
assert (
@@ -61,7 +68,9 @@ async def test_read_agent_config_container_invalid01(arpcs_no_ainit, mocker):
6168
@pytest.mark.asyncio
6269
async def test_read_agent_config_container_invalid02(arpcs_no_ainit, mocker):
6370
inspect_mock = AsyncMock(return_value={})
64-
mocker.patch.object(arpcs_no_ainit.etcd, "get_prefix", new=inspect_mock)
71+
mocker.patch.object(
72+
arpcs_no_ainit.etcd_client_registry.global_etcd, "get_prefix", new=inspect_mock
73+
)
6574
await arpcs_no_ainit.read_agent_config_container()
6675
# Check that kernel-gid and kernel-uid are still at their default values (converted from -1)
6776
assert (
@@ -75,7 +84,9 @@ async def test_read_agent_config_container_invalid02(arpcs_no_ainit, mocker):
7584
@pytest.mark.asyncio
7685
async def test_read_agent_config_container_1valid(arpcs_no_ainit, mocker):
7786
inspect_mock = AsyncMock(return_value={kgid: 10})
78-
mocker.patch.object(arpcs_no_ainit.etcd, "get_prefix", new=inspect_mock)
87+
mocker.patch.object(
88+
arpcs_no_ainit.etcd_client_registry.global_etcd, "get_prefix", new=inspect_mock
89+
)
7990
await arpcs_no_ainit.read_agent_config_container()
8091

8192
assert arpcs_no_ainit.local_config.container.kernel_gid.real == 10
@@ -87,7 +98,9 @@ async def test_read_agent_config_container_1valid(arpcs_no_ainit, mocker):
8798
@pytest.mark.asyncio
8899
async def test_read_agent_config_container_2valid(arpcs_no_ainit, mocker):
89100
inspect_mock = AsyncMock(return_value={kgid: 10, kuid: 20})
90-
mocker.patch.object(arpcs_no_ainit.etcd, "get_prefix", new=inspect_mock)
101+
mocker.patch.object(
102+
arpcs_no_ainit.etcd_client_registry.global_etcd, "get_prefix", new=inspect_mock
103+
)
91104
await arpcs_no_ainit.read_agent_config_container()
92105

93106
assert arpcs_no_ainit.local_config.container.kernel_gid.real == 10

0 commit comments

Comments
 (0)