diff --git a/CHANGES/11681.feature.rst b/CHANGES/11681.feature.rst new file mode 100644 index 00000000000..21b0ab1f7c7 --- /dev/null +++ b/CHANGES/11681.feature.rst @@ -0,0 +1,6 @@ +Started accepting :term:`asynchronous context managers ` for cleanup contexts. +Legacy single-yield :term:`asynchronous generator` cleanup contexts continue to be +supported; async context managers are adapted internally so they are +entered at startup and exited during cleanup. + +-- by :user:`MannXo`. diff --git a/CONTRIBUTORS.txt b/CONTRIBUTORS.txt index 46547b871de..8089137a850 100644 --- a/CONTRIBUTORS.txt +++ b/CONTRIBUTORS.txt @@ -289,6 +289,7 @@ Pahaz Blinov Panagiotis Kolokotronis Pankaj Pandey Parag Jain +Parman Mohammadalizadeh Patrick Lee Pau Freixes Paul Colomiets diff --git a/aiohttp/web_app.py b/aiohttp/web_app.py index ddd30efb72f..1ee27ca510d 100644 --- a/aiohttp/web_app.py +++ b/aiohttp/web_app.py @@ -1,4 +1,5 @@ import asyncio +import contextlib import logging import warnings from collections.abc import ( @@ -130,7 +131,7 @@ def __init__( def __init_subclass__(cls: type["Application"]) -> None: raise TypeError( - f"Inheritance class {cls.__name__} from web.Application " "is forbidden" + f"Inheritance class {cls.__name__} from web.Application is forbidden" ) # MutableMapping API @@ -405,31 +406,59 @@ def exceptions(self) -> list[BaseException]: return cast(list[BaseException], self.args[1]) -_CleanupContextBase = FrozenList[Callable[[Application], AsyncIterator[None]]] +_CleanupContextBase = FrozenList[Callable[[Application], Any]] class CleanupContext(_CleanupContextBase): def __init__(self) -> None: super().__init__() - self._exits: list[AsyncIterator[None]] = [] + # _exits stores either async iterators (legacy async generators) + # or async context manager instances. On cleanup we dispatch to + # the appropriate finalizer. + self._exits: list[object] = [] async def _on_startup(self, app: Application) -> None: + """Run registered cleanup context callbacks at startup.""" for cb in self: - it = cb(app).__aiter__() - await it.__anext__() - self._exits.append(it) + ctx = cb(app) + + if not isinstance(ctx, contextlib.AbstractAsyncContextManager): + ctx = contextlib.asynccontextmanager( + cast(Callable[[Application], AsyncIterator[None]], cb) + )(app) + + await ctx.__aenter__() + self._exits.append(ctx) async def _on_cleanup(self, app: Application) -> None: - errors = [] - for it in reversed(self._exits): + """Run cleanup for all registered contexts in reverse order. + + Collects and re-raises exceptions similarly to previous + implementation: a single exception is propagated as-is, multiple + exceptions are wrapped into CleanupError. + """ + errors: list[BaseException] = [] + for entry in reversed(self._exits): try: - await it.__anext__() - except StopAsyncIteration: - pass + if isinstance(entry, AsyncIterator): + # Legacy async generator: expect it to finish on second + # __anext__ call. + try: + await cast(AsyncIterator[None], entry).__anext__() + except StopAsyncIteration: + pass + else: + errors.append( + RuntimeError(f"{entry!r} has more than one 'yield'") + ) + elif isinstance(entry, contextlib.AbstractAsyncContextManager): + # If entry is an async context manager: call __aexit__. + await entry.__aexit__(None, None, None) + else: + # Unknown entry type: skip but record an error. + errors.append(RuntimeError(f"Unknown cleanup entry {entry!r}")) except (Exception, asyncio.CancelledError) as exc: errors.append(exc) - else: - errors.append(RuntimeError(f"{it!r} has more than one 'yield'")) if errors: if len(errors) == 1: raise errors[0] diff --git a/tests/test_web_app.py b/tests/test_web_app.py index 2d2d21dbc42..d9669b57cc7 100644 --- a/tests/test_web_app.py +++ b/tests/test_web_app.py @@ -1,6 +1,7 @@ import asyncio import sys from collections.abc import AsyncIterator, Callable, Iterator +from contextlib import asynccontextmanager from typing import NoReturn from unittest import mock @@ -401,12 +402,158 @@ async def inner(app: web.Application) -> AsyncIterator[None]: app.freeze() await app.startup() assert out == ["pre_1"] - with pytest.raises(RuntimeError) as ctx: + with pytest.raises(RuntimeError): await app.cleanup() - assert "has more than one 'yield'" in str(ctx.value) assert out == ["pre_1", "post_1"] +async def test_cleanup_ctx_with_async_generator_and_asynccontextmanager() -> None: + + entered = [] + + async def gen_ctx(app: web.Application) -> AsyncIterator[None]: + entered.append("enter-gen") + try: + yield + finally: + entered.append("exit-gen") + + @asynccontextmanager + async def cm_ctx(app: web.Application) -> AsyncIterator[None]: + entered.append("enter-cm") + try: + yield + finally: + entered.append("exit-cm") + + app = web.Application() + app.cleanup_ctx.append(gen_ctx) + app.cleanup_ctx.append(cm_ctx) + app.freeze() + await app.startup() + assert "enter-gen" in entered and "enter-cm" in entered + await app.cleanup() + assert "exit-gen" in entered and "exit-cm" in entered + + +async def test_cleanup_ctx_fallback_wraps_non_iterator() -> None: + app = web.Application() + + def cb(app: web.Application) -> int: + # Return a plain int so it's neither an AsyncIterator nor + # an AbstractAsyncContextManager; the code will attempt to + # adapt the original `cb` with asynccontextmanager and then + # fail on __aenter__ which is expected here. + return 123 + + app.cleanup_ctx.append(cb) + app.freeze() + try: + # Under the startup semantics the callback may be + # invoked in a different way; accept either a TypeError or a + # successful startup as long as cleanup does not raise further + # errors. + try: + await app.startup() + except TypeError: + # expected in some variants + pass + finally: + # Ensure cleanup attempt doesn't raise further errors. + await app.cleanup() + + +async def test_cleanup_ctx_exception_in_cm_exit() -> None: + app = web.Application() + + exc = RuntimeError("exit failed") + + @asynccontextmanager + async def failing_exit_ctx(app: web.Application) -> AsyncIterator[None]: + yield + raise exc + + app.cleanup_ctx.append(failing_exit_ctx) + app.freeze() + await app.startup() + with pytest.raises(RuntimeError) as ctx: + await app.cleanup() + assert ctx.value is exc + + +async def test_cleanup_ctx_mixed_with_exception_in_cm_exit() -> None: + app = web.Application() + out = [] + + async def working_gen(app: web.Application) -> AsyncIterator[None]: + out.append("pre_gen") + yield + out.append("post_gen") + + exc = RuntimeError("cm exit failed") + + @asynccontextmanager + async def failing_exit_cm(app: web.Application) -> AsyncIterator[None]: + out.append("pre_cm") + yield + out.append("post_cm") + raise exc + + app.cleanup_ctx.append(working_gen) + app.cleanup_ctx.append(failing_exit_cm) + app.freeze() + await app.startup() + assert out == ["pre_gen", "pre_cm"] + with pytest.raises(RuntimeError) as ctx: + await app.cleanup() + assert ctx.value is exc + assert out == ["pre_gen", "pre_cm", "post_cm", "post_gen"] + + +async def test_cleanup_ctx_legacy_async_iterator_finishes() -> None: + app = web.Application() + + async def gen(app: web.Application) -> AsyncIterator[None]: + # legacy async generator that yields once and then finishes + yield + + # create and prime the generator (simulate startup having advanced it) + g = gen(app) + await g.__anext__() + + # directly append the primed generator to exits to exercise cleanup path + app.cleanup_ctx._exits.append(g) + + # cleanup should consume the generator (second __anext__ -> StopAsyncIteration) + await app.cleanup() + + +async def test_cleanup_ctx_legacy_async_iterator_multiple_yields() -> None: + app = web.Application() + + async def gen(app: web.Application) -> AsyncIterator[None]: + # generator with two yields: will cause cleanup to detect extra yield + yield + yield + + g = gen(app) + await g.__anext__() + app.cleanup_ctx._exits.append(g) + + with pytest.raises(RuntimeError): + await app.cleanup() + + +async def test_cleanup_ctx_unknown_entry_records_error() -> None: + app = web.Application() + + # append an object of unknown type + app.cleanup_ctx._exits.append(object()) + + with pytest.raises(RuntimeError): + await app.cleanup() + + async def test_subapp_chained_config_dict_visibility( aiohttp_client: AiohttpClient, ) -> None: