|
1 | | -import logging |
2 | 1 | import asyncio |
3 | 2 | import importlib |
| 3 | +import logging |
4 | 4 | import os.path |
5 | | -import api.globals as cms_globals |
6 | | - |
7 | | -from typing import Dict, Any, Optional |
8 | 5 | from concurrent.futures import ThreadPoolExecutor |
9 | | -from anyio.lowlevel import RunVar |
| 6 | +from typing import Any, Dict, Optional |
| 7 | + |
10 | 8 | from anyio import CapacityLimiter |
| 9 | +from anyio.lowlevel import RunVar |
11 | 10 | from fastapi import FastAPI, Request |
| 11 | +from fastapi.openapi.docs import get_redoc_html, get_swagger_ui_html |
12 | 12 | from fastapi.openapi.utils import get_openapi |
13 | | -from fastapi.responses import RedirectResponse, HTMLResponse |
| 13 | +from fastapi.responses import HTMLResponse, RedirectResponse |
14 | 14 | from fastapi.staticfiles import StaticFiles |
15 | | -from fastapi.openapi.docs import get_swagger_ui_html, get_redoc_html |
16 | 15 | from prometheus_fastapi_instrumentator import Instrumentator |
17 | 16 |
|
| 17 | +from domain import Tags, TagsStreamable |
| 18 | +from utils import get_settings |
| 19 | + |
| 20 | +import api.globals as cms_globals |
18 | 21 | from api.auth.db import make_sure_db_and_tables |
19 | 22 | from api.auth.users import Props |
20 | 23 | from api.dependencies import ModelServiceDep |
21 | 24 | from api.utils import add_exception_handlers, add_rate_limiter |
22 | | -from domain import Tags, TagsStreamable |
23 | 25 | from management.tracker_client import TrackerClient |
24 | | -from utils import get_settings |
25 | | - |
26 | 26 |
|
27 | 27 | logging.getLogger("asyncio").setLevel(logging.ERROR) |
28 | 28 | logger = logging.getLogger("cms") |
@@ -87,25 +87,37 @@ def get_stream_server(msd_overwritten: Optional[ModelServiceDep] = None) -> Fast |
87 | 87 | return app |
88 | 88 |
|
89 | 89 |
|
90 | | -def _get_app(msd_overwritten: Optional[ModelServiceDep] = None, streamable: bool = False) -> FastAPI: |
91 | | - tags_metadata = [{"name": tag.name, "description": tag.value} for tag in (Tags if not streamable else TagsStreamable)] |
| 90 | +def _get_app( |
| 91 | + msd_overwritten: Optional[ModelServiceDep] = None, streamable: bool = False |
| 92 | +) -> FastAPI: |
| 93 | + tags_metadata = [ |
| 94 | + {"name": tag.name, "description": tag.value} |
| 95 | + for tag in (Tags if not streamable else TagsStreamable) |
| 96 | + ] |
92 | 97 | config = get_settings() |
93 | | - app = FastAPI(title="CogStack ModelServe", |
94 | | - summary="A model serving and governance system for CogStack NLP solutions", |
95 | | - docs_url=None, |
96 | | - redoc_url=None, |
97 | | - debug=(config.DEBUG == "true"), |
98 | | - openapi_tags=tags_metadata) |
| 98 | + app = FastAPI( |
| 99 | + title="CogStack ModelServe", |
| 100 | + summary="A model serving and governance system for CogStack NLP solutions", |
| 101 | + docs_url=None, |
| 102 | + redoc_url=None, |
| 103 | + debug=(config.DEBUG == "true"), |
| 104 | + openapi_tags=tags_metadata, |
| 105 | + ) |
99 | 106 | add_exception_handlers(app) |
100 | 107 | instrumentator = Instrumentator( |
101 | | - excluded_handlers=["/docs", "/redoc", "/metrics", "/openapi.json", "/favicon.ico", "none"]).instrument(app) |
| 108 | + excluded_handlers=["/docs", "/redoc", "/metrics", "/openapi.json", "/favicon.ico", "none"] |
| 109 | + ).instrument(app) |
102 | 110 |
|
103 | 111 | if msd_overwritten is not None: |
104 | 112 | cms_globals.model_service_dep = msd_overwritten |
105 | 113 |
|
106 | 114 | cms_globals.props = Props(config.AUTH_USER_ENABLED == "true") |
107 | 115 |
|
108 | | - app.mount("/static", StaticFiles(directory=os.path.join(os.path.dirname(__file__), "static")), name="static") |
| 116 | + app.mount( |
| 117 | + "/static", |
| 118 | + StaticFiles(directory=os.path.join(os.path.dirname(__file__), "static")), |
| 119 | + name="static", |
| 120 | + ) |
109 | 121 |
|
110 | 122 | @app.on_event("startup") |
111 | 123 | async def on_startup() -> None: |
@@ -160,8 +172,11 @@ def custom_openapi() -> Dict[str, Any]: |
160 | 172 | openapi_schema = get_openapi( |
161 | 173 | title=f"{cms_globals.model_service_dep().model_name} APIs", |
162 | 174 | version=cms_globals.model_service_dep().api_version, |
163 | | - description="by CogStack ModelServe, a model serving and governance system for CogStack NLP solutions.", |
164 | | - routes=app.routes |
| 175 | + description=( |
| 176 | + "by CogStack ModelServe, a model serving and governance system for CogStack NLP" |
| 177 | + " solutions." |
| 178 | + ), |
| 179 | + routes=app.routes, |
165 | 180 | ) |
166 | 181 | openapi_schema["info"]["x-logo"] = { |
167 | 182 | "url": "https://avatars.githubusercontent.com/u/28688163?s=200&v=4" |
@@ -189,69 +204,79 @@ def custom_openapi() -> Dict[str, Any]: |
189 | 204 |
|
190 | 205 | def _load_auth_router(app: FastAPI) -> FastAPI: |
191 | 206 | from api.routers import authentication |
| 207 | + |
192 | 208 | importlib.reload(authentication) |
193 | 209 | app.include_router(authentication.router) |
194 | 210 | return app |
195 | 211 |
|
196 | 212 |
|
197 | 213 | def _load_model_card(app: FastAPI) -> FastAPI: |
198 | 214 | from api.routers import model_card |
| 215 | + |
199 | 216 | importlib.reload(model_card) |
200 | 217 | app.include_router(model_card.router) |
201 | 218 | return app |
202 | 219 |
|
203 | 220 |
|
204 | 221 | def _load_invocation_router(app: FastAPI) -> FastAPI: |
205 | 222 | from api.routers import invocation |
| 223 | + |
206 | 224 | importlib.reload(invocation) |
207 | 225 | app.include_router(invocation.router) |
208 | 226 | return app |
209 | 227 |
|
210 | 228 |
|
211 | 229 | def _load_supervised_training_router(app: FastAPI) -> FastAPI: |
212 | 230 | from api.routers import supervised_training |
| 231 | + |
213 | 232 | importlib.reload(supervised_training) |
214 | 233 | app.include_router(supervised_training.router) |
215 | 234 | return app |
216 | 235 |
|
217 | 236 |
|
218 | 237 | def _load_evaluation_router(app: FastAPI) -> FastAPI: |
219 | 238 | from api.routers import evaluation |
| 239 | + |
220 | 240 | importlib.reload(evaluation) |
221 | 241 | app.include_router(evaluation.router) |
222 | 242 | return app |
223 | 243 |
|
224 | 244 |
|
225 | 245 | def _load_preview_router(app: FastAPI) -> FastAPI: |
226 | 246 | from api.routers import preview |
| 247 | + |
227 | 248 | importlib.reload(preview) |
228 | 249 | app.include_router(preview.router) |
229 | 250 | return app |
230 | 251 |
|
231 | 252 |
|
232 | 253 | def _load_unsupervised_training_router(app: FastAPI) -> FastAPI: |
233 | 254 | from api.routers import unsupervised_training |
| 255 | + |
234 | 256 | importlib.reload(unsupervised_training) |
235 | 257 | app.include_router(unsupervised_training.router) |
236 | 258 | return app |
237 | 259 |
|
238 | 260 |
|
239 | 261 | def _load_metacat_training_router(app: FastAPI) -> FastAPI: |
240 | 262 | from api.routers import metacat_training |
| 263 | + |
241 | 264 | importlib.reload(metacat_training) |
242 | 265 | app.include_router(metacat_training.router) |
243 | 266 | return app |
244 | 267 |
|
245 | 268 |
|
246 | 269 | def _load_health_check_router(app: FastAPI) -> FastAPI: |
247 | 270 | from api.routers import health_check |
| 271 | + |
248 | 272 | importlib.reload(health_check) |
249 | 273 | app.include_router(health_check.router) |
250 | 274 | return app |
251 | 275 |
|
252 | 276 |
|
253 | 277 | def _load_stream_router(app: FastAPI) -> FastAPI: |
254 | 278 | from api.routers import stream |
| 279 | + |
255 | 280 | importlib.reload(stream) |
256 | 281 | app.include_router(stream.router, prefix="/stream") |
257 | 282 | return app |
0 commit comments