From 485744415a59586e1b3bba39399fe70ddd6e15c0 Mon Sep 17 00:00:00 2001 From: A5rocks Date: Mon, 30 Dec 2024 17:22:36 +0900 Subject: [PATCH 1/2] Add a sanity check for a missing `derive` for exception groups --- src/trio/_core/_run.py | 25 ++++++++++++ src/trio/_core/_tests/test_run.py | 65 ++++++++++++++++++++++++++++++- 2 files changed, 89 insertions(+), 1 deletion(-) diff --git a/src/trio/_core/_run.py b/src/trio/_core/_run.py index 5dbaa18cab..4eff89cd61 100644 --- a/src/trio/_core/_run.py +++ b/src/trio/_core/_run.py @@ -639,6 +639,31 @@ def _close(self, exc: BaseException | None) -> BaseException | None: self.cancelled_caught = True exc = None elif isinstance(exc, BaseExceptionGroup): + # sanity check users + egs = [exc] + visited = set() + while egs: + next_eg = egs.pop() + if next_eg in visited: + continue + visited.add(next_eg) + if ( + "derive" not in type(next_eg).__dict__ + and type(next_eg) is not ExceptionGroup + ): + warnings.warn( + f"derive not implemented for {type(next_eg).__name__}, results may be unexpected", + stacklevel=1, + ) + + egs.extend( + [ + e + for e in next_eg.exceptions + if isinstance(e, BaseExceptionGroup) + ] + ) + matched, exc = exc.split(Cancelled) if matched: self.cancelled_caught = True diff --git a/src/trio/_core/_tests/test_run.py b/src/trio/_core/_tests/test_run.py index 75e5457d78..ed082c9f58 100644 --- a/src/trio/_core/_tests/test_run.py +++ b/src/trio/_core/_tests/test_run.py @@ -10,7 +10,7 @@ import weakref from contextlib import ExitStack, contextmanager, suppress from math import inf, nan -from typing import TYPE_CHECKING, NoReturn, TypeVar +from typing import TYPE_CHECKING, NoReturn, Self, TypeVar from unittest import mock import outcome @@ -44,6 +44,7 @@ Awaitable, Callable, Generator, + Sequence, ) if sys.version_info < (3, 11): @@ -2855,3 +2856,65 @@ def run(self, fn: Callable[[], object]) -> object: with mock.patch("trio._core._run.copy_context", return_value=Context()): assert _count_context_run_tb_frames() == 1 + + +def test_run_with_custom_exception_group() -> None: + class ExceptionGroupForTest(ExceptionGroup): + @staticmethod + def for_test(message: str, excs: list[Exception]) -> ExceptionGroupForTest: + raise NotImplementedError() + + async def check1(exception_group_type: type[ExceptionGroupForTest]) -> None: + raise exception_group_type.for_test("test message", [ValueError("uh oh")]) + + async def check2(exception_group_type: type[ExceptionGroupForTest]) -> None: + with _core.CancelScope(): + raise exception_group_type.for_test("test message", [ValueError("uh oh")]) + + async def check3(exception_group_type: type[ExceptionGroupForTest]) -> None: + async with _core.open_nursery(): + raise exception_group_type.for_test("test message", [ValueError("uh oh")]) + + class HasDerive(ExceptionGroupForTest): + def derive(self, excs: Sequence[BaseException]) -> HasDerive: + return HasDerive(self.message, excs) + + @staticmethod + def for_test(message: str, excs: list[Exception]) -> HasDerive: + return HasDerive(message, excs) + + class NormalNew(ExceptionGroupForTest): + @staticmethod + def for_test(message: str, excs: list[Exception]) -> NormalNew: + return NormalNew(message, excs) + + class AbnormalNew(ExceptionGroupForTest): + def __new__(cls, excs: Sequence[Exception]) -> Self: + return super().__new__(cls, f"has {len(excs)} exceptions", excs) + + @staticmethod + def for_test(message: str, excs: list[Exception]) -> AbnormalNew: + return AbnormalNew(excs) + + for check in (check1, check2, check3): + for error in (HasDerive, NormalNew, AbnormalNew): + if check is check3: + if error in (NormalNew, AbnormalNew): + with ( + pytest.warns(UserWarning, match="^derive not implemented"), + pytest.raises(ExceptionGroup) as e, + ): + _core.run(check, error) + + error = ExceptionGroup # we don't provide something better + else: + with pytest.raises(ExceptionGroup) as e: + _core.run(check, error) + + assert len(e.value.exceptions) == 1 + assert isinstance(e.value.exceptions[0], error) + else: + with pytest.raises(error): + _core.run(check, error) + + print(f"{check} + {error} PASSED") From d8e853f18611da963f3263b61ace1334305ab181 Mon Sep 17 00:00:00 2001 From: A5rocks Date: Mon, 30 Dec 2024 17:32:04 +0900 Subject: [PATCH 2/2] Fixes for CI --- newsfragments/3175.doc.rst | 1 + src/trio/_core/_run.py | 2 +- src/trio/_core/_tests/test_run.py | 10 ++++++---- 3 files changed, 8 insertions(+), 5 deletions(-) create mode 100644 newsfragments/3175.doc.rst diff --git a/newsfragments/3175.doc.rst b/newsfragments/3175.doc.rst new file mode 100644 index 0000000000..7c8475f4e7 --- /dev/null +++ b/newsfragments/3175.doc.rst @@ -0,0 +1 @@ +Warn if a user forgot to implement ``.derive`` for an ExceptionGroup subclass. diff --git a/src/trio/_core/_run.py b/src/trio/_core/_run.py index 4eff89cd61..0f243c4ff5 100644 --- a/src/trio/_core/_run.py +++ b/src/trio/_core/_run.py @@ -52,7 +52,7 @@ ) if sys.version_info < (3, 11): - from exceptiongroup import BaseExceptionGroup + from exceptiongroup import BaseExceptionGroup, ExceptionGroup if TYPE_CHECKING: diff --git a/src/trio/_core/_tests/test_run.py b/src/trio/_core/_tests/test_run.py index ed082c9f58..073ac0f281 100644 --- a/src/trio/_core/_tests/test_run.py +++ b/src/trio/_core/_tests/test_run.py @@ -10,7 +10,7 @@ import weakref from contextlib import ExitStack, contextmanager, suppress from math import inf, nan -from typing import TYPE_CHECKING, NoReturn, Self, TypeVar +from typing import TYPE_CHECKING, NoReturn, TypeVar from unittest import mock import outcome @@ -47,6 +47,8 @@ Sequence, ) + from typing_extensions import Self + if sys.version_info < (3, 11): from exceptiongroup import BaseExceptionGroup, ExceptionGroup @@ -2859,7 +2861,7 @@ def run(self, fn: Callable[[], object]) -> object: def test_run_with_custom_exception_group() -> None: - class ExceptionGroupForTest(ExceptionGroup): + class ExceptionGroupForTest(ExceptionGroup[Exception]): @staticmethod def for_test(message: str, excs: list[Exception]) -> ExceptionGroupForTest: raise NotImplementedError() @@ -2876,7 +2878,7 @@ async def check3(exception_group_type: type[ExceptionGroupForTest]) -> None: raise exception_group_type.for_test("test message", [ValueError("uh oh")]) class HasDerive(ExceptionGroupForTest): - def derive(self, excs: Sequence[BaseException]) -> HasDerive: + def derive(self, excs: Sequence[Exception]) -> HasDerive: return HasDerive(self.message, excs) @staticmethod @@ -2897,7 +2899,7 @@ def for_test(message: str, excs: list[Exception]) -> AbnormalNew: return AbnormalNew(excs) for check in (check1, check2, check3): - for error in (HasDerive, NormalNew, AbnormalNew): + for error in [HasDerive, NormalNew, AbnormalNew]: if check is check3: if error in (NormalNew, AbnormalNew): with (