diff --git a/starlette/_utils.py b/starlette/_utils.py index 0c389dcb2..143c660f3 100644 --- a/starlette/_utils.py +++ b/starlette/_utils.py @@ -4,7 +4,9 @@ import functools import sys import typing -from contextlib import contextmanager +from contextlib import asynccontextmanager + +import anyio.abc from starlette.types import Scope @@ -13,12 +15,14 @@ else: # pragma: no cover from typing_extensions import TypeGuard -has_exceptiongroups = True if sys.version_info < (3, 11): # pragma: no cover try: - from exceptiongroup import BaseExceptionGroup # type: ignore[unused-ignore,import-not-found] + from exceptiongroup import BaseExceptionGroup except ImportError: - has_exceptiongroups = False + + class BaseExceptionGroup(BaseException): # type: ignore[no-redef] + pass + T = typing.TypeVar("T") AwaitableCallable = typing.Callable[..., typing.Awaitable[T]] @@ -70,16 +74,28 @@ async def __aexit__(self, *args: typing.Any) -> None | bool: return None -@contextmanager -def collapse_excgroups() -> typing.Generator[None, None, None]: +@asynccontextmanager +async def create_collapsing_task_group() -> typing.AsyncGenerator[anyio.abc.TaskGroup, None]: try: - yield - except BaseException as exc: - if has_exceptiongroups: # pragma: no cover - while isinstance(exc, BaseExceptionGroup) and len(exc.exceptions) == 1: - exc = exc.exceptions[0] - - raise exc + async with anyio.create_task_group() as tg: + yield tg + except BaseExceptionGroup as excs: + if len(excs.exceptions) != 1: + raise + + exc = excs.exceptions[0] + context = exc.__context__ + tb = exc.__traceback__ + cause = exc.__cause__ + sc = exc.__suppress_context__ + try: + raise exc + finally: + exc.__traceback__ = tb + exc.__context__ = context + exc.__cause__ = cause + exc.__suppress_context__ = sc + del exc, cause, tb, context def get_route_path(scope: Scope) -> str: diff --git a/starlette/middleware/base.py b/starlette/middleware/base.py index f146984b3..56bbc9d12 100644 --- a/starlette/middleware/base.py +++ b/starlette/middleware/base.py @@ -4,7 +4,7 @@ import anyio -from starlette._utils import collapse_excgroups +from starlette._utils import create_collapsing_task_group from starlette.requests import ClientDisconnect, Request from starlette.responses import AsyncContentStream, Response from starlette.types import ASGIApp, Message, Receive, Scope, Send @@ -174,8 +174,8 @@ async def body_stream() -> typing.AsyncGenerator[bytes, None]: streams: anyio.create_memory_object_stream[Message] = anyio.create_memory_object_stream() send_stream, recv_stream = streams - with recv_stream, send_stream, collapse_excgroups(): - async with anyio.create_task_group() as task_group: + with recv_stream, send_stream: + async with create_collapsing_task_group() as task_group: response = await self.dispatch_func(request, call_next) await response(scope, wrapped_receive, send) response_sent.set() diff --git a/starlette/middleware/wsgi.py b/starlette/middleware/wsgi.py index 6e0a3fae6..3e9ad0296 100644 --- a/starlette/middleware/wsgi.py +++ b/starlette/middleware/wsgi.py @@ -9,6 +9,7 @@ import anyio from anyio.abc import ObjectReceiveStream, ObjectSendStream +from starlette._utils import create_collapsing_task_group from starlette.types import Receive, Scope, Send warnings.warn( @@ -102,7 +103,7 @@ async def __call__(self, receive: Receive, send: Send) -> None: more_body = message.get("more_body", False) environ = build_environ(self.scope, body) - async with anyio.create_task_group() as task_group: + async with create_collapsing_task_group() as task_group: task_group.start_soon(self.sender, send) async with self.stream_send: await anyio.to_thread.run_sync(self.wsgi, environ, self.start_response) diff --git a/starlette/responses.py b/starlette/responses.py index 31874f655..1f8b87bea 100644 --- a/starlette/responses.py +++ b/starlette/responses.py @@ -18,6 +18,7 @@ import anyio import anyio.to_thread +from starlette._utils import create_collapsing_task_group from starlette.background import BackgroundTask from starlette.concurrency import iterate_in_threadpool from starlette.datastructures import URL, Headers, MutableHeaders @@ -258,7 +259,7 @@ async def __call__(self, scope: Scope, receive: Receive, send: Send) -> None: except OSError: raise ClientDisconnect() else: - async with anyio.create_task_group() as task_group: + async with create_collapsing_task_group() as task_group: async def wrap(func: typing.Callable[[], typing.Awaitable[None]]) -> None: await func() diff --git a/tests/middleware/test_base.py b/tests/middleware/test_base.py index 7232cfd18..483dd869e 100644 --- a/tests/middleware/test_base.py +++ b/tests/middleware/test_base.py @@ -1,6 +1,7 @@ from __future__ import annotations import contextvars +import sys from collections.abc import AsyncGenerator, AsyncIterator, Generator from contextlib import AsyncExitStack from typing import Any @@ -21,6 +22,9 @@ from starlette.websockets import WebSocket from tests.types import TestClientFactory +if sys.version_info < (3, 11): # pragma: no cover + from exceptiongroup import ExceptionGroup + class CustomMiddleware(BaseHTTPMiddleware): async def dispatch( @@ -41,6 +45,10 @@ def exc(request: Request) -> None: raise Exception("Exc") +def eg(request: Request) -> None: + raise ExceptionGroup("my exception group", [ValueError("TEST")]) + + def exc_stream(request: Request) -> StreamingResponse: return StreamingResponse(_generate_faulty_stream()) @@ -76,6 +84,7 @@ async def websocket_endpoint(session: WebSocket) -> None: routes=[ Route("/", endpoint=homepage), Route("/exc", endpoint=exc), + Route("/eg", endpoint=eg), Route("/exc-stream", endpoint=exc_stream), Route("/no-response", endpoint=NoResponse), WebSocketRoute("/ws", endpoint=websocket_endpoint), @@ -89,13 +98,16 @@ def test_custom_middleware(test_client_factory: TestClientFactory) -> None: response = client.get("/") assert response.headers["Custom-Header"] == "Example" - with pytest.raises(Exception) as ctx: + with pytest.raises(Exception) as ctx1: response = client.get("/exc") - assert str(ctx.value) == "Exc" + assert str(ctx1.value) == "Exc" - with pytest.raises(Exception) as ctx: + with pytest.raises(Exception) as ctx2: response = client.get("/exc-stream") - assert str(ctx.value) == "Faulty Stream" + assert str(ctx2.value) == "Faulty Stream" + + with pytest.raises(ExceptionGroup, match=r"my exception group \(1 sub-exception\)"): + client.get("/eg") with pytest.raises(RuntimeError): response = client.get("/no-response") diff --git a/tests/middleware/test_wsgi.py b/tests/middleware/test_wsgi.py index 3511c89c9..418f0946f 100644 --- a/tests/middleware/test_wsgi.py +++ b/tests/middleware/test_wsgi.py @@ -4,7 +4,6 @@ import pytest -from starlette._utils import collapse_excgroups from starlette.middleware.wsgi import WSGIMiddleware, build_environ from tests.types import TestClientFactory @@ -86,7 +85,7 @@ def test_wsgi_exception(test_client_factory: TestClientFactory) -> None: # The HTTP protocol implementations would catch this error and return 500. app = WSGIMiddleware(raise_exception) client = test_client_factory(app) - with pytest.raises(RuntimeError), collapse_excgroups(): + with pytest.raises(RuntimeError): client.get("/") diff --git a/tests/test__utils.py b/tests/test__utils.py index 916f460d4..60af81392 100644 --- a/tests/test__utils.py +++ b/tests/test__utils.py @@ -1,11 +1,15 @@ import functools +import sys from typing import Any import pytest -from starlette._utils import get_route_path, is_async_callable +from starlette._utils import create_collapsing_task_group, get_route_path, is_async_callable from starlette.types import Scope +if sys.version_info < (3, 11): # pragma: no cover + from exceptiongroup import ExceptionGroup + def test_async_func() -> None: async def async_func() -> None: ... # pragma: no cover @@ -94,3 +98,31 @@ async def async_func( ) def test_get_route_path(scope: Scope, expected_result: str) -> None: assert get_route_path(scope) == expected_result + + +@pytest.mark.anyio +async def test_collapsing_task_group_one_exc() -> None: + class MyException(Exception): + pass + + with pytest.raises(MyException): + async with create_collapsing_task_group(): + raise MyException + + +@pytest.mark.anyio +async def test_collapsing_task_group_two_exc() -> None: + class MyException(Exception): + pass + + async def raise_exc() -> None: + raise MyException + + with pytest.raises(ExceptionGroup) as exc: + async with create_collapsing_task_group() as task_group: + task_group.start_soon(raise_exc) + raise MyException + + exc1, exc2 = exc.value.exceptions + assert isinstance(exc1, MyException) + assert isinstance(exc2, MyException)