Skip to content

Commit a1db2ba

Browse files
committed
Legibility refactor (more higher order functions instead of dataclasses)
1 parent a34c370 commit a1db2ba

File tree

5 files changed

+89
-85
lines changed

5 files changed

+89
-85
lines changed

Diff for: src/stac_auth_proxy/app.py

+4-4
Original file line numberDiff line numberDiff line change
@@ -12,7 +12,7 @@
1212

1313
from .auth import OpenIdConnectAuth
1414
from .config import Settings
15-
from .handlers import OpenApiSpecHandler, ReverseProxyHandler
15+
from .handlers import ReverseProxyHandler, build_openapi_spec_handler
1616
from .middleware import AddProcessTimeHeaderMiddleware
1717

1818
# from .utils import apply_filter
@@ -55,7 +55,7 @@ def create_app(settings: Optional[Settings] = None) -> FastAPI:
5555
collections_filter=collections_filter,
5656
items_filter=items_filter,
5757
)
58-
openapi_handler = OpenApiSpecHandler(
58+
openapi_handler = build_openapi_spec_handler(
5959
proxy=proxy_handler,
6060
oidc_config_url=str(settings.oidc_discovery_url),
6161
)
@@ -67,7 +67,7 @@ def create_app(settings: Optional[Settings] = None) -> FastAPI:
6767
(
6868
proxy_handler.stream
6969
if path != settings.openapi_spec_endpoint
70-
else openapi_handler.dispatch
70+
else openapi_handler
7171
),
7272
methods=methods,
7373
dependencies=[Security(auth_scheme.validated_user)],
@@ -80,7 +80,7 @@ def create_app(settings: Optional[Settings] = None) -> FastAPI:
8080
(
8181
proxy_handler.stream
8282
if path != settings.openapi_spec_endpoint
83-
else openapi_handler.dispatch
83+
else openapi_handler
8484
),
8585
methods=methods,
8686
dependencies=[Security(auth_scheme.maybe_validated_user)],

Diff for: src/stac_auth_proxy/auth.py

+70-62
Original file line numberDiff line numberDiff line change
@@ -4,7 +4,7 @@
44
import logging
55
import urllib.request
66
from dataclasses import dataclass, field
7-
from typing import Annotated, Any, Callable, Optional, Sequence
7+
from typing import Annotated, Optional, Sequence
88

99
import jwt
1010
from fastapi import HTTPException, Security, security, status
@@ -25,8 +25,6 @@ class OpenIdConnectAuth:
2525
# Generated attributes
2626
auth_scheme: SecurityBase = field(init=False)
2727
jwks_client: jwt.PyJWKClient = field(init=False)
28-
validated_user: Callable[..., Any] = field(init=False)
29-
maybe_validated_user: Callable[..., Any] = field(init=False)
3028

3129
def __post_init__(self):
3230
"""Initialize the OIDC authentication class."""
@@ -50,70 +48,80 @@ def __post_init__(self):
5048
openIdConnectUrl=str(self.openid_configuration_url),
5149
auto_error=False,
5250
)
53-
self.validated_user = self._build(auto_error=True)
54-
self.maybe_validated_user = self._build(auto_error=False)
55-
56-
def _build(self, auto_error: bool = True):
57-
"""Build a dependency for validating an OIDC token."""
58-
59-
def valid_token_dependency(
60-
auth_header: Annotated[str, Security(self.auth_scheme)],
61-
required_scopes: security.SecurityScopes,
62-
):
63-
"""Dependency to validate an OIDC token."""
64-
if not auth_header:
51+
52+
# Update annotations to support FastAPI's dependency injection
53+
for endpoint in [self.validated_user, self.maybe_validated_user]:
54+
endpoint.__annotations__["auth_header"] = Annotated[
55+
str,
56+
Security(self.auth_scheme),
57+
]
58+
59+
def maybe_validated_user(
60+
self,
61+
auth_header: Annotated[str, Security(...)],
62+
required_scopes: security.SecurityScopes,
63+
):
64+
"""Dependency to validate an OIDC token."""
65+
return self.validated_user(auth_header, required_scopes, auto_error=False)
66+
67+
def validated_user(
68+
self,
69+
auth_header: Annotated[str, Security(...)],
70+
required_scopes: security.SecurityScopes,
71+
auto_error: bool = True,
72+
):
73+
"""Dependency to validate an OIDC token."""
74+
if not auth_header:
75+
if auto_error:
76+
raise HTTPException(
77+
status_code=status.HTTP_403_FORBIDDEN,
78+
detail="Not authenticated",
79+
)
80+
return None
81+
82+
# Extract token from header
83+
token_parts = auth_header.split(" ")
84+
if len(token_parts) != 2 or token_parts[0].lower() != "bearer":
85+
logger.error(f"Invalid token: {auth_header}")
86+
raise HTTPException(
87+
status_code=status.HTTP_401_UNAUTHORIZED,
88+
detail="Could not validate credentials",
89+
headers={"WWW-Authenticate": "Bearer"},
90+
)
91+
[_, token] = token_parts
92+
93+
# Parse & validate token
94+
try:
95+
key = self.jwks_client.get_signing_key_from_jwt(token).key
96+
payload = jwt.decode(
97+
token,
98+
key,
99+
algorithms=["RS256"],
100+
# NOTE: Audience validation MUST match audience claim if set in token (https://pyjwt.readthedocs.io/en/stable/changelog.html?highlight=audience#id40)
101+
audience=self.allowed_jwt_audiences,
102+
)
103+
except (jwt.exceptions.InvalidTokenError, jwt.exceptions.DecodeError) as e:
104+
logger.exception(f"InvalidTokenError: {e=}")
105+
raise HTTPException(
106+
status_code=status.HTTP_401_UNAUTHORIZED,
107+
detail="Could not validate credentials",
108+
headers={"WWW-Authenticate": "Bearer"},
109+
) from e
110+
111+
# Validate scopes (if required)
112+
for scope in required_scopes.scopes:
113+
if scope not in payload["scope"]:
65114
if auto_error:
66115
raise HTTPException(
67-
status_code=status.HTTP_403_FORBIDDEN,
68-
detail="Not authenticated",
116+
status_code=status.HTTP_401_UNAUTHORIZED,
117+
detail="Not enough permissions",
118+
headers={
119+
"WWW-Authenticate": f'Bearer scope="{required_scopes.scope_str}"'
120+
},
69121
)
70122
return None
71123

72-
# Extract token from header
73-
token_parts = auth_header.split(" ")
74-
if len(token_parts) != 2 or token_parts[0].lower() != "bearer":
75-
logger.error(f"Invalid token: {auth_header}")
76-
raise HTTPException(
77-
status_code=status.HTTP_401_UNAUTHORIZED,
78-
detail="Could not validate credentials",
79-
headers={"WWW-Authenticate": "Bearer"},
80-
)
81-
[_, token] = token_parts
82-
83-
# Parse & validate token
84-
try:
85-
key = self.jwks_client.get_signing_key_from_jwt(token).key
86-
payload = jwt.decode(
87-
token,
88-
key,
89-
algorithms=["RS256"],
90-
# NOTE: Audience validation MUST match audience claim if set in token (https://pyjwt.readthedocs.io/en/stable/changelog.html?highlight=audience#id40)
91-
audience=self.allowed_jwt_audiences,
92-
)
93-
except (jwt.exceptions.InvalidTokenError, jwt.exceptions.DecodeError) as e:
94-
logger.exception(f"InvalidTokenError: {e=}")
95-
raise HTTPException(
96-
status_code=status.HTTP_401_UNAUTHORIZED,
97-
detail="Could not validate credentials",
98-
headers={"WWW-Authenticate": "Bearer"},
99-
) from e
100-
101-
# Validate scopes (if required)
102-
for scope in required_scopes.scopes:
103-
if scope not in payload["scope"]:
104-
if auto_error:
105-
raise HTTPException(
106-
status_code=status.HTTP_401_UNAUTHORIZED,
107-
detail="Not enough permissions",
108-
headers={
109-
"WWW-Authenticate": f'Bearer scope="{required_scopes.scope_str}"'
110-
},
111-
)
112-
return None
113-
114-
return payload
115-
116-
return valid_token_dependency
124+
return payload
117125

118126

119127
class OidcFetchError(Exception):

Diff for: src/stac_auth_proxy/filters/template.py

+1-1
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,6 @@
11
"""Generate CQL2 filter expressions via Jinja2 templating."""
22

3-
from typing import Any, Annotated, Callable
3+
from typing import Annotated, Any, Callable
44

55
from cql2 import Expr
66
from fastapi import Request, Security

Diff for: src/stac_auth_proxy/handlers/__init__.py

+2-2
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,6 @@
11
"""Handlers to process requests."""
22

3-
from .open_api_spec import OpenApiSpecHandler
3+
from .open_api_spec import build_openapi_spec_handler
44
from .reverse_proxy import ReverseProxyHandler
55

6-
__all__ = ["OpenApiSpecHandler", "ReverseProxyHandler"]
6+
__all__ = ["build_openapi_spec_handler", "ReverseProxyHandler"]

Diff for: src/stac_auth_proxy/handlers/open_api_spec.py

+12-16
Original file line numberDiff line numberDiff line change
@@ -1,7 +1,6 @@
11
"""Custom request handlers."""
22

33
import logging
4-
from dataclasses import dataclass
54

65
from fastapi import Request, Response
76
from fastapi.routing import APIRoute
@@ -12,17 +11,14 @@
1211
logger = logging.getLogger(__name__)
1312

1413

15-
@dataclass
16-
class OpenApiSpecHandler:
17-
"""Handler for OpenAPI spec requests."""
18-
19-
proxy: ReverseProxyHandler
20-
oidc_config_url: str
21-
auth_scheme_name: str = "oidcAuth"
22-
23-
async def dispatch(self, req: Request, res: Response):
14+
def build_openapi_spec_handler(
15+
proxy: ReverseProxyHandler,
16+
oidc_config_url: str,
17+
auth_scheme_name: str = "oidcAuth",
18+
):
19+
async def dispatch(req: Request, res: Response):
2420
"""Proxy the OpenAPI spec from the upstream STAC API, updating it with OIDC security requirements."""
25-
oidc_spec_response = await self.proxy.proxy_request(req)
21+
oidc_spec_response = await proxy.proxy_request(req)
2622
openapi_spec = oidc_spec_response.json()
2723

2824
# Pass along the response headers
@@ -45,10 +41,10 @@ async def dispatch(self, req: Request, res: Response):
4541

4642
# Add the OIDC security scheme to the components
4743
openapi_spec.setdefault("components", {}).setdefault("securitySchemes", {})[
48-
self.auth_scheme_name
44+
auth_scheme_name
4945
] = {
5046
"type": "openIdConnect",
51-
"openIdConnectUrl": self.oidc_config_url,
47+
"openIdConnectUrl": oidc_config_url,
5248
}
5349

5450
# Update the paths with the specified security requirements
@@ -61,9 +57,9 @@ async def dispatch(self, req: Request, res: Response):
6157
if match.name != "FULL":
6258
continue
6359
# Add the OIDC security requirement
64-
config.setdefault("security", []).append(
65-
{self.auth_scheme_name: []}
66-
)
60+
config.setdefault("security", []).append({auth_scheme_name: []})
6761
break
6862

6963
return openapi_spec
64+
65+
return dispatch

0 commit comments

Comments
 (0)