diff --git a/changes/6320.feature.md b/changes/6320.feature.md new file mode 100644 index 00000000000..3353768c962 --- /dev/null +++ b/changes/6320.feature.md @@ -0,0 +1 @@ +Update Agent server RPC functions to include agent ID for agent runtime with multiple agents \ No newline at end of file diff --git a/changes/6724.feature.md b/changes/6724.feature.md new file mode 100644 index 00000000000..d49567aeb75 --- /dev/null +++ b/changes/6724.feature.md @@ -0,0 +1 @@ +Add custom resource allocation in agent server config \ No newline at end of file diff --git a/configs/agent/sample.toml b/configs/agent/sample.toml index 3741e8fc687..50a5fd52dec 100644 --- a/configs/agent/sample.toml +++ b/configs/agent/sample.toml @@ -16,14 +16,6 @@ scaling-group = "default" # Scaling group type scaling-group-type = "compute" - # Allowed compute plugins - ## allow-compute-plugins = [ "ai.backend.accelerator.cuda_open", "ai.backend.activator.agent",] - # Blocked compute plugins - ## block-compute-plugins = [ "ai.backend.accelerator.mock",] - # Allowed network plugins - ## allow-network-plugins = [ "ai.backend.manager.network.overlay",] - # Blocked network plugins - ## block-network-plugins = [ "ai.backend.manager.network.overlay",] # Whether to force terminate abusing containers force-terminate-abusing-containers = false # Kernel creation concurrency @@ -81,6 +73,14 @@ metadata-server-bind-host = "0.0.0.0" # Metadata server port metadata-server-port = 40128 + # Allowed compute plugins + ## allow-compute-plugins = [ "ai.backend.activator.agent", "ai.backend.accelerator.cuda_open",] + # Blocked compute plugins + ## block-compute-plugins = [ "ai.backend.accelerator.mock",] + # Allowed network plugins + ## allow-network-plugins = [ "ai.backend.manager.network.overlay",] + # Blocked network plugins + ## block-network-plugins = [ "ai.backend.manager.network.overlay",] # Path for image commit image-commit-path = "tmp/backend.ai/commit" # Path for abuse reports @@ -156,6 +156,12 @@ # Currently 
this value is unused. In future releases, it may be used to preserve # the minimum disk space from the scratch disk allocation via loopback files. reserved-disk = "8G" + # Resource allocation mode for multi-agent scenarios. + # - `shared`: All agents share the full resource pool (default, backward + # compatible). + # - `auto-split`: Automatically divide resources equally (1/N) among all agents. + # - `manual`: Manually specify per-agent resource allocations via config. + allocation-mode = "shared" # The alignment of the reported main memory size to absorb tiny deviations from # per-node firmware/hardware settings. Recommended to be multiple of the # page/hugepage size (e.g., 2 MiB). @@ -165,6 +171,22 @@ # Affinity policy affinity-policy = "INTERLEAVED" + # Resource allocations. + # Only used in MANUAL allocation mode. + [resource.allocations] + # Hard CPU allocation for this agent (e.g., 8 cores). + # Only used in MANUAL allocation mode. + # All agents must specify this value when allocation-mode is MANUAL. + cpu = 8 + # Hard memory allocation for this agent (e.g., "32G"). + # Only used in MANUAL allocation mode. + # All agents must specify this value when allocation-mode is MANUAL. + mem = "32G" + + # Device-specific per-slot resource allocations. + # Only used in MANUAL allocation mode. 
+ [resource.allocations.devices] + # Pyroscope configuration [pyroscope] # Whether to enable Pyroscope profiling @@ -351,14 +373,6 @@ scaling-group = "default" # Scaling group type scaling-group-type = "compute" - # Allowed compute plugins - ## allow-compute-plugins = [ "ai.backend.accelerator.cuda_open", "ai.backend.activator.agent",] - # Blocked compute plugins - ## block-compute-plugins = [ "ai.backend.accelerator.mock",] - # Allowed network plugins - ## allow-network-plugins = [ "ai.backend.manager.network.overlay",] - # Blocked network plugins - ## block-network-plugins = [ "ai.backend.manager.network.overlay",] # Whether to force terminate abusing containers force-terminate-abusing-containers = false # Kernel creation concurrency @@ -383,7 +397,7 @@ # late into the agent's runtime. port-range = [ 30000, 31000,] # Statistics type - ## stats-type = "cgroup" + ## stats-type = "docker" # Sandbox type sandbox-type = "docker" # Jail arguments @@ -393,7 +407,7 @@ # Scratch root directory scratch-root = "scratches" # Scratch size - scratch-size = 0 + scratch-size = "0" # Scratch NFS address ## scratch-nfs-address = "192.168.1.100:/export" # Scratch NFS options @@ -409,24 +423,15 @@ # Resource config overrides for the individual agent [agents.resource] - # The number of CPU cores reserved for the operating system and the agent - # service. - reserved-cpu = 1 - # The memory space reserved for the operating system and the agent service. It - # is subtracted from the reported main memory size and not available for user - # workload allocation. Depending on the memory-align-size option and system - # configuration, this may not be the exact value but have slightly less or more - # values within the memory-align-size. - reserved-mem = 1073741824 - # The disk space reserved for the operating system and the agent service. - # Currently this value is unused. 
In future releases, it may be used to preserve - # the minimum disk space from the scratch disk allocation via loopback files. - reserved-disk = 8589934592 - # The alignment of the reported main memory size to absorb tiny deviations from - # per-node firmware/hardware settings. Recommended to be multiple of the - # page/hugepage size (e.g., 2 MiB). - memory-align-size = 16777216 - # Resource allocation order - allocation-order = [ "cuda", "rocm", "tpu", "cpu", "mem",] - # Affinity policy - affinity-policy = 1 + # Hard CPU allocation for this agent (e.g., 8 cores). + # Only used in MANUAL allocation mode. + # All agents must specify this value when allocation-mode is MANUAL. + cpu = 8 + # Hard memory allocation for this agent (e.g., "32G"). + # Only used in MANUAL allocation mode. + # All agents must specify this value when allocation-mode is MANUAL. + mem = "32G" + + # Device-specific per-slot resource allocations. + # Only used in MANUAL allocation mode. + [agents.resource.devices] diff --git a/src/ai/backend/agent/agent.py b/src/ai/backend/agent/agent.py index 37dea05727f..523cca2022f 100644 --- a/src/ai/backend/agent/agent.py +++ b/src/ai/backend/agent/agent.py @@ -2276,7 +2276,7 @@ async def scan_running_kernels(self) -> None: """ ipc_base_path = self.local_config.agent.ipc_base_path var_base_path = self.local_config.agent.var_base_path - last_registry_file = f"last_registry.{self.local_instance_id}.dat" + last_registry_file = f"last_registry.{self.id}.dat" if os.path.isfile(ipc_base_path / last_registry_file): shutil.move(ipc_base_path / last_registry_file, var_base_path / last_registry_file) try: @@ -3745,7 +3745,7 @@ async def save_last_registry(self, force=False) -> None: if (not force) and (now <= self.last_registry_written_time + 60): return # don't save too frequently var_base_path = self.local_config.agent.var_base_path - last_registry_file = f"last_registry.{self.local_instance_id}.dat" + last_registry_file = f"last_registry.{self.id}.dat" try: with 
open(var_base_path / last_registry_file, "wb") as f: pickle.dump(self.kernel_registry, f) diff --git a/src/ai/backend/agent/config/unified.py b/src/ai/backend/agent/config/unified.py index fee6e503823..623e88d3bfc 100644 --- a/src/ai/backend/agent/config/unified.py +++ b/src/ai/backend/agent/config/unified.py @@ -11,6 +11,7 @@ import os import sys import textwrap +from decimal import Decimal from pathlib import Path from typing import ( Any, @@ -49,6 +50,8 @@ BinarySizeField, ResourceGroupType, ServiceDiscoveryType, + SlotName, + SlotNameField, ) from ai.backend.logging import BraceStyleAdapter from ai.backend.logging.config import LoggingConfig @@ -77,6 +80,12 @@ class ScratchType(enum.StrEnum): K8S_NFS = "k8s-nfs" +class ResourceAllocationMode(enum.StrEnum): + SHARED = "shared" + AUTO_SPLIT = "auto-split" + MANUAL = "manual" + + class AgentConfigValidationContext(BaseConfigValidationContext): is_invoked_subcommand: bool @@ -504,6 +513,34 @@ class CommonAgentConfig(BaseConfigSchema): validation_alias=AliasChoices("metadata-server-port", "metadata_server_port"), serialization_alias="metadata-server-port", ) + allow_compute_plugins: Optional[set[str]] = Field( + default=None, + description="Allowed compute plugins", + examples=[{"ai.backend.activator.agent", "ai.backend.accelerator.cuda_open"}], + validation_alias=AliasChoices("allow-compute-plugins", "allow_compute_plugins"), + serialization_alias="allow-compute-plugins", + ) + block_compute_plugins: Optional[set[str]] = Field( + default=None, + description="Blocked compute plugins", + examples=[{"ai.backend.accelerator.mock"}], + validation_alias=AliasChoices("block-compute-plugins", "block_compute_plugins"), + serialization_alias="block-compute-plugins", + ) + allow_network_plugins: Optional[set[str]] = Field( + default=None, + description="Allowed network plugins", + examples=[{"ai.backend.manager.network.overlay"}], + validation_alias=AliasChoices("allow-network-plugins", "allow_network_plugins"), + 
serialization_alias="allow-network-plugins", + ) + block_network_plugins: Optional[set[str]] = Field( + default=None, + description="Blocked network plugins", + examples=[{"ai.backend.manager.network.overlay"}], + validation_alias=AliasChoices("block-network-plugins", "block_network_plugins"), + serialization_alias="block-network-plugins", + ) image_commit_path: AutoDirectoryPath = Field( default=AutoDirectoryPath("./tmp/backend.ai/commit"), description="Path for image commit", @@ -596,34 +633,6 @@ class OverridableAgentConfig(BaseConfigSchema): validation_alias=AliasChoices("scaling-group-type", "scaling_group_type"), serialization_alias="scaling-group-type", ) - allow_compute_plugins: Optional[set[str]] = Field( - default=None, - description="Allowed compute plugins", - examples=[{"ai.backend.activator.agent", "ai.backend.accelerator.cuda_open"}], - validation_alias=AliasChoices("allow-compute-plugins", "allow_compute_plugins"), - serialization_alias="allow-compute-plugins", - ) - block_compute_plugins: Optional[set[str]] = Field( - default=None, - description="Blocked compute plugins", - examples=[{"ai.backend.accelerator.mock"}], - validation_alias=AliasChoices("block-compute-plugins", "block_compute_plugins"), - serialization_alias="block-compute-plugins", - ) - allow_network_plugins: Optional[set[str]] = Field( - default=None, - description="Allowed network plugins", - examples=[{"ai.backend.manager.network.overlay"}], - validation_alias=AliasChoices("allow-network-plugins", "allow_network_plugins"), - serialization_alias="allow-network-plugins", - ) - block_network_plugins: Optional[set[str]] = Field( - default=None, - description="Blocked network plugins", - examples=[{"ai.backend.manager.network.overlay"}], - validation_alias=AliasChoices("block-network-plugins", "block_network_plugins"), - serialization_alias="block-network-plugins", - ) force_terminate_abusing_containers: bool = Field( default=False, description="Whether to force terminate abusing 
class ResourceAllocationConfig(BaseConfigSchema):
    """
    Hard resource allocations for a single agent: CPU cores, main memory and
    per-slot device amounts.

    Per the field descriptions, these values are only used when the resource
    ``allocation-mode`` is ``manual``.
    """

    cpu: int = Field(
        description=textwrap.dedent("""
        Hard CPU allocation for this agent (e.g., 8 cores).
        Only used in MANUAL allocation mode.
        All agents must specify this value when allocation-mode is MANUAL.
        """),
        examples=[8, 16],
    )
    mem: BinarySizeField = Field(
        description=textwrap.dedent("""
        Hard memory allocation for this agent (e.g., "32G").
        Only used in MANUAL allocation mode.
        All agents must specify this value when allocation-mode is MANUAL.
        """),
        examples=["32G", "64G"],
    )
    devices: Mapping[SlotNameField, Decimal] = Field(
        default_factory=dict,
        description=textwrap.dedent("""
        Device-specific per-slot resource allocations.
        Only used in MANUAL allocation mode.
        """),
        examples=[{"cuda.mem": "0.3", "cuda.shares": "0.5"}],
    )

    model_config = ConfigDict(
        extra="allow",
        arbitrary_types_allowed=True,
    )

    @model_validator(mode="after")
    def validate_values_are_positive(self) -> Self:
        """
        Reject negative allocations.

        Despite the method name, zero is accepted: only strictly negative
        ``cpu``, ``mem`` or device amounts raise.

        Raises:
            ValueError: if ``cpu``, ``mem`` or any device amount is negative.
        """
        # ``cpu`` and ``mem`` are required, non-optional fields, so they are
        # never None here and need no None guard.
        if self.cpu < 0:
            raise ValueError(f"Allocated cpu must not be a negative value, but given {self.cpu}")
        if self.mem < 0:
            raise ValueError(f"Allocated mem must not be a negative value, but given {self.mem}")
        # Name the offending slots so a multi-device config can be fixed
        # without bisecting it by hand.
        negative_slots = sorted(str(slot) for slot, amount in self.devices.items() if amount < 0)
        if negative_slots:
            raise ValueError(
                "All allocated device resource values must not be a negative value"
                f" (negative slots: {', '.join(negative_slots)})"
            )
        return self
+ - `shared`: All agents share the full resource pool (default, backward compatible). + - `auto-split`: Automatically divide resources equally (1/N) among all agents. + - `manual`: Manually specify per-agent resource allocations via config. + """), + examples=[item.value for item in ResourceAllocationMode], + validation_alias=AliasChoices("allocation-mode", "allocation_mode"), + serialization_alias="allocation-mode", + ) + allocations: Optional[ResourceAllocationConfig] = Field( + default=None, + description=textwrap.dedent(""" + Resource allocations. + Only used in MANUAL allocation mode. + """), + ) memory_align_size: BinarySizeField = Field( default=BinarySize.finite_from_str("16M"), description=( @@ -1183,11 +1253,11 @@ class AgentOverrideConfig(BaseConfigSchema): Only override fields if necessary. """), ) - container: OverridableContainerConfig | None = Field( + container: Optional[OverridableContainerConfig] = Field( default=None, description="Container config overrides for the individual agent", ) - resource: ResourceConfig | None = Field( + resource: Optional[ResourceAllocationConfig] = Field( default=None, description="Resource config overrides for the individual agent", ) @@ -1209,10 +1279,19 @@ def construct_unified_config(self, *, default: AgentUnifiedConfig) -> AgentUnifi update=container_override_fields ) if self.resource is not None: - resource_override_fields = self.resource.model_dump( - include=self.resource.model_fields_set + default_allocations = default.resource.allocations + override_allocations = self.resource + if default_allocations is None: + merged_allocations = override_allocations + else: + merged_allocations = default_allocations.model_copy( + update=override_allocations.model_dump( + include=override_allocations.model_fields_set + ) + ) + agent_updates["resource"] = default.resource.model_copy( + update={"allocations": merged_allocations} ) - agent_updates["resource"] = default.resource.model_copy(update=resource_override_fields) 
return default.model_copy(update=agent_updates) @@ -1314,3 +1393,30 @@ def _validate_agent_configs(self) -> Self: config.validate_agent_specific_config() return self + + @model_validator(mode="after") + def _validate_resource_allocation_mode(self) -> Self: + agent_configs = self.get_agent_configs() + + match self.resource.allocation_mode: + case ResourceAllocationMode.SHARED | ResourceAllocationMode.AUTO_SPLIT: + for config in agent_configs: + if config.resource.allocations is not None: + raise ValueError( + "On non-MANUAL mode, config must not specify manual resource allocations" + ) + + case ResourceAllocationMode.MANUAL: + slot_names: list[set[SlotName]] = [] + for config in agent_configs: + if config.resource.allocations is None: + raise ValueError( + "On MANUAL mode, config must specify cpu and mem resource allocations" + ) + + slot_names.append(set(config.resource.allocations.devices.keys())) + + if not all(slot_name == slot_names[0] for slot_name in slot_names): + raise ValueError("All agents must have the same slots defined in the devices!") + + return self diff --git a/src/ai/backend/agent/docker/agent.py b/src/ai/backend/agent/docker/agent.py index 530ab896486..023cc3d7e84 100644 --- a/src/ai/backend/agent/docker/agent.py +++ b/src/ai/backend/agent/docker/agent.py @@ -138,7 +138,6 @@ update_nested_dict, ) from .kernel import DockerKernel -from .metadata.server import MetadataServer from .resources import load_resources, scan_available_resources from .utils import PersistentServiceContainer @@ -1341,7 +1340,6 @@ class DockerAgent(AbstractAgent[DockerKernel, DockerKernelCreationContext]): monitor_docker_task: asyncio.Task agent_sockpath: Path agent_sock_task: asyncio.Task - metadata_server: MetadataServer docker_ptask_group: aiotools.PersistentTaskGroup gwbridge_subnet: Optional[str] checked_invalid_images: Set[str] @@ -1414,10 +1412,10 @@ async def __ainit__(self) -> None: self.gwbridge_subnet = None ipc_base_path = self.local_config.agent.ipc_base_path 
(ipc_base_path / "container").mkdir(parents=True, exist_ok=True) - self.agent_sockpath = ipc_base_path / "container" / f"agent.{self.local_instance_id}.sock" + self.agent_sockpath = ipc_base_path / "container" / f"agent.{self.id}.sock" # Workaround for Docker Desktop for Mac's UNIX socket mount failure with virtiofs if sys.platform != "darwin": - socket_relay_name = f"backendai-socket-relay.{self.local_instance_id}" + socket_relay_name = f"backendai-socket-relay.{self.id}" socket_relay_container = PersistentServiceContainer( "backendai-socket-relay:latest", { @@ -1443,12 +1441,6 @@ async def __ainit__(self) -> None: self.monitor_docker_task = asyncio.create_task(self.monitor_docker_events()) self.docker_ptask_group = aiotools.PersistentTaskGroup() - self.metadata_server = await MetadataServer.new( - self.local_config, - self.etcd, - self.kernel_registry, - ) - await self.metadata_server.start_server() # For legacy accelerator plugins self.docker = Docker() @@ -1477,7 +1469,6 @@ async def shutdown(self, stop_signal: signal.Signals): self.monitor_docker_task.cancel() await self.monitor_docker_task - await self.metadata_server.cleanup() if self.docker: await self.docker.close() diff --git a/src/ai/backend/agent/errors/__init__.py b/src/ai/backend/agent/errors/__init__.py new file mode 100644 index 00000000000..e69de29bb2d diff --git a/src/ai/backend/agent/errors/runtime.py b/src/ai/backend/agent/errors/runtime.py new file mode 100644 index 00000000000..9c329f8cd1c --- /dev/null +++ b/src/ai/backend/agent/errors/runtime.py @@ -0,0 +1,17 @@ +from ai.backend.common.exception import ( + BackendAIError, + ErrorCode, + ErrorDetail, + ErrorDomain, + ErrorOperation, +) + + +class AgentIdNotFoundError(BackendAIError): + @classmethod + def error_code(cls) -> ErrorCode: + return ErrorCode( + domain=ErrorDomain.AGENT, + operation=ErrorOperation.ACCESS, + error_detail=ErrorDetail.NOT_FOUND, + ) diff --git a/src/ai/backend/agent/kubernetes/agent.py 
class AgentRuntime:
    """
    Container for every agent hosted by one agent server process.

    It holds the per-agent etcd views, the agents themselves, the kernel
    registry passed to each agent and, for the Docker backend, the metadata
    server. Build instances with :meth:`create_runtime`; ``__init__`` only
    stores already-constructed collaborators.
    """

    _local_config: AgentUnifiedConfig
    _etcd_views: Mapping[AgentId, AgentEtcdClientView]
    _agents: Mapping[AgentId, AbstractAgent]
    _default_agent: AbstractAgent
    _kernel_registry: KernelRegistry
    _metadata_server: Optional[MetadataServer]
    _stop_signal: signal.Signals

    @classmethod
    async def create_runtime(
        cls,
        local_config: AgentUnifiedConfig,
        etcd: AsyncEtcd,
        stats_monitor: AgentStatsPluginContext,
        error_monitor: AgentErrorPluginContext,
        agent_public_key: Optional[PublicKey],
    ) -> AgentRuntime:
        """
        Create the metadata server (Docker backend only) and one agent per
        entry of ``local_config.get_agent_configs()``, concurrently.

        The first created agent becomes the default agent returned by
        ``get_agent(None)``.

        If anything fails after the metadata server has been started, the
        agents that were already created are shut down and the metadata
        server is cleaned up before the original error is re-raised.
        """
        kernel_registry = KernelRegistry()

        metadata_server: Optional[MetadataServer] = None
        if local_config.agent_common.backend == AgentBackend.DOCKER:
            metadata_server = await cls._create_metadata_server(local_config, etcd, kernel_registry)

        etcd_views: dict[AgentId, AgentEtcdClientView] = {}
        create_agent_tasks: list[asyncio.Task[AbstractAgent]] = []
        try:
            async with asyncio.TaskGroup() as tg:
                for agent_config in local_config.get_agent_configs():
                    agent_id = AgentId(agent_config.agent.id)

                    etcd_view = AgentEtcdClientView(etcd, agent_config)
                    etcd_views[agent_id] = etcd_view

                    create_agent_tasks.append(
                        tg.create_task(
                            cls._create_agent(
                                local_config,
                                etcd_view,
                                kernel_registry,
                                agent_config,
                                stats_monitor,
                                error_monitor,
                                agent_public_key,
                            )
                        )
                    )
            agents_list = [task.result() for task in create_agent_tasks]
            default_agent = agents_list[0]
        except BaseException:
            # Do not leak a listening metadata server or half of the agents
            # when startup fails part-way through.
            await cls._rollback_startup(create_agent_tasks, metadata_server)
            raise

        agents = {agent.id: agent for agent in agents_list}
        # Use ``cls`` so that subclasses get an instance of their own type.
        return cls(
            local_config=local_config,
            etcd_views=etcd_views,
            agents=agents,
            default_agent=default_agent,
            kernel_registry=kernel_registry,
            metadata_server=metadata_server,
        )

    @staticmethod
    async def _rollback_startup(
        create_agent_tasks: list[asyncio.Task[AbstractAgent]],
        metadata_server: Optional[MetadataServer],
    ) -> None:
        """Best-effort teardown of whatever ``create_runtime`` managed to start."""
        for task in create_agent_tasks:
            # Check ``cancelled()`` first: ``exception()`` raises on a cancelled task.
            if not task.done() or task.cancelled() or task.exception() is not None:
                continue
            try:
                await task.result().shutdown(signal.SIGTERM)
            except Exception:
                # The startup failure being propagated by the caller is the
                # error that matters; a secondary teardown failure must not
                # replace it.
                pass
        if metadata_server is not None:
            try:
                await metadata_server.cleanup()
            except Exception:
                # Same reasoning as above: keep the original startup error.
                pass

    @classmethod
    async def _create_metadata_server(
        cls,
        local_config: AgentUnifiedConfig,
        etcd: AsyncEtcd,
        kernel_registry: KernelRegistry,
    ) -> MetadataServer:
        """Create and start the metadata server over the registry's global view."""
        # Imported lazily; at module level the name is only imported under
        # TYPE_CHECKING.
        from .docker.metadata.server import MetadataServer

        metadata_server = await MetadataServer.new(
            local_config,
            etcd,
            kernel_registry=kernel_registry.global_view(),
        )
        await metadata_server.start_server()
        return metadata_server

    @classmethod
    async def _create_agent(
        cls,
        local_config: AgentUnifiedConfig,
        etcd_view: AgentEtcdClientView,
        kernel_registry: KernelRegistry,
        agent_config: AgentUnifiedConfig,
        stats_monitor: AgentStatsPluginContext,
        error_monitor: AgentErrorPluginContext,
        agent_public_key: Optional[PublicKey],
    ) -> AbstractAgent:
        """
        Instantiate one agent of the backend selected in ``local_config``,
        configured with ``agent_config``.
        """
        agent_kwargs = {
            "kernel_registry": kernel_registry,
            "stats_monitor": stats_monitor,
            "error_monitor": error_monitor,
            "agent_public_key": agent_public_key,
        }

        backend = local_config.agent_common.backend
        agent_mod = importlib.import_module(f"ai.backend.agent.{backend.value}")
        agent_cls: Type[AbstractAgent] = agent_mod.get_agent_cls()

        return await agent_cls.new(etcd_view, agent_config, **agent_kwargs)

    def __init__(
        self,
        local_config: AgentUnifiedConfig,
        etcd_views: Mapping[AgentId, AgentEtcdClientView],
        agents: Mapping[AgentId, AbstractAgent],
        default_agent: AbstractAgent,
        kernel_registry: KernelRegistry,
        metadata_server: Optional[MetadataServer] = None,
    ) -> None:
        """Store the given collaborators; the stop signal defaults to SIGTERM."""
        self._local_config = local_config
        self._etcd_views = etcd_views
        self._agents = agents
        self._default_agent = default_agent
        self._kernel_registry = kernel_registry
        self._metadata_server = metadata_server

        self._stop_signal = signal.SIGTERM

    async def __aexit__(self, *exc_info) -> None:
        """
        Shut down every agent with the marked stop signal, then clean up the
        metadata server if there is one.

        A failure in one step does not skip the remaining steps. A single
        failure is re-raised as is; several failures are raised together as
        an ``ExceptionGroup``.
        """
        errors: list[Exception] = []
        for agent in self._agents.values():
            try:
                await agent.shutdown(self._stop_signal)
            except Exception as e:
                errors.append(e)
        if self._metadata_server is not None:
            try:
                await self._metadata_server.cleanup()
            except Exception as e:
                errors.append(e)

        if len(errors) == 1:
            raise errors[0]
        if errors:
            raise ExceptionGroup("Failed to shut down the agent runtime cleanly", errors)

    def get_agents(self) -> list[AbstractAgent]:
        """Return all agents of this runtime as a new list."""
        return list(self._agents.values())

    def get_agent(self, agent_id: Optional[AgentId]) -> AbstractAgent:
        """
        Return the agent with the given ID, or the default agent when
        ``agent_id`` is None.

        Raises:
            AgentIdNotFoundError: if ``agent_id`` is not hosted by this runtime.
        """
        if agent_id is None:
            return self._default_agent
        if agent_id not in self._agents:
            raise AgentIdNotFoundError(
                f"Agent '{agent_id}' not found in this runtime. "
                f"Available agents: {', '.join(self._agents.keys())}"
            )
        return self._agents[agent_id]

    def get_etcd(self, agent_id: AgentId) -> AgentEtcdClientView:
        """
        Return the etcd view created for the given agent.

        Raises:
            AgentIdNotFoundError: if no etcd view exists for ``agent_id``.
        """
        if agent_id not in self._etcd_views:
            raise AgentIdNotFoundError(
                f"Etcd client for agent '{agent_id}' not found in this runtime. "
                f"Available agent etcd views: {', '.join(self._etcd_views.keys())}"
            )
        return self._etcd_views[agent_id]

    def mark_stop_signal(self, stop_signal: signal.Signals) -> None:
        """Record the signal that ``__aexit__`` passes to each agent's shutdown."""
        self._stop_signal = stop_signal

    async def update_status(self, status: str, agent_id: AgentId) -> None:
        """Write ``status`` under the empty key of the agent's NODE-scoped etcd view."""
        etcd = self.get_etcd(agent_id)
        await etcd.put("", status, scope=ConfigScopes.NODE)
+ async with asyncio.TaskGroup() as tg: + for agent_id in self.local_config.agent_ids: + tg.create_task(self.update_status("starting", agent_id)) + rpc_addr = self.local_config.agent_common.rpc_listen_addr self.rpc_server = Peer( bind=ZeroMQAddress(f"tcp://{rpc_addr.address}"), @@ -375,16 +380,20 @@ async def _debug_server_task(): self.debug_server_task = asyncio.create_task(_debug_server_task()) - etcd = self.runtime.get_etcd() - await etcd.put("ip", rpc_addr.host, scope=ConfigScopes.NODE) + async with asyncio.TaskGroup() as tg: + for agent in self.runtime.get_agents(): + etcd = self.runtime.get_etcd(agent.id) + tg.create_task(etcd.put("ip", rpc_addr.host, scope=ConfigScopes.NODE)) - watcher_port = utils.nmget( - self.local_config.model_dump(), "watcher.service-addr.port", None - ) - if watcher_port is not None: - await etcd.put("watcher_port", watcher_port, scope=ConfigScopes.NODE) + watcher_port = utils.nmget( + agent.local_config.model_dump(), "watcher.service-addr.port", None + ) + if watcher_port is not None: + tg.create_task(etcd.put("watcher_port", watcher_port, scope=ConfigScopes.NODE)) - await self.update_status("running") + async with asyncio.TaskGroup() as tg: + for agent in self.runtime.get_agents(): + tg.create_task(self.update_status("running", agent.id)) async def status_snapshot_request_handler( self, reader: asyncio.StreamReader, writer: asyncio.StreamWriter @@ -403,19 +412,21 @@ def _ensure_serializable(o) -> Any: return str(o) try: - agent = self.runtime.get_agent() - if agent: + if self.runtime.get_agents(): snapshot = { - "registry": { - str(kern_id): _ensure_serializable(kern.__getstate__()) - for kern_id, kern in agent.kernel_registry.items() - }, - "allocs": { - str(computer): _ensure_serializable( - dict(computer_ctx.alloc_map.allocations) - ) - for computer, computer_ctx in agent.computers.items() - }, + str(agent.id): { + "registry": { + str(kern_id): _ensure_serializable(kern.__getstate__()) + for kern_id, kern in 
def _update_scaling_group_default(
    self,
    config_data: tomlkit.TOMLDocument,
    scaling_group: str,
) -> None:
    """Set ``scaling-group`` in the top-level ``agent`` table of ``config_data``."""
    config_data["agent"]["scaling-group"] = scaling_group  # type: ignore[index]


def _update_scaling_group_override(
    self,
    config_data: tomlkit.TOMLDocument,
    scaling_group: str,
    agent: AbstractAgent,
) -> None:
    """
    Set ``scaling-group`` in the ``agents`` entry whose ``agent.id`` equals
    ``str(agent.id)``.

    Only the first matching entry is updated.

    Raises:
        ValueError: if ``config_data`` has no ``agents`` entry, or if none of
            the entries matches the given agent. Failing here stops the
            caller before it rewrites an unchanged config file.
    """
    # Explicit check instead of ``assert``: asserts are stripped under ``-O``.
    if "agents" not in config_data:
        raise ValueError("The agent config has no 'agents' section to update")

    target_id = str(agent.id)
    for agent_config in config_data["agents"]:  # type: ignore[union-attr]
        if agent_config["agent"]["id"] == target_id:  # type: ignore[index]
            agent_config["agent"]["scaling-group"] = scaling_group  # type: ignore[index]
            return

    raise ValueError(
        f"Agent '{target_id}' has no matching entry in the 'agents' section of the agent config"
    )
@collect_error async def ping(self, msg: str) -> str: @@ -564,45 +598,64 @@ async def ping(self, msg: str) -> str: @rpc_function @collect_error - async def gather_hwinfo(self) -> Mapping[str, HardwareMetadata]: + async def gather_hwinfo( + self, + agent_id: AgentId | None = None, + ) -> Mapping[str, HardwareMetadata]: log.debug("rpc::gather_hwinfo()") - agent = self.runtime.get_agent() + agent = self.runtime.get_agent(agent_id) return await agent.gather_hwinfo() @rpc_function @collect_error - async def ping_kernel(self, kernel_id: str) -> dict[str, float] | None: + async def ping_kernel( + self, + kernel_id: str, + agent_id: AgentId | None = None, + ) -> dict[str, float] | None: log.debug("rpc::ping_kernel(k:{})", kernel_id) - agent = self.runtime.get_agent() + agent = self.runtime.get_agent(agent_id) return await agent.ping_kernel(KernelId(UUID(kernel_id))) @rpc_function @collect_error - async def check_pulling(self, image_name: str) -> bool: + async def check_pulling( + self, + image_name: str, + agent_id: AgentId | None = None, + ) -> bool: """Check if an image is being pulled.""" log.debug("rpc::check_pulling(image:{})", image_name) - agent = self.runtime.get_agent() + agent = self.runtime.get_agent(agent_id) return image_name in agent._active_pulls @rpc_function @collect_error - async def check_creating(self, kernel_id: str) -> bool: + async def check_creating( + self, + kernel_id: str, + agent_id: AgentId | None = None, + ) -> bool: """Check if a kernel is being created or already exists.""" log.debug("rpc::check_creating(k:{})", kernel_id) kid = KernelId(UUID(kernel_id)) - agent = self.runtime.get_agent() + agent = self.runtime.get_agent(agent_id) # Check if kernel is being created OR already exists in registry return kid in agent._active_creates or kid in agent.kernel_registry @rpc_function @collect_error - async def check_running(self, kernel_id: str) -> bool: + async def check_running( + self, + kernel_id: str, + agent_id: AgentId | None = None, + ) -> 
bool: """Check if a kernel is running.""" log.debug("rpc::check_running(k:{})", kernel_id) kid = KernelId(UUID(kernel_id)) # Safely get kernel from registry - agent = self.runtime.get_agent() + agent = self.runtime.get_agent(agent_id) kernel_obj = agent.kernel_registry.get(kid) # Check if kernel exists and is running @@ -616,8 +669,9 @@ async def check_running(self, kernel_id: str) -> bool: async def sync_kernel_registry( self, raw_kernel_session_ids: Iterable[tuple[str, str]], + agent_id: AgentId | None = None, ) -> None: - agent = self.runtime.get_agent() + agent = self.runtime.get_agent(agent_id) kernel_session_ids = [ (KernelId(UUID(raw_kid)), SessionId(UUID(raw_sid))) @@ -656,13 +710,14 @@ async def sync_kernel_registry( async def check_and_pull( self, image_configs: Mapping[str, ImageConfig], + agent_id: AgentId | None = None, ) -> dict[str, str]: """ Check whether the agent has images and pull if needed. Delegates to agent's check_and_pull method which handles tracking. """ log.debug("rpc::check_and_pull(images:{})", list(image_configs.keys())) - agent = self.runtime.get_agent() + agent = self.runtime.get_agent(agent_id) return await agent.check_and_pull(image_configs) @rpc_function @@ -674,12 +729,13 @@ async def create_kernels( raw_configs: Sequence[dict], raw_cluster_info: dict, kernel_image_refs: dict[KernelId, ImageRef], + agent_id: AgentId | None = None, ): cluster_info = cast(ClusterInfo, raw_cluster_info) session_id = SessionId(UUID(raw_session_id)) coros = [] - agent = self.runtime.get_agent() - throttle_sema = asyncio.Semaphore(self.local_config.agent.kernel_creation_concurrency) + agent = self.runtime.get_agent(agent_id) + throttle_sema = asyncio.Semaphore(agent.local_config.agent.kernel_creation_concurrency) for raw_kernel_id, raw_config in zip(raw_kernel_ids, raw_configs): log.info( "rpc::create_kernel(k:{0}, img:{1})", @@ -740,11 +796,12 @@ async def destroy_kernel( session_id: str, reason: Optional[KernelLifecycleEventReason] = None, 
suppress_events: bool = False, + agent_id: AgentId | None = None, ): loop = asyncio.get_running_loop() done = loop.create_future() log.info("rpc::destroy_kernel(k:{0})", kernel_id) - agent = self.runtime.get_agent() + agent = self.runtime.get_agent(agent_id) await agent.inject_container_lifecycle_event( KernelId(UUID(kernel_id)), SessionId(UUID(session_id)), @@ -760,6 +817,7 @@ async def destroy_kernel( async def purge_containers( self, container_kernel_ids: list[tuple[str, str]], + agent_id: AgentId | None = None, ) -> PurgeContainersResp: str_kernel_ids = [str(kid) for _, kid in container_kernel_ids] log.info("rpc::purge_containers(kernel_ids:{0})", str_kernel_ids) @@ -770,7 +828,7 @@ async def purge_containers( ) for cid, kid in container_kernel_ids ] - agent = self.runtime.get_agent() + agent = self.runtime.get_agent(agent_id) asyncio.create_task(agent.purge_containers(kernel_container_pairs)) return PurgeContainersResp() @@ -779,33 +837,48 @@ async def purge_containers( async def drop_kernel_registry( self, kernel_ids: list[UUID], + agent_id: AgentId | None = None, ) -> DropKernelRegistryResp: str_kernel_ids = [str(kid) for kid in kernel_ids] log.info("rpc::drop_kernel_registry(kernel_ids:{0})", str_kernel_ids) kernel_ids_to_purge = [KernelId(kid) for kid in kernel_ids] - agent = self.runtime.get_agent() + agent = self.runtime.get_agent(agent_id) asyncio.create_task(agent.clean_kernel_objects(kernel_ids_to_purge)) return DropKernelRegistryResp() @rpc_function @collect_error - async def interrupt_kernel(self, kernel_id: str): + async def interrupt_kernel( + self, + kernel_id: str, + agent_id: AgentId | None = None, + ): log.info("rpc::interrupt_kernel(k:{0})", kernel_id) - agent = self.runtime.get_agent() + agent = self.runtime.get_agent(agent_id) await agent.interrupt_kernel(KernelId(UUID(kernel_id))) @rpc_function_v2 @collect_error - async def get_completions(self, kernel_id: str, text: str, opts: dict) -> CodeCompletionResp: + async def get_completions( + 
self, + kernel_id: str, + text: str, + opts: dict, + agent_id: AgentId | None = None, + ) -> CodeCompletionResp: log.debug("rpc::get_completions(k:{0}, ...)", kernel_id) - agent = self.runtime.get_agent() + agent = self.runtime.get_agent(agent_id) return await agent.get_completions(KernelId(UUID(kernel_id)), text, opts) @rpc_function @collect_error - async def get_logs(self, kernel_id: str): + async def get_logs( + self, + kernel_id: str, + agent_id: AgentId | None = None, + ): log.info("rpc::get_logs(k:{0})", kernel_id) - agent = self.runtime.get_agent() + agent = self.runtime.get_agent(agent_id) return await agent.get_logs(KernelId(UUID(kernel_id))) @rpc_function @@ -816,9 +889,10 @@ async def restart_kernel( kernel_id: str, kernel_image: ImageRef, updated_config: dict, + agent_id: AgentId | None = None, ) -> dict[str, Any]: log.info("rpc::restart_kernel(s:{0}, k:{1})", session_id, kernel_id) - agent = self.runtime.get_agent() + agent = self.runtime.get_agent(agent_id) return await agent.restart_kernel( KernelOwnershipData( KernelId(UUID(kernel_id)), @@ -841,6 +915,7 @@ async def execute( code: str, opts: dict[str, Any], flush_timeout: float, + agent_id: AgentId | None = None, ) -> dict[str, Any]: if mode != "continue": log.info( @@ -850,7 +925,7 @@ async def execute( mode, code[:20] + "..." 
if len(code) > 20 else code, ) - agent = self.runtime.get_agent() + agent = self.runtime.get_agent(agent_id) result = await agent.execute( SessionId(UUID(session_id)), KernelId(UUID(kernel_id)), @@ -871,6 +946,7 @@ async def trigger_batch_execution( kernel_id: str, code: str, timeout: Optional[float], + agent_id: AgentId | None = None, ) -> None: log.info( "rpc::trigger_batch_execution(k:{0}, s:{1}, code:{2}, timeout:{3})", @@ -879,7 +955,7 @@ async def trigger_batch_execution( code, timeout, ) - agent = self.runtime.get_agent() + agent = self.runtime.get_agent(agent_id) await agent.create_batch_execution_task( SessionId(UUID(session_id)), KernelId(UUID(kernel_id)), code, timeout ) @@ -891,9 +967,10 @@ async def start_service( kernel_id: str, service: str, opts: dict[str, Any], + agent_id: AgentId | None = None, ) -> dict[str, Any]: log.info("rpc::start_service(k:{0}, app:{1})", kernel_id, service) - agent = self.runtime.get_agent() + agent = self.runtime.get_agent(agent_id) return await agent.start_service(KernelId(UUID(kernel_id)), service, opts) @rpc_function @@ -902,10 +979,11 @@ async def get_commit_status( self, kernel_id: str, subdir: str, + agent_id: AgentId | None = None, ) -> dict[str, Any]: # Only this function logs debug since web sends request at short intervals log.debug("rpc::get_commit_status(k:{})", kernel_id) - agent = self.runtime.get_agent() + agent = self.runtime.get_agent(agent_id) status: CommitStatus = await agent.get_commit_status( KernelId(UUID(kernel_id)), subdir, @@ -925,9 +1003,10 @@ async def commit( canonical: str | None = None, filename: str | None = None, extra_labels: dict[str, str] = {}, + agent_id: AgentId | None = None, ) -> dict[str, Any]: log.info("rpc::commit(k:{})", kernel_id) - agent = self.runtime.get_agent() + agent = self.runtime.get_agent(agent_id) bgtask_mgr = agent.background_task_manager async def _commit(reporter: ProgressReporter) -> None: @@ -953,9 +1032,10 @@ async def push_image( self, image_ref: ImageRef, 
registry_conf: ImageRegistry, + agent_id: AgentId | None = None, ) -> dict[str, Any]: log.info("rpc::push_image(c:{})", image_ref.canonical) - agent = self.runtime.get_agent() + agent = self.runtime.get_agent(agent_id) bgtask_mgr = agent.background_task_manager image_push_timeout = cast(Optional[float], self.local_config.api.push_timeout) @@ -976,7 +1056,11 @@ async def _push_image(reporter: ProgressReporter) -> None: @rpc_function_v2 @collect_error async def purge_images( - self, image_canonicals: list[str], force: bool, noprune: bool + self, + image_canonicals: list[str], + force: bool, + noprune: bool, + agent_id: AgentId | None = None, ) -> PurgeImagesResp: log.info( "rpc::purge_images(images:{0}, force:{1}, noprune:{2})", @@ -984,14 +1068,14 @@ async def purge_images( force, noprune, ) - agent = self.runtime.get_agent() + agent = self.runtime.get_agent(agent_id) return await agent.purge_images( PurgeImagesReq(images=image_canonicals, force=force, noprune=noprune) ) @rpc_function @collect_error - async def get_local_config(self) -> Mapping[str, Any]: + async def get_local_config(self, agent_id: AgentId | None = None) -> Mapping[str, Any]: report_path: Path | None = self.local_config.agent_common.abuse_report_path return { "agent": { @@ -1006,65 +1090,95 @@ async def shutdown_service( self, kernel_id: str, service: str, + agent_id: AgentId | None = None, ): log.info("rpc::shutdown_service(k:{0}, app:{1})", kernel_id, service) - agent = self.runtime.get_agent() + agent = self.runtime.get_agent(agent_id) return await agent.shutdown_service(KernelId(UUID(kernel_id)), service) @rpc_function @collect_error - async def upload_file(self, kernel_id: str, filename: str, filedata: bytes): + async def upload_file( + self, + kernel_id: str, + filename: str, + filedata: bytes, + agent_id: AgentId | None = None, + ): log.info("rpc::upload_file(k:{0}, fn:{1})", kernel_id, filename) - agent = self.runtime.get_agent() + agent = self.runtime.get_agent(agent_id) await 
agent.accept_file(KernelId(UUID(kernel_id)), filename, filedata) @rpc_function @collect_error - async def download_file(self, kernel_id: str, filepath: str): + async def download_file( + self, + kernel_id: str, + filepath: str, + agent_id: AgentId | None = None, + ): log.info("rpc::download_file(k:{0}, fn:{1})", kernel_id, filepath) - agent = self.runtime.get_agent() + agent = self.runtime.get_agent(agent_id) return await agent.download_file(KernelId(UUID(kernel_id)), filepath) @rpc_function @collect_error - async def download_single(self, kernel_id: str, filepath: str): + async def download_single( + self, + kernel_id: str, + filepath: str, + agent_id: AgentId | None = None, + ): log.info("rpc::download_single(k:{0}, fn:{1})", kernel_id, filepath) - agent = self.runtime.get_agent() + agent = self.runtime.get_agent(agent_id) return await agent.download_single(KernelId(UUID(kernel_id)), filepath) @rpc_function @collect_error - async def list_files(self, kernel_id: str, path: str): + async def list_files( + self, + kernel_id: str, + path: str, + agent_id: AgentId | None = None, + ): log.info("rpc::list_files(k:{0}, fn:{1})", kernel_id, path) - agent = self.runtime.get_agent() + agent = self.runtime.get_agent(agent_id) return await agent.list_files(KernelId(UUID(kernel_id)), path) @rpc_function @collect_error - async def shutdown_agent(self, terminate_kernels: bool): + async def shutdown_agent(self, terminate_kernels: bool, agent_id: AgentId | None = None): # TODO: implement log.info("rpc::shutdown_agent()") pass @rpc_function @collect_error - async def create_local_network(self, network_name: str) -> None: + async def create_local_network( + self, + network_name: str, + agent_id: AgentId | None = None, + ) -> None: log.debug("rpc::create_local_network(name:{})", network_name) - agent = self.runtime.get_agent() + agent = self.runtime.get_agent(agent_id) return await agent.create_local_network(network_name) @rpc_function @collect_error - async def 
destroy_local_network(self, network_name: str) -> None: + async def destroy_local_network( + self, + network_name: str, + agent_id: AgentId | None = None, + ) -> None: log.debug("rpc::destroy_local_network(name:{})", network_name) - agent = self.runtime.get_agent() + agent = self.runtime.get_agent(agent_id) return await agent.destroy_local_network(network_name) @rpc_function @collect_error - async def reset_agent(self): + async def reset_agent(self, agent_id: AgentId | None = None): log.debug("rpc::reset()") - agent = self.runtime.get_agent() + agent = self.runtime.get_agent(agent_id) kernel_ids = tuple(agent.kernel_registry.keys()) tasks = [] for kernel_id in kernel_ids: @@ -1080,23 +1194,23 @@ async def reset_agent(self): @rpc_function @collect_error - async def assign_port(self): + async def assign_port(self, agent_id: AgentId | None = None): log.debug("rpc::assign_port()") - agent = self.runtime.get_agent() + agent = self.runtime.get_agent(agent_id) return agent.port_pool.pop() @rpc_function @collect_error - async def release_port(self, port_no: int): + async def release_port(self, port_no: int, agent_id: AgentId | None = None): log.debug("rpc::release_port(port_no:{})", port_no) - agent = self.runtime.get_agent() + agent = self.runtime.get_agent(agent_id) agent.port_pool.add(port_no) @rpc_function @collect_error - async def scan_gpu_alloc_map(self) -> Mapping[str, Any]: + async def scan_gpu_alloc_map(self, agent_id: AgentId | None = None) -> Mapping[str, Any]: log.debug("rpc::scan_gpu_alloc_map()") - agent = self.runtime.get_agent() + agent = self.runtime.get_agent(agent_id) scratch_root = agent.local_config.container.scratch_root result = await scan_gpu_alloc_map(list(agent.kernel_registry.keys()), scratch_root) return {k: str(v) for k, v in result.items()} diff --git a/src/ai/backend/common/configs/sample_generator.py b/src/ai/backend/common/configs/sample_generator.py index 2fbfebadc9e..d05f2036181 100644 --- 
a/src/ai/backend/common/configs/sample_generator.py +++ b/src/ai/backend/common/configs/sample_generator.py @@ -124,7 +124,10 @@ def _dump_toml_scalar( if ctx is not None: match ctx.hint: case "BinarySize": - value = f"{BinarySize(value):s}".upper() + if isinstance(value, BinarySize): + value = f"{value:s}".upper() + else: + value = f"{BinarySize.from_str(str(value)):s}".upper() case "HostPortPair": value = {"host": value["host"], "port": value["port"]} case "EnumByValue": diff --git a/src/ai/backend/common/types.py b/src/ai/backend/common/types.py index 2e9edaf1d48..e6e0d980db7 100644 --- a/src/ai/backend/common/types.py +++ b/src/ai/backend/common/types.py @@ -101,6 +101,7 @@ "ResourceSlot", "ResourceGroupType", "SlotName", + "SlotNameField", "SlotTypes", "IntrinsicSlotNames", "DefaultForUnspecified", @@ -365,6 +366,17 @@ def is_accelerator(self) -> bool: return False +def _validate_slot_name(v: Any) -> SlotName: + """Validator for SlotName fields.""" + if isinstance(v, SlotName): + return v + return SlotName(v) + + +# Create a custom type annotation for SlotName fields +SlotNameField = Annotated[SlotName, PlainValidator(_validate_slot_name)] + + MetricKey = NewType("MetricKey", str) AccessKey = NewType("AccessKey", str) diff --git a/src/ai/backend/manager/clients/agent/client.py b/src/ai/backend/manager/clients/agent/client.py index 59a5454d32f..37510dc9d90 100644 --- a/src/ai/backend/manager/clients/agent/client.py +++ b/src/ai/backend/manager/clients/agent/client.py @@ -83,47 +83,50 @@ async def _with_connection(self) -> AsyncIterator[PeerInvoker]: async def gather_hwinfo(self) -> Mapping[str, Any]: """Gather hardware information from the agent.""" async with self._with_connection() as rpc: - return await rpc.call.gather_hwinfo() + return await rpc.call.gather_hwinfo(agent_id=self.agent_id) @agent_client_resilience.apply() async def scan_gpu_alloc_map(self) -> Mapping[str, Any]: """Scan GPU allocation map from the agent.""" async with self._with_connection() 
as rpc: - return await rpc.call.scan_gpu_alloc_map() + return await rpc.call.scan_gpu_alloc_map(agent_id=self.agent_id) # Image management methods @agent_client_resilience.apply() async def check_and_pull(self, image_configs: Mapping[str, ImageConfig]) -> Mapping[str, str]: """Check and pull images on the agent.""" async with self._with_connection() as rpc: - return await rpc.call.check_and_pull(image_configs) + return await rpc.call.check_and_pull(image_configs, agent_id=self.agent_id) @agent_client_resilience.apply() async def purge_images( - self, images: list[str], force: bool, noprune: bool + self, + images: list[str], + force: bool, + noprune: bool, ) -> Mapping[str, Any]: """Purge images from the agent.""" async with self._with_connection() as rpc: - return await rpc.call.purge_images(images, force, noprune) + return await rpc.call.purge_images(images, force, noprune, agent_id=self.agent_id) # Network management methods @agent_client_resilience.apply() async def create_local_network(self, network_name: str) -> None: """Create a local network on the agent.""" async with self._with_connection() as rpc: - await rpc.call.create_local_network(network_name) + await rpc.call.create_local_network(network_name, agent_id=self.agent_id) @agent_client_resilience.apply() async def destroy_local_network(self, network_ref_name: str) -> None: """Destroy a local network on the agent.""" async with self._with_connection() as rpc: - await rpc.call.destroy_local_network(network_ref_name) + await rpc.call.destroy_local_network(network_ref_name, agent_id=self.agent_id) @agent_client_resilience.apply() async def assign_port(self) -> int: """Assign a host port on the agent.""" async with self._with_connection() as rpc: - return await rpc.call.assign_port() + return await rpc.call.assign_port(agent_id=self.agent_id) # Kernel management methods @agent_client_resilience.apply() @@ -143,6 +146,7 @@ async def create_kernels( kernel_configs, cluster_info, kernel_image_refs, + 
agent_id=self.agent_id, ) @agent_client_resilience.apply() @@ -160,6 +164,7 @@ async def destroy_kernel( session_id, reason, suppress_events=suppress_events, + agent_id=self.agent_id, ) @agent_client_resilience.apply() @@ -177,45 +182,46 @@ async def restart_kernel( kernel_id, image_ref, update_config, + agent_id=self.agent_id, ) @agent_client_resilience.apply() async def sync_kernel_registry(self, kernel_tuples: list[tuple[str, str]]) -> None: """Sync kernel registry on the agent.""" async with self._with_connection() as rpc: - return await rpc.call.sync_kernel_registry(kernel_tuples) + return await rpc.call.sync_kernel_registry(kernel_tuples, agent_id=self.agent_id) @agent_client_resilience.apply() async def drop_kernel_registry(self, kernel_id_list: list[KernelId]) -> None: """Drop kernel registry entries on the agent.""" async with self._with_connection() as rpc: - await rpc.call.drop_kernel_registry(kernel_id_list) + await rpc.call.drop_kernel_registry(kernel_id_list, agent_id=self.agent_id) # Health monitoring methods @agent_client_resilience.apply() async def check_pulling(self, image_name: str) -> bool: """Check if an image is being pulled.""" async with self._with_connection() as rpc: - return await rpc.call.check_pulling(image_name) + return await rpc.call.check_pulling(image_name, agent_id=self.agent_id) @agent_client_resilience.apply() async def check_creating(self, kernel_id: str) -> bool: """Check if a kernel is being created.""" async with self._with_connection() as rpc: - return await rpc.call.check_creating(str(kernel_id)) + return await rpc.call.check_creating(str(kernel_id), agent_id=self.agent_id) @agent_client_resilience.apply() async def check_running(self, kernel_id: str) -> bool: """Check if a kernel is running.""" async with self._with_connection() as rpc: - return await rpc.call.check_running(str(kernel_id)) + return await rpc.call.check_running(str(kernel_id), agent_id=self.agent_id) # Container management methods 
@agent_client_resilience.apply() async def purge_containers(self, serialized_data: list[tuple[str, str]]) -> None: """Purge containers on the agent.""" async with self._with_connection() as rpc: - await rpc.call.purge_containers(serialized_data) + await rpc.call.purge_containers(serialized_data, agent_id=self.agent_id) # Code execution methods @agent_client_resilience.apply() @@ -241,13 +247,14 @@ async def execute( code, opts, flush_timeout, + agent_id=self.agent_id, ) @agent_client_resilience.apply() async def interrupt_kernel(self, kernel_id: str) -> Mapping[str, Any]: """Interrupt a kernel on the agent.""" async with self._with_connection() as rpc: - return await rpc.call.interrupt_kernel(kernel_id) + return await rpc.call.interrupt_kernel(kernel_id, agent_id=self.agent_id) @agent_client_resilience.apply() async def trigger_batch_execution( @@ -264,6 +271,7 @@ async def trigger_batch_execution( kernel_id, startup_command, batch_timeout, + agent_id=self.agent_id, ) @agent_client_resilience.apply() @@ -275,7 +283,7 @@ async def get_completions( ) -> dict[str, Any]: """Get code completions from the agent.""" async with self._with_connection() as rpc: - return await rpc.call.get_completions(kernel_id, text, opts) + return await rpc.call.get_completions(kernel_id, text, opts, agent_id=self.agent_id) # Service management methods @agent_client_resilience.apply() @@ -287,66 +295,45 @@ async def start_service( ) -> Mapping[str, Any]: """Start a service on the agent.""" async with self._with_connection() as rpc: - return await rpc.call.start_service(kernel_id, service, opts) + return await rpc.call.start_service(kernel_id, service, opts, agent_id=self.agent_id) @agent_client_resilience.apply() - async def shutdown_service( - self, - kernel_id: str, - service: str, - ) -> None: + async def shutdown_service(self, kernel_id: str, service: str) -> None: """Shutdown a service on the agent.""" async with self._with_connection() as rpc: - await 
rpc.call.shutdown_service(kernel_id, service) + await rpc.call.shutdown_service(kernel_id, service, agent_id=self.agent_id) # File management methods @agent_client_resilience.apply() - async def upload_file( - self, - kernel_id: str, - filename: str, - payload: bytes, - ) -> Mapping[str, Any]: + async def upload_file(self, kernel_id: str, filename: str, payload: bytes) -> Mapping[str, Any]: """Upload a file to the agent.""" async with self._with_connection() as rpc: - return await rpc.call.upload_file(kernel_id, filename, payload) + return await rpc.call.upload_file(kernel_id, filename, payload, agent_id=self.agent_id) @agent_client_resilience.apply() - async def download_file( - self, - kernel_id: str, - filepath: str, - ) -> bytes: + async def download_file(self, kernel_id: str, filepath: str) -> bytes: """Download a file from the agent.""" async with self._with_connection() as rpc: - return await rpc.call.download_file(kernel_id, filepath) + return await rpc.call.download_file(kernel_id, filepath, agent_id=self.agent_id) @agent_client_resilience.apply() - async def download_single( - self, - kernel_id: str, - filepath: str, - ) -> bytes: + async def download_single(self, kernel_id: str, filepath: str) -> bytes: """Download a single file from the agent.""" async with self._with_connection() as rpc: - return await rpc.call.download_single(kernel_id, filepath) + return await rpc.call.download_single(kernel_id, filepath, agent_id=self.agent_id) @agent_client_resilience.apply() - async def list_files( - self, - kernel_id: str, - path: str, - ) -> Mapping[str, Any]: + async def list_files(self, kernel_id: str, path: str) -> Mapping[str, Any]: """List files on the agent.""" async with self._with_connection() as rpc: - return await rpc.call.list_files(kernel_id, path) + return await rpc.call.list_files(kernel_id, path, agent_id=self.agent_id) # Log management methods @agent_client_resilience.apply() async def get_logs(self, kernel_id: str) -> Mapping[str, str]: """Get 
logs from the agent.""" async with self._with_connection() as rpc: - return await rpc.call.get_logs(kernel_id) + return await rpc.call.get_logs(kernel_id, agent_id=self.agent_id) # Image commit methods @agent_client_resilience.apply() @@ -367,29 +354,26 @@ async def commit( kwargs["extra_labels"] = extra_labels if filename is not None: kwargs["filename"] = filename + kwargs["agent_id"] = self.agent_id return await rpc.call.commit(kernel_id, email, **kwargs) @agent_client_resilience.apply() - async def push_image( - self, - image_ref: ImageRef, - registry: Any, - ) -> Mapping[str, Any]: + async def push_image(self, image_ref: ImageRef, registry: Any) -> Mapping[str, Any]: """Push an image from the agent.""" async with self._with_connection() as rpc: - return await rpc.call.push_image(image_ref, registry) + return await rpc.call.push_image(image_ref, registry, agent_id=self.agent_id) # Scaling group management @agent_client_resilience.apply() async def update_scaling_group(self, scaling_group: str) -> None: """Update scaling group on the agent.""" async with self._with_connection() as rpc: - await rpc.call.update_scaling_group(scaling_group) + await rpc.call.update_scaling_group(scaling_group, self.agent_id) # Local configuration management @agent_client_resilience.apply() async def get_local_config(self) -> Mapping[str, str]: """Get local configuration from the agent.""" async with self._with_connection() as rpc: - return await rpc.call.get_local_config() + return await rpc.call.get_local_config(self.agent_id) diff --git a/tests/agent/conftest.py b/tests/agent/conftest.py index ff0a8d80b47..6e94da4e5ce 100644 --- a/tests/agent/conftest.py +++ b/tests/agent/conftest.py @@ -5,11 +5,13 @@ import subprocess from collections import defaultdict from pathlib import Path +from typing import AsyncIterator import aiodocker import pytest from ai.backend.agent.config.unified import AgentUnifiedConfig +from ai.backend.agent.runtime import AgentRuntime from ai.backend.common 
import config from ai.backend.common import validators as tx from ai.backend.common.arch import DEFAULT_IMAGE_ARCH @@ -271,3 +273,37 @@ async def _create_container(config): finally: if container is not None: await container.delete(force=True) + + +@pytest.fixture +async def agent_runtime( + local_config: AgentUnifiedConfig, + etcd, + mocker, +) -> AsyncIterator[AgentRuntime]: + """ + Create a real AgentRuntime instance for integration testing. + + This fixture provides a fully initialized AgentRuntime with: + - Real etcd client + - Real agent configuration + - Mocked stats and error monitors (external dependencies) + - Proper cleanup after tests + """ + from unittest.mock import Mock + + mock_stats_monitor = Mock() + mock_error_monitor = Mock() + + runtime = await AgentRuntime.create_runtime( + local_config, + etcd, + mock_stats_monitor, + mock_error_monitor, + None, + ) + + try: + yield runtime + finally: + await runtime.__aexit__(None, None, None) diff --git a/tests/agent/test_agent.py b/tests/agent/test_agent.py index 30d569acf44..79bdfead7f3 100644 --- a/tests/agent/test_agent.py +++ b/tests/agent/test_agent.py @@ -1,100 +1,256 @@ """ -TODO: rewrite +Tests for agent configuration and RPC server functionality. 
""" +from __future__ import annotations + import os -from unittest.mock import AsyncMock +from pathlib import Path +from typing import Callable +from unittest.mock import AsyncMock, Mock, patch import pytest +import tomlkit +from ai.backend.agent.agent import AbstractAgent from ai.backend.agent.config.unified import ( + AgentBackend, AgentConfig, AgentUnifiedConfig, ContainerConfig, EtcdConfig, ResourceConfig, + ScratchType, ) +from ai.backend.agent.dummy.agent import DummyAgent from ai.backend.agent.server import AgentRPCServer from ai.backend.common.typed_validators import HostPortPair - - -class Dummy: - pass - - -kgid = "kernel-gid" -kuid = "kernel-uid" -ctnr = "container" +from ai.backend.common.types import AgentId @pytest.fixture -async def arpcs_no_ainit(test_id, redis_container): - etcd = Dummy() +def mock_etcd() -> Mock: + """Create a mock etcd object with get_prefix method.""" + etcd = Mock() etcd.get_prefix = None + return etcd + - # Create a minimal pydantic config for testing - config = AgentUnifiedConfig( - agent=AgentConfig(backend="docker"), - container=ContainerConfig(scratch_type="hostdir"), +@pytest.fixture +def base_agent_config() -> AgentUnifiedConfig: + """Create a base agent configuration for testing.""" + return AgentUnifiedConfig( + agent=AgentConfig(backend=AgentBackend.DOCKER), + container=ContainerConfig(scratch_type=ScratchType.HOSTDIR), resource=ResourceConfig(), etcd=EtcdConfig(namespace="test", addr=HostPortPair(host="127.0.0.1", port=2379)), ) - ars = AgentRPCServer(etcd=etcd, local_config=config, skip_detect_manager=True) + +@pytest.fixture +async def agent_rpc_server( + mock_etcd: Mock, base_agent_config: AgentUnifiedConfig +) -> AgentRPCServer: + """Create an AgentRPCServer instance for testing without initialization.""" + ars = AgentRPCServer(etcd=mock_etcd, local_config=base_agent_config, skip_detect_manager=True) # Mock the runtime object to return the etcd client - runtime = Dummy() - runtime.get_etcd = lambda: etcd + 
runtime = Mock() + runtime.get_etcd = lambda agent_id=None: mock_etcd ars.runtime = runtime - yield ars - - -@pytest.mark.asyncio -async def test_read_agent_config_container_invalid01(arpcs_no_ainit, mocker): - inspect_mock = AsyncMock(return_value={"a": 1, "b": 2}) - mocker.patch.object(arpcs_no_ainit.etcd, "get_prefix", new=inspect_mock) - await arpcs_no_ainit.read_agent_config_container() - # Check that kernel-gid and kernel-uid are still at their default values (converted from -1) - assert ( - arpcs_no_ainit.local_config.container.kernel_gid.real == os.getgid() - ) # default value (os.getgid()) - assert ( - arpcs_no_ainit.local_config.container.kernel_uid.real == os.getuid() - ) # default value (os.getuid()) - - -@pytest.mark.asyncio -async def test_read_agent_config_container_invalid02(arpcs_no_ainit, mocker): - inspect_mock = AsyncMock(return_value={}) - mocker.patch.object(arpcs_no_ainit.etcd, "get_prefix", new=inspect_mock) - await arpcs_no_ainit.read_agent_config_container() - # Check that kernel-gid and kernel-uid are still at their default values (converted from -1) - assert ( - arpcs_no_ainit.local_config.container.kernel_gid.real == os.getgid() - ) # default value (os.getgid()) - assert ( - arpcs_no_ainit.local_config.container.kernel_uid.real == os.getuid() - ) # default value (os.getuid()) - - -@pytest.mark.asyncio -async def test_read_agent_config_container_1valid(arpcs_no_ainit, mocker): - inspect_mock = AsyncMock(return_value={kgid: 10}) - mocker.patch.object(arpcs_no_ainit.etcd, "get_prefix", new=inspect_mock) - await arpcs_no_ainit.read_agent_config_container() - - assert arpcs_no_ainit.local_config.container.kernel_gid.real == 10 - assert ( - arpcs_no_ainit.local_config.container.kernel_uid.real == os.getuid() - ) # default value (os.getuid()) - - -@pytest.mark.asyncio -async def test_read_agent_config_container_2valid(arpcs_no_ainit, mocker): - inspect_mock = AsyncMock(return_value={kgid: 10, kuid: 20}) - 
mocker.patch.object(arpcs_no_ainit.etcd, "get_prefix", new=inspect_mock) - await arpcs_no_ainit.read_agent_config_container() - - assert arpcs_no_ainit.local_config.container.kernel_gid.real == 10 - assert arpcs_no_ainit.local_config.container.kernel_uid.real == 20 + return ars + + +@pytest.fixture +def mock_agent_factory() -> Callable[[str, str], Mock]: + """Factory fixture to create mock agent instances with specified agent_id.""" + + def _create_agent(agent_id: str, scaling_group: str = "default") -> Mock: + mock_agent = Mock(spec=DummyAgent) + mock_agent.id = AgentId(agent_id) + mock_agent.local_config = AgentUnifiedConfig( + agent=AgentConfig(backend=AgentBackend.DUMMY, scaling_group=scaling_group, id=agent_id), + container=ContainerConfig(scratch_type=ScratchType.HOSTDIR), + resource=ResourceConfig(), + etcd=EtcdConfig(namespace="test", addr=HostPortPair(host="127.0.0.1", port=2379)), + ) + + # Use the real update_scaling_group method - capture agent in closure properly + def update_sg(sg: str) -> None: + AbstractAgent.update_scaling_group(mock_agent, sg) + + mock_agent.update_scaling_group = update_sg + return mock_agent + + return _create_agent + + +class TestAgentConfigReading: + """Tests for reading agent configuration from etcd.""" + + @pytest.mark.parametrize( + "etcd_response,expected_gid,expected_uid", + [ + # Invalid responses - should use defaults + ({"a": 1, "b": 2}, os.getgid(), os.getuid()), + ({}, os.getgid(), os.getuid()), + # Partial valid responses + ({"kernel-gid": 10}, 10, os.getuid()), + # Fully valid response + ({"kernel-gid": 10, "kernel-uid": 20}, 10, 20), + ], + ids=["invalid_keys", "empty", "only_gid", "both_valid"], + ) + @pytest.mark.asyncio + async def test_read_agent_config_container( + self, + agent_rpc_server: AgentRPCServer, + mocker, + etcd_response: dict, + expected_gid: int, + expected_uid: int, + ) -> None: + """Test reading container config from etcd with various responses.""" + inspect_mock = 
AsyncMock(return_value=etcd_response) + mocker.patch.object(agent_rpc_server.etcd, "get_prefix", new=inspect_mock) + + await agent_rpc_server.read_agent_config_container() + + assert agent_rpc_server.local_config.container.kernel_gid.real == expected_gid + assert agent_rpc_server.local_config.container.kernel_uid.real == expected_uid + + +class TestScalingGroupUpdates: + """Tests for updating scaling group configuration.""" + + def test_update_scaling_group_changes_config(self) -> None: + """Test that update_scaling_group modifies the in-memory config.""" + mock_agent = Mock(spec=DummyAgent) + mock_agent.local_config = AgentUnifiedConfig( + agent=AgentConfig(backend=AgentBackend.DUMMY, scaling_group="default"), + container=ContainerConfig(scratch_type=ScratchType.HOSTDIR), + resource=ResourceConfig(), + etcd=EtcdConfig(namespace="test", addr=HostPortPair(host="127.0.0.1", port=2379)), + ) + + AbstractAgent.update_scaling_group(mock_agent, "gpu") + + assert mock_agent.local_config.agent.scaling_group == "gpu" + + @pytest.mark.asyncio + async def test_update_scaling_group_persists_single_agent( + self, tmp_path: Path, mock_agent_factory: Callable[[str, str], Mock] + ) -> None: + """Test that scaling group updates persist to config file in single-agent mode.""" + config_file = tmp_path / "agent.toml" + config_file.write_text( + """[agent] +backend = "dummy" +scaling-group = "default" +id = "test-agent" + +[container] +scratch-type = "hostdir" + +[resource] + +[etcd] +namespace = "test" +addr = { host = "127.0.0.1", port = 2379 } +""" + ) + + # Create server with runtime + server = object.__new__(AgentRPCServer) + runtime = Mock() + runtime._default_agent_id = AgentId("test-agent") + + def get_agent_impl(agent_id=None): + if agent_id is None: + agent_id = runtime._default_agent_id + return runtime.agents[agent_id] + + runtime.get_agent = get_agent_impl + + mock_agent = mock_agent_factory("test-agent", "default") + runtime.agents = {AgentId("test-agent"): mock_agent} + 
server.runtime = runtime + + with patch("ai.backend.common.config.find_config_file", return_value=config_file): + await server.update_scaling_group.__wrapped__.__wrapped__(server, "gpu", None) # type: ignore[attr-defined] + + # Verify file was updated + with open(config_file) as f: + updated_config = tomlkit.load(f) + + assert updated_config["agent"]["scaling-group"] == "gpu" # type: ignore[index] + assert mock_agent.local_config.agent.scaling_group == "gpu" + + @pytest.mark.asyncio + async def test_update_scaling_group_persists_multi_agent( + self, tmp_path: Path, mock_agent_factory: Callable[[str, str], Mock] + ) -> None: + """Test that scaling group updates persist correctly in multi-agent mode.""" + config_file = tmp_path / "agent.toml" + config_file.write_text( + """[agent] +backend = "dummy" +scaling-group = "default" + +[container] +scratch-type = "hostdir" + +[resource] + +[etcd] +namespace = "test" +addr = { host = "127.0.0.1", port = 2379 } + +[[agents]] +[agents.agent] +id = "agent-1" +scaling-group = "default" + +[[agents]] +[agents.agent] +id = "agent-2" +scaling-group = "default" +""" + ) + + # Create server with runtime + server = object.__new__(AgentRPCServer) + runtime = Mock() + runtime._default_agent_id = AgentId("agent-1") + + def get_agent_impl(agent_id=None): + if agent_id is None: + agent_id = runtime._default_agent_id + return runtime.agents[agent_id] + + runtime.get_agent = get_agent_impl + + mock_agent1 = mock_agent_factory("agent-1", "default") + mock_agent2 = mock_agent_factory("agent-2", "default") + + runtime.agents = { + AgentId("agent-1"): mock_agent1, + AgentId("agent-2"): mock_agent2, + } + server.runtime = runtime + + with patch("ai.backend.common.config.find_config_file", return_value=config_file): + await server.update_scaling_group.__wrapped__.__wrapped__( # type: ignore[attr-defined] + server, "gpu", AgentId("agent-2") + ) # type: ignore[attr-defined] + + # Verify file was updated + with open(config_file) as f: + 
updated_config = tomlkit.load(f) + + # Only agent-2's scaling group should be updated + assert updated_config["agents"][1]["agent"]["scaling-group"] == "gpu" # type: ignore[index] + assert updated_config["agents"][0]["agent"]["scaling-group"] == "default" # type: ignore[index] + assert mock_agent2.local_config.agent.scaling_group == "gpu" + assert mock_agent1.local_config.agent.scaling_group == "default" diff --git a/tests/agent/test_agent_runtime.py b/tests/agent/test_agent_runtime.py new file mode 100644 index 00000000000..11ac64c2376 --- /dev/null +++ b/tests/agent/test_agent_runtime.py @@ -0,0 +1,160 @@ +from __future__ import annotations + +from typing import TYPE_CHECKING + +import pytest + +from ai.backend.agent.errors.runtime import AgentIdNotFoundError +from ai.backend.agent.runtime import AgentRuntime +from ai.backend.common.types import AgentId + +if TYPE_CHECKING: + from ai.backend.agent.config.unified import AgentUnifiedConfig + + +class TestAgentRuntimeSingleAgent: + """Test AgentRuntime with single agent configuration.""" + + @pytest.mark.asyncio + async def test_get_agent_returns_default_when_id_is_none( + self, + agent_runtime: AgentRuntime, + ) -> None: + """ + When agent_id is None, get_agent() should return the default agent. + """ + agent = agent_runtime.get_agent(None) + + assert agent is not None + assert agent.id is not None + # In single agent mode, the default agent should be the only agent + assert agent is agent_runtime.get_agent(agent.id) + + @pytest.mark.asyncio + async def test_get_agent_returns_agent_by_id( + self, + agent_runtime: AgentRuntime, + ) -> None: + """ + get_agent() should return the correct agent when given a specific ID. 
+ """ + # Get the default agent's ID + default_agent = agent_runtime.get_agent(None) + agent_id = default_agent.id + + # Retrieve by ID + agent = agent_runtime.get_agent(agent_id) + + assert agent is default_agent + assert agent.id == agent_id + + @pytest.mark.asyncio + async def test_get_agent_raises_error_for_nonexistent_id( + self, + agent_runtime: AgentRuntime, + ) -> None: + """ + get_agent() should raise AgentIdNotFoundError for non-existent agent IDs. + """ + nonexistent_id = AgentId("nonexistent-agent-id") + + with pytest.raises(AgentIdNotFoundError) as exc_info: + agent_runtime.get_agent(nonexistent_id) + + # Verify error message is helpful + assert str(nonexistent_id) in str(exc_info.value) + assert "not found" in str(exc_info.value).lower() + + @pytest.mark.asyncio + async def test_get_agents_returns_all_agents( + self, + agent_runtime: AgentRuntime, + ) -> None: + """ + get_agents() should return a list of all agents. + """ + agents = agent_runtime.get_agents() + + assert isinstance(agents, list) + assert len(agents) == 1 # Single agent mode + + # Verify the agent is accessible + for agent in agents: + assert agent.id is not None + assert agent is agent_runtime.get_agent(agent.id) + + +class TestAgentRuntimeInitialization: + """Test AgentRuntime initialization and cleanup.""" + + @pytest.mark.asyncio + async def test_runtime_creates_agents_from_config( + self, + local_config: AgentUnifiedConfig, + etcd, + mocker, + ) -> None: + """ + AgentRuntime.create_agents() should initialize agents from config. 
+ """ + from unittest.mock import Mock + + mock_stats_monitor = Mock() + mock_error_monitor = Mock() + + runtime = await AgentRuntime.create_runtime( + local_config, + etcd, + mock_stats_monitor, + mock_error_monitor, + None, + ) + + try: + # Verify agents were created + agents = runtime.get_agents() + assert len(agents) > 0 + + # Verify default agent is set + default_agent = runtime.get_agent(None) + assert default_agent is not None + + # Verify all agents have valid IDs + for agent in agents: + assert agent.id is not None + finally: + await runtime.__aexit__(None, None, None) + + @pytest.mark.asyncio + async def test_runtime_shutdown_cleans_up_agents( + self, + local_config: AgentUnifiedConfig, + etcd, + mocker, + ) -> None: + """ + AgentRuntime.shutdown() should properly clean up all agents. + """ + from unittest.mock import Mock + + mock_stats_monitor = Mock() + mock_error_monitor = Mock() + + runtime = await AgentRuntime.create_runtime( + local_config, + etcd, + mock_stats_monitor, + mock_error_monitor, + None, + ) + + # Verify agents exist before shutdown + agents = runtime.get_agents() + assert len(agents) > 0 + + # Shutdown + await runtime.__aexit__(None, None, None) + + # After shutdown, the runtime should be in a clean state + # (Specific behavior depends on implementation - adjust as needed) + # For now, we just verify shutdown doesn't raise errors diff --git a/tests/agent/test_config_validation.py b/tests/agent/test_config_validation.py index 84d15d71dda..089dfdabc9d 100644 --- a/tests/agent/test_config_validation.py +++ b/tests/agent/test_config_validation.py @@ -1,6 +1,7 @@ from __future__ import annotations import sys +from decimal import Decimal from pathlib import Path from typing import Any, Protocol from unittest.mock import patch @@ -19,11 +20,13 @@ ContainerSandboxType, CoreDumpConfig, DebugConfig, + ResourceAllocationMode, ResourceConfig, ScratchType, ) from ai.backend.agent.stats import StatModes from ai.backend.common.typed_validators import 
HostPortPair +from ai.backend.common.types import SlotName from ai.backend.logging.config import ( LoggingConfig, LogLevel, @@ -751,19 +754,34 @@ def test_agent_overrides_resource_config( ) -> None: raw_config = { **default_raw_config, + "resource": { + **default_raw_config["resource"], + "allocation-mode": "manual", + }, "agents": [ { "agent": {"id": "agent-1"}, - "resource": {"reserved-cpu": 2}, + "resource": { + "cpu": 2, + "mem": "8G", + }, + }, + { + "agent": {"id": "agent-2"}, + "resource": { + "cpu": 1, + "mem": "8G", + }, }, - {"agent": {"id": "agent-2"}}, ], } config = AgentUnifiedConfig.model_validate(raw_config) agent_configs = config.get_agent_configs() - assert agent_configs[0].resource.reserved_cpu == 2 - assert agent_configs[1].resource.reserved_cpu == 1 + assert agent_configs[0].resource.allocations is not None + assert agent_configs[0].resource.allocations.cpu == 2 + assert agent_configs[1].resource.allocations is not None + assert agent_configs[1].resource.allocations.cpu == 1 def test_agent_partial_override_preserves_other_fields( self, @@ -793,33 +811,6 @@ def test_agent_partial_override_preserves_other_fields( assert agent_configs[0].agent.kernel_creation_concurrency == 8 assert agent_configs[0].agent.allow_compute_plugins == {"plugin1", "plugin2"} - def test_agent_with_different_plugins( - self, - default_raw_config: RawConfigT, - ) -> None: - raw_config = { - **default_raw_config, - "agents": [ - { - "agent": { - "id": "agent-1", - "allow-compute-plugins": {"plugin1"}, - } - }, - { - "agent": { - "id": "agent-2", - "allow-compute-plugins": {"plugin2"}, - } - }, - ], - } - config = AgentUnifiedConfig.model_validate(raw_config) - - agent_configs = config.get_agent_configs() - assert agent_configs[0].agent.allow_compute_plugins == {"plugin1"} - assert agent_configs[1].agent.allow_compute_plugins == {"plugin2"} - def test_multiple_agents_validate_backend_specific_config( self, default_raw_config: RawConfigT, @@ -852,6 +843,14 @@ def 
test_multiple_agents_with_mixed_overrides( ) -> None: raw_config = { **default_raw_config, + "resource": { + **default_raw_config["resource"], + "allocation-mode": "manual", + "allocations": { + "cpu": 1, + "mem": "8G", + }, + }, "agents": [ { "agent": { @@ -859,7 +858,10 @@ def test_multiple_agents_with_mixed_overrides( "kernel-creation-concurrency": 8, }, "container": {"port-range": [31000, 32000]}, - "resource": {"reserved-cpu": 2}, + "resource": { + "cpu": 2, + "mem": "8G", + }, }, { "agent": { @@ -877,9 +879,11 @@ def test_multiple_agents_with_mixed_overrides( agent_configs = config.get_agent_configs() assert len(agent_configs) == 3 assert agent_configs[0].agent.kernel_creation_concurrency == 8 - assert agent_configs[0].resource.reserved_cpu == 2 + assert agent_configs[0].resource.allocations is not None + assert agent_configs[0].resource.allocations.cpu == 2 assert agent_configs[1].agent.kernel_creation_concurrency == 4 - assert agent_configs[1].resource.reserved_cpu == 1 + assert agent_configs[1].resource.allocations is not None + assert agent_configs[1].resource.allocations.cpu == 1 assert agent_configs[2].agent.kernel_creation_concurrency == 4 def test_agent_with_only_id_inherits_all_fields( @@ -944,7 +948,6 @@ def test_agent_with_empty_resource_override_inherits_global( "agents": [ { "agent": {"id": "agent-1"}, - "resource": {}, }, {"agent": {"id": "agent-2"}}, ], @@ -955,6 +958,31 @@ def test_agent_with_empty_resource_override_inherits_global( assert agent_configs[0].resource.reserved_cpu == 2 assert agent_configs[1].resource.reserved_cpu == 2 + def test_agent_with_empty_resource_dict_is_rejected( + self, + default_raw_config: RawConfigT, + ) -> None: + """Empty resource dict should be rejected - omit the field instead.""" + raw_config = { + **default_raw_config, + "resource": { + "reserved-cpu": 2, + "reserved-mem": "2G", + "reserved-disk": "10G", + }, + "agents": [ + { + "agent": {"id": "agent-1"}, + "resource": {}, # Empty dict should be invalid + }, 
+ {"agent": {"id": "agent-2"}}, + ], + } + with pytest.raises(ValidationError) as exc_info: + AgentUnifiedConfig.model_validate(raw_config) + + assert "Field required" in str(exc_info.value) + def test_overridable_agent_config_defaults_when_not_in_global( self, default_raw_config: RawConfigT, @@ -966,7 +994,7 @@ def test_overridable_agent_config_defaults_when_not_in_global( { "agent": { "id": "agent-2", - "allow-compute-plugins": {"plugin1"}, + "force-terminate-abusing-containers": True, }, }, ], @@ -977,10 +1005,8 @@ def test_overridable_agent_config_defaults_when_not_in_global( assert agent_configs[0].agent.agent_sock_port == 6007 assert agent_configs[0].agent.force_terminate_abusing_containers is False assert agent_configs[0].agent.use_experimental_redis_event_dispatcher is False - assert agent_configs[0].agent.allow_compute_plugins is None - assert agent_configs[0].agent.block_compute_plugins is None - assert agent_configs[1].agent.allow_compute_plugins == {"plugin1"} + assert agent_configs[1].agent.force_terminate_abusing_containers is True assert agent_configs[1].agent.agent_sock_port == 6007 def test_overridable_container_config_defaults_when_not_in_global( @@ -1009,3 +1035,592 @@ def test_overridable_container_config_defaults_when_not_in_global( assert agent_configs[1].container.port_range == (32000, 33000) assert agent_configs[1].container.scratch_type == ScratchType.HOSTDIR + + def test_agent_ids_must_be_unique(self, default_raw_config: RawConfigT) -> None: + raw_config = { + **default_raw_config, + "agents": [ + {"agent": {"id": "agent-1"}}, + {"agent": {"id": "agent-1"}}, + ], + } + + with pytest.raises(ValidationError) as exc_info: + AgentUnifiedConfig.model_validate(raw_config) + + assert "duplicate" in str(exc_info.value).lower() + + def test_different_scaling_groups_per_agent(self, default_raw_config: RawConfigT) -> None: + raw_config = { + **default_raw_config, + "agents": [ + { + "agent": { + "id": "agent-1", + "scaling-group": "default", + } + }, 
+ { + "agent": { + "id": "agent-2", + "scaling-group": "gpu", + } + }, + ], + } + config = AgentUnifiedConfig.model_validate(raw_config) + + agent_configs = config.get_agent_configs() + assert agent_configs[0].agent.scaling_group == "default" + assert agent_configs[1].agent.scaling_group == "gpu" + + +class TestResourceAllocationModes: + """Test the new resource allocation modes: SHARED, AUTO_SPLIT, and MANUAL.""" + + @pytest.fixture + def default_raw_config(self) -> RawConfigT: + return { + "agent": { + "backend": AgentBackend.DOCKER, + "rpc-listen-addr": HostPortPair(host="127.0.0.1", port=6001), + }, + "container": { + "scratch-type": ScratchType.HOSTDIR, + "port-range": [30000, 31000], + }, + "resource": { + "reserved-cpu": 2, + "reserved-mem": "2G", + "reserved-disk": "10G", + }, + "etcd": { + "namespace": "test", + "addr": HostPortPair(host="127.0.0.1", port=2379), + }, + } + + def test_allocation_mode_defaults_to_shared( + self, + default_raw_config: RawConfigT, + ) -> None: + config = AgentUnifiedConfig.model_validate(default_raw_config) + assert config.resource.allocation_mode == ResourceAllocationMode.SHARED + + def test_shared_mode_single_agent( + self, + default_raw_config: RawConfigT, + ) -> None: + raw_config = { + **default_raw_config, + "resource": { + **default_raw_config["resource"], + "allocation-mode": "shared", + }, + } + config = AgentUnifiedConfig.model_validate(raw_config) + assert config.resource.allocation_mode == ResourceAllocationMode.SHARED + + def test_shared_mode_multiple_agents_no_allocations( + self, + default_raw_config: RawConfigT, + ) -> None: + raw_config = { + **default_raw_config, + "resource": { + **default_raw_config["resource"], + "allocation-mode": "shared", + }, + "agents": [ + {"agent": {"id": "agent-1"}}, + {"agent": {"id": "agent-2"}}, + ], + } + config = AgentUnifiedConfig.model_validate(raw_config) + assert config.resource.allocation_mode == ResourceAllocationMode.SHARED + assert len(config.get_agent_configs()) == 2 
+ + def test_shared_mode_rejects_allocated_cpu( + self, + default_raw_config: RawConfigT, + ) -> None: + raw_config = { + **default_raw_config, + "resource": { + **default_raw_config["resource"], + "allocation-mode": "shared", + }, + "agents": [ + { + "agent": {"id": "agent-1"}, + "resource": { + "cpu": 8, + "mem": "8G", + }, + }, + {"agent": {"id": "agent-2"}}, + ], + } + with pytest.raises(ValidationError) as exc_info: + AgentUnifiedConfig.model_validate(raw_config) + + assert "must not specify manual resource" in str(exc_info.value) + + def test_shared_mode_rejects_allocated_devices( + self, + default_raw_config: RawConfigT, + ) -> None: + raw_config = { + **default_raw_config, + "resource": { + **default_raw_config["resource"], + "allocation-mode": "shared", + }, + "agents": [ + { + "agent": {"id": "agent-1"}, + "resource": { + "cpu": 8, + "mem": "8G", + "devices": { + SlotName("cuda.mem"): 0.5, + }, + }, + }, + {"agent": {"id": "agent-2"}}, + ], + } + with pytest.raises(ValidationError) as exc_info: + AgentUnifiedConfig.model_validate(raw_config) + + assert "must not specify manual resource" in str(exc_info.value) + + def test_auto_split_mode_multiple_agents( + self, + default_raw_config: RawConfigT, + ) -> None: + raw_config = { + **default_raw_config, + "resource": { + **default_raw_config["resource"], + "allocation-mode": "auto-split", + }, + "agents": [ + {"agent": {"id": "agent-1"}}, + {"agent": {"id": "agent-2"}}, + ], + } + config = AgentUnifiedConfig.model_validate(raw_config) + assert config.resource.allocation_mode == ResourceAllocationMode.AUTO_SPLIT + + def test_auto_split_mode_rejects_allocated_cpu( + self, + default_raw_config: RawConfigT, + ) -> None: + raw_config = { + **default_raw_config, + "resource": { + **default_raw_config["resource"], + "allocation-mode": "auto-split", + }, + "agents": [ + { + "agent": {"id": "agent-1"}, + "resource": { + "cpu": 8, + "mem": "8G", + }, + }, + {"agent": {"id": "agent-2"}}, + ], + } + with 
pytest.raises(ValidationError) as exc_info: + AgentUnifiedConfig.model_validate(raw_config) + + assert "must not specify manual resource" in str(exc_info.value) + + def test_manual_mode_requires_allocated_cpu_mem_disk( + self, + default_raw_config: RawConfigT, + ) -> None: + raw_config = { + **default_raw_config, + "resource": { + **default_raw_config["resource"], + "allocation-mode": "manual", + }, + "agents": [ + { + "agent": {"id": "agent-1"}, + "resource": { + "cpu": 8, + # Missing mem - this should fail because mem is required + }, + }, + {"agent": {"id": "agent-2"}}, + ], + } + with pytest.raises(ValidationError) as exc_info: + AgentUnifiedConfig.model_validate(raw_config) + + assert "Field required" in str(exc_info.value) + + def test_manual_mode_valid_configuration( + self, + default_raw_config: RawConfigT, + ) -> None: + raw_config = { + **default_raw_config, + "resource": { + **default_raw_config["resource"], + "allocation-mode": "manual", + }, + "agents": [ + { + "agent": {"id": "agent-1"}, + "resource": { + "cpu": 8, + "mem": "32G", + }, + }, + { + "agent": {"id": "agent-2"}, + "resource": { + "cpu": 4, + "mem": "32G", + }, + }, + ], + } + config = AgentUnifiedConfig.model_validate(raw_config) + assert config.resource.allocation_mode == ResourceAllocationMode.MANUAL + agent_configs = config.get_agent_configs() + assert agent_configs[0].resource.allocations is not None + assert agent_configs[0].resource.allocations.cpu == 8 + assert agent_configs[1].resource.allocations is not None + assert agent_configs[1].resource.allocations.cpu == 4 + + def test_manual_mode_with_allocated_devices( + self, + default_raw_config: RawConfigT, + ) -> None: + raw_config = { + **default_raw_config, + "resource": { + **default_raw_config["resource"], + "allocation-mode": "manual", + }, + "agents": [ + { + "agent": {"id": "agent-1"}, + "resource": { + "cpu": 8, + "mem": "32G", + "devices": { + SlotName("cuda.mem"): 0.3, + }, + }, + }, + { + "agent": {"id": "agent-2"}, + 
"resource": { + "cpu": 8, + "mem": "32G", + "devices": { + SlotName("cuda.mem"): 0.7, + }, + }, + }, + ], + } + config = AgentUnifiedConfig.model_validate(raw_config) + agent_configs = config.get_agent_configs() + assert agent_configs[0].resource.allocations is not None + assert agent_configs[0].resource.allocations.devices[SlotName("cuda.mem")] == Decimal("0.3") + assert agent_configs[1].resource.allocations is not None + assert agent_configs[1].resource.allocations.devices[SlotName("cuda.mem")] == Decimal("0.7") + + def test_manual_mode_agents_with_same_slots_allowed( + self, + default_raw_config: RawConfigT, + ) -> None: + """Test that agents with the same slot names are allowed in MANUAL mode.""" + raw_config = { + **default_raw_config, + "resource": { + **default_raw_config["resource"], + "allocation-mode": "manual", + }, + "agents": [ + { + "agent": {"id": "agent-1"}, + "resource": { + "cpu": 8, + "mem": "32G", + "devices": { + SlotName("cuda.mem"): 0.3, + SlotName("cuda.shares"): 1.0, + }, + }, + }, + { + "agent": {"id": "agent-2"}, + "resource": { + "cpu": 4, + "mem": "16G", + "devices": { + SlotName("cuda.mem"): 0.7, + SlotName("cuda.shares"): 2.0, + }, + }, + }, + ], + } + config = AgentUnifiedConfig.model_validate(raw_config) + agent_configs = config.get_agent_configs() + + # Check that both agents have the same slot names + assert agent_configs[0].resource.allocations is not None + assert set(agent_configs[0].resource.allocations.devices.keys()) == { + SlotName("cuda.mem"), + SlotName("cuda.shares"), + } + assert agent_configs[1].resource.allocations is not None + assert set(agent_configs[1].resource.allocations.devices.keys()) == { + SlotName("cuda.mem"), + SlotName("cuda.shares"), + } + + # Check that values can differ + assert agent_configs[0].resource.allocations.devices[SlotName("cuda.mem")] == Decimal("0.3") + assert agent_configs[1].resource.allocations.devices[SlotName("cuda.mem")] == Decimal("0.7") + + def 
test_manual_mode_agents_with_different_slots_rejected( + self, + default_raw_config: RawConfigT, + ) -> None: + """Test that agents with different slot names are rejected in MANUAL mode.""" + raw_config = { + **default_raw_config, + "resource": { + **default_raw_config["resource"], + "allocation-mode": "manual", + }, + "agents": [ + { + "agent": {"id": "agent-1"}, + "resource": { + "cpu": 8, + "mem": "32G", + "devices": { + SlotName("cuda.mem"): 0.7, + }, + }, + }, + { + "agent": {"id": "agent-2"}, + "resource": { + "cpu": 8, + "mem": "32G", + "devices": { + SlotName("cuda.shares"): 0.6, + }, + }, + }, + ], + } + with pytest.raises(ValidationError) as exc_info: + AgentUnifiedConfig.model_validate(raw_config) + + assert "All agents must have the same slots defined" in str(exc_info.value) + + def test_manual_mode_agents_with_subset_of_slots_rejected( + self, + default_raw_config: RawConfigT, + ) -> None: + """Test that agents where one has a subset of slots are rejected in MANUAL mode.""" + raw_config = { + **default_raw_config, + "resource": { + **default_raw_config["resource"], + "allocation-mode": "manual", + }, + "agents": [ + { + "agent": {"id": "agent-1"}, + "resource": { + "cpu": 8, + "mem": "32G", + "devices": { + SlotName("cuda.mem"): 0.5, + SlotName("cuda.shares"): 1.0, + }, + }, + }, + { + "agent": {"id": "agent-2"}, + "resource": { + "cpu": 8, + "mem": "32G", + "devices": { + SlotName("cuda.mem"): 0.5, + }, + }, + }, + ], + } + with pytest.raises(ValidationError) as exc_info: + AgentUnifiedConfig.model_validate(raw_config) + + assert "All agents must have the same slots defined" in str(exc_info.value) + + def test_manual_mode_agents_with_empty_devices_on_some_agents( + self, + default_raw_config: RawConfigT, + ) -> None: + """Test that agents with empty allocated_devices on some agents are rejected.""" + raw_config = { + **default_raw_config, + "resource": { + **default_raw_config["resource"], + "allocation-mode": "manual", + }, + "agents": [ + { + 
"agent": {"id": "agent-1"}, + "resource": { + "cpu": 8, + "mem": "32G", + "devices": { + SlotName("cuda.mem"): 0.5, + }, + }, + }, + { + "agent": {"id": "agent-2"}, + "resource": { + "cpu": 8, + "mem": "32G", + # No devices specified + }, + }, + ], + } + with pytest.raises(ValidationError) as exc_info: + AgentUnifiedConfig.model_validate(raw_config) + + assert "All agents must have the same slots defined" in str(exc_info.value) + + def test_manual_mode_agents_all_with_empty_devices_allowed( + self, + default_raw_config: RawConfigT, + ) -> None: + """Test that all agents with empty allocated_devices are allowed.""" + raw_config = { + **default_raw_config, + "resource": { + **default_raw_config["resource"], + "allocation-mode": "manual", + }, + "agents": [ + { + "agent": {"id": "agent-1"}, + "resource": { + "cpu": 8, + "mem": "32G", + # No devices specified + }, + }, + { + "agent": {"id": "agent-2"}, + "resource": { + "cpu": 4, + "mem": "16G", + # No devices specified + }, + }, + ], + } + config = AgentUnifiedConfig.model_validate(raw_config) + agent_configs = config.get_agent_configs() + + # Both should have empty allocated_devices + assert agent_configs[0].resource.allocations is not None + assert agent_configs[0].resource.allocations.devices == {} + assert agent_configs[1].resource.allocations is not None + assert agent_configs[1].resource.allocations.devices == {} + + def test_allocated_devices_parses_decimal_strings( + self, + default_raw_config: RawConfigT, + ) -> None: + raw_config = { + **default_raw_config, + "resource": { + **default_raw_config["resource"], + "allocation-mode": "manual", + }, + "agents": [ + { + "agent": {"id": "agent-1"}, + "resource": { + "cpu": 8, + "mem": "32G", + "devices": { + SlotName("foo"): "0.25", # String value + }, + }, + }, + { + "agent": {"id": "agent-2"}, + "resource": { + "cpu": 8, + "mem": "32G", + "devices": { + SlotName("foo"): 0.75, # Numeric value + }, + }, + }, + ], + } + config = 
AgentUnifiedConfig.model_validate(raw_config) + agent_configs = config.get_agent_configs() + assert agent_configs[0].resource.allocations is not None + assert float(agent_configs[0].resource.allocations.devices[SlotName("foo")]) == 0.25 + assert agent_configs[1].resource.allocations is not None + assert float(agent_configs[1].resource.allocations.devices[SlotName("foo")]) == 0.75 + + def test_allocated_devices_rejects_negative_values( + self, + default_raw_config: RawConfigT, + ) -> None: + raw_config = { + **default_raw_config, + "resource": { + **default_raw_config["resource"], + "allocation-mode": "manual", + }, + "agents": [ + { + "agent": {"id": "agent-1"}, + "resource": { + "cpu": 8, + "mem": "32G", + "devices": { + SlotName("foo"): "-1", + }, + }, + }, + { + "agent": {"id": "agent-2"}, + "resource": { + "cpu": 8, + "mem": "32G", + }, + }, + ], + } + with pytest.raises(ValidationError) as exc_info: + AgentUnifiedConfig.model_validate(raw_config) + + assert "must not be a negative value" in str(exc_info.value) diff --git a/tests/manager/test_agent_client.py b/tests/manager/test_agent_client.py new file mode 100644 index 00000000000..9fd3595f34b --- /dev/null +++ b/tests/manager/test_agent_client.py @@ -0,0 +1,186 @@ +from unittest.mock import AsyncMock, MagicMock, Mock + +import pytest + +from ai.backend.common.types import AgentId, ClusterInfo, ImageConfig, KernelCreationConfig +from ai.backend.manager.clients.agent.client import AgentClient + + +@pytest.fixture +def mock_agent_cache() -> tuple[MagicMock, MagicMock]: + cache = MagicMock() + mock_rpc = MagicMock() + mock_rpc.call = MagicMock() + + cache.rpc_context.return_value.__aenter__ = AsyncMock(return_value=mock_rpc) + cache.rpc_context.return_value.__aexit__ = AsyncMock() + + return cache, mock_rpc + + +class TestAgentClientPassesAgentId: + @pytest.mark.asyncio + async def test_gather_hwinfo_passes_agent_id( + self, mock_agent_cache: tuple[MagicMock, MagicMock] + ) -> None: + cache, mock_rpc = 
mock_agent_cache + client = AgentClient(cache, AgentId("test-agent")) + + mock_rpc.call.gather_hwinfo = AsyncMock(return_value={}) + + await client.gather_hwinfo() + + mock_rpc.call.gather_hwinfo.assert_called_once_with(agent_id=AgentId("test-agent")) + + @pytest.mark.asyncio + async def test_scan_gpu_alloc_map_passes_agent_id( + self, mock_agent_cache: tuple[MagicMock, MagicMock] + ) -> None: + cache, mock_rpc = mock_agent_cache + client = AgentClient(cache, AgentId("gpu-agent")) + + mock_rpc.call.scan_gpu_alloc_map = AsyncMock(return_value={}) + + await client.scan_gpu_alloc_map() + + mock_rpc.call.scan_gpu_alloc_map.assert_called_once_with(agent_id=AgentId("gpu-agent")) + + @pytest.mark.asyncio + async def test_create_kernels_passes_agent_id( + self, mock_agent_cache: tuple[MagicMock, MagicMock] + ) -> None: + cache, mock_rpc = mock_agent_cache + client = AgentClient(cache, AgentId("agent-2")) + + mock_rpc.call.create_kernels = AsyncMock(return_value={}) + kernel_configs = [Mock(spec=KernelCreationConfig)] + cluster_info = Mock(spec=ClusterInfo) + + await client.create_kernels("session-1", ["kernel-1"], kernel_configs, cluster_info, {}) # type: ignore[arg-type] + + args, kwargs = mock_rpc.call.create_kernels.call_args + assert kwargs["agent_id"] == AgentId("agent-2") + + @pytest.mark.asyncio + async def test_destroy_kernel_passes_agent_id( + self, mock_agent_cache: tuple[MagicMock, MagicMock] + ) -> None: + cache, mock_rpc = mock_agent_cache + client = AgentClient(cache, AgentId("agent-1")) + + mock_rpc.call.destroy_kernel = AsyncMock() + + await client.destroy_kernel("kernel-id", "session-id", "test-reason") + + args, kwargs = mock_rpc.call.destroy_kernel.call_args + assert kwargs["agent_id"] == AgentId("agent-1") + + @pytest.mark.asyncio + async def test_execute_passes_agent_id( + self, mock_agent_cache: tuple[MagicMock, MagicMock] + ) -> None: + cache, mock_rpc = mock_agent_cache + client = AgentClient(cache, AgentId("exec-agent")) + + mock_rpc.call.execute = 
AsyncMock(return_value={}) + + await client.execute( + "session-1", "kernel-1", 1, "run-1", "query", "print('hello')", {}, None + ) + + args, kwargs = mock_rpc.call.execute.call_args + assert kwargs["agent_id"] == AgentId("exec-agent") + + @pytest.mark.asyncio + async def test_check_and_pull_passes_agent_id( + self, mock_agent_cache: tuple[MagicMock, MagicMock] + ) -> None: + cache, mock_rpc = mock_agent_cache + client = AgentClient(cache, AgentId("image-agent")) + + mock_rpc.call.check_and_pull = AsyncMock(return_value={}) + image_configs = {"python": Mock(spec=ImageConfig)} + + await client.check_and_pull(image_configs) + + mock_rpc.call.check_and_pull.assert_called_once_with( + image_configs, agent_id=AgentId("image-agent") + ) + + @pytest.mark.asyncio + async def test_create_local_network_passes_agent_id( + self, mock_agent_cache: tuple[MagicMock, MagicMock] + ) -> None: + cache, mock_rpc = mock_agent_cache + client = AgentClient(cache, AgentId("network-agent")) + + mock_rpc.call.create_local_network = AsyncMock() + + await client.create_local_network("test-network") + + mock_rpc.call.create_local_network.assert_called_once_with( + "test-network", agent_id=AgentId("network-agent") + ) + + @pytest.mark.asyncio + async def test_assign_port_passes_agent_id( + self, mock_agent_cache: tuple[MagicMock, MagicMock] + ) -> None: + cache, mock_rpc = mock_agent_cache + client = AgentClient(cache, AgentId("port-agent")) + + mock_rpc.call.assign_port = AsyncMock(return_value=30000) + + result = await client.assign_port() + + assert result == 30000 + mock_rpc.call.assign_port.assert_called_once_with(agent_id=AgentId("port-agent")) + + @pytest.mark.asyncio + async def test_upload_file_passes_agent_id( + self, mock_agent_cache: tuple[MagicMock, MagicMock] + ) -> None: + cache, mock_rpc = mock_agent_cache + client = AgentClient(cache, AgentId("file-agent")) + + mock_rpc.call.upload_file = AsyncMock(return_value={}) + + await client.upload_file("kernel-1", "test.py", b"data") + 
+ mock_rpc.call.upload_file.assert_called_once_with( + "kernel-1", "test.py", b"data", agent_id=AgentId("file-agent") + ) + + @pytest.mark.asyncio + async def test_start_service_passes_agent_id( + self, mock_agent_cache: tuple[MagicMock, MagicMock] + ) -> None: + cache, mock_rpc = mock_agent_cache + client = AgentClient(cache, AgentId("service-agent")) + + mock_rpc.call.start_service = AsyncMock(return_value={}) + + await client.start_service("kernel-1", "jupyter", {}) + + mock_rpc.call.start_service.assert_called_once_with( + "kernel-1", "jupyter", {}, agent_id=AgentId("service-agent") + ) + + @pytest.mark.asyncio + async def test_different_agents_use_different_ids( + self, mock_agent_cache: tuple[MagicMock, MagicMock] + ) -> None: + cache, mock_rpc = mock_agent_cache + + client1 = AgentClient(cache, AgentId("agent-1")) + client2 = AgentClient(cache, AgentId("agent-2")) + + mock_rpc.call.gather_hwinfo = AsyncMock(return_value={}) + + await client1.gather_hwinfo() + await client2.gather_hwinfo() + + calls = mock_rpc.call.gather_hwinfo.call_args_list + assert len(calls) == 2 + assert calls[0].kwargs["agent_id"] == AgentId("agent-1") + assert calls[1].kwargs["agent_id"] == AgentId("agent-2")