Skip to content

Commit a733629

Browse files
authored
feat(core)!: Add zero-cost middleware exclusion (#4372)
1 parent acef4e2 commit a733629

File tree

6 files changed

+203
-57
lines changed

6 files changed

+203
-57
lines changed

docs/release-notes/changelog.rst

Lines changed: 28 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -198,3 +198,31 @@
198198
Use ``litestar[piccolo]`` extra installation target
199199
and ``litestar_piccolo`` plugin instead:
200200
https://github.com/litestar-org/litestar-piccolo
201+
202+
.. change:: Zero cost excluded middlewares
203+
:type: feature
204+
:breaking:
205+
206+
Middlewares inheriting from :class:`~litestar.middleware.base.ASGIMiddleware`
207+
will now have zero runtime cost when they are excluded e.g. via the ``scope`` or
208+
``exclude_opt_key`` options.
209+
210+
Previously, the base middleware was always being invoked for every request,
211+
evaluating the exclusion criteria, and then calling the user defined middleware
212+
functions. If a middleware had defined ``scopes = (ScopeType.HTTP,)``, it would
213+
still be called for *every* request, regardless of the scope type. Only for
214+
requests with the type ``HTTP``, it would then call the user's function.
215+
216+
.. note::
217+
This behaviour is still true for the legacy ``AbstractMiddleware``
218+
219+
With *zero cost exclusion*, the exclusion is being evaluated statically. At app
220+
creation time, when route handlers are registered and their middleware stacks
221+
are being built, a middleware that is to be excluded will simply not be included
222+
in the stack.
223+
224+
.. note::
225+
Even though this change is marked as breaking, no runtime behaviour
226+
difference is expected. Some test cases may break though if they relied on
227+
the fact that the middleware wrapper created by ``ASGIMiddleware`` was
228+
always being called

litestar/_asgi/routing_trie/mapping.py

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -183,6 +183,7 @@ def build_route_middleware_stack(
183183
An ASGIApp that is composed of a "stack" of middlewares.
184184
"""
185185
from litestar.middleware.allowed_hosts import AllowedHostsMiddleware
186+
from litestar.middleware.base import ASGIMiddleware
186187
from litestar.middleware.compression import CompressionMiddleware
187188
from litestar.middleware.csrf import CSRFMiddleware
188189
from litestar.middleware.response_cache import ResponseCacheMiddleware
@@ -243,5 +244,7 @@ def build_route_middleware_stack(
243244
# +--------------------+
244245
# --> response
245246
for middleware in reversed(handler_middleware):
247+
if isinstance(middleware, ASGIMiddleware) and middleware.should_bypass_for_handler(route_handler):
248+
continue
246249
asgi_handler = middleware(asgi_handler)
247250
return asgi_handler

litestar/middleware/_utils.py

Lines changed: 12 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -48,6 +48,16 @@ def build_exclude_path_pattern(
4848
) from e
4949

5050

51+
def match_exclude_path(exclude_path_pattern: Pattern, scope: Scope) -> bool:
52+
return bool(
53+
exclude_path_pattern.findall(
54+
scope["raw_path"].decode()
55+
if getattr(scope.get("route_handler", None), "is_mount", False)
56+
else scope["path"]
57+
)
58+
)
59+
60+
5161
def should_bypass_middleware(
5262
*,
5363
exclude_http_methods: Sequence[Method] | None = None,
@@ -56,7 +66,7 @@ def should_bypass_middleware(
5666
scope: Scope,
5767
scopes: Scopes,
5868
) -> bool:
59-
"""Determine weather a middleware should be bypassed.
69+
"""Determine whether a middleware should be bypassed.
6070
6171
Args:
6272
exclude_http_methods: A sequence of http methods that do not require authentication.
@@ -77,9 +87,4 @@ def should_bypass_middleware(
7787
if exclude_http_methods and scope.get("method") in exclude_http_methods:
7888
return True
7989

80-
return bool(
81-
exclude_path_pattern
82-
and exclude_path_pattern.findall(
83-
scope["raw_path"].decode() if getattr(scope.get("route_handler", {}), "is_mount", False) else scope["path"]
84-
)
85-
)
90+
return exclude_path_pattern is not None and match_exclude_path(exclude_path_pattern, scope)

litestar/middleware/base.py

Lines changed: 65 additions & 13 deletions
Original file line numberDiff line numberDiff line change
@@ -20,7 +20,7 @@
2020

2121
if TYPE_CHECKING:
2222
from litestar.middleware.constraints import MiddlewareConstraints
23-
from litestar.types import Scopes
23+
from litestar.types import RouteHandlerType, Scopes
2424
from litestar.types.asgi_types import ASGIApp, Receive, Scope, Send
2525

2626

@@ -215,27 +215,79 @@ async def handle(
215215
ScopeType.WEBSOCKET,
216216
ScopeType.ASGI,
217217
)
218+
"""Scope types this middleware should be applied to"""
218219
exclude_path_pattern: str | tuple[str, ...] | None = None
220+
r"""
221+
A regex pattern (or tuple of patterns) to exclude this middleware from route
222+
handlers whose path matches any of the provided patterns.
223+
224+
.. important::
225+
Pattern matching is performed against the **handler's path** (e.g.,
226+
``/user/{user_id:int}/``), NOT against the actual **request path** (e.g.,
227+
``/user/1234/``). This is a critical distinction for dynamic routes.
228+
229+
**Example 1: Static path**
230+
231+
Handler path::
232+
233+
/api/health
234+
235+
To exclude this handler, use a pattern like::
236+
237+
exclude_path_pattern = r"^/api/health$"
238+
239+
**Example 2: Dynamic path (path parameters)**
240+
241+
Handler path::
242+
243+
/user/{user_id:int}/profile
244+
└─────┬──────┘
245+
└─ This is what the pattern matches against
246+
247+
Actual request paths that match this handler::
248+
249+
/user/1234/profile
250+
/user/5678/profile
251+
/user/9999/profile
252+
253+
To exclude this handler, the pattern must match the **handler**, not the actual
254+
request path:
255+
256+
exclude_path_pattern = "/user/{user_id:int}/profile"
257+
exclude_path_pattern = "/user/\{.+?\}/"
258+
"""
219259
exclude_opt_key: str | None = None
220260
constraints: MiddlewareConstraints | None = None
221261

262+
def should_bypass_for_handler(self, handler: RouteHandlerType) -> bool:
263+
"""Return ``True`` if this middleware should be bypassed for ``handler``, according
264+
to ``scopes``, ``exclude_path_pattern`` or ``exclude_opt_key``, otherwise
265+
``False``.
266+
"""
267+
from litestar.handlers import ASGIRouteHandler, HTTPRouteHandler, WebsocketRouteHandler
268+
269+
if isinstance(handler, HTTPRouteHandler) and ScopeType.HTTP not in self.scopes:
270+
return True
271+
if isinstance(handler, WebsocketRouteHandler) and ScopeType.WEBSOCKET not in self.scopes:
272+
return True
273+
if isinstance(handler, ASGIRouteHandler) and ScopeType.ASGI not in self.scopes:
274+
return True
275+
276+
if self.exclude_opt_key and handler.opt.get(self.exclude_opt_key):
277+
return True
278+
279+
pattern = build_exclude_path_pattern(exclude=self.exclude_path_pattern, middleware_cls=type(self))
280+
if pattern and any(pattern.search(path) for path in handler.paths):
281+
return True
282+
283+
return False
284+
222285
def __call__(self, app: ASGIApp) -> ASGIApp:
223286
"""Create the actual middleware callable"""
224287
handle = self.handle
225-
exclude_pattern = build_exclude_path_pattern(exclude=self.exclude_path_pattern, middleware_cls=type(self))
226-
scopes = set(self.scopes)
227-
exclude_opt_key = self.exclude_opt_key
228288

229289
async def middleware(scope: Scope, receive: Receive, send: Send) -> None:
230-
if should_bypass_middleware(
231-
scope=scope,
232-
scopes=scopes, # type: ignore[arg-type]
233-
exclude_opt_key=exclude_opt_key,
234-
exclude_path_pattern=exclude_pattern,
235-
):
236-
await app(scope, receive, send)
237-
else:
238-
await handle(scope=scope, receive=receive, send=send, next_app=app)
290+
await handle(scope=scope, receive=receive, send=send, next_app=app)
239291

240292
return middleware
241293

pyproject.toml

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -429,6 +429,7 @@ lint.ignore = [
429429
"A005", # Module shadows a built-in Python standard library module
430430
"PLC0415", # `import` should be at the top-level of a file
431431
"PLW1641", # Object does not implement `__hash__` method
432+
"SIM103", # 'Return the condition directly': This is a bit overzealous
432433
]
433434
lint.select = [
434435
"A", # flake8-builtins

tests/unit/test_middleware/test_base_middleware.py

Lines changed: 94 additions & 37 deletions
Original file line numberDiff line numberDiff line change
@@ -1,10 +1,12 @@
11
from typing import TYPE_CHECKING, Union
2+
from unittest.mock import MagicMock, call
23
from warnings import catch_warnings
34

45
import pytest
56

6-
from litestar import MediaType, asgi, get
7+
from litestar import MediaType, WebSocket, asgi, get, websocket
78
from litestar.datastructures.headers import MutableScopeHeaders
9+
from litestar.enums import ScopeType
810
from litestar.exceptions import LitestarWarning, ValidationException
911
from litestar.middleware import AbstractMiddleware, ASGIMiddleware, DefineMiddleware
1012
from litestar.response.base import ASGIResponse
@@ -211,19 +213,64 @@ def handler() -> dict:
211213
assert response.status_code == HTTP_400_BAD_REQUEST
212214

213215

216+
@pytest.mark.parametrize(
217+
"allowed_scopes,expected_calls",
218+
[
219+
((ScopeType.HTTP,), ["/http"]),
220+
((ScopeType.ASGI,), ["/asgi"]),
221+
((ScopeType.WEBSOCKET,), ["/ws"]),
222+
((ScopeType.HTTP, ScopeType.ASGI), ["/http", "/asgi"]),
223+
((ScopeType.HTTP, ScopeType.WEBSOCKET), ["/http", "/ws"]),
224+
((ScopeType.ASGI, ScopeType.WEBSOCKET), ["/asgi", "/ws"]),
225+
],
226+
)
227+
def test_asgi_middleware_exclude_by_scope_type(
228+
allowed_scopes: tuple[ScopeType, ...], expected_calls: list[str]
229+
) -> None:
230+
mock = MagicMock()
231+
232+
class SubclassMiddleware(ASGIMiddleware):
233+
scopes = allowed_scopes
234+
235+
async def handle(self, scope: "Scope", receive: "Receive", send: "Send", next_app: "ASGIApp") -> None:
236+
mock(scope["path"])
237+
await next_app(scope, receive, send)
238+
239+
@get("/http")
240+
def http_handler() -> None:
241+
return None
242+
243+
@websocket("/ws")
244+
async def websocket_handler(socket: WebSocket) -> None:
245+
await socket.accept()
246+
await socket.close()
247+
248+
@asgi("/asgi")
249+
async def asgi_handler(scope: "Scope", receive: "Receive", send: "Send") -> None:
250+
response = ASGIResponse(body=b"ok", media_type=MediaType.TEXT)
251+
await response(scope, receive, send)
252+
253+
with create_test_client(
254+
[http_handler, asgi_handler, websocket_handler], middleware=[SubclassMiddleware()]
255+
) as client:
256+
assert client.get("/http").status_code == 200
257+
assert client.get("/asgi").status_code == 200
258+
with client.websocket_connect("/ws"):
259+
pass
260+
261+
mock.assert_has_calls([call(path) for path in expected_calls])
262+
263+
214264
def test_asgi_middleware_exclude_by_pattern() -> None:
265+
mock = MagicMock()
266+
215267
class SubclassMiddleware(ASGIMiddleware):
216268
def __init__(self) -> None:
217269
self.exclude_path_pattern = r"^/123"
218270

219271
async def handle(self, scope: "Scope", receive: "Receive", send: "Send", next_app: "ASGIApp") -> None:
220-
async def _send(message: "Message") -> None:
221-
if message["type"] == "http.response.start":
222-
headers = MutableScopeHeaders(message)
223-
headers.add("test", str(123))
224-
await send(message)
225-
226-
await next_app(scope, receive, _send)
272+
mock(scope["raw_path"].decode())
273+
await next_app(scope, receive, send)
227274

228275
@get("/123")
229276
def first_handler() -> dict:
@@ -239,28 +286,22 @@ async def handler(scope: "Scope", receive: "Receive", send: "Send") -> None:
239286
await response(scope, receive, send)
240287

241288
with create_test_client([first_handler, second_handler, handler], middleware=[SubclassMiddleware()]) as client:
242-
response = client.get("/123")
243-
assert "test" not in response.headers
244-
245-
response = client.get("/456")
246-
assert "test" in response.headers
289+
assert client.get("/123").status_code == 200
290+
assert client.get("/456").status_code == 200
291+
assert client.get("/mount/123").status_code == 200
247292

248-
response = client.get("/mount/123")
249-
assert "test" in response.headers
293+
mock.assert_has_calls([call("/456"), call("/mount/123")])
250294

251295

252296
def test_asgi_middleware_exclude_by_pattern_tuple() -> None:
297+
mock = MagicMock()
298+
253299
class SubclassMiddleware(ASGIMiddleware):
254300
exclude_path_pattern = ("123", "456")
255301

256302
async def handle(self, scope: "Scope", receive: "Receive", send: "Send", next_app: "ASGIApp") -> None:
257-
async def _send(message: "Message") -> None:
258-
if message["type"] == "http.response.start":
259-
headers = MutableScopeHeaders(message)
260-
headers.add("test", str(123))
261-
await send(message)
262-
263-
await next_app(scope, receive, _send)
303+
mock(scope["path"])
304+
await next_app(scope, receive, send)
264305

265306
@get("/123")
266307
def first_handler() -> dict:
@@ -277,12 +318,31 @@ def third_handler() -> dict:
277318
with create_test_client(
278319
[first_handler, second_handler, third_handler], middleware=[SubclassMiddleware()]
279320
) as client:
280-
response = client.get("/123")
281-
assert "test" not in response.headers
282-
response = client.get("/456")
283-
assert "test" not in response.headers
284-
response = client.get("/789")
285-
assert "test" in response.headers
321+
assert client.get("/123").status_code == 200
322+
assert client.get("/456").status_code == 200
323+
assert client.get("/789").status_code == 200
324+
325+
mock.assert_called_once_with("/789")
326+
327+
328+
def test_asgi_middleware_exclude_dynamic_handler_by_pattern() -> None:
329+
mock = MagicMock()
330+
331+
class SubclassMiddleware(ASGIMiddleware):
332+
def __init__(self) -> None:
333+
self.exclude_path_pattern = r"^/foo/{bar" # use a pattern that ensures we match the raw handler path
334+
335+
async def handle(self, scope: "Scope", receive: "Receive", send: "Send", next_app: "ASGIApp") -> None:
336+
mock()
337+
await next_app(scope, receive, send)
338+
339+
@get("/foo/{bar:int}")
340+
def handler(bar: int) -> None:
341+
return None
342+
343+
with create_test_client([handler], middleware=[SubclassMiddleware()]) as client:
344+
assert client.get("/foo/1").status_code == 200
345+
mock.assert_not_called()
286346

287347

288348
@pytest.mark.parametrize("excludes", ["/", ("/", "/foo"), "/*", "/.*"])
@@ -310,22 +370,19 @@ async def handle(self, scope: "Scope", receive: "Receive", send: "Send", next_ap
310370

311371

312372
def test_asgi_middleware_exclude_by_opt_key() -> None:
373+
mock = MagicMock()
374+
313375
class SubclassMiddleware(ASGIMiddleware):
314376
exclude_opt_key = "exclude_route"
315377

316378
async def handle(self, scope: "Scope", receive: "Receive", send: "Send", next_app: "ASGIApp") -> None:
317-
async def _send(message: "Message") -> None:
318-
if message["type"] == "http.response.start":
319-
headers = MutableScopeHeaders(message)
320-
headers.add("test", str(123))
321-
await send(message)
322-
323-
await next_app(scope, receive, send)
379+
mock()
380+
await next_app(scope, receive, send)
324381

325382
@get("/", exclude_route=True)
326383
def handler() -> dict:
327384
return {"hello": "world"}
328385

329386
with create_test_client(handler, middleware=[SubclassMiddleware()]) as client:
330-
response = client.get("/")
331-
assert "test" not in response.headers
387+
assert client.get("/").status_code == 200
388+
mock.assert_not_called()

0 commit comments

Comments
 (0)