Skip to content

Integrate with Authentication Extension #41

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 40 commits into from
Apr 3, 2025
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
Show all changes
40 commits
Select commit Hold shift + click to select a range
e377d23
chore(docs): cleanup
alukach Mar 15, 2025
e462d3c
Merge branch 'main' into feature/s3-url-signing
alukach Mar 17, 2025
4ea48ae
In progress
alukach Mar 17, 2025
7a2bcfe
Merge branch 'main' into feature/s3-url-signing
alukach Mar 19, 2025
b27f28a
Merge branch 'main' into feature/s3-url-signing
alukach Mar 20, 2025
5ffcf7f
Rework middleware to abstract common response body parsing logic
alukach Mar 20, 2025
f024b2b
set max expiration, fix arg
alukach Mar 20, 2025
94fb5fb
add boto3 dependency
alukach Mar 20, 2025
02f94df
bump pgstac version
alukach Mar 20, 2025
e0dfcef
chore(cicd): parallelize tests
alukach Mar 20, 2025
e07e92e
Return signed url as JSON object
alukach Mar 20, 2025
c4bc3f2
support get item endpoint
alukach Mar 20, 2025
7d4177b
buildout working example
alukach Mar 20, 2025
882dd24
More progress
alukach Mar 21, 2025
d40af69
Working
alukach Mar 22, 2025
4469e3f
Merge branch 'main' into feature/authentication-extension
alukach Mar 23, 2025
1aba35c
fix: handle item view
alukach Mar 23, 2025
8d8342e
Merge branch 'main' into feature/authentication-extension
alukach Mar 24, 2025
0c1c9c8
Merge branch 'main' into feature/authentication-extension
alukach Mar 24, 2025
08dfde1
bugfix: handle feature collection view
alukach Mar 24, 2025
e356616
feat: support specifying cls & args & kwargs separately
alukach Mar 24, 2025
858c894
Merge branch 'main' into feature/authentication-extension
alukach Mar 25, 2025
b12df14
Rm redundant decompression
alukach Mar 25, 2025
71ac590
Rm unused authlib
alukach Mar 25, 2025
e33980b
Merge branch 'main' into feature/authentication-extension
alukach Mar 26, 2025
eb24e8f
fix: missing import
alukach Mar 26, 2025
a627c87
middleware util: use regex pattern
alukach Mar 26, 2025
3bfb408
Merge branch 'main' into feature/authentication-extension
alukach Mar 26, 2025
9ce9cad
rm asset signing logic
alukach Mar 26, 2025
86df1f3
Update pattern matching
alukach Mar 26, 2025
b8fbb63
Merge branch 'main' into feature/authentication-extension
alukach Mar 26, 2025
3f0b6af
Merge branch 'main' into feature/authentication-extension
alukach Mar 31, 2025
27b3552
Retrieve oidc info from scope
alukach Apr 1, 2025
42b032e
Rectify test
alukach Apr 1, 2025
d583363
Rm asset signer endpoint
alukach Apr 3, 2025
eb1435a
pass in request rather than scope, use state_key
alukach Apr 3, 2025
425777e
Finalize simplified authentication extension
alukach Apr 3, 2025
7f2b1fc
Add some tests
alukach Apr 3, 2025
b8424f9
add missing requirement
alukach Apr 3, 2025
1deb016
fix content-type matching
alukach Apr 3, 2025
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 2 additions & 0 deletions pyproject.toml
Original file line number Diff line number Diff line change
Expand Up @@ -6,8 +6,10 @@ classifiers = [
"License :: OSI Approved :: MIT License",
]
dependencies = [
"boto3>=1.37.16",
"brotli>=1.1.0",
"cql2>=0.3.6",
"cryptography>=44.0.1",
"fastapi>=0.115.5",
"httpx[http2]>=0.28.0",
"jinja2>=3.1.4",
Expand Down
8 changes: 8 additions & 0 deletions src/stac_auth_proxy/app.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,6 +17,7 @@
from .middleware import (
AddProcessTimeHeaderMiddleware,
ApplyCql2FilterMiddleware,
AuthenticationExtensionMiddleware,
BuildCql2FilterMiddleware,
EnforceAuthMiddleware,
OpenApiMiddleware,
Expand Down Expand Up @@ -86,6 +87,13 @@ async def lifespan(app: FastAPI):
#
# Middleware (order is important, last added = first to run)
#
app.add_middleware(
AuthenticationExtensionMiddleware,
default_public=settings.default_public,
public_endpoints=settings.public_endpoints,
private_endpoints=settings.private_endpoints,
)

if settings.openapi_spec_endpoint:
app.add_middleware(
OpenApiMiddleware,
Expand Down
125 changes: 125 additions & 0 deletions src/stac_auth_proxy/middleware/AuthenticationExtensionMiddleware.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,125 @@
"""Middleware to add auth information to item response served by upstream API."""

import logging
import re
from dataclasses import dataclass, field
from itertools import chain
from typing import Any
from urllib.parse import urlparse

from starlette.datastructures import Headers
from starlette.requests import Request
from starlette.types import ASGIApp

from ..config import EndpointMethods
from ..utils.middleware import JsonResponseMiddleware
from ..utils.requests import find_match

logger = logging.getLogger(__name__)


@dataclass
class AuthenticationExtensionMiddleware(JsonResponseMiddleware):
"""Middleware to add the authentication extension to the response."""

app: ASGIApp

default_public: bool
private_endpoints: EndpointMethods
public_endpoints: EndpointMethods

auth_scheme_name: str = "oauth"
auth_scheme: dict[str, Any] = field(default_factory=dict)
extension_url: str = (
"https://stac-extensions.github.io/authentication/v1.1.0/schema.json"
)

json_content_type_expr: str = r"application/(geo\+)?json"

state_key: str = "oidc_metadata"

def should_transform_response(
self, request: Request, response_headers: Headers
) -> bool:
"""Determine if the response should be transformed."""
# Match STAC catalog, collection, or item URLs with a single regex
return all(
[
re.match(
# catalog, collections, collection, items, item, search
r"^(/|/collections(/[^/]+(/items(/[^/]+)?)?)?|/search)$",
request.url.path,
),
re.match(
self.json_content_type_expr,
response_headers.get("content-type", ""),
),
]
)

def transform_json(self, data: dict[str, Any], request: Request) -> dict[str, Any]:
"""Augment the STAC Item with auth information."""
extensions = data.setdefault("stac_extensions", [])
if self.extension_url not in extensions:
extensions.append(self.extension_url)

# auth:schemes
# ---
# A property that contains all of the scheme definitions used by Assets and
# Links in the STAC Item or Collection.
# - Catalogs
# - Collections
# - Item Properties

oidc_metadata = getattr(request.state, self.state_key, {})
if not oidc_metadata:
logger.error(
"OIDC metadata not found in scope. Skipping authentication extension."
)
return data

scheme_loc = data["properties"] if "properties" in data else data
schemes = scheme_loc.setdefault("auth:schemes", {})
schemes[self.auth_scheme_name] = {
"type": "oauth2",
"description": "requires an authentication bearertoken",
"flows": {
"authorizationCode": {
"authorizationUrl": oidc_metadata["authorization_endpoint"],
"tokenUrl": oidc_metadata.get("token_endpoint"),
"scopes": {
k: k for k in sorted(oidc_metadata.get("scopes_supported", []))
},
},
},
}

# auth:refs
# ---
# Annotate links with "auth:refs": [auth_scheme]
links = chain(
# Item/Collection
data.get("links", []),
# Collections/Items/Search
(
link
for prop in ["features", "collections"]
for object_with_links in data.get(prop, [])
for link in object_with_links.get("links", [])
),
)
for link in links:
if "href" not in link:
logger.warning("Link %s has no href", link)
continue
match = find_match(
path=urlparse(link["href"]).path,
method="GET",
private_endpoints=self.private_endpoints,
public_endpoints=self.public_endpoints,
default_public=self.default_public,
)
if match.is_private:
link.setdefault("auth:refs", []).append(self.auth_scheme_name)

return data
106 changes: 58 additions & 48 deletions src/stac_auth_proxy/middleware/EnforceAuthMiddleware.py
Original file line number Diff line number Diff line change
@@ -1,7 +1,7 @@
"""Middleware to enforce authentication."""

import logging
from dataclasses import dataclass
from dataclasses import dataclass, field
from typing import Annotated, Any, Optional, Sequence
from urllib.parse import urlparse, urlunparse

Expand All @@ -18,6 +18,53 @@
logger = logging.getLogger(__name__)


@dataclass
class OidcService:
"""OIDC configuration and JWKS client."""

oidc_config_url: HttpUrl
jwks_client: jwt.PyJWKClient = field(init=False)
metadata: dict[str, Any] = field(init=False)

def __post_init__(self) -> None:
"""Initialize OIDC config and JWKS client."""
logger.debug("Requesting OIDC config")
origin_url = str(self.oidc_config_url)

try:
response = httpx.get(origin_url)
response.raise_for_status()
self.metadata = response.json()
assert self.metadata, "OIDC metadata is empty"

# NOTE: We manually replace the origin of the jwks_uri in the event that
# the jwks_uri is not available from within the proxy.
oidc_url = urlparse(origin_url)
jwks_uri = urlunparse(
urlparse(self.metadata["jwks_uri"])._replace(
netloc=oidc_url.netloc, scheme=oidc_url.scheme
)
)
if jwks_uri != self.metadata["jwks_uri"]:
logger.warning(
"JWKS URI has been rewritten from %s to %s",
self.metadata["jwks_uri"],
jwks_uri,
)
self.jwks_client = jwt.PyJWKClient(jwks_uri)
except httpx.HTTPStatusError as e:
logger.error(
"Received a non-200 response when fetching OIDC config: %s",
e.response.text,
)
raise OidcFetchError(
f"Request for OIDC config failed with status {e.response.status_code}"
) from e
except httpx.RequestError as e:
logger.error("Error fetching OIDC config from %s: %s", origin_url, str(e))
raise OidcFetchError(f"Request for OIDC config failed: {str(e)}") from e


@dataclass
class EnforceAuthMiddleware:
"""Middleware to enforce authentication."""
Expand All @@ -26,56 +73,11 @@ class EnforceAuthMiddleware:
private_endpoints: EndpointMethods
public_endpoints: EndpointMethods
default_public: bool

oidc_config_url: HttpUrl
allowed_jwt_audiences: Optional[Sequence[str]] = None

state_key: str = "payload"

# Generated attributes
_jwks_client: Optional[jwt.PyJWKClient] = None

@property
def jwks_client(self) -> jwt.PyJWKClient:
"""Get the OIDC configuration URL."""
if not self._jwks_client:
logger.debug("Requesting OIDC config")
origin_url = str(self.oidc_config_url)

try:
response = httpx.get(origin_url)
response.raise_for_status()
oidc_config = response.json()

# NOTE: We manually replace the origin of the jwks_uri in the event that
# the jwks_uri is not available from within the proxy.
oidc_url = urlparse(origin_url)
jwks_uri = urlunparse(
urlparse(oidc_config["jwks_uri"])._replace(
netloc=oidc_url.netloc, scheme=oidc_url.scheme
)
)
if jwks_uri != oidc_config["jwks_uri"]:
logger.warning(
"JWKS URI has been rewritten from %s to %s",
oidc_config["jwks_uri"],
jwks_uri,
)
self._jwks_client = jwt.PyJWKClient(jwks_uri)
except httpx.HTTPStatusError as e:
logger.error(
"Received a non-200 response when fetching OIDC config: %s",
e.response.text,
)
raise OidcFetchError(
f"Request for OIDC config failed with status {e.response.status_code}"
) from e
except httpx.RequestError as e:
logger.error(
"Error fetching OIDC config from %s: %s", origin_url, str(e)
)
raise OidcFetchError(f"Request for OIDC config failed: {str(e)}") from e
return self._jwks_client
_oidc_config: Optional[OidcService] = None

async def __call__(self, scope: Scope, receive: Receive, send: Send) -> None:
"""Enforce authentication."""
Expand Down Expand Up @@ -107,6 +109,7 @@ async def __call__(self, scope: Scope, receive: Receive, send: Send) -> None:
self.state_key,
payload,
)
setattr(request.state, "oidc_metadata", self.oidc_config.metadata)
return await self.app(scope, receive, send)

def validate_token(
Expand Down Expand Up @@ -137,7 +140,7 @@ def validate_token(

# Parse & validate token
try:
key = self.jwks_client.get_signing_key_from_jwt(token).key
key = self.oidc_config.jwks_client.get_signing_key_from_jwt(token).key
payload = jwt.decode(
token,
key,
Expand All @@ -163,6 +166,13 @@ def validate_token(
)
return payload

@property
def oidc_config(self) -> OidcService:
"""Get the OIDC configuration."""
if not self._oidc_config:
self._oidc_config = OidcService(oidc_config_url=self.oidc_config_url)
return self._oidc_config


class OidcFetchError(Exception):
"""Error fetching OIDC configuration."""
Expand Down
8 changes: 4 additions & 4 deletions src/stac_auth_proxy/middleware/UpdateOpenApiMiddleware.py
Original file line number Diff line number Diff line change
Expand Up @@ -41,15 +41,15 @@ def should_transform_response(
]
)

def transform_json(self, openapi_spec: dict[str, Any]) -> dict[str, Any]:
def transform_json(self, data: dict[str, Any], request: Request) -> dict[str, Any]:
"""Augment the OpenAPI spec with auth information."""
components = openapi_spec.setdefault("components", {})
components = data.setdefault("components", {})
securitySchemes = components.setdefault("securitySchemes", {})
securitySchemes[self.oidc_auth_scheme_name] = {
"type": "openIdConnect",
"openIdConnectUrl": self.oidc_config_url,
}
for path, method_config in openapi_spec["paths"].items():
for path, method_config in data["paths"].items():
for method, config in method_config.items():
match = find_match(
path,
Expand All @@ -62,4 +62,4 @@ def transform_json(self, openapi_spec: dict[str, Any]) -> dict[str, Any]:
config.setdefault("security", []).append(
{self.oidc_auth_scheme_name: match.required_scopes}
)
return openapi_spec
return data
15 changes: 7 additions & 8 deletions src/stac_auth_proxy/middleware/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,17 +2,16 @@

from .AddProcessTimeHeaderMiddleware import AddProcessTimeHeaderMiddleware
from .ApplyCql2FilterMiddleware import ApplyCql2FilterMiddleware
from .AuthenticationExtensionMiddleware import AuthenticationExtensionMiddleware
from .BuildCql2FilterMiddleware import BuildCql2FilterMiddleware
from .EnforceAuthMiddleware import EnforceAuthMiddleware
from .UpdateOpenApiMiddleware import OpenApiMiddleware

__all__ = [
x.__name__
for x in [
OpenApiMiddleware,
AddProcessTimeHeaderMiddleware,
EnforceAuthMiddleware,
BuildCql2FilterMiddleware,
ApplyCql2FilterMiddleware,
]
"AddProcessTimeHeaderMiddleware",
"ApplyCql2FilterMiddleware",
"AuthenticationExtensionMiddleware",
"BuildCql2FilterMiddleware",
"EnforceAuthMiddleware",
"OpenApiMiddleware",
]
7 changes: 4 additions & 3 deletions src/stac_auth_proxy/utils/middleware.py
Original file line number Diff line number Diff line change
Expand Up @@ -29,7 +29,7 @@ def should_transform_response(
...

@abstractmethod
def transform_json(self, data: Any) -> Any:
def transform_json(self, data: Any, request: Request) -> Any:
"""
Transform the JSON data.

Expand All @@ -56,9 +56,10 @@ async def transform_response(message: Message) -> None:

start_message = start_message or message
headers = MutableHeaders(scope=start_message)
request = Request(scope)

if not self.should_transform_response(
request=Request(scope),
request=request,
response_headers=headers,
):
# For non-JSON responses, send the start message immediately
Expand All @@ -78,7 +79,7 @@ async def transform_response(message: Message) -> None:
# Transform the JSON body
if body:
data = json.loads(body)
transformed = self.transform_json(data)
transformed = self.transform_json(data, request=request)
body = json.dumps(transformed).encode()

# Update content-length header
Expand Down
Loading