Skip to content

Commit 425777e

Browse files
committed
Finalize simplified authentication extension
1 parent eb1435a commit 425777e

File tree

3 files changed

+23
-25
lines changed

3 files changed

+23
-25
lines changed

src/stac_auth_proxy/middleware/AuthenticationExtensionMiddleware.py

+16-23
Original file line numberDiff line numberDiff line change
@@ -63,7 +63,6 @@ def transform_json(self, data: dict[str, Any], request: Request) -> dict[str, An
6363
if self.extension_url not in extensions:
6464
extensions.append(self.extension_url)
6565

66-
# TODO: Should we add this to items even if the assets don't match the asset expression?
6766
# auth:schemes
6867
# ---
6968
# A property that contains all of the scheme definitions used by Assets and
@@ -72,18 +71,28 @@ def transform_json(self, data: dict[str, Any], request: Request) -> dict[str, An
7271
# - Collections
7372
# - Item Properties
7473

75-
if self.state_key not in request.state:
74+
oidc_metadata = getattr(request.state, self.state_key, {})
75+
if not oidc_metadata:
7676
logger.error(
77-
"OIDC metadata not found in scope. "
78-
"Skipping authentication extension."
77+
"OIDC metadata not found in scope. Skipping authentication extension."
7978
)
8079
return data
8180

8281
scheme_loc = data["properties"] if "properties" in data else data
8382
schemes = scheme_loc.setdefault("auth:schemes", {})
84-
schemes[self.auth_scheme_name] = self.parse_oidc_config(
85-
request.state.get(self.state_key, {})
86-
)
83+
schemes[self.auth_scheme_name] = {
84+
"type": "oauth2",
85+
"description": "requires an authentication token",
86+
"flows": {
87+
"authorizationCode": {
88+
"authorizationUrl": oidc_metadata["authorization_endpoint"],
89+
"tokenUrl": oidc_metadata.get("token_endpoint"),
90+
"scopes": {
91+
k: k for k in sorted(oidc_metadata.get("scopes_supported", []))
92+
},
93+
},
94+
},
95+
}
8796

8897
# auth:refs
8998
# ---
@@ -114,19 +123,3 @@ def transform_json(self, data: dict[str, Any], request: Request) -> dict[str, An
114123
link.setdefault("auth:refs", []).append(self.auth_scheme_name)
115124

116125
return data
117-
118-
def parse_oidc_config(self, oidc_config: dict[str, Any]) -> dict[str, Any]:
119-
"""Parse the OIDC configuration."""
120-
return {
121-
"type": "oauth2",
122-
"description": "requires an authentication token",
123-
"flows": {
124-
"authorizationCode": {
125-
"authorizationUrl": oidc_config["authorization_endpoint"],
126-
"tokenUrl": oidc_config.get("token_endpoint"),
127-
"scopes": {
128-
k: k for k in sorted(oidc_config.get("scopes_supported", []))
129-
},
130-
},
131-
},
132-
}

tests/conftest.py

+6-1
Original file line numberDiff line numberDiff line change
@@ -32,7 +32,12 @@ def public_key(test_key: jwk.JWK) -> dict[str, Any]:
3232
@pytest.fixture(autouse=True)
3333
def mock_jwks(public_key: dict[str, Any]):
3434
"""Mock JWKS endpoint."""
35-
mock_oidc_config = {"jwks_uri": "https://example.com/jwks"}
35+
mock_oidc_config = {
36+
"jwks_uri": "https://example.com/jwks",
37+
"authorization_endpoint": "https://example.com/auth",
38+
"token_endpoint": "https://example.com/token",
39+
"scopes_supported": ["openid", "profile", "email", "collection:create"],
40+
}
3641

3742
mock_jwks = {"keys": [public_key]}
3843

tests/test_filters_jinja2.py

+1-1
Original file line numberDiff line numberDiff line change
@@ -310,7 +310,7 @@ def test_item_get(
310310
if is_authenticated:
311311
assert response.status_code == 200
312312
assert response.json()["id"] == "bar"
313-
assert response.json()["properties"] == {"private": True}
313+
assert response.json()["properties"].get("private") is True
314314
else:
315315
assert response.status_code == 404
316316
assert response.json() == {"message": "Not found"}

0 commit comments

Comments
 (0)