11import logging
22from collections .abc import Sequence
33from http import HTTPStatus
4- from typing import Any , List , Mapping , MutableMapping , cast
4+ from typing import Any , Iterator , List , Mapping , MutableMapping , cast , override
55from uuid import UUID
66
77import attr
88from aiodocker .docker import Docker
99from aiohttp import web
1010from aiohttp .typedefs import Handler , Middleware
1111
12+ from ai .backend .agent .agent import AbstractAgent
1213from ai .backend .agent .config .unified import AgentUnifiedConfig
1314from ai .backend .agent .docker .kernel import prepare_kernel_metadata_uri_handling
1415from ai .backend .agent .kernel import AbstractKernel
1718from ai .backend .common .etcd import AsyncEtcd
1819from ai .backend .common .json import dump_json_str
1920from ai .backend .common .plugin import BasePluginContext
20- from ai .backend .common .types import KernelId , aobject
21+ from ai .backend .common .types import AgentId , KernelId , aobject
2122from ai .backend .logging import BraceStyleAdapter
2223
2324from .plugin import MetadataPlugin
@@ -87,17 +88,50 @@ async def list_versions(request: web.Request) -> web.Response:
8788 return web .Response (body = "latest/" )
8889
8990
91+ class AggregateKernelRegistry (Mapping [KernelId , AbstractKernel ]):
92+ _agents : dict [AgentId , AbstractAgent ]
93+
94+ def __init__ (self ) -> None :
95+ self ._agents = {}
96+
97+ def register_agent (self , agent : AbstractAgent ) -> None :
98+ self ._agents [agent .id ] = agent
99+
100+ @override
101+ def __getitem__ (self , kernel_id : KernelId ) -> AbstractKernel :
102+ for agent in self ._agents .values ():
103+ if kernel_id in agent .kernel_registry :
104+ return agent .kernel_registry [kernel_id ]
105+ raise KeyError (kernel_id )
106+
107+ @override
108+ def __iter__ (self ) -> Iterator [KernelId ]:
109+ for agent in self ._agents .values ():
110+ yield from agent .kernel_registry .keys ()
111+
112+ @override
113+ def __len__ (self ) -> int :
114+ return sum (len (agent .kernel_registry ) for agent in self ._agents .values ())
115+
116+ @override
117+ def __contains__ (self , x : object , / ) -> bool :
118+ if not isinstance (x , str ):
119+ return False
120+ return any (agent .__contains__ (x ) for agent in self ._agents )
121+
122+
90123class MetadataServer (aobject ):
91124 app : web .Application
92125 runner : web .AppRunner
93126 route_structure : MutableMapping [str , Any ]
94127 loaded_apps : List [str ]
128+ kernel_registry : AggregateKernelRegistry
95129
96130 def __init__ (
97131 self ,
98132 local_config : AgentUnifiedConfig ,
99133 etcd : AsyncEtcd ,
100- kernel_registry : Mapping [ KernelId , AbstractKernel ] ,
134+ kernel_registry : AggregateKernelRegistry ,
101135 ) -> None :
102136 app = web .Application (
103137 middlewares = [
@@ -112,6 +146,7 @@ def __init__(
112146 self .app = app
113147 self .loaded_apps = []
114148 self .route_structure = {"latest" : {"extension" : {}}}
149+ self .kernel_registry = kernel_registry
115150
116151 async def __ainit__ (self ):
117152 local_config = cast (AgentUnifiedConfig , self .app ["_root.context" ].local_config )
@@ -132,6 +167,9 @@ async def __ainit__(self):
132167 self .app .router .add_route ("GET" , "/" , list_versions )
133168 self .app .router .add_route ("GET" , "/{version}" , self .list_available_apps )
134169
170+ def register_agent (self , agent : AbstractAgent ) -> None :
171+ self .kernel_registry .register_agent (agent )
172+
135173 async def list_available_apps (self , request : web .Request ) -> web .Response :
136174 return web .Response (body = "\n " .join ([x + "/" for x in self .loaded_apps ]))
137175
0 commit comments