diff --git a/CHANGES.md b/CHANGES.md index d2d3be5b..25a8eaea 100644 --- a/CHANGES.md +++ b/CHANGES.md @@ -9,6 +9,7 @@ ### Added - add `enable_direct_response` settings to by-pass Pydantic validation and FastAPI serialization for responses +- add `/_mgmt/health` endpoint (`readiness`) and `health_check: Callable[[], [Dict]` optional attribute in `StacApi` class ## [5.1.1] - 2025-03-17 diff --git a/stac_fastapi/api/stac_fastapi/api/app.py b/stac_fastapi/api/stac_fastapi/api/app.py index c6874ce0..bd03b4fe 100644 --- a/stac_fastapi/api/stac_fastapi/api/app.py +++ b/stac_fastapi/api/stac_fastapi/api/app.py @@ -1,7 +1,7 @@ """Fastapi app creation.""" -from typing import Dict, List, Optional, Tuple, Type, Union +from typing import Awaitable, Callable, Dict, List, Optional, Tuple, Type, Union import attr from brotli_asgi import BrotliMiddleware @@ -67,6 +67,10 @@ class StacApi: specified routes. This is useful for applying custom auth requirements to routes defined elsewhere in the application. + health_check: + A Callable which return application's `health` information. + Defaults to `def health: return {"status": "UP"}` + """ settings: ApiSettings = attr.ib() @@ -128,6 +132,9 @@ class StacApi: ) ) route_dependencies: List[Tuple[List[Scope], List[Depends]]] = attr.ib(default=[]) + health_check: Union[Callable[[], Dict], Callable[[], Awaitable[Dict]]] = attr.ib( + default=lambda: {"status": "UP"} + ) def get_extension(self, extension: Type[ApiExtension]) -> Optional[ApiExtension]: """Get an extension. @@ -363,14 +370,44 @@ def register_core(self) -> None: def add_health_check(self) -> None: """Add a health check.""" - mgmt_router = APIRouter(prefix=self.app.state.router_prefix) - @mgmt_router.get("/_mgmt/ping") async def ping(): - """Liveliness/readiness probe.""" + """Liveliness probe.""" return {"message": "PONG"} - self.app.include_router(mgmt_router, tags=["Liveliness/Readiness"]) + self.app.router.add_api_route( + name="Ping", + path="/_mgmt/ping", + response_model=Dict, + responses={ + 200: { + "content": { + MimeTypes.json.value: {}, + }, + }, + }, + response_class=self.response_class, + methods=["GET"], + endpoint=ping, + tags=["Liveliness/Readiness"], + ) + + self.app.router.add_api_route( + name="Health", + path="/_mgmt/health", + response_model=Dict, + responses={ + 200: { + "content": { + MimeTypes.json.value: {}, + }, + }, + }, + response_class=self.response_class, + methods=["GET"], + endpoint=self.health_check, + tags=["Liveliness/Readiness"], + ) def add_route_dependencies( self, scopes: List[Scope], dependencies: List[Depends] diff --git a/stac_fastapi/api/tests/test_app.py b/stac_fastapi/api/tests/test_app.py index 44894c89..5500303e 100644 --- a/stac_fastapi/api/tests/test_app.py +++ b/stac_fastapi/api/tests/test_app.py @@ -6,6 +6,7 @@ from fastapi.testclient import TestClient from pydantic import ValidationError from stac_pydantic import api +from starlette.requests import Request from typing_extensions import Annotated from stac_fastapi.api import app @@ -529,3 +530,51 @@ def item_collection( assert post_search.json() == "2020-01-01T00:00:00.00001Z" assert post_search_zero.status_code == 200, post_search_zero.text assert post_search_zero.json() == "2020-01-01T00:00:00.0000Z" + + +def test_mgmt_endpoints(AsyncTestCoreClient): + """Test ping/health endpoints.""" + + test_app = app.StacApi( + settings=ApiSettings(), + client=AsyncTestCoreClient(), + ) + + with TestClient(test_app.app) as client: + resp = client.get("/_mgmt/ping") + assert resp.status_code == 200 + assert resp.json() == {"message": "PONG"} + + resp = client.get("/_mgmt/health") + assert resp.status_code == 200 + assert resp.json() == {"status": "UP"} + + def health_check(request: Request): + return { + "status": "UP", + "database": { + "status": "UP", + "version": "0.1.0", + }, + } + + test_app = app.StacApi( + settings=ApiSettings(), + client=AsyncTestCoreClient(), + health_check=health_check, + ) + + with TestClient(test_app.app) as client: + resp = client.get("/_mgmt/ping") + assert resp.status_code == 200 + assert resp.json() == {"message": "PONG"} + + resp = client.get("/_mgmt/health") + assert resp.status_code == 200 + assert resp.json() == { + "status": "UP", + "database": { + "status": "UP", + "version": "0.1.0", + }, + }