Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 1 addition & 0 deletions changes/6721.feature.md
Original file line number Diff line number Diff line change
@@ -0,0 +1 @@
Add `AgentEtcdClientView` for clean handling of etcd clients across multiple agents
6 changes: 3 additions & 3 deletions src/ai/backend/agent/agent.py
Original file line number Diff line number Diff line change
Expand Up @@ -64,6 +64,7 @@
)
from trafaret import DataError

from ai.backend.agent.etcd import AgentEtcdClientView
from ai.backend.agent.metrics.metric import (
StatScope,
StatTaskObserver,
Expand Down Expand Up @@ -256,7 +257,6 @@

if TYPE_CHECKING:
from ai.backend.common.auth import PublicKey
from ai.backend.common.etcd import AsyncEtcd

log = BraceStyleAdapter(logging.getLogger(__spec__.name))

Expand Down Expand Up @@ -762,7 +762,7 @@ class AbstractAgent(
id: AgentId
loop: asyncio.AbstractEventLoop
local_config: AgentUnifiedConfig
etcd: AsyncEtcd
etcd: AgentEtcdClientView
local_instance_id: str
kernel_registry: MutableMapping[KernelId, AbstractKernel]
computers: MutableMapping[DeviceName, ComputerContext]
Expand Down Expand Up @@ -829,7 +829,7 @@ def track_create(

def __init__(
self,
etcd: AsyncEtcd,
etcd: AgentEtcdClientView,
local_config: AgentUnifiedConfig,
*,
stats_monitor: StatsPluginContext,
Expand Down
4 changes: 2 additions & 2 deletions src/ai/backend/agent/docker/agent.py
Original file line number Diff line number Diff line change
Expand Up @@ -48,6 +48,7 @@
from aiotools import TaskGroup
from async_timeout import timeout

from ai.backend.agent.etcd import AgentEtcdClientView
from ai.backend.common.cgroup import get_cgroup_mount_point
from ai.backend.common.data.image.types import ScannedImage
from ai.backend.common.docker import (
Expand Down Expand Up @@ -143,7 +144,6 @@

if TYPE_CHECKING:
from ai.backend.common.auth import PublicKey
from ai.backend.common.etcd import AsyncEtcd

log = BraceStyleAdapter(logging.getLogger(__spec__.name))
eof_sentinel = Sentinel.TOKEN
Expand Down Expand Up @@ -1350,7 +1350,7 @@ class DockerAgent(AbstractAgent[DockerKernel, DockerKernelCreationContext]):

def __init__(
self,
etcd: AsyncEtcd,
etcd: AgentEtcdClientView,
local_config: AgentUnifiedConfig,
*,
stats_monitor: StatsPluginContext,
Expand Down
4 changes: 2 additions & 2 deletions src/ai/backend/agent/docker/resources.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,7 +5,7 @@

import aiofiles

from ai.backend.common.etcd import AsyncEtcd
from ai.backend.common.etcd import AbstractKVStore
from ai.backend.common.types import DeviceName, SlotName
from ai.backend.logging import BraceStyleAdapter

Expand All @@ -21,7 +21,7 @@


async def load_resources(
etcd: AsyncEtcd,
etcd: AbstractKVStore,
local_config: Mapping[str, Any],
) -> Mapping[DeviceName, AbstractComputePlugin]:
"""
Expand Down
4 changes: 2 additions & 2 deletions src/ai/backend/agent/dummy/resources.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,7 @@
from decimal import Decimal
from typing import Any, Mapping, MutableMapping

from ai.backend.common.etcd import AsyncEtcd
from ai.backend.common.etcd import AbstractKVStore
from ai.backend.common.types import DeviceName, SlotName
from ai.backend.logging import BraceStyleAdapter

Expand All @@ -17,7 +17,7 @@


async def load_resources(
etcd: AsyncEtcd,
etcd: AbstractKVStore,
local_config: Mapping[str, Any],
dummy_config: Mapping[str, Any],
) -> Mapping[DeviceName, AbstractComputePlugin]:
Expand Down
235 changes: 235 additions & 0 deletions src/ai/backend/agent/etcd.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,235 @@
from collections.abc import AsyncGenerator, Iterable, Mapping
from typing import ChainMap, MutableMapping, Optional, cast, override

from etcd_client import CondVar

from ai.backend.agent.config.unified import AgentUnifiedConfig
from ai.backend.common.etcd import (
AbstractKVStore,
AsyncEtcd,
ConfigScopes,
Event,
GetPrefixValue,
NestedStrKeyedMapping,
)
from ai.backend.common.exception import (
BackendAIError,
ErrorCode,
ErrorDetail,
ErrorDomain,
ErrorOperation,
)
from ai.backend.common.types import QueueSentinel


class AgentEtcdError(BackendAIError):
    """Backend.AI error whose code is AGENT / SETUP / NOT_FOUND."""

    @classmethod
    def error_code(cls) -> ErrorCode:
        """Return the structured error code attached to this error type."""
        domain = ErrorDomain.AGENT
        operation = ErrorOperation.SETUP
        detail = ErrorDetail.NOT_FOUND
        return ErrorCode(domain=domain, operation=operation, error_detail=detail)


class AgentEtcdClientView(AbstractKVStore):
    """
    A key-value store facade over a wrapped ``AsyncEtcd`` client.

    Every operation is forwarded to the wrapped client after the scope prefix
    map has been completed with the scaling-group and node prefixes derived
    from the agent configuration held by this view.
    """

    _etcd: AsyncEtcd
    _config: AgentUnifiedConfig

    def __init__(
        self,
        etcd: AsyncEtcd,
        config: AgentUnifiedConfig,
    ) -> None:
        """Store the wrapped etcd client and the agent configuration."""
        self._config = config
        self._etcd = etcd

    @property
    def _agent_scope_prefix_map(self) -> Mapping[ConfigScopes, str]:
        """
        Build the scope prefixes for this agent from the current config.

        This is deliberately a @property rather than a value computed once:
        the mapping is rebuilt on every access, so any updates made to the
        config object (e.g. scaling group) are reflected immediately.
        """
        agent_cfg = self._config.agent
        return {
            ConfigScopes.SGROUP: f"sgroup/{agent_cfg.scaling_group}",
            ConfigScopes.NODE: f"nodes/agents/{agent_cfg.id}",
        }

    def _augment_scope_prefix_map(
        self,
        override: Optional[Mapping[ConfigScopes, str]],
    ) -> Mapping[ConfigScopes, str]:
        """
        Layer the caller-supplied prefixes on top of the agent's own prefixes.

        The result is exposed as a read-only ``Mapping``; the casts exist only
        because typeshed has no immutable variant of ChainMap.
        (ref: https://github.com/python/typeshed/issues/6042)
        """
        # The caller's map comes first so that its entries win on lookup.
        caller_map = cast(MutableMapping, override) if override else {}
        agent_map = cast(MutableMapping, self._agent_scope_prefix_map)
        return ChainMap(caller_map, agent_map)

    @override
    async def put(
        self,
        key: str,
        val: str,
        *,
        scope: ConfigScopes = ConfigScopes.GLOBAL,
        scope_prefix_map: Optional[Mapping[ConfigScopes, str]] = None,
    ):
        """Forward ``put`` to the wrapped client with the augmented prefix map."""
        prefixes = self._augment_scope_prefix_map(scope_prefix_map)
        await self._etcd.put(key, val, scope=scope, scope_prefix_map=prefixes)

    @override
    async def put_prefix(
        self,
        key: str,
        dict_obj: NestedStrKeyedMapping,
        *,
        scope: ConfigScopes = ConfigScopes.GLOBAL,
        scope_prefix_map: Optional[Mapping[ConfigScopes, str]] = None,
    ):
        """Forward ``put_prefix`` to the wrapped client with the augmented prefix map."""
        prefixes = self._augment_scope_prefix_map(scope_prefix_map)
        await self._etcd.put_prefix(key, dict_obj, scope=scope, scope_prefix_map=prefixes)

    @override
    async def put_dict(
        self,
        flattened_dict_obj: Mapping[str, str],
        *,
        scope: ConfigScopes = ConfigScopes.GLOBAL,
        scope_prefix_map: Optional[Mapping[ConfigScopes, str]] = None,
    ):
        """Forward ``put_dict`` to the wrapped client with the augmented prefix map."""
        prefixes = self._augment_scope_prefix_map(scope_prefix_map)
        await self._etcd.put_dict(flattened_dict_obj, scope=scope, scope_prefix_map=prefixes)

    @override
    async def get(
        self,
        key: str,
        *,
        scope: ConfigScopes = ConfigScopes.MERGED,
        scope_prefix_map: Optional[Mapping[ConfigScopes, str]] = None,
    ) -> Optional[str]:
        """Forward ``get`` to the wrapped client and return its result."""
        prefixes = self._augment_scope_prefix_map(scope_prefix_map)
        return await self._etcd.get(key, scope=scope, scope_prefix_map=prefixes)

    @override
    async def get_prefix(
        self,
        key_prefix: str,
        *,
        scope: ConfigScopes = ConfigScopes.MERGED,
        scope_prefix_map: Optional[Mapping[ConfigScopes, str]] = None,
    ) -> GetPrefixValue:
        """Forward ``get_prefix`` to the wrapped client and return its result."""
        prefixes = self._augment_scope_prefix_map(scope_prefix_map)
        return await self._etcd.get_prefix(key_prefix, scope=scope, scope_prefix_map=prefixes)

    @override
    async def replace(
        self,
        key: str,
        initial_val: str,
        new_val: str,
        *,
        scope: ConfigScopes = ConfigScopes.GLOBAL,
        scope_prefix_map: Optional[Mapping[ConfigScopes, str]] = None,
    ) -> bool:
        """Forward ``replace`` to the wrapped client and return its result."""
        prefixes = self._augment_scope_prefix_map(scope_prefix_map)
        return await self._etcd.replace(
            key, initial_val, new_val, scope=scope, scope_prefix_map=prefixes
        )

    @override
    async def delete(
        self,
        key: str,
        *,
        scope: ConfigScopes = ConfigScopes.GLOBAL,
        scope_prefix_map: Optional[Mapping[ConfigScopes, str]] = None,
    ):
        """Forward ``delete`` to the wrapped client with the augmented prefix map."""
        prefixes = self._augment_scope_prefix_map(scope_prefix_map)
        await self._etcd.delete(key, scope=scope, scope_prefix_map=prefixes)

    @override
    async def delete_multi(
        self,
        keys: Iterable[str],
        *,
        scope: ConfigScopes = ConfigScopes.GLOBAL,
        scope_prefix_map: Optional[Mapping[ConfigScopes, str]] = None,
    ):
        """Forward ``delete_multi`` to the wrapped client with the augmented prefix map."""
        prefixes = self._augment_scope_prefix_map(scope_prefix_map)
        await self._etcd.delete_multi(keys, scope=scope, scope_prefix_map=prefixes)

    @override
    async def delete_prefix(
        self,
        key_prefix: str,
        *,
        scope: ConfigScopes = ConfigScopes.GLOBAL,
        scope_prefix_map: Optional[Mapping[ConfigScopes, str]] = None,
    ):
        """Forward ``delete_prefix`` to the wrapped client with the augmented prefix map."""
        prefixes = self._augment_scope_prefix_map(scope_prefix_map)
        await self._etcd.delete_prefix(key_prefix, scope=scope, scope_prefix_map=prefixes)

    @override
    async def watch(
        self,
        key: str,
        *,
        scope: ConfigScopes = ConfigScopes.GLOBAL,
        scope_prefix_map: Optional[Mapping[ConfigScopes, str]] = None,
        once: bool = False,
        ready_event: Optional[CondVar] = None,
        cleanup_event: Optional[CondVar] = None,
        wait_timeout: Optional[float] = None,
    ) -> AsyncGenerator[QueueSentinel | Event, None]:
        """Re-yield every item produced by the wrapped client's ``watch``."""
        prefixes = self._augment_scope_prefix_map(scope_prefix_map)
        async for event in self._etcd.watch(
            key,
            scope=scope,
            scope_prefix_map=prefixes,
            once=once,
            ready_event=ready_event,
            cleanup_event=cleanup_event,
            wait_timeout=wait_timeout,
        ):
            yield event

    @override
    async def watch_prefix(
        self,
        key_prefix: str,
        *,
        scope: ConfigScopes = ConfigScopes.GLOBAL,
        scope_prefix_map: Optional[Mapping[ConfigScopes, str]] = None,
        once: bool = False,
        ready_event: Optional[CondVar] = None,
        cleanup_event: Optional[CondVar] = None,
        wait_timeout: Optional[float] = None,
    ) -> AsyncGenerator[QueueSentinel | Event, None]:
        """Re-yield every item produced by the wrapped client's ``watch_prefix``."""
        prefixes = self._augment_scope_prefix_map(scope_prefix_map)
        async for event in self._etcd.watch_prefix(
            key_prefix,
            scope=scope,
            scope_prefix_map=prefixes,
            once=once,
            ready_event=ready_event,
            cleanup_event=cleanup_event,
            wait_timeout=wait_timeout,
        ):
            yield event
4 changes: 2 additions & 2 deletions src/ai/backend/agent/kubernetes/agent.py
Original file line number Diff line number Diff line change
Expand Up @@ -32,11 +32,11 @@
from kubernetes_asyncio import client as kube_client
from kubernetes_asyncio import config as kube_config

from ai.backend.agent.etcd import AgentEtcdClientView
from ai.backend.common.asyncio import current_loop
from ai.backend.common.docker import ImageRef, KernelFeatures
from ai.backend.common.dto.agent.response import PurgeImagesResp
from ai.backend.common.dto.manager.rpc_request import PurgeImagesReq
from ai.backend.common.etcd import AsyncEtcd
from ai.backend.common.events.dispatcher import EventProducer
from ai.backend.common.plugin.monitor import ErrorPluginContext, StatsPluginContext
from ai.backend.common.types import (
Expand Down Expand Up @@ -822,7 +822,7 @@ class KubernetesAgent(

def __init__(
self,
etcd: AsyncEtcd,
etcd: AgentEtcdClientView,
local_config: AgentUnifiedConfig,
*,
stats_monitor: StatsPluginContext,
Expand Down
4 changes: 2 additions & 2 deletions src/ai/backend/agent/kubernetes/resources.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,7 +5,7 @@

import aiofiles

from ai.backend.common.etcd import AsyncEtcd
from ai.backend.common.etcd import AbstractKVStore
from ai.backend.common.types import DeviceName, SlotName
from ai.backend.logging import BraceStyleAdapter

Expand All @@ -21,7 +21,7 @@


async def load_resources(
etcd: AsyncEtcd, local_config: Mapping[str, Any]
etcd: AbstractKVStore, local_config: Mapping[str, Any]
) -> Mapping[DeviceName, AbstractComputePlugin]:
"""
Detect and load the accelerator plugins.
Expand Down
Loading
Loading