Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
16 changes: 10 additions & 6 deletions litestar/middleware/_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -44,6 +44,15 @@ def build_exclude_path_pattern(
) from e


def should_bypass_for_path_pattern(scope: Scope, pattern: Pattern | None = None) -> bool:
return bool(
pattern
and pattern.findall(
scope["raw_path"].decode() if getattr(scope.get("route_handler", {}), "is_mount", False) else scope["path"]
)
)


def should_bypass_middleware(
*,
exclude_http_methods: Sequence[Method] | None = None,
Expand Down Expand Up @@ -73,9 +82,4 @@ def should_bypass_middleware(
if exclude_http_methods and scope.get("method") in exclude_http_methods:
return True

return bool(
exclude_path_pattern
and exclude_path_pattern.findall(
scope["raw_path"].decode() if getattr(scope.get("route_handler", {}), "is_mount", False) else scope["path"]
)
)
return should_bypass_for_path_pattern(scope, exclude_path_pattern)
58 changes: 53 additions & 5 deletions litestar/middleware/base.py
Original file line number Diff line number Diff line change
@@ -1,12 +1,14 @@
from __future__ import annotations

import abc
import warnings
from abc import abstractmethod
from typing import TYPE_CHECKING, Any, Callable, Protocol, runtime_checkable

from litestar.enums import ScopeType
from litestar.middleware._utils import (
build_exclude_path_pattern,
should_bypass_for_path_pattern,
should_bypass_middleware,
)
from litestar.utils.deprecation import warn_deprecation
Expand Down Expand Up @@ -217,22 +219,68 @@ async def handle(
exclude_path_pattern: str | tuple[str, ...] | None = None
exclude_opt_key: str | None = None

should_bypass_for_scope: Callable[[Scope], bool] | None = None
r"""
A callable that takes in the :class:`~litestar.types.Scope` of the current
connection and returns a boolean, indicating if the middleware should be skipped for
the current request.

This can for example be used to exclude a middleware based on a dynamic path::

should_bypass_for_scope = lambda scope: scope["path"].endswith(".jpg")

Applied to a route with a dynamic path like ``/static/{file_name:str}``, it would
be skipped *only* if ``file_name`` has a ``.jpg`` extension.

.. versionadded:: 2.19
"""

def __call__(self, app: ASGIApp) -> ASGIApp:
"""Create the actual middleware callable"""
handle = self.handle
exclude_pattern = build_exclude_path_pattern(exclude=self.exclude_path_pattern, middleware_cls=type(self))
scopes = set(self.scopes)
exclude_opt_key = self.exclude_opt_key
should_bypass_for_scope = self.should_bypass_for_scope

def exclude_pattern_matches_handler_path(scope: Scope) -> bool:
handler = scope["route_handler"]
return any(exclude_pattern.search(path) for path in handler.paths)

async def middleware(scope: Scope, receive: Receive, send: Send) -> None:
if should_bypass_middleware(
scope=scope,
scopes=scopes, # type: ignore[arg-type]
exclude_opt_key=exclude_opt_key,
exclude_path_pattern=exclude_pattern,
path_excluded = False
if (
should_bypass_middleware(
scope=scope,
scopes=scopes, # type: ignore[arg-type]
exclude_opt_key=exclude_opt_key,
)
or (path_excluded := should_bypass_for_path_pattern(scope, exclude_pattern))
or (should_bypass_for_scope and should_bypass_for_scope(scope))
):
if path_excluded and exclude_pattern is not None and not exclude_pattern_matches_handler_path(scope):
warnings.warn(
f"{type(self).__name__}.exclude_path_pattern={exclude_pattern.pattern!r} "
"did match the request path but did not match the route "
"handler's path. When upgrading to Litestar 3, this middleware "
"would NOT be excluded. To keep the current behaviour, use "
"'should_bypass_for_scope' instead.",
category=DeprecationWarning,
stacklevel=2,
)

await app(scope, receive, send)
else:
if exclude_pattern is not None and exclude_pattern_matches_handler_path(scope):
warnings.warn(
f"{type(self).__name__}.exclude_path_pattern={exclude_pattern.pattern!r} "
"did not match the request path but did match the route "
"handler's path. When upgrading to Litestar 3, this middleware "
"would be excluded. To keep the current behaviour, use "
"'should_bypass_for_scope' instead.",
category=DeprecationWarning,
stacklevel=2,
)
await handle(scope=scope, receive=receive, send=send, next_app=app)

return middleware
Expand Down
56 changes: 56 additions & 0 deletions tests/unit/test_middleware/test_base_middleware.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,5 @@
from typing import TYPE_CHECKING, List, Tuple, Union
from unittest.mock import MagicMock
from warnings import catch_warnings

import pytest
Expand Down Expand Up @@ -285,6 +286,61 @@ def third_handler() -> dict:
assert "test" in response.headers


def test_asgi_middleware_should_exclude_scope() -> None:
mock = MagicMock()

class SubclassMiddleware(ASGIMiddleware):
@staticmethod
def should_bypass_for_scope(scope: "Scope") -> bool:
return scope["path"].endswith(".jpg")

async def handle(self, scope: "Scope", receive: "Receive", send: "Send", next_app: "ASGIApp") -> None:
mock(scope["path"])
await next_app(scope, receive, send)

@get("/{file_name:str}")
def handler(file_name: str) -> str:
return file_name

with create_test_client([handler], middleware=[SubclassMiddleware()]) as client:
assert client.get("/test.txt").status_code == 200
assert client.get("/test.jpg").status_code == 200

mock.assert_called_once_with("/test.txt")


def test_asgi_middleware_path_exclude_warns_future_use() -> None:
mock = MagicMock()

class SubclassMiddleware(ASGIMiddleware):
def __init__(self, pattern: str) -> None:
self.exclude_path_pattern = pattern

async def handle(self, scope: "Scope", receive: "Receive", send: "Send", next_app: "ASGIApp") -> None:
mock(scope["path"])
await next_app(scope, receive, send)

@get("/{file_name:str}")
def handler(file_name: str) -> str:
return file_name

# this configuration would NOT be excluded in the future
with create_test_client([handler], middleware=[SubclassMiddleware(".jpg")]) as client:
with pytest.warns(DeprecationWarning, match=".*exclude_path_pattern.* did match the request path"):
assert client.get("/test.jpg").status_code == 200

assert client.get("/test.txt").status_code == 200
mock.assert_called_once_with("/test.txt")

mock.reset_mock()

# this configuration WOULD be excluded in the future
with create_test_client([handler], middleware=[SubclassMiddleware("str")]) as client:
with pytest.warns(DeprecationWarning, match=".*exclude_path_pattern.* did not match the request path"):
assert client.get("/test.jpg").status_code == 200
mock.assert_called_once_with("/test.jpg")


@pytest.mark.parametrize("excludes", ["/", ("/", "/foo"), "/*", "/.*"])
def test_asgi_middleware_exclude_by_pattern_warns_if_exclude_all(excludes: Union[str, Tuple[str, ...]]) -> None:
class SubclassMiddleware(ASGIMiddleware):
Expand Down
Loading