Skip to content

Commit 9f8f057

Browse files
authored
refactor: JSON parsing middleware helper (#42)
1 parent 45b4e38 commit 9f8f057

File tree

2 files changed

+121
-60
lines changed

2 files changed

+121
-60
lines changed

src/stac_auth_proxy/middleware/UpdateOpenApiMiddleware.py

+9-60
Original file line numberDiff line numberDiff line change
@@ -1,19 +1,18 @@
11
"""Middleware to add auth information to the OpenAPI spec served by upstream API."""
22

3-
import json
43
from dataclasses import dataclass
5-
from typing import Any, Optional
4+
from typing import Any
65

7-
from starlette.datastructures import MutableHeaders
86
from starlette.requests import Request
9-
from starlette.types import ASGIApp, Message, Receive, Scope, Send
7+
from starlette.types import ASGIApp
108

119
from ..config import EndpointMethods
12-
from ..utils.requests import dict_to_bytes, find_match
10+
from ..utils.middleware import JsonResponseMiddleware
11+
from ..utils.requests import find_match
1312

1413

1514
@dataclass(frozen=True)
16-
class OpenApiMiddleware:
15+
class OpenApiMiddleware(JsonResponseMiddleware):
1716
"""Middleware to add the OpenAPI spec to the response."""
1817

1918
app: ASGIApp
@@ -24,61 +23,11 @@ class OpenApiMiddleware:
2423
default_public: bool
2524
oidc_auth_scheme_name: str = "oidcAuth"
2625

27-
async def __call__(self, scope: Scope, receive: Receive, send: Send) -> None:
28-
"""Add the OpenAPI spec to the response."""
29-
if scope["type"] != "http" or Request(scope).url.path != self.openapi_spec_path:
30-
return await self.app(scope, receive, send)
26+
def should_transform_response(self, request: Request) -> bool:
27+
"""Only transform responses for the OpenAPI spec path."""
28+
return request.url.path == self.openapi_spec_path
3129

32-
start_message: Optional[Message] = None
33-
body = b""
34-
35-
async def augment_oidc_spec(message: Message):
36-
nonlocal start_message
37-
nonlocal body
38-
if message["type"] == "http.response.start":
39-
# NOTE: Because we are modifying the response body, we will need to update
40-
# the content-length header. However, headers are sent before we see the
41-
# body. To handle this, we delay sending the http.response.start message
42-
# until after we alter the body.
43-
start_message = message
44-
return
45-
elif message["type"] != "http.response.body":
46-
return await send(message)
47-
48-
body += message["body"]
49-
50-
# Skip body chunks until all chunks have been received
51-
if message.get("more_body"):
52-
return
53-
54-
# Maybe decompress the body
55-
headers = MutableHeaders(scope=start_message)
56-
57-
# Augment the spec
58-
body = dict_to_bytes(self.augment_spec(json.loads(body)))
59-
60-
# Update the content-length header
61-
headers["content-length"] = str(len(body))
62-
assert start_message, "Expected start_message to be set"
63-
start_message["headers"] = [
64-
(key.encode(), value.encode()) for key, value in headers.items()
65-
]
66-
67-
# Send http.response.start
68-
await send(start_message)
69-
70-
# Send http.response.body
71-
await send(
72-
{
73-
"type": "http.response.body",
74-
"body": body,
75-
"more_body": False,
76-
}
77-
)
78-
79-
return await self.app(scope, receive, augment_oidc_spec)
80-
81-
def augment_spec(self, openapi_spec) -> dict[str, Any]:
30+
def transform_json(self, openapi_spec: dict[str, Any]) -> dict[str, Any]:
8231
"""Augment the OpenAPI spec with auth information."""
8332
components = openapi_spec.setdefault("components", {})
8433
securitySchemes = components.setdefault("securitySchemes", {})
+112
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,112 @@
1+
"""Utilities for middleware response handling."""
2+
3+
import json
4+
import re
5+
from abc import ABC, abstractmethod
6+
from typing import Any, Optional
7+
8+
from starlette.datastructures import Headers, MutableHeaders
9+
from starlette.requests import Request
10+
from starlette.types import ASGIApp, Message, Receive, Scope, Send
11+
12+
13+
class JsonResponseMiddleware(ABC):
14+
"""Base class for middleware that transforms JSON response bodies."""
15+
16+
app: ASGIApp
17+
json_content_type_expr: str = (
18+
r"application/vnd\.oai\.openapi\+json;.*|application/json|application/geo\+json"
19+
)
20+
21+
@abstractmethod
22+
def should_transform_response(self, request: Request) -> bool:
23+
"""
24+
Determine if this request's response should be transformed.
25+
26+
Args:
27+
request: The incoming request
28+
29+
Returns
30+
-------
31+
bool: True if the response should be transformed
32+
"""
33+
return bool(
34+
re.match(self.json_content_type_expr, request.headers.get("accept", ""))
35+
)
36+
37+
@abstractmethod
38+
def transform_json(self, data: Any) -> Any:
39+
"""
40+
Transform the JSON data.
41+
42+
Args:
43+
data: The parsed JSON data
44+
45+
Returns
46+
-------
47+
The transformed JSON data
48+
"""
49+
pass
50+
51+
async def __call__(self, scope: Scope, receive: Receive, send: Send) -> None:
52+
"""Process the request/response."""
53+
if scope["type"] != "http":
54+
return await self.app(scope, receive, send)
55+
56+
request = Request(scope)
57+
if not self.should_transform_response(request):
58+
return await self.app(scope, receive, send)
59+
60+
start_message: Optional[Message] = None
61+
body = b""
62+
not_json = False
63+
64+
async def process_message(message: Message) -> None:
65+
nonlocal start_message
66+
nonlocal body
67+
nonlocal not_json
68+
if message["type"] == "http.response.start":
69+
# Delay sending start message until we've processed the body
70+
if not re.match(
71+
self.json_content_type_expr,
72+
Headers(scope=message).get("content-type", ""),
73+
):
74+
not_json = True
75+
return await send(message)
76+
start_message = message
77+
return
78+
elif message["type"] != "http.response.body" or not_json:
79+
return await send(message)
80+
81+
body += message["body"]
82+
83+
# Skip body chunks until all chunks have been received
84+
if message.get("more_body"):
85+
return
86+
87+
headers = MutableHeaders(scope=start_message)
88+
89+
# Transform the JSON body
90+
if body:
91+
data = json.loads(body)
92+
transformed = self.transform_json(data)
93+
body = json.dumps(transformed).encode()
94+
95+
# Update content-length header
96+
headers["content-length"] = str(len(body))
97+
assert start_message, "Expected start_message to be set"
98+
start_message["headers"] = [
99+
(key.encode(), value.encode()) for key, value in headers.items()
100+
]
101+
102+
# Send response
103+
await send(start_message)
104+
await send(
105+
{
106+
"type": "http.response.body",
107+
"body": body,
108+
"more_body": False,
109+
}
110+
)
111+
112+
return await self.app(scope, receive, process_message)

0 commit comments

Comments
 (0)