4949from setproctitle import setproctitle
5050from zmq .auth .certs import load_certificate
5151
52+ from ai .backend .agent .etcd import EtcdClientRegistry
5253from ai .backend .agent .metrics .metric import RPCMetricObserver
5354from ai .backend .agent .resources import scan_gpu_alloc_map
5455from ai .backend .common import config , identity , msgpack , utils
9697 ServiceMetadata ,
9798)
9899from ai .backend .common .types import (
100+ AgentId ,
99101 ClusterInfo ,
100102 CommitStatus ,
101103 ContainerId ,
@@ -309,6 +311,7 @@ class AgentRPCServer(aobject):
309311
310312 loop : asyncio .AbstractEventLoop
311313 agent : AbstractAgent
314+ etcd_client_registry : EtcdClientRegistry
312315 rpc_server : Peer
313316 rpc_addr : str
314317 agent_addr : str
@@ -318,17 +321,19 @@ class AgentRPCServer(aobject):
318321
319322 def __init__ (
320323 self ,
321- etcd : AsyncEtcd ,
324+ etcd_client_registry : EtcdClientRegistry ,
322325 local_config : AgentUnifiedConfig ,
323326 * ,
324327 skip_detect_manager : bool = False ,
325328 ) -> None :
326329 self .loop = current_loop ()
327- self .etcd = etcd
330+ self .etcd_client_registry = etcd_client_registry
328331 self .local_config = local_config
329332 self .skip_detect_manager = skip_detect_manager
330333 self ._stop_signal = signal .SIGTERM
331334
335+ self .etcd_client_registry .prefill_clients (self .local_config .agent_configs )
336+
332337 async def __ainit__ (self ) -> None :
333338 # Start serving requests.
334339 await self .update_status ("starting" )
@@ -339,11 +344,12 @@ async def __ainit__(self) -> None:
339344 await self .read_agent_config ()
340345 await self .read_agent_config_container ()
341346
347+ global_etcd = self .etcd_client_registry .global_etcd
342348 self .stats_monitor = AgentStatsPluginContext (
343- self . etcd , self .local_config .model_dump (by_alias = True )
349+ global_etcd , self .local_config .model_dump (by_alias = True )
344350 )
345351 self .error_monitor = AgentErrorPluginContext (
346- self . etcd , self .local_config .model_dump (by_alias = True )
352+ global_etcd , self .local_config .model_dump (by_alias = True )
347353 )
348354 await self .stats_monitor .init ()
349355 await self .error_monitor .init ()
@@ -380,7 +386,7 @@ async def __ainit__(self) -> None:
380386 backend = self .local_config .agent_common .backend
381387 agent_mod = importlib .import_module (f"ai.backend.agent.{ backend .value } " )
382388 self .agent = await agent_mod .get_agent_cls ().new ( # type: ignore
383- self .etcd ,
389+ self .etcd_client_registry . get_client ( AgentId ( self . local_config . agent . id )) ,
384390 self .local_config ,
385391 stats_monitor = self .stats_monitor ,
386392 error_monitor = self .error_monitor ,
@@ -422,13 +428,13 @@ async def _debug_server_task():
422428
423429 self .debug_server_task = asyncio .create_task (_debug_server_task ())
424430
425- await self .etcd .put ("ip" , rpc_addr .host , scope = ConfigScopes .NODE )
431+ await self .agent . etcd .put ("ip" , rpc_addr .host , scope = ConfigScopes .NODE )
426432
427433 watcher_port = utils .nmget (
428434 self .local_config .model_dump (), "watcher.service-addr.port" , None
429435 )
430436 if watcher_port is not None :
431- await self .etcd .put ("watcher_port" , watcher_port , scope = ConfigScopes .NODE )
437+ await self .agent . etcd .put ("watcher_port" , watcher_port , scope = ConfigScopes .NODE )
432438
433439 await self .update_status ("running" )
434440
@@ -472,10 +478,11 @@ def _ensure_serializable(o) -> Any:
472478
473479 async def detect_manager (self ):
474480 log .info ("detecting the manager..." )
475- manager_instances = await self .etcd .get_prefix ("nodes/manager" )
481+ global_etcd = self .etcd_client_registry .global_etcd
482+ manager_instances = await global_etcd .get_prefix ("nodes/manager" )
476483 if not manager_instances :
477484 log .warning ("watching etcd to wait for the manager being available" )
478- async with aclosing (self . etcd . watch_prefix ("nodes/manager" )) as agen :
485+ async with aclosing (global_etcd . watch_prefix ("nodes/manager" )) as agen : # type: ignore
479486 async for ev in agen :
480487 match ev :
481488 case QueueSentinel .CLOSED | QueueSentinel .TIMEOUT :
@@ -486,9 +493,11 @@ async def detect_manager(self):
486493 log .info ("detected at least one manager running" )
487494
488495 async def read_agent_config (self ):
496+ global_etcd = self .etcd_client_registry .global_etcd
497+
489498 # Fill up Redis configs from etcd and store as separate attributes
490499 self ._redis_config = config .redis_config_iv .check (
491- await self . etcd .get_prefix ("config/redis" ),
500+ await global_etcd .get_prefix ("config/redis" ),
492501 )
493502 log .info ("configured redis: {0}" , self ._redis_config )
494503
@@ -506,7 +515,7 @@ async def read_agent_config(self):
506515 # Fill up vfolder configs from etcd and store as separate attributes
507516 # TODO: Integrate vfolder_config into local_config
508517 self ._vfolder_config = config .vfolder_config_iv .check (
509- await self . etcd .get_prefix ("volumes" ),
518+ await global_etcd .get_prefix ("volumes" ),
510519 )
511520 if self ._vfolder_config ["mount" ] is None :
512521 log .info (
@@ -517,7 +526,7 @@ async def read_agent_config(self):
517526 log .info ("configured vfolder fs prefix: {0}" , self ._vfolder_config ["fsprefix" ])
518527
519528 # Fill up shared agent configurations from etcd.
520- agent_etcd_config_raw = await self . etcd .get_prefix ("config/agent" )
529+ agent_etcd_config_raw = await global_etcd .get_prefix ("config/agent" )
521530 if agent_etcd_config_raw :
522531 try :
523532 # Parse specific etcd configs and update the unified config
@@ -541,7 +550,9 @@ async def read_agent_config(self):
541550 async def read_agent_config_container (self ):
542551 # Fill up global container configurations from etcd.
543552 try :
544- container_etcd_config_raw = await self .etcd .get_prefix ("config/container" )
553+ container_etcd_config_raw = await self .etcd_client_registry .global_etcd .get_prefix (
554+ "config/container"
555+ )
545556 if container_etcd_config_raw :
546557 # Update config by creating a new instance with modified values
547558 container_updates = {}
@@ -586,7 +597,7 @@ async def __aexit__(self, *exc_info) -> None:
586597
587598 @collect_error
588599 async def update_status (self , status ):
589- await self .etcd .put ("" , status , scope = ConfigScopes .NODE )
600+ await self .agent . etcd .put ("" , status , scope = ConfigScopes .NODE )
590601
591602 @rpc_function
592603 @collect_error
@@ -1221,29 +1232,12 @@ async def aiomonitor_ctx(
12211232
12221233
@asynccontextmanager
async def etcd_ctx(local_config: AgentUnifiedConfig) -> AsyncGenerator[EtcdClientRegistry]:
    """Provide an ``EtcdClientRegistry`` for the duration of an ``async with`` block.

    The registry is built from ``local_config.etcd.to_dataclass()`` and is
    closed when the block exits, whether normally or through an exception.

    Args:
        local_config: The agent's unified local configuration.

    Yields:
        The newly created ``EtcdClientRegistry``.
    """
    etcd_settings = local_config.etcd.to_dataclass()
    registry = EtcdClientRegistry(etcd_settings)
    try:
        yield registry
    finally:
        # Always release the registry, even if the managed body raised.
        await registry.close()
12471241
12481242
12491243async def prepare_krunner_volumes (local_config : AgentUnifiedConfig ) -> AgentUnifiedConfig :
@@ -1253,7 +1247,7 @@ async def prepare_krunner_volumes(local_config: AgentUnifiedConfig) -> AgentUnif
12531247 )
12541248 krunner_volumes : Mapping [str , str ] = await kernel_mod .prepare_krunner_env (
12551249 local_config .model_dump (by_alias = True )
1256- ) # type: ignore
1250+ )
12571251 # TODO: merge k8s branch: nfs_mount_path = local_config['baistatic']['mounted-at']
12581252 log .info ("Kernel runner environments: {}" , [* krunner_volumes .keys ()])
12591253 return local_config .with_updates (container_update = {"krunner_volumes" : krunner_volumes })
@@ -1317,10 +1311,10 @@ async def auto_detect_agent_network(
13171311
13181312@asynccontextmanager
13191313async def agent_server_ctx (
1320- local_config : AgentUnifiedConfig , etcd : AsyncEtcd
1314+ local_config : AgentUnifiedConfig , etcd_client_registry : EtcdClientRegistry
13211315) -> AsyncGenerator [AgentRPCServer ]:
13221316 agent_server = await AgentRPCServer .new (
1323- etcd ,
1317+ etcd_client_registry ,
13241318 local_config ,
13251319 skip_detect_manager = local_config .agent_common .skip_manager_detection ,
13261320 )
@@ -1419,18 +1413,19 @@ async def server_main(
14191413 local_config = await auto_detect_agent_identity (local_config )
14201414
14211415 # etcd's scope-prefix map depends on the auto-detected identity info.
1422- etcd = await agent_init_stack .enter_async_context (etcd_ctx (local_config ))
1423- local_config = await auto_detect_agent_network (local_config , etcd )
1424- plugins = await etcd .get_prefix_dict ("config/plugins/accelerator" )
1416+ etcd_client_registry = await agent_init_stack .enter_async_context (etcd_ctx (local_config ))
1417+ global_etcd = etcd_client_registry .global_etcd
1418+ local_config = await auto_detect_agent_network (local_config , global_etcd )
1419+ plugins = await global_etcd .get_prefix_dict ("config/plugins/accelerator" )
14251420 local_config = local_config .with_changes (plugins = plugins )
14261421
14271422 # Start RPC server.
14281423 agent_server = await agent_init_stack .enter_async_context (
1429- agent_server_ctx (local_config , etcd )
1424+ agent_server_ctx (local_config , etcd_client_registry )
14301425 )
14311426 monitor .console_locals ["agent_server" ] = agent_server
14321427
1433- await agent_init_stack .enter_async_context (service_discovery_ctx (etcd , agent_server ))
1428+ await agent_init_stack .enter_async_context (service_discovery_ctx (global_etcd , agent_server ))
14341429 log .info ("Started the agent service." )
14351430 except Exception :
14361431 log .exception ("Server initialization failure; triggering shutdown..." )
0 commit comments