Skip to content

Feature/conformance check #46

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 9 commits into from
Mar 28, 2025
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
29 changes: 21 additions & 8 deletions src/stac_auth_proxy/app.py
Original file line number Diff line number Diff line change
Expand Up @@ -21,7 +21,7 @@
EnforceAuthMiddleware,
OpenApiMiddleware,
)
from .utils.lifespan import check_server_health
from .utils.lifespan import check_conformance, check_server_health

logger = logging.getLogger(__name__)

Expand All @@ -40,9 +40,22 @@ async def lifespan(app: FastAPI):

# Wait for upstream servers to become available
if settings.wait_for_upstream:
logger.info("Running upstream server health checks...")
for url in [settings.upstream_url, settings.oidc_discovery_internal_url]:
await check_server_health(url=url)

# Log all middleware connected to the app
logger.debug(
"Connected middleware:\n%s",
"\n".join([f" - {m.cls.__name__}" for m in app.user_middleware]),
)

if settings.check_conformance:
await check_conformance(
app.user_middleware,
str(settings.upstream_url),
)

yield

app = FastAPI(
Expand Down Expand Up @@ -88,19 +101,19 @@ async def lifespan(app: FastAPI):
)

app.add_middleware(
EnforceAuthMiddleware,
public_endpoints=settings.public_endpoints,
private_endpoints=settings.private_endpoints,
default_public=settings.default_public,
oidc_config_url=settings.oidc_discovery_internal_url,
CompressionMiddleware,
)

app.add_middleware(
CompressionMiddleware,
AddProcessTimeHeaderMiddleware,
)

app.add_middleware(
AddProcessTimeHeaderMiddleware,
EnforceAuthMiddleware,
public_endpoints=settings.public_endpoints,
private_endpoints=settings.private_endpoints,
default_public=settings.default_public,
oidc_config_url=settings.oidc_discovery_internal_url,
)

return app
1 change: 1 addition & 0 deletions src/stac_auth_proxy/config.py
Original file line number Diff line number Diff line change
Expand Up @@ -39,6 +39,7 @@ class Settings(BaseSettings):
oidc_discovery_internal_url: HttpUrl

wait_for_upstream: bool = True
check_conformance: bool = True

# Endpoints
healthz_prefix: str = Field(pattern=_PREFIX_PATTERN, default="/healthz")
Expand Down
8 changes: 8 additions & 0 deletions src/stac_auth_proxy/middleware/ApplyCql2FilterMiddleware.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,10 +13,18 @@
from starlette.types import ASGIApp, Message, Receive, Scope, Send

from ..utils import filters
from ..utils.middleware import required_conformance

logger = getLogger(__name__)


@required_conformance(
r"http://www.opengis.net/spec/cql2/1.0/conf/basic-cql2",
r"http://www.opengis.net/spec/cql2/1.0/conf/cql2-text",
r"http://www.opengis.net/spec/cql2/1.0/conf/cql2-json",
r"http://www.opengis.net/spec/ogcapi-features-3/1.0/conf/features-filter",
r"https://api.stacspec.org/v1\.\d+\.\d+(?:-[\w\.]+)?/item-search#filter",
)
@dataclass(frozen=True)
class ApplyCql2FilterMiddleware:
"""Middleware to apply the Cql2Filter to the request."""
Expand Down
57 changes: 54 additions & 3 deletions src/stac_auth_proxy/utils/lifespan.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,9 +2,11 @@

import asyncio
import logging
import re

import httpx
from pydantic import HttpUrl
from starlette.middleware import Middleware

logger = logging.getLogger(__name__)

Expand All @@ -21,14 +23,16 @@ async def check_server_health(
if isinstance(url, HttpUrl):
url = str(url)

async with httpx.AsyncClient(timeout=timeout, follow_redirects=True) as client:
async with httpx.AsyncClient(
base_url=url, timeout=timeout, follow_redirects=True
) as client:
for attempt in range(max_retries):
try:
response = await client.get(url)
response = await client.get("/")
response.raise_for_status()
logger.info(f"Upstream API {url!r} is healthy")
return
except Exception as e:
except httpx.ConnectError as e:
logger.warning(f"Upstream health check for {url!r} failed: {e}")
retry_in = min(retry_delay * (2**attempt), retry_delay_max)
logger.warning(
Expand All @@ -40,3 +44,50 @@ async def check_server_health(
raise RuntimeError(
f"Upstream API {url!r} failed to respond after {max_retries} attempts"
)


async def check_conformance(
middleware_classes: list[Middleware],
api_url: str,
attr_name: str = "__required_conformances__",
endpoint: str = "/conformance",
):
"""Check if the upstream API supports a given conformance class."""
required_conformances: dict[str, list[str]] = {}
for middleware in middleware_classes:

for conformance in getattr(middleware.cls, attr_name, []):
required_conformances.setdefault(conformance, []).append(
middleware.cls.__name__
)

async with httpx.AsyncClient(base_url=api_url) as client:
response = await client.get(endpoint)
response.raise_for_status()
api_conforms_to = response.json().get("conformsTo", [])

missing = [
req_conformance
for req_conformance in required_conformances.keys()
if not any(
re.match(req_conformance, conformance) for conformance in api_conforms_to
)
]

def conformance_str(conformance: str) -> str:
return f" - {conformance} [{','.join(required_conformances[conformance])}]"

if missing:
missing_str = [conformance_str(c) for c in missing]
raise RuntimeError(
"\n".join(
[
"Upstream catalog is missing the following conformance classes:",
*missing_str,
]
)
)
logger.debug(
"Upstream catalog conforms to the following required conformance classes: \n%s",
"\n".join([conformance_str(c) for c in required_conformances]),
)
13 changes: 13 additions & 0 deletions src/stac_auth_proxy/utils/middleware.py
Original file line number Diff line number Diff line change
Expand Up @@ -99,3 +99,16 @@ async def transform_response(message: Message) -> None:
)

return await self.app(scope, receive, transform_response)


def required_conformance(
*conformances: str,
attr_name: str = "__required_conformances__",
):
"""Register required conformance classes with a middleware class."""

def decorator(middleware):
setattr(middleware, attr_name, list(conformances))
return middleware

return decorator
84 changes: 84 additions & 0 deletions tests/test_lifespan.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,84 @@
"""Tests for lifespan module."""

from dataclasses import dataclass
from unittest.mock import patch

import pytest
from starlette.middleware import Middleware
from starlette.types import ASGIApp

from stac_auth_proxy.utils.lifespan import check_conformance, check_server_health
from stac_auth_proxy.utils.middleware import required_conformance


@required_conformance("http://example.com/conformance")
@dataclass
class TestMiddleware:
"""Test middleware with required conformance."""

app: ASGIApp


async def test_check_server_health_success(source_api_server):
"""Test successful health check."""
await check_server_health(source_api_server)


async def test_check_server_health_failure():
"""Test health check failure."""
with patch("asyncio.sleep") as mock_sleep:
with pytest.raises(RuntimeError) as exc_info:
await check_server_health("http://localhost:9999")
assert "failed to respond after" in str(exc_info.value)
# Verify sleep was called with exponential backoff
assert mock_sleep.call_count > 0
# First call should be with base delay
# NOTE: When testing individually, the mock_sleep strangely has a first call of
# 0 seconds (possibly by httpx), however when running all tests, this does not
# occur. So, we have to check for 1.0 in the first two calls.
assert 1.0 in [mock_sleep.call_args_list[i][0][0] for i in range(2)]
# Last call should be with max delay
assert mock_sleep.call_args_list[-1][0][0] == 5.0


async def test_check_conformance_success(source_api_server, source_api_responses):
"""Test successful conformance check."""
middleware = [Middleware(TestMiddleware)]
await check_conformance(middleware, source_api_server)


async def test_check_conformance_failure(source_api_server, source_api_responses):
"""Test conformance check failure."""
# Override the conformance response to not include required conformance
source_api_responses["/conformance"]["GET"] = {"conformsTo": []}

middleware = [Middleware(TestMiddleware)]
with pytest.raises(RuntimeError) as exc_info:
await check_conformance(middleware, source_api_server)
assert "missing the following conformance classes" in str(exc_info.value)


async def test_check_conformance_multiple_middleware(source_api_server):
"""Test conformance check with multiple middleware."""

@required_conformance("http://example.com/conformance")
class TestMiddleware2:
def __init__(self, app):
self.app = app

middleware = [
Middleware(TestMiddleware),
Middleware(TestMiddleware2),
]
await check_conformance(middleware, source_api_server)


async def test_check_conformance_no_required(source_api_server):
"""Test conformance check with middleware that has no required conformances."""

class NoConformanceMiddleware:
def __init__(self, app):
self.app = app

middleware = [Middleware(NoConformanceMiddleware)]
await check_conformance(middleware, source_api_server)