Skip to content

Commit 33ecbc4

Browse files
committed
feat(BA-3023): Add unified multi agent etcd client handling
This change introduces AgentEtcdClientView, an adaptor class that ensures the etcd config scope is always in sync with the specific agent's scaling group and agent ID.
1 parent f3e6537 commit 33ecbc4

File tree

8 files changed

+749
-15
lines changed

8 files changed

+749
-15
lines changed

changes/6721.feature.md

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1 @@
1+
Add AgentEtcdClientView for clean handling of etcd clients across multiple agents

src/ai/backend/agent/agent.py

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -64,6 +64,7 @@
6464
)
6565
from trafaret import DataError
6666

67+
from ai.backend.agent.etcd import AgentEtcdClientView
6768
from ai.backend.agent.metrics.metric import (
6869
StatScope,
6970
StatTaskObserver,
@@ -256,7 +257,6 @@
256257

257258
if TYPE_CHECKING:
258259
from ai.backend.common.auth import PublicKey
259-
from ai.backend.common.etcd import AsyncEtcd
260260

261261
log = BraceStyleAdapter(logging.getLogger(__spec__.name))
262262

@@ -762,7 +762,7 @@ class AbstractAgent(
762762
id: AgentId
763763
loop: asyncio.AbstractEventLoop
764764
local_config: AgentUnifiedConfig
765-
etcd: AsyncEtcd
765+
etcd: AgentEtcdClientView
766766
local_instance_id: str
767767
kernel_registry: MutableMapping[KernelId, AbstractKernel]
768768
computers: MutableMapping[DeviceName, ComputerContext]
@@ -829,7 +829,7 @@ def track_create(
829829

830830
def __init__(
831831
self,
832-
etcd: AsyncEtcd,
832+
etcd: AgentEtcdClientView,
833833
local_config: AgentUnifiedConfig,
834834
*,
835835
stats_monitor: StatsPluginContext,

src/ai/backend/agent/docker/agent.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -48,6 +48,7 @@
4848
from aiotools import TaskGroup
4949
from async_timeout import timeout
5050

51+
from ai.backend.agent.etcd import AgentEtcdClientView
5152
from ai.backend.common.cgroup import get_cgroup_mount_point
5253
from ai.backend.common.data.image.types import ScannedImage
5354
from ai.backend.common.docker import (
@@ -143,7 +144,6 @@
143144

144145
if TYPE_CHECKING:
145146
from ai.backend.common.auth import PublicKey
146-
from ai.backend.common.etcd import AsyncEtcd
147147

148148
log = BraceStyleAdapter(logging.getLogger(__spec__.name))
149149
eof_sentinel = Sentinel.TOKEN
@@ -1350,7 +1350,7 @@ class DockerAgent(AbstractAgent[DockerKernel, DockerKernelCreationContext]):
13501350

13511351
def __init__(
13521352
self,
1353-
etcd: AsyncEtcd,
1353+
etcd: AgentEtcdClientView,
13541354
local_config: AgentUnifiedConfig,
13551355
*,
13561356
stats_monitor: StatsPluginContext,

src/ai/backend/agent/etcd.py

Lines changed: 229 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,229 @@
1+
from collections.abc import AsyncGenerator, Iterable, Mapping
2+
from typing import Optional, override
3+
4+
from etcd_client import CondVar
5+
6+
from ai.backend.agent.config.unified import AgentUnifiedConfig
7+
from ai.backend.common.etcd import (
8+
AbstractKVStore,
9+
AsyncEtcd,
10+
ConfigScopes,
11+
Event,
12+
GetPrefixValue,
13+
NestedStrKeyedMapping,
14+
)
15+
from ai.backend.common.exception import (
16+
BackendAIError,
17+
ErrorCode,
18+
ErrorDetail,
19+
ErrorDomain,
20+
ErrorOperation,
21+
)
22+
from ai.backend.common.types import QueueSentinel
23+
24+
25+
class AgentEtcdError(BackendAIError):
    """Error type for agent-side etcd handling.

    Its error code is fixed to the AGENT domain, the SETUP operation and the
    NOT_FOUND detail.
    """

    @classmethod
    def error_code(cls) -> ErrorCode:
        """Return the error code that identifies this error class."""
        code = ErrorCode(
            domain=ErrorDomain.AGENT,
            operation=ErrorOperation.SETUP,
            error_detail=ErrorDetail.NOT_FOUND,
        )
        return code
33+
34+
35+
# AgentEtcdClientView inherits from AsyncEtcd, but really it's just composing an AsyncEtcd instance
36+
# and acting as an adaptor. Inheritance is made only to make the type checker happy, and the places
37+
# that use AsyncEtcd really should not take the concrete implementation type AsyncEtcd, but rather
38+
# the interface type AbstractKVStore. In the current codebase, manually modifying places that
39+
# currently take in an AsyncEtcd instance to instead take in AbstractKVStore would be too invasive.
40+
class AgentEtcdClientView(AsyncEtcd, AbstractKVStore):
    """Adaptor over an ``AsyncEtcd`` client bound to one agent's configuration.

    Every key-value operation is forwarded to the wrapped client after the
    caller's ``scope_prefix_map`` has been overlaid with the ``SGROUP`` and
    ``NODE`` prefixes derived from the agent config held by this view. The
    prefixes are recomputed from the config on every call.
    """

    _etcd: AsyncEtcd  # wrapped client that actually performs the operations
    _config: AgentUnifiedConfig  # source of the scaling group and agent ID

    def __init__(
        self,
        etcd: AsyncEtcd,
        config: AgentUnifiedConfig,
    ) -> None:
        """Remember the wrapped client and the agent config.

        NOTE(review): ``AsyncEtcd.__init__`` is not invoked here; only the
        methods overridden in this class are routed to the wrapped client.
        Confirm that no caller relies on inherited, non-overridden members.
        """
        self._config = config
        self._etcd = etcd

    def _augment_scope_prefix_map(
        self,
        scope_prefix_map: Optional[Mapping[ConfigScopes, str]],
    ) -> Mapping[ConfigScopes, str]:
        """Return a copy of *scope_prefix_map* with this agent's prefixes applied.

        ``None`` is treated as an empty mapping. Entries for ``SGROUP`` and
        ``NODE`` supplied by the caller are overwritten by the values derived
        from the agent config; all other entries are kept.
        """
        agent = self._config.agent
        merged: dict[ConfigScopes, str] = {}
        if scope_prefix_map is not None:
            merged.update(scope_prefix_map)
        # Assigned last so the agent's own scopes always win over caller input.
        merged[ConfigScopes.SGROUP] = f"sgroup/{agent.scaling_group}"
        merged[ConfigScopes.NODE] = f"nodes/agents/{agent.id}"
        return merged

    @override
    async def put(
        self,
        key: str,
        val: str,
        *,
        scope: ConfigScopes = ConfigScopes.GLOBAL,
        scope_prefix_map: Optional[Mapping[ConfigScopes, str]] = None,
    ):
        """Forward ``put`` to the wrapped client with the agent's scope prefixes."""
        prefixes = self._augment_scope_prefix_map(scope_prefix_map)
        await self._etcd.put(key, val, scope=scope, scope_prefix_map=prefixes)

    @override
    async def put_prefix(
        self,
        key: str,
        dict_obj: NestedStrKeyedMapping,
        *,
        scope: ConfigScopes = ConfigScopes.GLOBAL,
        scope_prefix_map: Optional[Mapping[ConfigScopes, str]] = None,
    ):
        """Forward ``put_prefix`` to the wrapped client with the agent's scope prefixes."""
        prefixes = self._augment_scope_prefix_map(scope_prefix_map)
        await self._etcd.put_prefix(key, dict_obj, scope=scope, scope_prefix_map=prefixes)

    @override
    async def put_dict(
        self,
        flattened_dict_obj: Mapping[str, str],
        *,
        scope: ConfigScopes = ConfigScopes.GLOBAL,
        scope_prefix_map: Optional[Mapping[ConfigScopes, str]] = None,
    ):
        """Forward ``put_dict`` to the wrapped client with the agent's scope prefixes."""
        prefixes = self._augment_scope_prefix_map(scope_prefix_map)
        await self._etcd.put_dict(
            flattened_dict_obj,
            scope=scope,
            scope_prefix_map=prefixes,
        )

    @override
    async def get(
        self,
        key: str,
        *,
        scope: ConfigScopes = ConfigScopes.MERGED,
        scope_prefix_map: Optional[Mapping[ConfigScopes, str]] = None,
    ) -> Optional[str]:
        """Forward ``get`` to the wrapped client and return its result."""
        prefixes = self._augment_scope_prefix_map(scope_prefix_map)
        value = await self._etcd.get(key, scope=scope, scope_prefix_map=prefixes)
        return value

    @override
    async def get_prefix(
        self,
        key_prefix: str,
        *,
        scope: ConfigScopes = ConfigScopes.MERGED,
        scope_prefix_map: Optional[Mapping[ConfigScopes, str]] = None,
    ) -> GetPrefixValue:
        """Forward ``get_prefix`` to the wrapped client and return its result."""
        prefixes = self._augment_scope_prefix_map(scope_prefix_map)
        values = await self._etcd.get_prefix(
            key_prefix,
            scope=scope,
            scope_prefix_map=prefixes,
        )
        return values

    @override
    async def replace(
        self,
        key: str,
        initial_val: str,
        new_val: str,
        *,
        scope: ConfigScopes = ConfigScopes.GLOBAL,
        scope_prefix_map: Optional[Mapping[ConfigScopes, str]] = None,
    ) -> bool:
        """Forward ``replace`` to the wrapped client and return its boolean result."""
        prefixes = self._augment_scope_prefix_map(scope_prefix_map)
        replaced = await self._etcd.replace(
            key,
            initial_val,
            new_val,
            scope=scope,
            scope_prefix_map=prefixes,
        )
        return replaced

    @override
    async def delete(
        self,
        key: str,
        *,
        scope: ConfigScopes = ConfigScopes.GLOBAL,
        scope_prefix_map: Optional[Mapping[ConfigScopes, str]] = None,
    ):
        """Forward ``delete`` to the wrapped client with the agent's scope prefixes."""
        prefixes = self._augment_scope_prefix_map(scope_prefix_map)
        await self._etcd.delete(key, scope=scope, scope_prefix_map=prefixes)

    @override
    async def delete_multi(
        self,
        keys: Iterable[str],
        *,
        scope: ConfigScopes = ConfigScopes.GLOBAL,
        scope_prefix_map: Optional[Mapping[ConfigScopes, str]] = None,
    ):
        """Forward ``delete_multi`` to the wrapped client with the agent's scope prefixes."""
        prefixes = self._augment_scope_prefix_map(scope_prefix_map)
        await self._etcd.delete_multi(keys, scope=scope, scope_prefix_map=prefixes)

    @override
    async def delete_prefix(
        self,
        key_prefix: str,
        *,
        scope: ConfigScopes = ConfigScopes.GLOBAL,
        scope_prefix_map: Optional[Mapping[ConfigScopes, str]] = None,
    ):
        """Forward ``delete_prefix`` to the wrapped client with the agent's scope prefixes."""
        prefixes = self._augment_scope_prefix_map(scope_prefix_map)
        await self._etcd.delete_prefix(key_prefix, scope=scope, scope_prefix_map=prefixes)

    @override
    async def watch(
        self,
        key: str,
        *,
        scope: ConfigScopes = ConfigScopes.GLOBAL,
        scope_prefix_map: Optional[Mapping[ConfigScopes, str]] = None,
        once: bool = False,
        ready_event: Optional[CondVar] = None,
        cleanup_event: Optional[CondVar] = None,
        wait_timeout: Optional[float] = None,
    ) -> AsyncGenerator[QueueSentinel | Event, None]:
        """Re-yield every item produced by the wrapped client's ``watch``.

        All watch options are passed through unchanged; only the scope prefix
        map is augmented with the agent's prefixes.
        """
        prefixes = self._augment_scope_prefix_map(scope_prefix_map)
        upstream = self._etcd.watch(
            key,
            scope=scope,
            scope_prefix_map=prefixes,
            once=once,
            ready_event=ready_event,
            cleanup_event=cleanup_event,
            wait_timeout=wait_timeout,
        )
        async for event in upstream:
            yield event

    @override
    async def watch_prefix(
        self,
        key_prefix: str,
        *,
        scope: ConfigScopes = ConfigScopes.GLOBAL,
        scope_prefix_map: Optional[Mapping[ConfigScopes, str]] = None,
        once: bool = False,
        ready_event: Optional[CondVar] = None,
        cleanup_event: Optional[CondVar] = None,
        wait_timeout: Optional[float] = None,
    ) -> AsyncGenerator[QueueSentinel | Event, None]:
        """Re-yield every item produced by the wrapped client's ``watch_prefix``.

        All watch options are passed through unchanged; only the scope prefix
        map is augmented with the agent's prefixes.
        """
        prefixes = self._augment_scope_prefix_map(scope_prefix_map)
        upstream = self._etcd.watch_prefix(
            key_prefix,
            scope=scope,
            scope_prefix_map=prefixes,
            once=once,
            ready_event=ready_event,
            cleanup_event=cleanup_event,
            wait_timeout=wait_timeout,
        )
        async for event in upstream:
            yield event

src/ai/backend/agent/kubernetes/agent.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -32,11 +32,11 @@
3232
from kubernetes_asyncio import client as kube_client
3333
from kubernetes_asyncio import config as kube_config
3434

35+
from ai.backend.agent.etcd import AgentEtcdClientView
3536
from ai.backend.common.asyncio import current_loop
3637
from ai.backend.common.docker import ImageRef, KernelFeatures
3738
from ai.backend.common.dto.agent.response import PurgeImagesResp
3839
from ai.backend.common.dto.manager.rpc_request import PurgeImagesReq
39-
from ai.backend.common.etcd import AsyncEtcd
4040
from ai.backend.common.events.dispatcher import EventProducer
4141
from ai.backend.common.plugin.monitor import ErrorPluginContext, StatsPluginContext
4242
from ai.backend.common.types import (
@@ -822,7 +822,7 @@ class KubernetesAgent(
822822

823823
def __init__(
824824
self,
825-
etcd: AsyncEtcd,
825+
etcd: AgentEtcdClientView,
826826
local_config: AgentUnifiedConfig,
827827
*,
828828
stats_monitor: StatsPluginContext,

src/ai/backend/agent/runtime.py

Lines changed: 10 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -3,6 +3,7 @@
33

44
from ai.backend.agent.agent import AbstractAgent
55
from ai.backend.agent.config.unified import AgentUnifiedConfig
6+
from ai.backend.agent.etcd import AgentEtcdClientView
67
from ai.backend.agent.monitor import AgentErrorPluginContext, AgentStatsPluginContext
78
from ai.backend.agent.types import AgentBackend
89
from ai.backend.common.auth import PublicKey
@@ -14,6 +15,7 @@ class AgentRuntime(aobject):
1415
local_config: AgentUnifiedConfig
1516
agent: AbstractAgent
1617
etcd: AsyncEtcd
18+
etcd_view: AgentEtcdClientView
1719

1820
_stop_signal: signal.Signals
1921

@@ -27,6 +29,7 @@ def __init__(
2729
) -> None:
2830
self.local_config = local_config
2931
self.etcd = etcd
32+
self.etcd_view = AgentEtcdClientView(etcd, self.local_config)
3033

3134
self._stop_signal = signal.SIGTERM
3235

@@ -35,23 +38,23 @@ def __init__(
3538
self.agent_public_key = agent_public_key
3639

3740
async def __ainit__(self) -> None:
38-
self.agent = await self._create_agent(self.etcd, self.local_config)
41+
self.agent = await self._create_agent(self.etcd_view, self.local_config)
3942

4043
async def __aexit__(self, *exc_info) -> None:
4144
await self.agent.shutdown(self._stop_signal)
4245

4346
def get_agent(self) -> AbstractAgent:
4447
return self.agent
4548

46-
def get_etcd(self) -> AsyncEtcd:
47-
return self.etcd
49+
def get_etcd(self) -> AgentEtcdClientView:
50+
return self.etcd_view
4851

4952
def mark_stop_signal(self, stop_signal: signal.Signals) -> None:
5053
self._stop_signal = stop_signal
5154

5255
async def _create_agent(
5356
self,
54-
etcd: AsyncEtcd,
57+
etcd_view: AgentEtcdClientView,
5558
agent_config: AgentUnifiedConfig,
5659
) -> AbstractAgent:
5760
agent_kwargs = {
@@ -64,12 +67,12 @@ async def _create_agent(
6467
case AgentBackend.DOCKER:
6568
from .docker.agent import DockerAgent
6669

67-
return await DockerAgent.new(etcd, agent_config, **agent_kwargs)
70+
return await DockerAgent.new(etcd_view, agent_config, **agent_kwargs)
6871
case AgentBackend.KUBERNETES:
6972
from .kubernetes.agent import KubernetesAgent
7073

71-
return await KubernetesAgent.new(etcd, agent_config, **agent_kwargs)
74+
return await KubernetesAgent.new(etcd_view, agent_config, **agent_kwargs)
7275
case AgentBackend.DUMMY:
7376
from ai.backend.agent.dummy.agent import DummyAgent
7477

75-
return await DummyAgent.new(etcd, agent_config, **agent_kwargs)
78+
return await DummyAgent.new(etcd_view, agent_config, **agent_kwargs)

src/ai/backend/agent/server.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1242,7 +1242,7 @@ async def prepare_krunner_volumes(local_config: AgentUnifiedConfig) -> None:
12421242
)
12431243
krunner_volumes: Mapping[str, str] = await kernel_mod.prepare_krunner_env(
12441244
local_config.model_dump(by_alias=True)
1245-
) # type: ignore
1245+
)
12461246
# TODO: merge k8s branch: nfs_mount_path = local_config['baistatic']['mounted-at']
12471247
log.info("Kernel runner environments: {}", [*krunner_volumes.keys()])
12481248
local_config.update(container_update={"krunner_volumes": krunner_volumes})

0 commit comments

Comments
 (0)