Skip to content
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
Show all changes
29 commits
Select commit Hold shift + click to select a range
865db04
feat: handle CancelledError - cancel if no other waiters
BobTheBuidler Nov 2, 2025
9244bc7
Update __init__.py
BobTheBuidler Nov 2, 2025
8c2b878
Create test_cancel.py
BobTheBuidler Nov 2, 2025
b283118
Update test_cancel.py
BobTheBuidler Nov 2, 2025
f8ceb96
Update test_cancel.py
Dreamsorcerer Nov 2, 2025
27100f3
fix: name error CancelledError
BobTheBuidler Nov 2, 2025
19b1b3b
feat(test): more cancel tests
BobTheBuidler Nov 2, 2025
e85e099
finish up impl
BobTheBuidler Nov 2, 2025
2c4d654
Update test_cancel.py
BobTheBuidler Nov 2, 2025
f9cfbf9
lint
BobTheBuidler Nov 2, 2025
6441a1d
Update __init__.py
BobTheBuidler Nov 2, 2025
94c303b
Update __init__.py
BobTheBuidler Nov 3, 2025
1dec1b4
chore: refactor out __task and fut
BobTheBuidler Nov 3, 2025
85f2f51
Update __init__.py
BobTheBuidler Nov 3, 2025
094a3cc
Update __init__.py
BobTheBuidler Nov 3, 2025
0b06d9b
fix missing import
BobTheBuidler Nov 3, 2025
b439eca
Update __init__.py
BobTheBuidler Nov 3, 2025
83f921c
Merge branch 'master' into refactored
BobTheBuidler Nov 7, 2025
f2152ee
Update __init__.py
BobTheBuidler Nov 7, 2025
4bbab74
Update test_cancel.py
BobTheBuidler Nov 7, 2025
157c7d5
Update benchmark.py
BobTheBuidler Nov 7, 2025
23daa05
Update test_internals.py
BobTheBuidler Nov 7, 2025
3972c0f
Update test_basic.py
BobTheBuidler Nov 7, 2025
4878867
Update test_internals.py
BobTheBuidler Nov 7, 2025
3943555
Update __init__.py
BobTheBuidler Nov 7, 2025
d76b537
Update __init__.py
BobTheBuidler Nov 7, 2025
9ed3a42
Update __init__.py
BobTheBuidler Nov 7, 2025
4675c7a
Update test_internals.py
BobTheBuidler Nov 7, 2025
33e5dca
Update test_internals.py
BobTheBuidler Nov 7, 2025
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
86 changes: 56 additions & 30 deletions async_lru/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,9 +9,9 @@
Coroutine,
Generic,
Hashable,
List,
Optional,
OrderedDict,
Set,
Type,
TypedDict,
TypeVar,
Expand Down Expand Up @@ -54,8 +54,9 @@ class _CacheParameters(TypedDict):
@final
@dataclasses.dataclass
class _CacheItem(Generic[_R]):
fut: "asyncio.Future[_R]"
task: "asyncio.Task[_R]"
later_call: Optional[asyncio.Handle]
waiters: int

def cancel(self) -> None:
if self.later_call is not None:
Expand Down Expand Up @@ -108,7 +109,17 @@ def __init__(
self.__closed = False
self.__hits = 0
self.__misses = 0
self.__tasks: Set["asyncio.Task[_R]"] = set()

@property
def __tasks(self) -> List["asyncio.Task[_R]"]:
# NOTE: I don't think we need to form a set first here but not too sure we want it for guarantees
return list(
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I think a list is probably overkill, we could just return an iterator for all the uses I saw.

{
cache_item.task
for cache_item in self.__cache.values()
if not cache_item.task.done()
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I don't think the condition is needed or useful for the use cases I saw.

}
)

def cache_invalidate(self, /, *args: Hashable, **kwargs: Any) -> bool:
key = _make_key(args, kwargs, self.__typed)
Expand All @@ -128,12 +139,11 @@ def cache_clear(self) -> None:
if c.later_call:
c.later_call.cancel()
self.__cache.clear()
self.__tasks.clear()

async def cache_close(self, *, wait: bool = False) -> None:
self.__closed = True

tasks = list(self.__tasks)
tasks = self.__tasks
Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

we need to materialize tasks into a list or set or some other container here

Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Well, that's already handled by the original code.

if not tasks:
return

Expand Down Expand Up @@ -167,19 +177,8 @@ def _cache_hit(self, key: Hashable) -> None:
def _cache_miss(self, key: Hashable) -> None:
self.__misses += 1

def _task_done_callback(
self, fut: "asyncio.Future[_R]", key: Hashable, task: "asyncio.Task[_R]"
) -> None:
self.__tasks.discard(task)

if task.cancelled():
fut.cancel()
self.__cache.pop(key, None)
return

exc = task.exception()
if exc is not None:
fut.set_exception(exc)
def _task_done_callback(self, key: Hashable, task: "asyncio.Task[_R]") -> None:
if task.cancelled() or task.exception() is not None:
self.__cache.pop(key, None)
return

Expand All @@ -190,7 +189,16 @@ def _task_done_callback(
self.__ttl, self.__cache.pop, key, None
)

fut.set_result(task.result())
def _handle_cancelled_error(
self, key: Hashable, cache_item: "_CacheItem[Any]"
) -> None:
# Called when a waiter is cancelled.
# If this is the last waiter and the underlying task is not done,
# cancel the underlying task and remove the cache entry.
if cache_item.waiters == 1 and not cache_item.task.done():
cache_item.cancel() # Cancel TTL expiration
cache_item.task.cancel() # Cancel the running coroutine
self.__cache.pop(key, None) # Remove from cache

async def __call__(self, /, *fn_args: Any, **fn_kwargs: Any) -> _R:
if self.__closed:
Expand All @@ -204,25 +212,43 @@ async def __call__(self, /, *fn_args: Any, **fn_kwargs: Any) -> _R:

if cache_item is not None:
self._cache_hit(key)
if not cache_item.fut.done():
return await asyncio.shield(cache_item.fut)

return cache_item.fut.result()
if not cache_item.task.done():
# Each logical waiter increments waiters on entry.
cache_item.waiters += 1

try:
# All waiters await the same shielded task.
return await asyncio.shield(cache_item.task)
except asyncio.CancelledError:
# If a waiter is cancelled, handle possible last-waiter cleanup.
self._handle_cancelled_error(key, cache_item)
raise
finally:
# Each logical waiter decrements waiters on exit (normal or cancelled).
cache_item.waiters -= 1
# If the task is already done, just return the result.
return cache_item.task.result()

fut = loop.create_future()
coro = self.__wrapped__(*fn_args, **fn_kwargs)
task: asyncio.Task[_R] = loop.create_task(coro)
self.__tasks.add(task)
task.add_done_callback(partial(self._task_done_callback, fut, key))
task.add_done_callback(partial(self._task_done_callback, key))

self.__cache[key] = _CacheItem(fut, None)
cache_item = _CacheItem(task, None, 1)
self.__cache[key] = cache_item

if self.__maxsize is not None and len(self.__cache) > self.__maxsize:
dropped_key, cache_item = self.__cache.popitem(last=False)
cache_item.cancel()
dropped_key, dropped_cache_item = self.__cache.popitem(last=False)
dropped_cache_item.cancel()

self._cache_miss(key)
return await asyncio.shield(fut)

try:
return await asyncio.shield(task)
except asyncio.CancelledError:
self._handle_cancelled_error(key, cache_item)
raise
finally:
cache_item.waiters -= 1

def __get__(
self, instance: _T, owner: Optional[Type[_T]]
Expand Down
3 changes: 1 addition & 2 deletions benchmark.py
Original file line number Diff line number Diff line change
Expand Up @@ -305,11 +305,10 @@ async def dummy_coro():
pass

iterations = range(1000)
create_future = loop.create_future
callback_fn = func._task_done_callback

@benchmark
def run() -> None:
for i in iterations:
callback = partial(callback_fn, create_future(), i)
callback = partial(callback_fn, i)
callback(task)
4 changes: 2 additions & 2 deletions tests/test_basic.py
Original file line number Diff line number Diff line change
Expand Up @@ -152,8 +152,8 @@ async def coro(val: int) -> int:
assert ret1 == ret2

assert (
coro1._LRUCacheWrapper__cache[1].fut.result() # type: ignore[attr-defined]
== coro2._LRUCacheWrapper__cache[1].fut.result() # type: ignore[attr-defined]
coro1._LRUCacheWrapper__cache[1].task.result() # type: ignore[attr-defined]
== coro2._LRUCacheWrapper__cache[1].task.result() # type: ignore[attr-defined]
)
assert coro1._LRUCacheWrapper__cache != coro2._LRUCacheWrapper__cache # type: ignore[attr-defined]
assert coro1._LRUCacheWrapper__cache.keys() == coro2._LRUCacheWrapper__cache.keys() # type: ignore[attr-defined]
Expand Down
58 changes: 58 additions & 0 deletions tests/test_cancel.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,58 @@
import asyncio

import pytest

from async_lru import alru_cache


@pytest.mark.parametrize("num_to_cancel", [0, 1, 2, 3])
async def test_cancel(num_to_cancel: int) -> None:
cache_item_task_finished = False

@alru_cache
async def coro(val: int) -> int:
# I am a long running coro function
nonlocal cache_item_task_finished
await asyncio.sleep(2)
cache_item_task_finished = True
return val

# create 3 tasks for the cached function using the same key
tasks = [asyncio.create_task(coro(1)) for _ in range(3)]

# force the event loop to run once so the tasks can begin
await asyncio.sleep(0)

# maybe cancel some tasks
for i in range(num_to_cancel):
tasks[i].cancel()

# allow enough time for the non-cancelled tasks to complete
await asyncio.sleep(3)

# check state
assert cache_item_task_finished == (num_to_cancel < 3)


@pytest.mark.asyncio
async def test_cancel_single_waiter_triggers_handle_cancelled_error() -> None:
# This test ensures the _handle_cancelled_error path (waiters == 1) is exercised.
cache_item_task_finished = False

@alru_cache
async def coro(val: int) -> int:
nonlocal cache_item_task_finished
await asyncio.sleep(2)
cache_item_task_finished = True
return val

task = asyncio.create_task(coro(42))
await asyncio.sleep(0)
task.cancel()
try:
await task
except asyncio.CancelledError:
pass

# The underlying coroutine should be cancelled, so the flag should remain False
assert cache_item_task_finished is False
36 changes: 4 additions & 32 deletions tests/test_internals.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,62 +11,34 @@ async def test_done_callback_cancelled() -> None:
wrapped = _LRUCacheWrapper(mock.ANY, None, False, None)
loop = asyncio.get_running_loop()
task = loop.create_future()
fut = loop.create_future()

key = 1

task.add_done_callback(partial(wrapped._task_done_callback, fut, key))
wrapped._LRUCacheWrapper__tasks.add(task) # type: ignore[attr-defined]
task.add_done_callback(partial(wrapped._task_done_callback, key))

task.cancel()

await asyncio.sleep(0)

assert fut.cancelled()
assert task not in wrapped._LRUCacheWrapper__tasks # type: ignore[attr-defined]


async def test_done_callback_exception() -> None:
wrapped = _LRUCacheWrapper(mock.ANY, None, False, None)
loop = asyncio.get_running_loop()
task = loop.create_future()
fut = loop.create_future()

key = 1

task.add_done_callback(partial(wrapped._task_done_callback, fut, key))
wrapped._LRUCacheWrapper__tasks.add(task) # type: ignore[attr-defined]
task.add_done_callback(partial(wrapped._task_done_callback, key))

exc = ZeroDivisionError()

task.set_exception(exc)

await asyncio.sleep(0)

with pytest.raises(ZeroDivisionError):
await fut

with pytest.raises(ZeroDivisionError):
fut.result()

assert fut.exception() is exc


async def test_done_callback() -> None:
wrapped = _LRUCacheWrapper(mock.ANY, None, False, None)
loop = asyncio.get_running_loop()
task = loop.create_future()

key = 1
fut = loop.create_future()

task.add_done_callback(partial(wrapped._task_done_callback, fut, key))
wrapped._LRUCacheWrapper__tasks.add(task) # type: ignore[attr-defined]

task.set_result(1)

await asyncio.sleep(0)

assert fut.result() == 1
assert task not in wrapped._LRUCacheWrapper__tasks # type: ignore[attr-defined]


async def test_cache_invalidate_typed() -> None:
Expand Down
Loading