@@ -57,6 +57,7 @@ class _CacheItem(Generic[_R]):
5757 fut : "asyncio.Future[_R]"
5858 later_call : Optional [asyncio .Handle ]
5959 waiters : int
60+ task : "asyncio.Task[_R]"
6061
6162 def cancel (self ) -> None :
6263 if self .later_call is not None :
@@ -193,7 +194,17 @@ def _task_done_callback(
193194
194195 fut .set_result (task .result ())
195196
197+ def _handle_cancelled_error (self , key : Hashable , cache_item : "_CacheItem" ) -> None :
198+ # Called when a waiter is cancelled.
199+ # If this is the last waiter and the underlying task is not done,
200+ # cancel the underlying task and remove the cache entry.
201+ if cache_item .waiters == 1 and not cache_item .task .done ():
202+ cache_item .cancel () # Cancel TTL expiration
203+ cache_item .task .cancel () # Cancel the running coroutine
204+ self .__cache .pop (key , None ) # Remove from cache
205+
196206 async def __call__ (self , / , * fn_args : Any , ** fn_kwargs : Any ) -> _R :
207+ # Main entry point for cached coroutine calls.
197208 if self .__closed :
198209 raise RuntimeError (f"alru_cache is closed for { self } " )
199210
@@ -206,15 +217,21 @@ async def __call__(self, /, *fn_args: Any, **fn_kwargs: Any) -> _R:
206217 if cache_item is not None :
207218 self ._cache_hit (key )
208219 if not cache_item .fut .done ():
220+
221+ # Each logical waiter increments waiters on entry.
209222 cache_item .waiters += 1
223+
210224 try :
225+ # All waiters await the same future.
211226 return await asyncio .shield (cache_item .fut )
212227 except asyncio .CancelledError :
213- _handle_cancelled_error (cache_item , task )
228+ # If a waiter is cancelled, handle possible last-waiter cleanup.
229+ self ._handle_cancelled_error (key , cache_item )
214230 raise
215231 finally :
232+ # Each logical waiter decrements waiters on exit (normal or cancelled).
216233 cache_item .waiters -= 1
217-
234+ # If the future is already done, just return the result.
218235 return cache_item .fut .result ()
219236
220237 fut = loop .create_future ()
@@ -223,19 +240,19 @@ async def __call__(self, /, *fn_args: Any, **fn_kwargs: Any) -> _R:
223240 self .__tasks .add (task )
224241 task .add_done_callback (partial (self ._task_done_callback , fut , key ))
225242
226- cache_item = _CacheItem (fut , None , 1 )
243+ cache_item = _CacheItem (fut , None , 1 , task )
227244 self .__cache [key ] = cache_item
228245
229246 if self .__maxsize is not None and len (self .__cache ) > self .__maxsize :
230- dropped_key , cache_item = self .__cache .popitem (last = False )
231- cache_item .cancel ()
247+ dropped_key , dropped_cache_item = self .__cache .popitem (last = False )
248+ dropped_cache_item .cancel ()
232249
233250 self ._cache_miss (key )
234251
235252 try :
236253 return await asyncio .shield (fut )
237254 except asyncio .CancelledError :
238- _handle_cancelled_error (cache_item , task )
255+ self . _handle_cancelled_error (key , cache_item )
239256 raise
240257 finally :
241258 cache_item .waiters -= 1
@@ -249,13 +266,6 @@ def __get__(
249266 return _LRUCacheWrapperInstanceMethod (self , instance )
250267
251268
252- def _handle_cancelled_error (cache_item : _CacheItem , task : asyncio .Task [Any ]) -> None :
253- if cache_item .waiters == 1 and not task .done ():
254- task .cancel ()
255- cache_item .cancel ()
256- self .__cache .pop (key )
257-
258-
259269@final
260270class _LRUCacheWrapperInstanceMethod (Generic [_R , _T ]):
261271 def __init__ (
0 commit comments