Skip to content

Commit 3d87471

Browse files
committed
refactor: simplify lifespan
1 parent c565ae7 commit 3d87471

File tree

5 files changed

+55
-112
lines changed

5 files changed

+55
-112
lines changed

src/stac_auth_proxy/app.py

+13-9
Original file line numberDiff line numberDiff line change
@@ -6,21 +6,22 @@
66
"""
77

88
import logging
9+
from contextlib import asynccontextmanager
910
from typing import Optional
1011

1112
from fastapi import FastAPI
1213
from starlette_cramjam.middleware import CompressionMiddleware
1314

1415
from .config import Settings
1516
from .handlers import HealthzHandler, ReverseProxyHandler
16-
from .lifespan import LifespanManager, ServerHealthCheck
1717
from .middleware import (
1818
AddProcessTimeHeaderMiddleware,
1919
ApplyCql2FilterMiddleware,
2020
BuildCql2FilterMiddleware,
2121
EnforceAuthMiddleware,
2222
OpenApiMiddleware,
2323
)
24+
from .utils.lifespan import check_server_health
2425

2526
logger = logging.getLogger(__name__)
2627

@@ -32,14 +33,17 @@ def create_app(settings: Optional[Settings] = None) -> FastAPI:
3233
#
3334
# Application
3435
#
35-
upstream_urls = (
36-
[settings.upstream_url, settings.oidc_discovery_internal_url]
37-
if settings.wait_for_upstream
38-
else []
39-
)
40-
lifespan = LifespanManager(
41-
on_startup=([ServerHealthCheck(url=url) for url in upstream_urls])
42-
)
36+
37+
@asynccontextmanager
38+
async def lifespan(app: FastAPI):
39+
assert settings
40+
41+
# Wait for upstream servers to become available
42+
if settings.wait_for_upstream:
43+
for url in [settings.upstream_url, settings.oidc_discovery_internal_url]:
44+
await check_server_health(url=url)
45+
46+
yield
4347

4448
app = FastAPI(
4549
openapi_url=None, # Disable OpenAPI schema endpoint, we want to serve upstream's schema

src/stac_auth_proxy/lifespan/LifespanManager.py

-37
This file was deleted.

src/stac_auth_proxy/lifespan/ServerHealthCheck.py

-57
This file was deleted.

src/stac_auth_proxy/lifespan/__init__.py

-9
This file was deleted.

src/stac_auth_proxy/utils/lifespan.py

+42
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,42 @@
1+
"""Health check implementations for lifespan events."""
2+
3+
import asyncio
4+
import logging
5+
6+
import httpx
7+
from pydantic import HttpUrl
8+
9+
logger = logging.getLogger(__name__)
10+
11+
12+
async def check_server_health(
13+
url: str | HttpUrl,
14+
max_retries: int = 10,
15+
retry_delay: float = 1.0,
16+
retry_delay_max: float = 5.0,
17+
timeout: float = 5.0,
18+
) -> None:
19+
"""Wait for upstream API to become available."""
20+
# Convert url to string if it's a HttpUrl
21+
if isinstance(url, HttpUrl):
22+
url = str(url)
23+
24+
async with httpx.AsyncClient(timeout=timeout, follow_redirects=True) as client:
25+
for attempt in range(max_retries):
26+
try:
27+
response = await client.get(url)
28+
response.raise_for_status()
29+
logger.info(f"Upstream API {url!r} is healthy")
30+
return
31+
except Exception as e:
32+
logger.warning(f"Upstream health check for {url!r} failed: {e}")
33+
retry_in = min(retry_delay * (2**attempt), retry_delay_max)
34+
logger.warning(
35+
f"Upstream API {url!r} not healthy, retrying in {retry_in:.1f}s "
36+
f"(attempt {attempt + 1}/{max_retries})"
37+
)
38+
await asyncio.sleep(retry_in)
39+
40+
raise RuntimeError(
41+
f"Upstream API {url!r} failed to respond after {max_retries} attempts"
42+
)

0 commit comments

Comments
 (0)