diff --git a/async_lru/__init__.py b/async_lru/__init__.py index 447e9cdb..aadb96e8 100644 --- a/async_lru/__init__.py +++ b/async_lru/__init__.py @@ -56,6 +56,8 @@ class _CacheParameters(TypedDict): class _CacheItem(Generic[_R]): fut: "asyncio.Future[_R]" later_call: Optional[asyncio.Handle] + waiters: int + task: "asyncio.Task[_R]" def cancel(self) -> None: if self.later_call is not None: @@ -192,6 +194,15 @@ def _task_done_callback( fut.set_result(task.result()) + def _handle_cancelled_error(self, key: Hashable, cache_item: "_CacheItem") -> None: + # Called when a waiter is cancelled. + # If this is the last waiter and the underlying task is not done, + # cancel the underlying task and remove the cache entry. + if cache_item.waiters == 1 and not cache_item.task.done(): + cache_item.cancel() # Cancel TTL expiration + cache_item.task.cancel() # Cancel the running coroutine + self.__cache.pop(key, None) # Remove from cache + async def __call__(self, /, *fn_args: Any, **fn_kwargs: Any) -> _R: if self.__closed: raise RuntimeError(f"alru_cache is closed for {self}") @@ -205,8 +216,20 @@ async def __call__(self, /, *fn_args: Any, **fn_kwargs: Any) -> _R: if cache_item is not None: self._cache_hit(key) if not cache_item.fut.done(): - return await asyncio.shield(cache_item.fut) - + # Each logical waiter increments waiters on entry. + cache_item.waiters += 1 + + try: + # All waiters await the same future. + return await asyncio.shield(cache_item.fut) + except asyncio.CancelledError: + # If a waiter is cancelled, handle possible last-waiter cleanup. + self._handle_cancelled_error(key, cache_item) + raise + finally: + # Each logical waiter decrements waiters on exit (normal or cancelled). + cache_item.waiters -= 1 + # If the future is already done, just return the result. return cache_item.fut.result() fut = loop.create_future() @@ -215,14 +238,22 @@ async def __call__(self, /, *fn_args: Any, **fn_kwargs: Any) -> _R: self.__tasks.add(task) task.add_done_callback(partial(self._task_done_callback, fut, key)) - self.__cache[key] = _CacheItem(fut, None) + cache_item = _CacheItem(fut, None, 1, task) + self.__cache[key] = cache_item if self.__maxsize is not None and len(self.__cache) > self.__maxsize: - dropped_key, cache_item = self.__cache.popitem(last=False) - cache_item.cancel() + dropped_key, dropped_cache_item = self.__cache.popitem(last=False) + dropped_cache_item.cancel() self._cache_miss(key) - return await asyncio.shield(fut) + + try: + return await asyncio.shield(fut) + except asyncio.CancelledError: + self._handle_cancelled_error(key, cache_item) + raise + finally: + cache_item.waiters -= 1 def __get__( self, instance: _T, owner: Optional[Type[_T]] diff --git a/tests/test_cancel.py b/tests/test_cancel.py new file mode 100644 index 00000000..8dc1e962 --- /dev/null +++ b/tests/test_cancel.py @@ -0,0 +1,58 @@ +import asyncio + +import pytest + +from async_lru import alru_cache + + +@pytest.mark.parametrize("num_to_cancel", [0, 1, 2, 3]) +async def test_cancel(num_to_cancel: int) -> None: + cache_item_task_finished = False + + @alru_cache + async def coro(val: int) -> int: + # I am a long running coro function + nonlocal cache_item_task_finished + await asyncio.sleep(2) + cache_item_task_finished = True + return val + + # create 3 tasks for the cached function using the same key + tasks = [asyncio.create_task(coro(1)) for _ in range(3)] + + # force the event loop to run once so the tasks can begin + await asyncio.sleep(0) + + # maybe cancel some tasks + for i in range(num_to_cancel): + tasks[i].cancel() + + # allow enough time for the non-cancelled tasks to complete + await asyncio.sleep(3) + + # check state + assert cache_item_task_finished == (num_to_cancel < 3) + + +@pytest.mark.asyncio +async def test_cancel_single_waiter_triggers_handle_cancelled_error(): + # This test ensures the _handle_cancelled_error path (waiters == 1) is exercised. + cache_item_task_finished = False + + @alru_cache + async def coro(val: int) -> int: + nonlocal cache_item_task_finished + await asyncio.sleep(2) + cache_item_task_finished = True + return val + + task = asyncio.create_task(coro(42)) + await asyncio.sleep(0) + task.cancel() + try: + await task + except asyncio.CancelledError: + pass + + # The underlying coroutine should be cancelled, so the flag should remain False + assert cache_item_task_finished is False