Skip to content

Commit 1e4b560

Browse files
committed
feat(BA-3023): Add unified multi agent etcd client handling
This change introduces AgentEtcdClientView, an adaptor class around the etcd client that keeps the etcd config scopes in sync with each agent's scaling group and agent ID.
1 parent ab573ad commit 1e4b560

File tree

5 files changed

+242
-3
lines changed

5 files changed

+242
-3
lines changed

changes/6721.feature.md

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1 @@
1+
Add AgentEtcdClientView for clean handling of etcd clients when running multiple agents

src/ai/backend/agent/agent.py

Lines changed: 3 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -64,6 +64,7 @@
6464
)
6565
from trafaret import DataError
6666

67+
from ai.backend.agent.etcd import AgentEtcdClientView
6768
from ai.backend.agent.metrics.metric import (
6869
StatScope,
6970
StatTaskObserver,
@@ -762,7 +763,7 @@ class AbstractAgent(
762763
id: AgentId
763764
loop: asyncio.AbstractEventLoop
764765
local_config: AgentUnifiedConfig
765-
etcd: AsyncEtcd
766+
etcd: AgentEtcdClientView
766767
local_instance_id: str
767768
kernel_registry: MutableMapping[KernelId, AbstractKernel]
768769
computers: MutableMapping[DeviceName, ComputerContext]
@@ -839,7 +840,7 @@ def __init__(
839840
) -> None:
840841
self._skip_initial_scan = skip_initial_scan
841842
self.loop = current_loop()
842-
self.etcd = etcd
843+
self.etcd = AgentEtcdClientView(etcd, config_container=self)
843844
self.local_config = local_config
844845
self.id = AgentId(local_config.agent.id or f"agent-{uuid4()}")
845846
self.local_instance_id = generate_local_instance_id(__file__)

src/ai/backend/agent/config/unified.py

Lines changed: 5 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -16,6 +16,7 @@
1616
Any,
1717
Mapping,
1818
Optional,
19+
Protocol,
1920
Self,
2021
Sequence,
2122
TypeVar,
@@ -1309,3 +1310,7 @@ def _validate_agent_configs(self) -> Self:
13091310
config.validate_agent_specific_config()
13101311

13111312
return self
1313+
1314+
1315+
class AgentConfigContainer(Protocol):
    """Structural type for any object that exposes a ``local_config`` attribute."""

    # The unified agent configuration currently held by the container.
    local_config: AgentUnifiedConfig

src/ai/backend/agent/etcd.py

Lines changed: 231 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,231 @@
1+
from collections.abc import AsyncGenerator, Iterable, Mapping
2+
from typing import Optional, override
3+
4+
from etcd_client import CondVar
5+
6+
from ai.backend.agent.config.unified import AgentConfigContainer
7+
from ai.backend.common.etcd import (
8+
AbstractKVStore,
9+
AsyncEtcd,
10+
ConfigScopes,
11+
Event,
12+
GetPrefixValue,
13+
NestedStrKeyedMapping,
14+
)
15+
from ai.backend.common.exception import (
16+
BackendAIError,
17+
ErrorCode,
18+
ErrorDetail,
19+
ErrorDomain,
20+
ErrorOperation,
21+
)
22+
from ai.backend.common.types import QueueSentinel
23+
24+
25+
class AgentEtcdError(BackendAIError):
    """Error class of the agent etcd module, reported as AGENT / SETUP / NOT_FOUND."""

    @classmethod
    def error_code(cls) -> ErrorCode:
        """Return the error code that identifies this error class."""
        code = ErrorCode(
            domain=ErrorDomain.AGENT,
            operation=ErrorOperation.SETUP,
            error_detail=ErrorDetail.NOT_FOUND,
        )
        return code
33+
34+
35+
# AgentEtcdClientView inherits from AsyncEtcd, but really it's just composing an AsyncEtcd instance
36+
# and acting as an adaptor. Inheritance is made only to make the type checker happy, and the places
37+
# that use AsyncEtcd really should not take the concrete implementation type AsyncEtcd, but rather
38+
# the interface type AbstractKVStore. In the current codebase, manually modifying places that
39+
# currently take in an AsyncEtcd instance to instead take in AbstractKVStore would be too invasive.
40+
class AgentEtcdClientView(AsyncEtcd, AbstractKVStore):
    """
    Adaptor around an ``AsyncEtcd`` client that pins the per-agent config scopes.

    Every key-value operation is forwarded to the wrapped client after the
    ``ConfigScopes.SGROUP`` and ``ConfigScopes.NODE`` entries of
    ``scope_prefix_map`` have been overwritten with prefixes derived from the
    ``local_config.agent`` section of the config container.

    NOTE(review): ``__init__`` does not invoke the base-class initializer, so
    inherited methods that are not overridden here may depend on state that is
    never set on the view -- confirm before calling them through this class.
    """

    _etcd: AsyncEtcd
    # We take in a config container instead of the unified config object itself, because config
    # objects are never mutated, but rather replaced with a copy.
    _config_container: AgentConfigContainer

    def __init__(
        self,
        etcd: AsyncEtcd,
        config_container: AgentConfigContainer,
    ) -> None:
        """
        Wrap an existing etcd client.

        :param etcd: The client that actually performs the operations.
        :param config_container: Object whose ``local_config`` is read on every
            call to compute the scope prefixes.
        """
        self._etcd = etcd
        self._config_container = config_container

    def _augment_scope_prefix_map(
        self,
        scope_prefix_map: Optional[Mapping[ConfigScopes, str]],
    ) -> Mapping[ConfigScopes, str]:
        """
        Return a new mapping with the agent's SGROUP and NODE prefixes applied.

        Entries supplied by the caller for other scopes are preserved; entries
        for ``SGROUP`` and ``NODE`` are always replaced. The input mapping is
        not mutated.
        """
        if scope_prefix_map is None:
            scope_prefix_map = {}

        # Read the config through the container on every call so that a replaced
        # config object is picked up immediately.
        agent_config = self._config_container.local_config.agent
        # NOTE(review): AbstractAgent falls back to a generated id when
        # ``local_config.agent.id`` is empty, whereas the raw config value is used
        # here -- confirm the two cannot diverge.
        return {
            **scope_prefix_map,
            ConfigScopes.SGROUP: f"sgroup/{agent_config.scaling_group}",
            ConfigScopes.NODE: f"nodes/agents/{agent_config.id}",
        }

    @override
    async def put(
        self,
        key: str,
        val: str,
        *,
        scope: ConfigScopes = ConfigScopes.GLOBAL,
        scope_prefix_map: Optional[Mapping[ConfigScopes, str]] = None,
    ) -> None:
        """Forward ``put`` to the wrapped client with the agent scope prefixes applied."""
        scope_prefix_map = self._augment_scope_prefix_map(scope_prefix_map)
        await self._etcd.put(key, val, scope=scope, scope_prefix_map=scope_prefix_map)

    @override
    async def put_prefix(
        self,
        key: str,
        dict_obj: NestedStrKeyedMapping,
        *,
        scope: ConfigScopes = ConfigScopes.GLOBAL,
        scope_prefix_map: Optional[Mapping[ConfigScopes, str]] = None,
    ) -> None:
        """Forward ``put_prefix`` to the wrapped client with the agent scope prefixes applied."""
        scope_prefix_map = self._augment_scope_prefix_map(scope_prefix_map)
        await self._etcd.put_prefix(key, dict_obj, scope=scope, scope_prefix_map=scope_prefix_map)

    @override
    async def put_dict(
        self,
        flattened_dict_obj: Mapping[str, str],
        *,
        scope: ConfigScopes = ConfigScopes.GLOBAL,
        scope_prefix_map: Optional[Mapping[ConfigScopes, str]] = None,
    ) -> None:
        """Forward ``put_dict`` to the wrapped client with the agent scope prefixes applied."""
        scope_prefix_map = self._augment_scope_prefix_map(scope_prefix_map)
        await self._etcd.put_dict(
            flattened_dict_obj, scope=scope, scope_prefix_map=scope_prefix_map
        )

    @override
    async def get(
        self,
        key: str,
        *,
        scope: ConfigScopes = ConfigScopes.MERGED,
        scope_prefix_map: Optional[Mapping[ConfigScopes, str]] = None,
    ) -> Optional[str]:
        """Forward ``get`` to the wrapped client and return its result."""
        scope_prefix_map = self._augment_scope_prefix_map(scope_prefix_map)
        return await self._etcd.get(key, scope=scope, scope_prefix_map=scope_prefix_map)

    @override
    async def get_prefix(
        self,
        key_prefix: str,
        *,
        scope: ConfigScopes = ConfigScopes.MERGED,
        scope_prefix_map: Optional[Mapping[ConfigScopes, str]] = None,
    ) -> GetPrefixValue:
        """Forward ``get_prefix`` to the wrapped client and return its result."""
        scope_prefix_map = self._augment_scope_prefix_map(scope_prefix_map)
        return await self._etcd.get_prefix(
            key_prefix,
            scope=scope,
            scope_prefix_map=scope_prefix_map,
        )

    @override
    async def replace(
        self,
        key: str,
        initial_val: str,
        new_val: str,
        *,
        scope: ConfigScopes = ConfigScopes.GLOBAL,
        scope_prefix_map: Optional[Mapping[ConfigScopes, str]] = None,
    ) -> bool:
        """Forward ``replace`` to the wrapped client and return its boolean result."""
        scope_prefix_map = self._augment_scope_prefix_map(scope_prefix_map)
        return await self._etcd.replace(
            key,
            initial_val,
            new_val,
            scope=scope,
            scope_prefix_map=scope_prefix_map,
        )

    @override
    async def delete(
        self,
        key: str,
        *,
        scope: ConfigScopes = ConfigScopes.GLOBAL,
        scope_prefix_map: Optional[Mapping[ConfigScopes, str]] = None,
    ) -> None:
        """Forward ``delete`` to the wrapped client with the agent scope prefixes applied."""
        scope_prefix_map = self._augment_scope_prefix_map(scope_prefix_map)
        await self._etcd.delete(key, scope=scope, scope_prefix_map=scope_prefix_map)

    @override
    async def delete_multi(
        self,
        keys: Iterable[str],
        *,
        scope: ConfigScopes = ConfigScopes.GLOBAL,
        scope_prefix_map: Optional[Mapping[ConfigScopes, str]] = None,
    ) -> None:
        """Forward ``delete_multi`` to the wrapped client with the agent scope prefixes applied."""
        scope_prefix_map = self._augment_scope_prefix_map(scope_prefix_map)
        await self._etcd.delete_multi(keys, scope=scope, scope_prefix_map=scope_prefix_map)

    @override
    async def delete_prefix(
        self,
        key_prefix: str,
        *,
        scope: ConfigScopes = ConfigScopes.GLOBAL,
        scope_prefix_map: Optional[Mapping[ConfigScopes, str]] = None,
    ) -> None:
        """Forward ``delete_prefix`` to the wrapped client with the agent scope prefixes applied."""
        scope_prefix_map = self._augment_scope_prefix_map(scope_prefix_map)
        await self._etcd.delete_prefix(key_prefix, scope=scope, scope_prefix_map=scope_prefix_map)

    @override
    async def watch(
        self,
        key: str,
        *,
        scope: ConfigScopes = ConfigScopes.GLOBAL,
        scope_prefix_map: Optional[Mapping[ConfigScopes, str]] = None,
        once: bool = False,
        ready_event: Optional[CondVar] = None,
        cleanup_event: Optional[CondVar] = None,
        wait_timeout: Optional[float] = None,
    ) -> AsyncGenerator[QueueSentinel | Event, None]:
        """
        Relay the items produced by the wrapped client's ``watch``.

        The scope prefixes are computed once, when iteration starts. The wrapped
        generator is closed as soon as this generator finishes or is closed.
        """
        scope_prefix_map = self._augment_scope_prefix_map(scope_prefix_map)
        watch_result = self._etcd.watch(
            key,
            scope=scope,
            scope_prefix_map=scope_prefix_map,
            once=once,
            ready_event=ready_event,
            cleanup_event=cleanup_event,
            wait_timeout=wait_timeout,
        )
        try:
            async for item in watch_result:
                yield item
        finally:
            # Close the wrapped generator deterministically when the consumer stops
            # early, instead of leaving its cleanup to garbage collection.
            await watch_result.aclose()

    @override
    async def watch_prefix(
        self,
        key_prefix: str,
        *,
        scope: ConfigScopes = ConfigScopes.GLOBAL,
        scope_prefix_map: Optional[Mapping[ConfigScopes, str]] = None,
        once: bool = False,
        ready_event: Optional[CondVar] = None,
        cleanup_event: Optional[CondVar] = None,
        wait_timeout: Optional[float] = None,
    ) -> AsyncGenerator[QueueSentinel | Event, None]:
        """
        Relay the items produced by the wrapped client's ``watch_prefix``.

        The scope prefixes are computed once, when iteration starts. The wrapped
        generator is closed as soon as this generator finishes or is closed.
        """
        scope_prefix_map = self._augment_scope_prefix_map(scope_prefix_map)
        watch_prefix_result = self._etcd.watch_prefix(
            key_prefix,
            scope=scope,
            scope_prefix_map=scope_prefix_map,
            once=once,
            ready_event=ready_event,
            cleanup_event=cleanup_event,
            wait_timeout=wait_timeout,
        )
        try:
            async for item in watch_prefix_result:
                yield item
        finally:
            # Close the wrapped generator deterministically when the consumer stops
            # early, instead of leaving its cleanup to garbage collection.
            await watch_prefix_result.aclose()

src/ai/backend/agent/server.py

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -309,6 +309,7 @@ class AgentRPCServer(aobject):
309309

310310
loop: asyncio.AbstractEventLoop
311311
agent: AbstractAgent
312+
etcd: AsyncEtcd
312313
rpc_server: Peer
313314
rpc_addr: str
314315
agent_addr: str
@@ -1253,7 +1254,7 @@ async def prepare_krunner_volumes(local_config: AgentUnifiedConfig) -> AgentUnif
12531254
)
12541255
krunner_volumes: Mapping[str, str] = await kernel_mod.prepare_krunner_env(
12551256
local_config.model_dump(by_alias=True)
1256-
) # type: ignore
1257+
)
12571258
# TODO: merge k8s branch: nfs_mount_path = local_config['baistatic']['mounted-at']
12581259
log.info("Kernel runner environments: {}", [*krunner_volumes.keys()])
12591260
return local_config.with_updates(container_update={"krunner_volumes": krunner_volumes})

0 commit comments

Comments
 (0)