Skip to content

Commit 0c0fce7

Browse files
committed
finish up impl
1 parent d1fe71e commit 0c0fce7

File tree

1 file changed

+23
-13
lines changed

1 file changed

+23
-13
lines changed

async_lru/__init__.py

Lines changed: 23 additions & 13 deletions
Original file line numberDiff line numberDiff line change
@@ -57,6 +57,7 @@ class _CacheItem(Generic[_R]):
5757
fut: "asyncio.Future[_R]"
5858
later_call: Optional[asyncio.Handle]
5959
waiters: int
60+
task: "asyncio.Task[_R]"
6061

6162
def cancel(self) -> None:
6263
if self.later_call is not None:
@@ -193,7 +194,17 @@ def _task_done_callback(
193194

194195
fut.set_result(task.result())
195196

197+
def _handle_cancelled_error(self, key: Hashable, cache_item: "_CacheItem") -> None:
198+
# Called when a waiter is cancelled.
199+
# If this is the last waiter and the underlying task is not done,
200+
# cancel the underlying task and remove the cache entry.
201+
if cache_item.waiters == 1 and not cache_item.task.done():
202+
cache_item.cancel() # Cancel TTL expiration
203+
cache_item.task.cancel() # Cancel the running coroutine
204+
self.__cache.pop(key, None) # Remove from cache
205+
196206
async def __call__(self, /, *fn_args: Any, **fn_kwargs: Any) -> _R:
207+
# Main entry point for cached coroutine calls.
197208
if self.__closed:
198209
raise RuntimeError(f"alru_cache is closed for {self}")
199210

@@ -206,15 +217,21 @@ async def __call__(self, /, *fn_args: Any, **fn_kwargs: Any) -> _R:
206217
if cache_item is not None:
207218
self._cache_hit(key)
208219
if not cache_item.fut.done():
220+
221+
# Each logical waiter increments waiters on entry.
209222
cache_item.waiters += 1
223+
210224
try:
225+
# All waiters await the same future.
211226
return await asyncio.shield(cache_item.fut)
212227
except asyncio.CancelledError:
213-
_handle_cancelled_error(cache_item, task)
228+
# If a waiter is cancelled, handle possible last-waiter cleanup.
229+
self._handle_cancelled_error(key, cache_item)
214230
raise
215231
finally:
232+
# Each logical waiter decrements waiters on exit (normal or cancelled).
216233
cache_item.waiters -= 1
217-
234+
# If the future is already done, just return the result.
218235
return cache_item.fut.result()
219236

220237
fut = loop.create_future()
@@ -223,19 +240,19 @@ async def __call__(self, /, *fn_args: Any, **fn_kwargs: Any) -> _R:
223240
self.__tasks.add(task)
224241
task.add_done_callback(partial(self._task_done_callback, fut, key))
225242

226-
cache_item = _CacheItem(fut, None, 1)
243+
cache_item = _CacheItem(fut, None, 1, task)
227244
self.__cache[key] = cache_item
228245

229246
if self.__maxsize is not None and len(self.__cache) > self.__maxsize:
230-
dropped_key, cache_item = self.__cache.popitem(last=False)
231-
cache_item.cancel()
247+
dropped_key, dropped_cache_item = self.__cache.popitem(last=False)
248+
dropped_cache_item.cancel()
232249

233250
self._cache_miss(key)
234251

235252
try:
236253
return await asyncio.shield(fut)
237254
except asyncio.CancelledError:
238-
_handle_cancelled_error(cache_item, task)
255+
self._handle_cancelled_error(key, cache_item)
239256
raise
240257
finally:
241258
cache_item.waiters -= 1
@@ -249,13 +266,6 @@ def __get__(
249266
return _LRUCacheWrapperInstanceMethod(self, instance)
250267

251268

252-
def _handle_cancelled_error(cache_item: _CacheItem, task: asyncio.Task[Any]) -> None:
253-
if cache_item.waiters == 1 and not task.done():
254-
task.cancel()
255-
cache_item.cancel()
256-
self.__cache.pop(key)
257-
258-
259269
@final
260270
class _LRUCacheWrapperInstanceMethod(Generic[_R, _T]):
261271
def __init__(

0 commit comments

Comments
 (0)