feat: handle CancelledError - cancel if no other waiters + refactoring #708
@@ -9,9 +9,9 @@
     Coroutine,
     Generic,
     Hashable,
+    List,
     Optional,
     OrderedDict,
-    Set,
     Type,
     TypedDict,
     TypeVar,
@@ -54,8 +54,9 @@ class _CacheParameters(TypedDict):
 @final
 @dataclasses.dataclass
 class _CacheItem(Generic[_R]):
-    fut: "asyncio.Future[_R]"
+    task: "asyncio.Task[_R]"
     later_call: Optional[asyncio.Handle]
+    waiters: int

     def cancel(self) -> None:
         if self.later_call is not None:
@@ -108,7 +109,17 @@ def __init__(
         self.__closed = False
         self.__hits = 0
         self.__misses = 0
-        self.__tasks: Set["asyncio.Task[_R]"] = set()
+
+    @property
+    def __tasks(self) -> List["asyncio.Task[_R]"]:
+        # NOTE: I don't think we need to form a set first here but not too sure we want it for guarantees
+        return list(
+            {
+                cache_item.task
+                for cache_item in self.__cache.values()
+                if not cache_item.task.done()
+            }
+        )

     def cache_invalidate(self, /, *args: Hashable, **kwargs: Any) -> bool:
         key = _make_key(args, kwargs, self.__typed)

Member (on `if not cache_item.task.done()`): I don't think the condition is needed or useful for the use cases I saw.
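For illustration, a sketch of the simplified property this comment points toward, dropping the `done()` filter and (per the NOTE above) the intermediate set. This is a hypothetical variant, not code from the PR:

```python
@property
def __tasks(self) -> List["asyncio.Task[_R]"]:
    # Hypothetical simplification: return every cached task as-is and let
    # callers decide whether already-finished tasks matter to them.
    return [cache_item.task for cache_item in self.__cache.values()]
```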
@@ -128,12 +139,11 @@ def cache_clear(self) -> None:
             if c.later_call:
                 c.later_call.cancel()
         self.__cache.clear()
-        self.__tasks.clear()

     async def cache_close(self, *, wait: bool = False) -> None:
         self.__closed = True

-        tasks = list(self.__tasks)
+        tasks = self.__tasks
         if not tasks:
             return

Contributor (Author), on `tasks = self.__tasks`: we need to materialize tasks into a list or set or some other container here.

Member: Well, that's already handled by the original code.
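The point being debated — that the task collection has to be snapshotted into a concrete container before anything is awaited — can be illustrated in isolation. A hypothetical, self-contained sketch (not code from the PR): the done-callbacks pop entries out of the dict, so iterating the live view while awaiting would fail, whereas a snapshot is safe.

```python
import asyncio
from typing import Dict

# Hypothetical illustration of why cache_close wants a materialized snapshot:
# each task's done-callback removes its own cache entry, so the dict mutates
# while the tasks are being awaited.


async def main() -> None:
    cache: Dict[int, "asyncio.Task[int]"] = {}

    async def work(i: int) -> int:
        await asyncio.sleep(0.01 * i)
        return i

    for i in range(5):
        task = asyncio.ensure_future(work(i))
        # Mimics _task_done_callback popping the finished entry from the cache.
        task.add_done_callback(lambda _t, i=i: cache.pop(i, None))
        cache[i] = task

    # Awaiting tasks one by one while looping over cache.values() directly
    # would raise "RuntimeError: dictionary changed size during iteration".
    # Snapshotting first (what list(self.__tasks) gives us) is safe:
    tasks = list(cache.values())
    await asyncio.gather(*tasks)
    assert not cache  # every entry removed itself via its done-callback


asyncio.run(main())
```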
@@ -167,19 +177,8 @@ def _cache_hit(self, key: Hashable) -> None:
     def _cache_miss(self, key: Hashable) -> None:
         self.__misses += 1

-    def _task_done_callback(
-        self, fut: "asyncio.Future[_R]", key: Hashable, task: "asyncio.Task[_R]"
-    ) -> None:
-        self.__tasks.discard(task)
-
-        if task.cancelled():
-            fut.cancel()
-            self.__cache.pop(key, None)
-            return
-
-        exc = task.exception()
-        if exc is not None:
-            fut.set_exception(exc)
+    def _task_done_callback(self, key: Hashable, task: "asyncio.Task[_R]") -> None:
+        if task.cancelled() or task.exception() is not None:
+            self.__cache.pop(key, None)
+            return
@@ -190,7 +189,16 @@ def _task_done_callback(
                 self.__ttl, self.__cache.pop, key, None
             )

-        fut.set_result(task.result())
+    def _handle_cancelled_error(
+        self, key: Hashable, cache_item: "_CacheItem[Any]"
+    ) -> None:
+        # Called when a waiter is cancelled.
+        # If this is the last waiter and the underlying task is not done,
+        # cancel the underlying task and remove the cache entry.
+        if cache_item.waiters == 1 and not cache_item.task.done():
+            cache_item.cancel()  # Cancel TTL expiration
+            cache_item.task.cancel()  # Cancel the running coroutine
+            self.__cache.pop(key, None)  # Remove from cache

     async def __call__(self, /, *fn_args: Any, **fn_kwargs: Any) -> _R:
         if self.__closed:
@@ -204,25 +212,43 @@ async def __call__(self, /, *fn_args: Any, **fn_kwargs: Any) -> _R:

         if cache_item is not None:
             self._cache_hit(key)
-            if not cache_item.fut.done():
-                return await asyncio.shield(cache_item.fut)
-
-            return cache_item.fut.result()
+            if not cache_item.task.done():
+                # Each logical waiter increments waiters on entry.
+                cache_item.waiters += 1
+
+                try:
+                    # All waiters await the same shielded task.
+                    return await asyncio.shield(cache_item.task)
+                except asyncio.CancelledError:
+                    # If a waiter is cancelled, handle possible last-waiter cleanup.
+                    self._handle_cancelled_error(key, cache_item)
+                    raise
+                finally:
+                    # Each logical waiter decrements waiters on exit (normal or cancelled).
+                    cache_item.waiters -= 1
+
+            # If the task is already done, just return the result.
+            return cache_item.task.result()

-        fut = loop.create_future()
         coro = self.__wrapped__(*fn_args, **fn_kwargs)
         task: asyncio.Task[_R] = loop.create_task(coro)
-        self.__tasks.add(task)
-        task.add_done_callback(partial(self._task_done_callback, fut, key))
+        task.add_done_callback(partial(self._task_done_callback, key))

-        self.__cache[key] = _CacheItem(fut, None)
+        cache_item = _CacheItem(task, None, 1)
+        self.__cache[key] = cache_item

         if self.__maxsize is not None and len(self.__cache) > self.__maxsize:
-            dropped_key, cache_item = self.__cache.popitem(last=False)
-            cache_item.cancel()
+            dropped_key, dropped_cache_item = self.__cache.popitem(last=False)
+            dropped_cache_item.cancel()

         self._cache_miss(key)
-        return await asyncio.shield(fut)
+
+        try:
+            return await asyncio.shield(task)
+        except asyncio.CancelledError:
+            self._handle_cancelled_error(key, cache_item)
+            raise
+        finally:
+            cache_item.waiters -= 1

     def __get__(
         self, instance: _T, owner: Optional[Type[_T]]
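To make the intent of the new `waiters` bookkeeping concrete, here is a small caller-side sketch of the behavior the hunks above aim for. This is assumed semantics mirroring the new tests, not code copied from the repository:

```python
import asyncio

from async_lru import alru_cache


@alru_cache
async def fetch(key: int) -> int:
    await asyncio.sleep(1)
    return key


async def main() -> None:
    # Two callers share the same cache key and thus the same underlying task.
    first = asyncio.create_task(fetch(1))
    second = asyncio.create_task(fetch(1))
    await asyncio.sleep(0)  # let both waiters attach to the cached task

    # Cancelling one waiter leaves another waiter, so the underlying
    # coroutine keeps running and the survivor still gets the result.
    first.cancel()
    assert await second == 1

    # A key with a single waiter: cancelling that last waiter cancels the
    # underlying coroutine and drops the entry from the cache.
    only = asyncio.create_task(fetch(2))
    await asyncio.sleep(0)
    only.cancel()
    try:
        await only
    except asyncio.CancelledError:
        pass


asyncio.run(main())
```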
New test file:

@@ -0,0 +1,58 @@
import asyncio

import pytest

from async_lru import alru_cache


@pytest.mark.parametrize("num_to_cancel", [0, 1, 2, 3])
async def test_cancel(num_to_cancel: int) -> None:
    cache_item_task_finished = False

    @alru_cache
    async def coro(val: int) -> int:
        # I am a long running coro function
        nonlocal cache_item_task_finished
        await asyncio.sleep(2)
        cache_item_task_finished = True
        return val

    # create 3 tasks for the cached function using the same key
    tasks = [asyncio.create_task(coro(1)) for _ in range(3)]

    # force the event loop to run once so the tasks can begin
    await asyncio.sleep(0)

    # maybe cancel some tasks
    for i in range(num_to_cancel):
        tasks[i].cancel()

    # allow enough time for the non-cancelled tasks to complete
    await asyncio.sleep(3)

    # check state
    assert cache_item_task_finished == (num_to_cancel < 3)


@pytest.mark.asyncio
async def test_cancel_single_waiter_triggers_handle_cancelled_error() -> None:
    # This test ensures the _handle_cancelled_error path (waiters == 1) is exercised.
    cache_item_task_finished = False

    @alru_cache
    async def coro(val: int) -> int:
        nonlocal cache_item_task_finished
        await asyncio.sleep(2)
        cache_item_task_finished = True
        return val

    task = asyncio.create_task(coro(42))
    await asyncio.sleep(0)
    task.cancel()
    try:
        await task
    except asyncio.CancelledError:
        pass

    # The underlying coroutine should be cancelled, so the flag should remain False
    assert cache_item_task_finished is False
Review comment: I think a list is probably overkill, we could just return an iterator for all the uses I saw.