Skip to content

Commit d8cd9f3

Browse files
Merge pull request #7 from developmentseed/feature/update-openapi
update openapi response to add html parameters
2 parents 3f731eb + d9982e9 commit d8cd9f3

File tree

2 files changed

+126
-20
lines changed

2 files changed

+126
-20
lines changed

stac_fastapi/html/middleware.py

+61-17
Original file line numberDiff line numberDiff line change
@@ -25,17 +25,17 @@
2525
DEFAULT_TEMPLATES = Jinja2Templates(env=jinja2_env)
2626

2727
ENDPOINT_TEMPLATES = {
28-
# endpoint Name: template name
29-
"Landing Page": "landing",
30-
"Conformance Classes": "conformances",
31-
"Get Collections": "collections",
32-
"Get Collection": "collection",
33-
"Get ItemCollection": "items",
34-
"Get Item": "item",
35-
"Search": "search",
28+
# endpoint Name (lower case): template name
29+
"landing page": "landing",
30+
"conformance classes": "conformances",
31+
"get collections": "collections",
32+
"get collection": "collection",
33+
"get itemcollection": "items",
34+
"get item": "item",
35+
"search": "search",
3636
# Extensions
37-
"Queryables": "queryables",
38-
"Collection Queryables": "queryables",
37+
"queryables": "queryables",
38+
"collection queryables": "queryables",
3939
}
4040

4141

@@ -84,7 +84,7 @@ class HTMLRenderMiddleware:
8484

8585
app: ASGIApp
8686
templates: Jinja2Templates = field(default_factory=lambda: DEFAULT_TEMPLATES)
87-
endpoints_names: dict[str, str] = field(default_factory=lambda: ENDPOINT_TEMPLATES)
87+
endpoint_names: dict[str, str] = field(default_factory=lambda: ENDPOINT_TEMPLATES)
8888

8989
def create_html_response(
9090
self,
@@ -153,7 +153,7 @@ async def __call__(self, scope: Scope, receive: Receive, send: Send): # noqa: C
153153
start_message: Message
154154
body = b""
155155

156-
async def send_as_html(message: Message):
156+
async def send_as_html(message: Message): # noqa: C901
157157
nonlocal start_message
158158
nonlocal body
159159

@@ -182,19 +182,63 @@ async def send_as_html(message: Message):
182182
) and not request.query_params.get("f", ""):
183183
encode_to_html = True
184184

185-
if start_message["status"] == 200 and encode_to_html:
185+
response_headers = MutableHeaders(scope=start_message)
186+
187+
# stac-fastapi application overwrite the content-type for
188+
# openapi response and use "application/vnd.oai.openapi+json;version=3.0"
189+
if (
190+
response_headers.get("Content-Type")
191+
== "application/vnd.oai.openapi+json;version=3.0"
192+
):
193+
openapi_doc = json.loads(body.decode())
194+
for _, path in openapi_doc.get("paths").items():
195+
if (
196+
path.get("get", {}).get("summary", "").lower()
197+
in self.endpoint_names
198+
):
199+
if "parameters" not in path["get"]:
200+
path["get"]["parameters"] = []
201+
202+
path["get"]["parameters"].append(
203+
{
204+
"name": "f",
205+
"in": "query",
206+
"required": False,
207+
"schema": {
208+
"anyOf": [
209+
{
210+
"enum": [
211+
"html",
212+
],
213+
"type": "string",
214+
},
215+
{"type": "null"},
216+
],
217+
"description": "Response MediaType.",
218+
"title": "F",
219+
},
220+
"description": "Response MediaType.",
221+
}
222+
)
223+
path["get"]["responses"]["200"]["content"].update(
224+
{"text/html": {}}
225+
)
226+
227+
body = json.dumps(openapi_doc).encode("utf-8")
228+
response_headers["Content-Length"] = str(len(body))
229+
230+
elif start_message["status"] == 200 and encode_to_html:
186231
# NOTE: `scope["route"]` seems to be specific to FastAPI application
187232
if route := scope.get("route"):
188-
if tpl := self.endpoints_names.get(route.name):
233+
if tpl := self.endpoint_names.get(route.name.lower()):
189234
body = self.create_html_response(
190235
request,
191236
json.loads(body.decode()),
192237
template_name=tpl,
193238
title=route.name,
194239
)
195-
headers = MutableHeaders(scope=start_message)
196-
headers["Content-Type"] = "text/html"
197-
headers["Content-Length"] = str(len(body))
240+
response_headers["Content-Type"] = "text/html"
241+
response_headers["Content-Length"] = str(len(body))
198242

199243
# Send http.response.start
200244
await send(start_message)

tests/test_middleware.py

+65-3
Original file line numberDiff line numberDiff line change
@@ -7,7 +7,8 @@
77
import pytest
88
from fastapi import FastAPI
99
from starlette.requests import Request
10-
from starlette.responses import JSONResponse
10+
from starlette.responses import JSONResponse, Response
11+
from starlette.routing import Route, request_response
1112
from starlette.testclient import TestClient
1213

1314
from stac_fastapi.html.middleware import HTMLRenderMiddleware, preferred_encoding
@@ -57,10 +58,45 @@ def test_get_compression_backend(header, expected):
5758

5859
@pytest.fixture
5960
def client(): # noqa: C901
61+
# Ref: https://github.com/stac-utils/stac-fastapi/blob/20ae9cfaf87ed892ef3235d979892e7e24c63fc0/stac_fastapi/api/stac_fastapi/api/openapi.py
62+
def update_openapi(app: FastAPI) -> FastAPI:
63+
"""Update OpenAPI response content-type.
64+
65+
This function modifies the openapi route to comply with the STAC API spec's required
66+
content-type response header.
67+
"""
68+
# Find the route for the openapi_url in the app
69+
openapi_route: Route = next(
70+
route for route in app.router.routes if route.path == app.openapi_url
71+
)
72+
# Store the old endpoint function so we can call it from the patched function
73+
old_endpoint = openapi_route.endpoint
74+
75+
# Create a patched endpoint function that modifies the content type of the response
76+
async def patched_openapi_endpoint(req: Request) -> Response:
77+
# Get the response from the old endpoint function
78+
response: JSONResponse = await old_endpoint(req)
79+
# Update the content type header in place
80+
response.headers["content-type"] = (
81+
"application/vnd.oai.openapi+json;version=3.0"
82+
)
83+
# Return the updated response
84+
return response
85+
86+
# When a Route is accessed the `handle` function calls `self.app`. Which is
87+
# the endpoint function wrapped with `request_response`. So we need to wrap
88+
# our patched function and replace the existing app with it.
89+
openapi_route.app = request_response(patched_openapi_endpoint)
90+
91+
# return the patched app
92+
return app
93+
6094
app = FastAPI(
6195
openapi_url="/api",
6296
docs_url="/api.html",
6397
)
98+
update_openapi(app)
99+
64100
app.add_middleware(HTMLRenderMiddleware)
65101

66102
@app.get("/", name="Landing Page")
@@ -135,9 +171,11 @@ def test_html_middleware(client):
135171
response = client.post("/search", headers={"Accept": "text/html"})
136172
assert response.headers["Content-Type"] == "application/geo+json"
137173

138-
# No influence on endpoint outside stac-fastapi scope
174+
# No influence on endpoint outside scope
139175
response = client.get("/api", headers={"Accept": "text/html"})
140-
assert response.headers["Content-Type"] == "application/json"
176+
assert (
177+
response.headers["Content-Type"] == "application/vnd.oai.openapi+json;version=3.0"
178+
)
141179

142180

143181
@pytest.mark.parametrize(
@@ -157,3 +195,27 @@ def test_html_middleware(client):
157195
def test_all_routes(client, route, accept, result):
158196
response = client.get(route, headers={"accept": accept})
159197
assert response.headers["Content-Type"] == result
198+
199+
200+
def test_openapi_override(client):
201+
"""Test OpenAPI update."""
202+
response = client.get("/api", headers={"Accept": "text/html"})
203+
assert (
204+
response.headers["Content-Type"] == "application/vnd.oai.openapi+json;version=3.0"
205+
)
206+
body = response.json()
207+
208+
for endpoint in [
209+
"/",
210+
"/conformance",
211+
"/collections",
212+
"/collections/{collectionId}",
213+
"/collections/{collectionId}/items",
214+
"/collections/{collectionId}/items/{itemId}",
215+
"/search",
216+
"/queryables",
217+
"/collections/{collectionId}/queryables",
218+
]:
219+
path = body["paths"][endpoint]
220+
assert next(filter(lambda p: p["name"] == "f", path["get"]["parameters"]))
221+
assert "text/html" in path["get"]["responses"]["200"]["content"]

0 commit comments

Comments
 (0)