Skip to content

Commit d1ad8b3

Browse files
committed
feature: conformance check (#46)
Tooling for automated conformance checks. Closes #26
1 parent fe46940 commit d1ad8b3

File tree

6 files changed

+181
-11
lines changed

6 files changed

+181
-11
lines changed

src/stac_auth_proxy/app.py

+21-8
Original file line numberDiff line numberDiff line change
@@ -21,7 +21,7 @@
2121
EnforceAuthMiddleware,
2222
OpenApiMiddleware,
2323
)
24-
from .utils.lifespan import check_server_health
24+
from .utils.lifespan import check_conformance, check_server_health
2525

2626
logger = logging.getLogger(__name__)
2727

@@ -40,9 +40,22 @@ async def lifespan(app: FastAPI):
4040

4141
# Wait for upstream servers to become available
4242
if settings.wait_for_upstream:
43+
logger.info("Running upstream server health checks...")
4344
for url in [settings.upstream_url, settings.oidc_discovery_internal_url]:
4445
await check_server_health(url=url)
4546

47+
# Log all middleware connected to the app
48+
logger.debug(
49+
"Connected middleware:\n%s",
50+
"\n".join([f" - {m.cls.__name__}" for m in app.user_middleware]),
51+
)
52+
53+
if settings.check_conformance:
54+
await check_conformance(
55+
app.user_middleware,
56+
str(settings.upstream_url),
57+
)
58+
4659
yield
4760

4861
app = FastAPI(
@@ -88,19 +101,19 @@ async def lifespan(app: FastAPI):
88101
)
89102

90103
app.add_middleware(
91-
EnforceAuthMiddleware,
92-
public_endpoints=settings.public_endpoints,
93-
private_endpoints=settings.private_endpoints,
94-
default_public=settings.default_public,
95-
oidc_config_url=settings.oidc_discovery_internal_url,
104+
CompressionMiddleware,
96105
)
97106

98107
app.add_middleware(
99-
CompressionMiddleware,
108+
AddProcessTimeHeaderMiddleware,
100109
)
101110

102111
app.add_middleware(
103-
AddProcessTimeHeaderMiddleware,
112+
EnforceAuthMiddleware,
113+
public_endpoints=settings.public_endpoints,
114+
private_endpoints=settings.private_endpoints,
115+
default_public=settings.default_public,
116+
oidc_config_url=settings.oidc_discovery_internal_url,
104117
)
105118

106119
return app

src/stac_auth_proxy/config.py

+1
Original file line numberDiff line numberDiff line change
@@ -39,6 +39,7 @@ class Settings(BaseSettings):
3939
oidc_discovery_internal_url: HttpUrl
4040

4141
wait_for_upstream: bool = True
42+
check_conformance: bool = True
4243

4344
# Endpoints
4445
healthz_prefix: str = Field(pattern=_PREFIX_PATTERN, default="/healthz")

src/stac_auth_proxy/middleware/ApplyCql2FilterMiddleware.py

+8
Original file line numberDiff line numberDiff line change
@@ -13,10 +13,18 @@
1313
from starlette.types import ASGIApp, Message, Receive, Scope, Send
1414

1515
from ..utils import filters
16+
from ..utils.middleware import required_conformance
1617

1718
logger = getLogger(__name__)
1819

1920

21+
@required_conformance(
22+
r"http://www.opengis.net/spec/cql2/1.0/conf/basic-cql2",
23+
r"http://www.opengis.net/spec/cql2/1.0/conf/cql2-text",
24+
r"http://www.opengis.net/spec/cql2/1.0/conf/cql2-json",
25+
r"http://www.opengis.net/spec/ogcapi-features-3/1.0/conf/features-filter",
26+
r"https://api.stacspec.org/v1\.\d+\.\d+(?:-[\w\.]+)?/item-search#filter",
27+
)
2028
@dataclass(frozen=True)
2129
class ApplyCql2FilterMiddleware:
2230
"""Middleware to apply the Cql2Filter to the request."""

src/stac_auth_proxy/utils/lifespan.py

+54-3
Original file line numberDiff line numberDiff line change
@@ -2,9 +2,11 @@
22

33
import asyncio
44
import logging
5+
import re
56

67
import httpx
78
from pydantic import HttpUrl
9+
from starlette.middleware import Middleware
810

911
logger = logging.getLogger(__name__)
1012

@@ -21,14 +23,16 @@ async def check_server_health(
2123
if isinstance(url, HttpUrl):
2224
url = str(url)
2325

24-
async with httpx.AsyncClient(timeout=timeout, follow_redirects=True) as client:
26+
async with httpx.AsyncClient(
27+
base_url=url, timeout=timeout, follow_redirects=True
28+
) as client:
2529
for attempt in range(max_retries):
2630
try:
27-
response = await client.get(url)
31+
response = await client.get("/")
2832
response.raise_for_status()
2933
logger.info(f"Upstream API {url!r} is healthy")
3034
return
31-
except Exception as e:
35+
except httpx.ConnectError as e:
3236
logger.warning(f"Upstream health check for {url!r} failed: {e}")
3337
retry_in = min(retry_delay * (2**attempt), retry_delay_max)
3438
logger.warning(
@@ -40,3 +44,50 @@ async def check_server_health(
4044
raise RuntimeError(
4145
f"Upstream API {url!r} failed to respond after {max_retries} attempts"
4246
)
47+
48+
49+
async def check_conformance(
50+
middleware_classes: list[Middleware],
51+
api_url: str,
52+
attr_name: str = "__required_conformances__",
53+
endpoint: str = "/conformance",
54+
):
55+
"""Check if the upstream API supports a given conformance class."""
56+
required_conformances: dict[str, list[str]] = {}
57+
for middleware in middleware_classes:
58+
59+
for conformance in getattr(middleware.cls, attr_name, []):
60+
required_conformances.setdefault(conformance, []).append(
61+
middleware.cls.__name__
62+
)
63+
64+
async with httpx.AsyncClient(base_url=api_url) as client:
65+
response = await client.get(endpoint)
66+
response.raise_for_status()
67+
api_conforms_to = response.json().get("conformsTo", [])
68+
69+
missing = [
70+
req_conformance
71+
for req_conformance in required_conformances.keys()
72+
if not any(
73+
re.match(req_conformance, conformance) for conformance in api_conforms_to
74+
)
75+
]
76+
77+
def conformance_str(conformance: str) -> str:
78+
return f" - {conformance} [{','.join(required_conformances[conformance])}]"
79+
80+
if missing:
81+
missing_str = [conformance_str(c) for c in missing]
82+
raise RuntimeError(
83+
"\n".join(
84+
[
85+
"Upstream catalog is missing the following conformance classes:",
86+
*missing_str,
87+
]
88+
)
89+
)
90+
logger.debug(
91+
"Upstream catalog conforms to the following required conformance classes: \n%s",
92+
"\n".join([conformance_str(c) for c in required_conformances]),
93+
)

src/stac_auth_proxy/utils/middleware.py

+13
Original file line numberDiff line numberDiff line change
@@ -99,3 +99,16 @@ async def transform_response(message: Message) -> None:
9999
)
100100

101101
return await self.app(scope, receive, transform_response)
102+
103+
104+
def required_conformance(
105+
*conformances: str,
106+
attr_name: str = "__required_conformances__",
107+
):
108+
"""Register required conformance classes with a middleware class."""
109+
110+
def decorator(middleware):
111+
setattr(middleware, attr_name, list(conformances))
112+
return middleware
113+
114+
return decorator

tests/test_lifespan.py

+84
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,84 @@
1+
"""Tests for lifespan module."""
2+
3+
from dataclasses import dataclass
4+
from unittest.mock import patch
5+
6+
import pytest
7+
from starlette.middleware import Middleware
8+
from starlette.types import ASGIApp
9+
10+
from stac_auth_proxy.utils.lifespan import check_conformance, check_server_health
11+
from stac_auth_proxy.utils.middleware import required_conformance
12+
13+
14+
@required_conformance("http://example.com/conformance")
15+
@dataclass
16+
class TestMiddleware:
17+
"""Test middleware with required conformance."""
18+
19+
app: ASGIApp
20+
21+
22+
async def test_check_server_health_success(source_api_server):
23+
"""Test successful health check."""
24+
await check_server_health(source_api_server)
25+
26+
27+
async def test_check_server_health_failure():
28+
"""Test health check failure."""
29+
with patch("asyncio.sleep") as mock_sleep:
30+
with pytest.raises(RuntimeError) as exc_info:
31+
await check_server_health("http://localhost:9999")
32+
assert "failed to respond after" in str(exc_info.value)
33+
# Verify sleep was called with exponential backoff
34+
assert mock_sleep.call_count > 0
35+
# First call should be with base delay
36+
# NOTE: When testing individually, the mock_sleep strangely has a first call of
37+
# 0 seconds (possibly by httpx), however when running all tests, this does not
38+
# occur. So, we have to check for 1.0 in the first two calls.
39+
assert 1.0 in [mock_sleep.call_args_list[i][0][0] for i in range(2)]
40+
# Last call should be with max delay
41+
assert mock_sleep.call_args_list[-1][0][0] == 5.0
42+
43+
44+
async def test_check_conformance_success(source_api_server, source_api_responses):
45+
"""Test successful conformance check."""
46+
middleware = [Middleware(TestMiddleware)]
47+
await check_conformance(middleware, source_api_server)
48+
49+
50+
async def test_check_conformance_failure(source_api_server, source_api_responses):
51+
"""Test conformance check failure."""
52+
# Override the conformance response to not include required conformance
53+
source_api_responses["/conformance"]["GET"] = {"conformsTo": []}
54+
55+
middleware = [Middleware(TestMiddleware)]
56+
with pytest.raises(RuntimeError) as exc_info:
57+
await check_conformance(middleware, source_api_server)
58+
assert "missing the following conformance classes" in str(exc_info.value)
59+
60+
61+
async def test_check_conformance_multiple_middleware(source_api_server):
62+
"""Test conformance check with multiple middleware."""
63+
64+
@required_conformance("http://example.com/conformance")
65+
class TestMiddleware2:
66+
def __init__(self, app):
67+
self.app = app
68+
69+
middleware = [
70+
Middleware(TestMiddleware),
71+
Middleware(TestMiddleware2),
72+
]
73+
await check_conformance(middleware, source_api_server)
74+
75+
76+
async def test_check_conformance_no_required(source_api_server):
77+
"""Test conformance check with middleware that has no required conformances."""
78+
79+
class NoConformanceMiddleware:
80+
def __init__(self, app):
81+
self.app = app
82+
83+
middleware = [Middleware(NoConformanceMiddleware)]
84+
await check_conformance(middleware, source_api_server)

0 commit comments

Comments
 (0)