11from typing import TYPE_CHECKING , Union
2+ from unittest .mock import MagicMock , call
23from warnings import catch_warnings
34
45import pytest
56
6- from litestar import MediaType , asgi , get
7+ from litestar import MediaType , WebSocket , asgi , get , websocket
78from litestar .datastructures .headers import MutableScopeHeaders
9+ from litestar .enums import ScopeType
810from litestar .exceptions import LitestarWarning , ValidationException
911from litestar .middleware import AbstractMiddleware , ASGIMiddleware , DefineMiddleware
1012from litestar .response .base import ASGIResponse
@@ -211,19 +213,64 @@ def handler() -> dict:
211213 assert response .status_code == HTTP_400_BAD_REQUEST
212214
213215
216+ @pytest .mark .parametrize (
217+ "allowed_scopes,expected_calls" ,
218+ [
219+ ((ScopeType .HTTP ,), ["/http" ]),
220+ ((ScopeType .ASGI ,), ["/asgi" ]),
221+ ((ScopeType .WEBSOCKET ,), ["/ws" ]),
222+ ((ScopeType .HTTP , ScopeType .ASGI ), ["/http" , "/asgi" ]),
223+ ((ScopeType .HTTP , ScopeType .WEBSOCKET ), ["/http" , "/ws" ]),
224+ ((ScopeType .ASGI , ScopeType .WEBSOCKET ), ["/asgi" , "/ws" ]),
225+ ],
226+ )
227+ def test_asgi_middleware_exclude_by_scope_type (
228+ allowed_scopes : tuple [ScopeType , ...], expected_calls : list [str ]
229+ ) -> None :
230+ mock = MagicMock ()
231+
232+ class SubclassMiddleware (ASGIMiddleware ):
233+ scopes = allowed_scopes
234+
235+ async def handle (self , scope : "Scope" , receive : "Receive" , send : "Send" , next_app : "ASGIApp" ) -> None :
236+ mock (scope ["path" ])
237+ await next_app (scope , receive , send )
238+
239+ @get ("/http" )
240+ def http_handler () -> None :
241+ return None
242+
243+ @websocket ("/ws" )
244+ async def websocket_handler (socket : WebSocket ) -> None :
245+ await socket .accept ()
246+ await socket .close ()
247+
248+ @asgi ("/asgi" )
249+ async def asgi_handler (scope : "Scope" , receive : "Receive" , send : "Send" ) -> None :
250+ response = ASGIResponse (body = b"ok" , media_type = MediaType .TEXT )
251+ await response (scope , receive , send )
252+
253+ with create_test_client (
254+ [http_handler , asgi_handler , websocket_handler ], middleware = [SubclassMiddleware ()]
255+ ) as client :
256+ assert client .get ("/http" ).status_code == 200
257+ assert client .get ("/asgi" ).status_code == 200
258+ with client .websocket_connect ("/ws" ):
259+ pass
260+
261+ mock .assert_has_calls ([call (path ) for path in expected_calls ])
262+
263+
214264def test_asgi_middleware_exclude_by_pattern () -> None :
265+ mock = MagicMock ()
266+
215267 class SubclassMiddleware (ASGIMiddleware ):
216268 def __init__ (self ) -> None :
217269 self .exclude_path_pattern = r"^/123"
218270
219271 async def handle (self , scope : "Scope" , receive : "Receive" , send : "Send" , next_app : "ASGIApp" ) -> None :
220- async def _send (message : "Message" ) -> None :
221- if message ["type" ] == "http.response.start" :
222- headers = MutableScopeHeaders (message )
223- headers .add ("test" , str (123 ))
224- await send (message )
225-
226- await next_app (scope , receive , _send )
272+ mock (scope ["raw_path" ].decode ())
273+ await next_app (scope , receive , send )
227274
228275 @get ("/123" )
229276 def first_handler () -> dict :
@@ -239,28 +286,22 @@ async def handler(scope: "Scope", receive: "Receive", send: "Send") -> None:
239286 await response (scope , receive , send )
240287
241288 with create_test_client ([first_handler , second_handler , handler ], middleware = [SubclassMiddleware ()]) as client :
242- response = client .get ("/123" )
243- assert "test" not in response .headers
244-
245- response = client .get ("/456" )
246- assert "test" in response .headers
289+ assert client .get ("/123" ).status_code == 200
290+ assert client .get ("/456" ).status_code == 200
291+ assert client .get ("/mount/123" ).status_code == 200
247292
248- response = client .get ("/mount/123" )
249- assert "test" in response .headers
293+ mock .assert_has_calls ([call ("/456" ), call ("/mount/123" )])
250294
251295
252296def test_asgi_middleware_exclude_by_pattern_tuple () -> None :
297+ mock = MagicMock ()
298+
253299 class SubclassMiddleware (ASGIMiddleware ):
254300 exclude_path_pattern = ("123" , "456" )
255301
256302 async def handle (self , scope : "Scope" , receive : "Receive" , send : "Send" , next_app : "ASGIApp" ) -> None :
257- async def _send (message : "Message" ) -> None :
258- if message ["type" ] == "http.response.start" :
259- headers = MutableScopeHeaders (message )
260- headers .add ("test" , str (123 ))
261- await send (message )
262-
263- await next_app (scope , receive , _send )
303+ mock (scope ["path" ])
304+ await next_app (scope , receive , send )
264305
265306 @get ("/123" )
266307 def first_handler () -> dict :
@@ -277,12 +318,31 @@ def third_handler() -> dict:
277318 with create_test_client (
278319 [first_handler , second_handler , third_handler ], middleware = [SubclassMiddleware ()]
279320 ) as client :
280- response = client .get ("/123" )
281- assert "test" not in response .headers
282- response = client .get ("/456" )
283- assert "test" not in response .headers
284- response = client .get ("/789" )
285- assert "test" in response .headers
321+ assert client .get ("/123" ).status_code == 200
322+ assert client .get ("/456" ).status_code == 200
323+ assert client .get ("/789" ).status_code == 200
324+
325+ mock .assert_called_once_with ("/789" )
326+
327+
328+ def test_asgi_middleware_exclude_dynamic_handler_by_pattern () -> None :
329+ mock = MagicMock ()
330+
331+ class SubclassMiddleware (ASGIMiddleware ):
332+ def __init__ (self ) -> None :
333+ self .exclude_path_pattern = r"^/foo/{bar" # use a pattern that ensures we match the raw handler path
334+
335+ async def handle (self , scope : "Scope" , receive : "Receive" , send : "Send" , next_app : "ASGIApp" ) -> None :
336+ mock ()
337+ await next_app (scope , receive , send )
338+
339+ @get ("/foo/{bar:int}" )
340+ def handler (bar : int ) -> None :
341+ return None
342+
343+ with create_test_client ([handler ], middleware = [SubclassMiddleware ()]) as client :
344+ assert client .get ("/foo/1" ).status_code == 200
345+ mock .assert_not_called ()
286346
287347
288348@pytest .mark .parametrize ("excludes" , ["/" , ("/" , "/foo" ), "/*" , "/.*" ])
@@ -310,22 +370,19 @@ async def handle(self, scope: "Scope", receive: "Receive", send: "Send", next_ap
310370
311371
312372def test_asgi_middleware_exclude_by_opt_key () -> None :
373+ mock = MagicMock ()
374+
313375 class SubclassMiddleware (ASGIMiddleware ):
314376 exclude_opt_key = "exclude_route"
315377
316378 async def handle (self , scope : "Scope" , receive : "Receive" , send : "Send" , next_app : "ASGIApp" ) -> None :
317- async def _send (message : "Message" ) -> None :
318- if message ["type" ] == "http.response.start" :
319- headers = MutableScopeHeaders (message )
320- headers .add ("test" , str (123 ))
321- await send (message )
322-
323- await next_app (scope , receive , send )
379+ mock ()
380+ await next_app (scope , receive , send )
324381
325382 @get ("/" , exclude_route = True )
326383 def handler () -> dict :
327384 return {"hello" : "world" }
328385
329386 with create_test_client (handler , middleware = [SubclassMiddleware ()]) as client :
330- response = client .get ("/" )
331- assert "test" not in response . headers
387+ assert client .get ("/" ). status_code == 200
388+ mock . assert_not_called ()
0 commit comments