diff --git a/litestar/middleware/_utils.py b/litestar/middleware/_utils.py index 9ffdbd9066..8bf1d4d337 100644 --- a/litestar/middleware/_utils.py +++ b/litestar/middleware/_utils.py @@ -44,6 +44,15 @@ def build_exclude_path_pattern( ) from e +def should_bypass_for_path_pattern(scope: Scope, pattern: Pattern | None = None) -> bool: + return bool( + pattern + and pattern.findall( + scope["raw_path"].decode() if getattr(scope.get("route_handler", {}), "is_mount", False) else scope["path"] + ) + ) + + def should_bypass_middleware( *, exclude_http_methods: Sequence[Method] | None = None, @@ -73,9 +82,4 @@ def should_bypass_middleware( if exclude_http_methods and scope.get("method") in exclude_http_methods: return True - return bool( - exclude_path_pattern - and exclude_path_pattern.findall( - scope["raw_path"].decode() if getattr(scope.get("route_handler", {}), "is_mount", False) else scope["path"] - ) - ) + return should_bypass_for_path_pattern(scope, exclude_path_pattern) diff --git a/litestar/middleware/base.py b/litestar/middleware/base.py index 1d85bcbb38..f11488d93c 100644 --- a/litestar/middleware/base.py +++ b/litestar/middleware/base.py @@ -1,12 +1,14 @@ from __future__ import annotations import abc +import warnings from abc import abstractmethod from typing import TYPE_CHECKING, Any, Callable, Protocol, runtime_checkable from litestar.enums import ScopeType from litestar.middleware._utils import ( build_exclude_path_pattern, + should_bypass_for_path_pattern, should_bypass_middleware, ) from litestar.utils.deprecation import warn_deprecation @@ -217,22 +219,68 @@ async def handle( exclude_path_pattern: str | tuple[str, ...] | None = None exclude_opt_key: str | None = None + should_bypass_for_scope: Callable[[Scope], bool] | None = None + r""" + A callable that takes in the :class:`~litestar.types.Scope` of the current + connection and returns a boolean, indicating if the middleware should be skipped for + the current request. + + This can for example be used to exclude a middleware based on a dynamic path:: + + should_bypass_for_scope = lambda scope: scope["path"].endswith(".jpg") + + Applied to a route with a dynamic path like ``/static/{file_name:str}``, it would + be skipped *only* if ``file_name`` has a ``.jpg`` extension. + + .. versionadded:: 2.19 + """ + def __call__(self, app: ASGIApp) -> ASGIApp: """Create the actual middleware callable""" handle = self.handle exclude_pattern = build_exclude_path_pattern(exclude=self.exclude_path_pattern, middleware_cls=type(self)) scopes = set(self.scopes) exclude_opt_key = self.exclude_opt_key + should_bypass_for_scope = self.should_bypass_for_scope + + def exclude_pattern_matches_handler_path(scope: Scope) -> bool: + handler = scope["route_handler"] + return any(exclude_pattern.search(path) for path in handler.paths) async def middleware(scope: Scope, receive: Receive, send: Send) -> None: - if should_bypass_middleware( - scope=scope, - scopes=scopes, # type: ignore[arg-type] - exclude_opt_key=exclude_opt_key, - exclude_path_pattern=exclude_pattern, + path_excluded = False + if ( + should_bypass_middleware( + scope=scope, + scopes=scopes, # type: ignore[arg-type] + exclude_opt_key=exclude_opt_key, + ) + or (path_excluded := should_bypass_for_path_pattern(scope, exclude_pattern)) + or (should_bypass_for_scope and should_bypass_for_scope(scope)) ): + if path_excluded and exclude_pattern is not None and not exclude_pattern_matches_handler_path(scope): + warnings.warn( + f"{type(self).__name__}.exclude_path_pattern={exclude_pattern.pattern!r} " + "did match the request path but did not match the route " + "handler's path. When upgrading to Litestar 3, this middleware " + "would NOT be excluded. To keep the current behaviour, use " + "'should_bypass_for_scope' instead.", + category=DeprecationWarning, + stacklevel=2, + ) + await app(scope, receive, send) else: + if exclude_pattern is not None and exclude_pattern_matches_handler_path(scope): + warnings.warn( + f"{type(self).__name__}.exclude_path_pattern={exclude_pattern.pattern!r} " + "did not match the request path but did match the route " + "handler's path. When upgrading to Litestar 3, this middleware " + "would be excluded. To keep the current behaviour, use " + "'should_bypass_for_scope' instead.", + category=DeprecationWarning, + stacklevel=2, + ) await handle(scope=scope, receive=receive, send=send, next_app=app) return middleware diff --git a/tests/unit/test_middleware/test_base_middleware.py b/tests/unit/test_middleware/test_base_middleware.py index cc0dece19f..d08d043bb8 100644 --- a/tests/unit/test_middleware/test_base_middleware.py +++ b/tests/unit/test_middleware/test_base_middleware.py @@ -1,4 +1,5 @@ from typing import TYPE_CHECKING, List, Tuple, Union +from unittest.mock import MagicMock from warnings import catch_warnings import pytest @@ -285,6 +286,61 @@ def third_handler() -> dict: assert "test" in response.headers +def test_asgi_middleware_should_exclude_scope() -> None: + mock = MagicMock() + + class SubclassMiddleware(ASGIMiddleware): + @staticmethod + def should_bypass_for_scope(scope: "Scope") -> bool: + return scope["path"].endswith(".jpg") + + async def handle(self, scope: "Scope", receive: "Receive", send: "Send", next_app: "ASGIApp") -> None: + mock(scope["path"]) + await next_app(scope, receive, send) + + @get("/{file_name:str}") + def handler(file_name: str) -> str: + return file_name + + with create_test_client([handler], middleware=[SubclassMiddleware()]) as client: + assert client.get("/test.txt").status_code == 200 + assert client.get("/test.jpg").status_code == 200 + + mock.assert_called_once_with("/test.txt") + + +def test_asgi_middleware_path_exclude_warns_future_use() -> None: + mock = MagicMock() + + class SubclassMiddleware(ASGIMiddleware): + def __init__(self, pattern: str) -> None: + self.exclude_path_pattern = pattern + + async def handle(self, scope: "Scope", receive: "Receive", send: "Send", next_app: "ASGIApp") -> None: + mock(scope["path"]) + await next_app(scope, receive, send) + + @get("/{file_name:str}") + def handler(file_name: str) -> str: + return file_name + + # this configuration would NOT be excluded in the future + with create_test_client([handler], middleware=[SubclassMiddleware(".jpg")]) as client: + with pytest.warns(DeprecationWarning, match=".*exclude_path_pattern.* did match the request path"): + assert client.get("/test.jpg").status_code == 200 + + assert client.get("/test.txt").status_code == 200 + mock.assert_called_once_with("/test.txt") + + mock.reset_mock() + + # this configuration WOULD be excluded in the future + with create_test_client([handler], middleware=[SubclassMiddleware("str")]) as client: + with pytest.warns(DeprecationWarning, match=".*exclude_path_pattern.* did not match the request path"): + assert client.get("/test.jpg").status_code == 200 + mock.assert_called_once_with("/test.jpg") + + @pytest.mark.parametrize("excludes", ["/", ("/", "/foo"), "/*", "/.*"]) def test_asgi_middleware_exclude_by_pattern_warns_if_exclude_all(excludes: Union[str, Tuple[str, ...]]) -> None: class SubclassMiddleware(ASGIMiddleware):