diff --git a/async_lru/__init__.py b/async_lru/__init__.py index 447e9cdb..1ebbdcf8 100644 --- a/async_lru/__init__.py +++ b/async_lru/__init__.py @@ -9,6 +9,7 @@ Coroutine, Generic, Hashable, + Literal, Optional, OrderedDict, Set, @@ -63,8 +64,7 @@ def cancel(self) -> None: self.later_call = None -@final -class _LRUCacheWrapper(Generic[_R]): +class _LRUCacheWrapperBase(Generic[_R]): def __init__( self, fn: _CB[_R], @@ -161,8 +161,7 @@ def cache_parameters(self) -> _CacheParameters: ) def _cache_hit(self, key: Hashable) -> None: - self.__hits += 1 - self.__cache.move_to_end(key) + raise NotImplementedError("must be implemented by subclass") def _cache_miss(self, key: Hashable) -> None: self.__misses += 1 @@ -192,6 +191,23 @@ def _task_done_callback( fut.set_result(task.result()) + async def __call__(self, /, *fn_args: Any, **fn_kwargs: Any) -> _R: + raise NotImplementedError("must be implemented by subclass") + + def __get__( + self, instance: _T, owner: Optional[Type[_T]] + ) -> Union[Self, "_LRUCacheWrapperInstanceMethod[_R, _T]"]: + if owner is None: + return self + else: + return _LRUCacheWrapperInstanceMethod(self, instance) + +@final +class _LRUCacheWrapper(_LRUCacheWrapperBase[_R]): + def _cache_hit(self, key: Hashable) -> None: + self.__hits += 1 + self.__cache.move_to_end(key) + async def __call__(self, /, *fn_args: Any, **fn_kwargs: Any) -> _R: if self.__closed: raise RuntimeError(f"alru_cache is closed for {self}") @@ -217,20 +233,44 @@ async def __call__(self, /, *fn_args: Any, **fn_kwargs: Any) -> _R: self.__cache[key] = _CacheItem(fut, None) - if self.__maxsize is not None and len(self.__cache) > self.__maxsize: + if len(self.__cache) > self.__maxsize: dropped_key, cache_item = self.__cache.popitem(last=False) cache_item.cancel() self._cache_miss(key) return await asyncio.shield(fut) + +@final +class _LRUCacheWrapperUnbounded(_LRUCacheWrapperBase[_R]): + def _cache_hit(self, key: Hashable) -> None: + self.__hits += 1 - def __get__( - self, instance: _T, owner: Optional[Type[_T]] - ) -> Union[Self, "_LRUCacheWrapperInstanceMethod[_R, _T]"]: - if owner is None: - return self - else: - return _LRUCacheWrapperInstanceMethod(self, instance) + async def __call__(self, /, *fn_args: Any, **fn_kwargs: Any) -> _R: + if self.__closed: + raise RuntimeError(f"alru_cache is closed for {self}") + + loop = asyncio.get_running_loop() + + key = _make_key(fn_args, fn_kwargs, self.__typed) + + cache_item = self.__cache.get(key) + + if cache_item is not None: + self._cache_hit(key) + if not cache_item.fut.done(): + return await asyncio.shield(cache_item.fut) + + return cache_item.fut.result() + + fut = loop.create_future() + coro = self.__wrapped__(*fn_args, **fn_kwargs) + task: asyncio.Task[_R] = loop.create_task(coro) + self.__tasks.add(task) + task.add_done_callback(partial(self._task_done_callback, fut, key)) + + self.__cache[key] = _CacheItem(fut, None) + self._cache_miss(key) + return await asyncio.shield(fut) @final @@ -293,11 +333,32 @@ async def __call__(self, /, *fn_args: Any, **fn_kwargs: Any) -> _R: return await self.__wrapper(self.__instance, *fn_args, **fn_kwargs) +@overload def _make_wrapper( - maxsize: Optional[int], + maxsize: int, typed: bool, ttl: Optional[float] = None, ) -> Callable[[_CBP[_R]], _LRUCacheWrapper[_R]]: + ... + + +@overload +def _make_wrapper( + maxsize: Literal[None], + typed: bool, + ttl: Optional[float] = None, +) -> Callable[[_CBP[_R]], _LRUCacheWrapperUnbounded[_R]]: + ... + + +def _make_wrapper( + maxsize: Optional[int], + typed: bool, + ttl: Optional[float] = None, +) -> Union[ + Callable[[_CBP[_R]], _LRUCacheWrapper[_R]], + Callable[[_CBP[_R]], _LRUCacheWrapperUnbounded[_R]], +]: def wrapper(fn: _CBP[_R]) -> _LRUCacheWrapper[_R]: origin = fn @@ -311,7 +372,8 @@ def wrapper(fn: _CBP[_R]) -> _LRUCacheWrapper[_R]: if hasattr(fn, "_make_unbound_method"): fn = fn._make_unbound_method() - wrapper = _LRUCacheWrapper(cast(_CB[_R], fn), maxsize, typed, ttl) + wrapper_cls = _LRUCacheWrapperUnbounded if maxsize is None else _LRUCacheWrapper + wrapper = wrapper_cls(cast(_CB[_R], fn), maxsize, typed, ttl) if sys.version_info >= (3, 12): wrapper = inspect.markcoroutinefunction(wrapper) return wrapper @@ -319,6 +381,16 @@ def wrapper(fn: _CBP[_R]) -> _LRUCacheWrapper[_R]: return wrapper +@overload +def alru_cache( + maxsize: Literal[None], + typed: bool = False, + *, + ttl: Optional[float] = None, +) -> Callable[[_CBP[_R]], _LRUCacheWrapperUnbounded[_R]]: + ... + + @overload def alru_cache( maxsize: Optional[int] = 128, @@ -342,7 +414,11 @@ def alru_cache( typed: bool = False, *, ttl: Optional[float] = None, -) -> Union[Callable[[_CBP[_R]], _LRUCacheWrapper[_R]], _LRUCacheWrapper[_R]]: +) -> Union[ + Callable[[_CBP[_R]], _LRUCacheWrapper[_R]], + Callable[[_CBP[_R]], _LRUCacheWrapperUnbounded[_R]], + _LRUCacheWrapper[_R], +]: if maxsize is None or isinstance(maxsize, int): return _make_wrapper(maxsize, typed, ttl) else: