1
1
"""Middleware to add auth information to the OpenAPI spec served by upstream API."""
2
2
3
- import json
4
3
from dataclasses import dataclass
5
- from typing import Any , Optional
4
+ from typing import Any
6
5
7
- from starlette .datastructures import MutableHeaders
8
6
from starlette .requests import Request
9
- from starlette .types import ASGIApp , Message , Receive , Scope , Send
7
+ from starlette .types import ASGIApp
10
8
11
9
from ..config import EndpointMethods
12
- from ..utils .requests import dict_to_bytes , find_match
10
+ from ..utils .middleware import JsonResponseMiddleware
11
+ from ..utils .requests import find_match
13
12
14
13
15
14
@dataclass (frozen = True )
16
- class OpenApiMiddleware :
15
+ class OpenApiMiddleware ( JsonResponseMiddleware ) :
17
16
"""Middleware to add the OpenAPI spec to the response."""
18
17
19
18
app : ASGIApp
@@ -24,61 +23,11 @@ class OpenApiMiddleware:
24
23
default_public : bool
25
24
oidc_auth_scheme_name : str = "oidcAuth"
26
25
27
- async def __call__ (self , scope : Scope , receive : Receive , send : Send ) -> None :
28
- """Add the OpenAPI spec to the response."""
29
- if scope ["type" ] != "http" or Request (scope ).url .path != self .openapi_spec_path :
30
- return await self .app (scope , receive , send )
26
+ def should_transform_response (self , request : Request ) -> bool :
27
+ """Only transform responses for the OpenAPI spec path."""
28
+ return request .url .path == self .openapi_spec_path
31
29
32
- start_message : Optional [Message ] = None
33
- body = b""
34
-
35
- async def augment_oidc_spec (message : Message ):
36
- nonlocal start_message
37
- nonlocal body
38
- if message ["type" ] == "http.response.start" :
39
- # NOTE: Because we are modifying the response body, we will need to update
40
- # the content-length header. However, headers are sent before we see the
41
- # body. To handle this, we delay sending the http.response.start message
42
- # until after we alter the body.
43
- start_message = message
44
- return
45
- elif message ["type" ] != "http.response.body" :
46
- return await send (message )
47
-
48
- body += message ["body" ]
49
-
50
- # Skip body chunks until all chunks have been received
51
- if message .get ("more_body" ):
52
- return
53
-
54
- # Maybe decompress the body
55
- headers = MutableHeaders (scope = start_message )
56
-
57
- # Augment the spec
58
- body = dict_to_bytes (self .augment_spec (json .loads (body )))
59
-
60
- # Update the content-length header
61
- headers ["content-length" ] = str (len (body ))
62
- assert start_message , "Expected start_message to be set"
63
- start_message ["headers" ] = [
64
- (key .encode (), value .encode ()) for key , value in headers .items ()
65
- ]
66
-
67
- # Send http.response.start
68
- await send (start_message )
69
-
70
- # Send http.response.body
71
- await send (
72
- {
73
- "type" : "http.response.body" ,
74
- "body" : body ,
75
- "more_body" : False ,
76
- }
77
- )
78
-
79
- return await self .app (scope , receive , augment_oidc_spec )
80
-
81
- def augment_spec (self , openapi_spec ) -> dict [str , Any ]:
30
+ def transform_json (self , openapi_spec : dict [str , Any ]) -> dict [str , Any ]:
82
31
"""Augment the OpenAPI spec with auth information."""
83
32
components = openapi_spec .setdefault ("components" , {})
84
33
securitySchemes = components .setdefault ("securitySchemes" , {})
0 commit comments