diff --git a/stac_fastapi/html/middleware.py b/stac_fastapi/html/middleware.py
index 80720c9..bcc91a6 100644
--- a/stac_fastapi/html/middleware.py
+++ b/stac_fastapi/html/middleware.py
@@ -25,17 +25,17 @@
DEFAULT_TEMPLATES = Jinja2Templates(env=jinja2_env)
ENDPOINT_TEMPLATES = {
- # endpoint Name: template name
- "Landing Page": "landing",
- "Conformance Classes": "conformances",
- "Get Collections": "collections",
- "Get Collection": "collection",
- "Get ItemCollection": "items",
- "Get Item": "item",
- "Search": "search",
+ # endpoint Name (lower case): template name
+ "landing page": "landing",
+ "conformance classes": "conformances",
+ "get collections": "collections",
+ "get collection": "collection",
+ "get itemcollection": "items",
+ "get item": "item",
+ "search": "search",
# Extensions
- "Queryables": "queryables",
- "Collection Queryables": "queryables",
+ "queryables": "queryables",
+ "collection queryables": "queryables",
}
@@ -84,7 +84,7 @@ class HTMLRenderMiddleware:
app: ASGIApp
templates: Jinja2Templates = field(default_factory=lambda: DEFAULT_TEMPLATES)
- endpoints_names: dict[str, str] = field(default_factory=lambda: ENDPOINT_TEMPLATES)
+ endpoint_names: dict[str, str] = field(default_factory=lambda: ENDPOINT_TEMPLATES)
def create_html_response(
self,
@@ -153,7 +153,7 @@ async def __call__(self, scope: Scope, receive: Receive, send: Send): # noqa: C
start_message: Message
body = b""
- async def send_as_html(message: Message):
+ async def send_as_html(message: Message): # noqa: C901
nonlocal start_message
nonlocal body
@@ -182,19 +182,63 @@ async def send_as_html(message: Message):
) and not request.query_params.get("f", ""):
encode_to_html = True
- if start_message["status"] == 200 and encode_to_html:
+ response_headers = MutableHeaders(scope=start_message)
+
+ # stac-fastapi application overwrite the content-type for
+ # openapi response and use "application/vnd.oai.openapi+json;version=3.0"
+ if (
+ response_headers.get("Content-Type")
+ == "application/vnd.oai.openapi+json;version=3.0"
+ ):
+ openapi_doc = json.loads(body.decode())
+ for _, path in openapi_doc.get("paths").items():
+ if (
+ path.get("get", {}).get("summary", "").lower()
+ in self.endpoint_names
+ ):
+ if "parameters" not in path["get"]:
+ path["get"]["parameters"] = []
+
+ path["get"]["parameters"].append(
+ {
+ "name": "f",
+ "in": "query",
+ "required": False,
+ "schema": {
+ "anyOf": [
+ {
+ "enum": [
+ "html",
+ ],
+ "type": "string",
+ },
+ {"type": "null"},
+ ],
+ "description": "Response MediaType.",
+ "title": "F",
+ },
+ "description": "Response MediaType.",
+ }
+ )
+ path["get"]["responses"]["200"]["content"].update(
+ {"text/html": {}}
+ )
+
+ body = json.dumps(openapi_doc).encode("utf-8")
+ response_headers["Content-Length"] = str(len(body))
+
+ elif start_message["status"] == 200 and encode_to_html:
# NOTE: `scope["route"]` seems to be specific to FastAPI application
if route := scope.get("route"):
- if tpl := self.endpoints_names.get(route.name):
+ if tpl := self.endpoint_names.get(route.name.lower()):
body = self.create_html_response(
request,
json.loads(body.decode()),
template_name=tpl,
title=route.name,
)
- headers = MutableHeaders(scope=start_message)
- headers["Content-Type"] = "text/html"
- headers["Content-Length"] = str(len(body))
+ response_headers["Content-Type"] = "text/html"
+ response_headers["Content-Length"] = str(len(body))
# Send http.response.start
await send(start_message)
diff --git a/tests/test_middleware.py b/tests/test_middleware.py
index bde53f2..90db42c 100644
--- a/tests/test_middleware.py
+++ b/tests/test_middleware.py
@@ -7,7 +7,8 @@
import pytest
from fastapi import FastAPI
from starlette.requests import Request
-from starlette.responses import JSONResponse
+from starlette.responses import JSONResponse, Response
+from starlette.routing import Route, request_response
from starlette.testclient import TestClient
from stac_fastapi.html.middleware import HTMLRenderMiddleware, preferred_encoding
@@ -57,10 +58,45 @@ def test_get_compression_backend(header, expected):
@pytest.fixture
def client(): # noqa: C901
+ # Ref: https://github.com/stac-utils/stac-fastapi/blob/20ae9cfaf87ed892ef3235d979892e7e24c63fc0/stac_fastapi/api/stac_fastapi/api/openapi.py
+ def update_openapi(app: FastAPI) -> FastAPI:
+ """Update OpenAPI response content-type.
+
+ This function modifies the openapi route to comply with the STAC API spec's required
+ content-type response header.
+ """
+ # Find the route for the openapi_url in the app
+ openapi_route: Route = next(
+ route for route in app.router.routes if route.path == app.openapi_url
+ )
+ # Store the old endpoint function so we can call it from the patched function
+ old_endpoint = openapi_route.endpoint
+
+ # Create a patched endpoint function that modifies the content type of the response
+ async def patched_openapi_endpoint(req: Request) -> Response:
+ # Get the response from the old endpoint function
+ response: JSONResponse = await old_endpoint(req)
+ # Update the content type header in place
+ response.headers["content-type"] = (
+ "application/vnd.oai.openapi+json;version=3.0"
+ )
+ # Return the updated response
+ return response
+
+ # When a Route is accessed the `handle` function calls `self.app`. Which is
+ # the endpoint function wrapped with `request_response`. So we need to wrap
+ # our patched function and replace the existing app with it.
+ openapi_route.app = request_response(patched_openapi_endpoint)
+
+ # return the patched app
+ return app
+
app = FastAPI(
openapi_url="/api",
docs_url="/api.html",
)
+ update_openapi(app)
+
app.add_middleware(HTMLRenderMiddleware)
@app.get("/", name="Landing Page")
@@ -135,9 +171,11 @@ def test_html_middleware(client):
response = client.post("/search", headers={"Accept": "text/html"})
assert response.headers["Content-Type"] == "application/geo+json"
- # No influence on endpoint outside stac-fastapi scope
+ # No influence on endpoint outside scope
response = client.get("/api", headers={"Accept": "text/html"})
- assert response.headers["Content-Type"] == "application/json"
+ assert (
+ response.headers["Content-Type"] == "application/vnd.oai.openapi+json;version=3.0"
+ )
@pytest.mark.parametrize(
@@ -157,3 +195,27 @@ def test_html_middleware(client):
def test_all_routes(client, route, accept, result):
response = client.get(route, headers={"accept": accept})
assert response.headers["Content-Type"] == result
+
+
+def test_openapi_override(client):
+ """Test OpenAPI update."""
+ response = client.get("/api", headers={"Accept": "text/html"})
+ assert (
+ response.headers["Content-Type"] == "application/vnd.oai.openapi+json;version=3.0"
+ )
+ body = response.json()
+
+ for endpoint in [
+ "/",
+ "/conformance",
+ "/collections",
+ "/collections/{collectionId}",
+ "/collections/{collectionId}/items",
+ "/collections/{collectionId}/items/{itemId}",
+ "/search",
+ "/queryables",
+ "/collections/{collectionId}/queryables",
+ ]:
+ path = body["paths"][endpoint]
+ assert next(filter(lambda p: p["name"] == "f", path["get"]["parameters"]))
+ assert "text/html" in path["get"]["responses"]["200"]["content"]