Skip to content

update openapi response to add html parameters #7

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 3 commits into from
Mar 25, 2025
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
78 changes: 61 additions & 17 deletions stac_fastapi/html/middleware.py
Original file line number Diff line number Diff line change
Expand Up @@ -25,17 +25,17 @@
DEFAULT_TEMPLATES = Jinja2Templates(env=jinja2_env)

ENDPOINT_TEMPLATES = {
# endpoint Name: template name
"Landing Page": "landing",
"Conformance Classes": "conformances",
"Get Collections": "collections",
"Get Collection": "collection",
"Get ItemCollection": "items",
"Get Item": "item",
"Search": "search",
# endpoint Name (lower case): template name
"landing page": "landing",
"conformance classes": "conformances",
"get collections": "collections",
"get collection": "collection",
"get itemcollection": "items",
"get item": "item",
"search": "search",
# Extensions
"Queryables": "queryables",
"Collection Queryables": "queryables",
"queryables": "queryables",
"collection queryables": "queryables",
}


Expand Down Expand Up @@ -84,7 +84,7 @@ class HTMLRenderMiddleware:

app: ASGIApp
templates: Jinja2Templates = field(default_factory=lambda: DEFAULT_TEMPLATES)
endpoints_names: dict[str, str] = field(default_factory=lambda: ENDPOINT_TEMPLATES)
endpoint_names: dict[str, str] = field(default_factory=lambda: ENDPOINT_TEMPLATES)

def create_html_response(
self,
Expand Down Expand Up @@ -153,7 +153,7 @@ async def __call__(self, scope: Scope, receive: Receive, send: Send): # noqa: C
start_message: Message
body = b""

async def send_as_html(message: Message):
async def send_as_html(message: Message): # noqa: C901
nonlocal start_message
nonlocal body

Expand Down Expand Up @@ -182,19 +182,63 @@ async def send_as_html(message: Message):
) and not request.query_params.get("f", ""):
encode_to_html = True

if start_message["status"] == 200 and encode_to_html:
response_headers = MutableHeaders(scope=start_message)

# stac-fastapi application overwrite the content-type for
# openapi response and use "application/vnd.oai.openapi+json;version=3.0"
if (
response_headers.get("Content-Type")
== "application/vnd.oai.openapi+json;version=3.0"
):
openapi_doc = json.loads(body.decode())
for _, path in openapi_doc.get("paths").items():
if (
path.get("get", {}).get("summary", "").lower()
Copy link
Member Author

@vincentsarago vincentsarago Mar 25, 2025

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

openapi will use camel case for the summary, so by using lower case we make sure the name can match the one provided in endpoint_names

in self.endpoint_names
):
if "parameters" not in path["get"]:
path["get"]["parameters"] = []

path["get"]["parameters"].append(
{
"name": "f",
"in": "query",
"required": False,
"schema": {
"anyOf": [
{
"enum": [
"html",
],
"type": "string",
},
{"type": "null"},
],
"description": "Response MediaType.",
"title": "F",
},
"description": "Response MediaType.",
}
)
path["get"]["responses"]["200"]["content"].update(
{"text/html": {}}
)

body = json.dumps(openapi_doc).encode("utf-8")
response_headers["Content-Length"] = str(len(body))

elif start_message["status"] == 200 and encode_to_html:
# NOTE: `scope["route"]` seems to be specific to FastAPI application
if route := scope.get("route"):
if tpl := self.endpoints_names.get(route.name):
if tpl := self.endpoint_names.get(route.name.lower()):
body = self.create_html_response(
request,
json.loads(body.decode()),
template_name=tpl,
title=route.name,
)
headers = MutableHeaders(scope=start_message)
headers["Content-Type"] = "text/html"
headers["Content-Length"] = str(len(body))
response_headers["Content-Type"] = "text/html"
response_headers["Content-Length"] = str(len(body))

# Send http.response.start
await send(start_message)
Expand Down
68 changes: 65 additions & 3 deletions tests/test_middleware.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,7 +7,8 @@
import pytest
from fastapi import FastAPI
from starlette.requests import Request
from starlette.responses import JSONResponse
from starlette.responses import JSONResponse, Response
from starlette.routing import Route, request_response
from starlette.testclient import TestClient

from stac_fastapi.html.middleware import HTMLRenderMiddleware, preferred_encoding
Expand Down Expand Up @@ -57,10 +58,45 @@ def test_get_compression_backend(header, expected):

@pytest.fixture
def client(): # noqa: C901
# Ref: https://github.com/stac-utils/stac-fastapi/blob/20ae9cfaf87ed892ef3235d979892e7e24c63fc0/stac_fastapi/api/stac_fastapi/api/openapi.py
def update_openapi(app: FastAPI) -> FastAPI:
"""Update OpenAPI response content-type.

This function modifies the openapi route to comply with the STAC API spec's required
content-type response header.
"""
# Find the route for the openapi_url in the app
openapi_route: Route = next(
route for route in app.router.routes if route.path == app.openapi_url
)
# Store the old endpoint function so we can call it from the patched function
old_endpoint = openapi_route.endpoint

# Create a patched endpoint function that modifies the content type of the response
async def patched_openapi_endpoint(req: Request) -> Response:
# Get the response from the old endpoint function
response: JSONResponse = await old_endpoint(req)
# Update the content type header in place
response.headers["content-type"] = (
"application/vnd.oai.openapi+json;version=3.0"
)
# Return the updated response
return response

# When a Route is accessed the `handle` function calls `self.app`. Which is
# the endpoint function wrapped with `request_response`. So we need to wrap
# our patched function and replace the existing app with it.
openapi_route.app = request_response(patched_openapi_endpoint)

# return the patched app
return app

app = FastAPI(
openapi_url="/api",
docs_url="/api.html",
)
update_openapi(app)

app.add_middleware(HTMLRenderMiddleware)

@app.get("/", name="Landing Page")
Expand Down Expand Up @@ -135,9 +171,11 @@ def test_html_middleware(client):
response = client.post("/search", headers={"Accept": "text/html"})
assert response.headers["Content-Type"] == "application/geo+json"

# No influence on endpoint outside stac-fastapi scope
# No influence on endpoint outside scope
response = client.get("/api", headers={"Accept": "text/html"})
assert response.headers["Content-Type"] == "application/json"
assert (
response.headers["Content-Type"] == "application/vnd.oai.openapi+json;version=3.0"
)


@pytest.mark.parametrize(
Expand All @@ -157,3 +195,27 @@ def test_html_middleware(client):
def test_all_routes(client, route, accept, result):
response = client.get(route, headers={"accept": accept})
assert response.headers["Content-Type"] == result


def test_openapi_override(client):
"""Test OpenAPI update."""
response = client.get("/api", headers={"Accept": "text/html"})
assert (
response.headers["Content-Type"] == "application/vnd.oai.openapi+json;version=3.0"
)
body = response.json()

for endpoint in [
"/",
"/conformance",
"/collections",
"/collections/{collectionId}",
"/collections/{collectionId}/items",
"/collections/{collectionId}/items/{itemId}",
"/search",
"/queryables",
"/collections/{collectionId}/queryables",
]:
path = body["paths"][endpoint]
assert next(filter(lambda p: p["name"] == "f", path["get"]["parameters"]))
assert "text/html" in path["get"]["responses"]["200"]["content"]
Loading