Skip to content

Commit daee31e

Browse files
committed
feat(BA-3026): Extract agent common resources to AgentRuntime
This change introduces AgentRuntime, which contains the common resources shared by all agents. Currently the class is minimal, but future multi-agent changes will make more use of it.
1 parent 6ebc74f commit daee31e

File tree

8 files changed

+236
-145
lines changed

8 files changed

+236
-145
lines changed

changes/6728.feature.md

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1 @@
1+
Extract agent common resources to AgentRuntime
Lines changed: 0 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -1,8 +0,0 @@
1-
from typing import Type
2-
3-
from ..agent import AbstractAgent
4-
from .agent import DockerAgent
5-
6-
7-
def get_agent_cls() -> Type[AbstractAgent]:
    """Return ``DockerAgent`` as the ``AbstractAgent`` implementation class."""
    return DockerAgent

src/ai/backend/agent/docker/agent.py

Lines changed: 44 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -119,14 +119,14 @@
119119
known_slot_types,
120120
)
121121
from ..scratch import create_loop_filesystem, destroy_loop_filesystem
122-
from ..server import get_extra_volumes
123122
from ..types import (
124123
AgentEventData,
125124
Container,
126125
KernelOwnershipData,
127126
LifecycleEvent,
128127
MountInfo,
129128
Port,
129+
VolumeInfo,
130130
)
131131
from ..utils import (
132132
closing_async,
@@ -161,6 +161,49 @@
161161
2.39: "ubuntu24.04",
162162
}
163163

164+
# Keywords searched for (via ``in``) inside the image name by
# get_extra_volumes(); any hit makes the sample volume below a mount candidate.
deeplearning_image_keys = {
    "caffe",
    "keras",
    "mxnet",
    "tensorflow",
    "theano",
    "torch",
}

# Extra volume offered to deep-learning images.
# NOTE(review): the positional fields are presumably
# (volume name, container mount path, mount mode) -- confirm against VolumeInfo.
deeplearning_sample_volume = VolumeInfo(
    "deeplearning-samples",
    "/home/work/samples",
    "ro",
)
178+
179+
180+
async def get_extra_volumes(docker, lang):
    """Pick the extra volumes to attach to a kernel built from image *lang*.

    The volume listing is fetched through ``docker.volumes.list()``.  When
    *lang* contains one of ``deeplearning_image_keys``, the
    ``deeplearning_sample_volume`` becomes a candidate; candidates are kept
    only if a volume of the same name shows up in the listing.

    Args:
        docker: Client object whose ``volumes.list()`` is awaitable and yields
            a mapping with a ``"Volumes"`` entry.
        lang: Image identifier that is searched for the deep-learning
            keywords (presumably the image name string -- confirm with the
            caller).

    Returns:
        A list of the candidate volumes that actually exist; an empty list
        when the listing reports no volumes at all.
    """
    listing = await docker.volumes.list()
    existing = listing["Volumes"]
    if not existing:
        return []
    existing_names = {entry["Name"] for entry in existing}

    # deeplearning specialization
    # TODO: extract as config
    candidates = []
    if any(keyword in lang for keyword in deeplearning_image_keys):
        candidates.append(deeplearning_sample_volume)

    # Mount only actually existing volumes
    attached = []
    for candidate in candidates:
        if candidate.name not in existing_names:
            log.info(
                "skipped attaching extra volume {0} to a kernel based on image {1}",
                candidate.name,
                lang,
            )
            continue
        attached.append(candidate)
    return attached
206+
164207

165208
def container_from_docker_container(src: DockerContainer) -> Container:
166209
ports = []
Lines changed: 0 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -1,8 +0,0 @@
1-
from typing import Type
2-
3-
from ..agent import AbstractAgent
4-
from .agent import DummyAgent
5-
6-
7-
def get_agent_cls() -> Type[AbstractAgent]:
    """Return ``DummyAgent`` as the ``AbstractAgent`` implementation class."""
    return DummyAgent
Lines changed: 0 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -1,8 +0,0 @@
1-
from typing import Type
2-
3-
from ..agent import AbstractAgent
4-
from .agent import KubernetesAgent
5-
6-
7-
def get_agent_cls() -> Type[AbstractAgent]:
    """Return ``KubernetesAgent`` as the ``AbstractAgent`` implementation class."""
    return KubernetesAgent

src/ai/backend/agent/runtime.py

Lines changed: 71 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,71 @@
1+
import signal
2+
from typing import Optional, Type
3+
4+
from ai.backend.agent.agent import AbstractAgent
5+
from ai.backend.agent.config.unified import AgentUnifiedConfig
6+
from ai.backend.agent.monitor import AgentErrorPluginContext, AgentStatsPluginContext
7+
from ai.backend.agent.types import AgentBackend
8+
from ai.backend.common.auth import PublicKey
9+
from ai.backend.common.etcd import AsyncEtcd
10+
from ai.backend.common.types import aobject
11+
12+
13+
class AgentRuntime(aobject):
14+
local_config: AgentUnifiedConfig
15+
agent: AbstractAgent
16+
17+
_stop_signal: signal.Signals
18+
19+
def __init__(
20+
self,
21+
local_config: AgentUnifiedConfig,
22+
etcd: AsyncEtcd,
23+
stats_monitor: AgentStatsPluginContext,
24+
error_monitor: AgentErrorPluginContext,
25+
agent_public_key: Optional[PublicKey],
26+
) -> None:
27+
self.local_config = local_config
28+
29+
self._stop_signal = signal.SIGTERM
30+
31+
self.etcd = etcd
32+
self.stats_monitor = stats_monitor
33+
self.error_monitor = error_monitor
34+
self.agent_public_key = agent_public_key
35+
36+
async def __ainit__(self) -> None:
37+
agent_cls = self._get_agent_cls()
38+
self.agent = await agent_cls.new(
39+
self.etcd,
40+
self.local_config,
41+
stats_monitor=self.stats_monitor,
42+
error_monitor=self.error_monitor,
43+
agent_public_key=self.agent_public_key,
44+
)
45+
46+
async def __aexit__(self, *exc_info) -> None:
47+
await self.agent.shutdown(self._stop_signal)
48+
49+
def get_agent(self) -> AbstractAgent:
50+
return self.agent
51+
52+
def get_etcd(self) -> AsyncEtcd:
53+
return self.etcd
54+
55+
def mark_stop_signal(self, stop_signal: signal.Signals) -> None:
56+
self._stop_signal = stop_signal
57+
58+
def _get_agent_cls(self) -> Type[AbstractAgent]:
59+
match self.local_config.agent_common.backend:
60+
case AgentBackend.DOCKER:
61+
from ai.backend.agent.docker.agent import DockerAgent
62+
63+
return DockerAgent
64+
case AgentBackend.KUBERNETES:
65+
from ai.backend.agent.kubernetes.agent import KubernetesAgent
66+
67+
return KubernetesAgent
68+
case AgentBackend.DUMMY:
69+
from ai.backend.agent.dummy.agent import DummyAgent
70+
71+
return DummyAgent

0 commit comments

Comments
 (0)