Skip to content
Open
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
106 changes: 91 additions & 15 deletions async_lru/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,6 +9,7 @@
Coroutine,
Generic,
Hashable,
Literal,
Optional,
OrderedDict,
Set,
Expand Down Expand Up @@ -63,8 +64,7 @@ def cancel(self) -> None:
self.later_call = None


@final
class _LRUCacheWrapper(Generic[_R]):
class _LRUCacheWrapperBase(Generic[_R]):
def __init__(
self,
fn: _CB[_R],
Expand Down Expand Up @@ -161,8 +161,7 @@ def cache_parameters(self) -> _CacheParameters:
)

def _cache_hit(self, key: Hashable) -> None:
self.__hits += 1
self.__cache.move_to_end(key)
raise NotImplementedError("must be implemented by subclass")

def _cache_miss(self, key: Hashable) -> None:
self.__misses += 1
Expand Down Expand Up @@ -192,6 +191,23 @@ def _task_done_callback(

fut.set_result(task.result())

async def __call__(self, /, *fn_args: Any, **fn_kwargs: Any) -> _R:
raise NotImplementedError("must be implemented by subclass")

def __get__(
self, instance: _T, owner: Optional[Type[_T]]
) -> Union[Self, "_LRUCacheWrapperInstanceMethod[_R, _T]"]:
if owner is None:
return self
else:
return _LRUCacheWrapperInstanceMethod(self, instance)

@final
class _LRUCacheWrapper(_LRUCacheWrapperBase[_R]):
def _cache_hit(self, key: Hashable) -> None:
self.__hits += 1
self.__cache.move_to_end(key)

async def __call__(self, /, *fn_args: Any, **fn_kwargs: Any) -> _R:
if self.__closed:
raise RuntimeError(f"alru_cache is closed for {self}")
Expand All @@ -217,20 +233,44 @@ async def __call__(self, /, *fn_args: Any, **fn_kwargs: Any) -> _R:

self.__cache[key] = _CacheItem(fut, None)

if self.__maxsize is not None and len(self.__cache) > self.__maxsize:
if len(self.__cache) > self.__maxsize:
dropped_key, cache_item = self.__cache.popitem(last=False)
cache_item.cancel()

self._cache_miss(key)
return await asyncio.shield(fut)

@final
class _LRUCacheWrapperUnbounded(_LRUCacheWrapperBase[_R]):
def _cache_hit(self, key: Hashable) -> None:
self.__hits += 1

def __get__(
self, instance: _T, owner: Optional[Type[_T]]
) -> Union[Self, "_LRUCacheWrapperInstanceMethod[_R, _T]"]:
if owner is None:
return self
else:
return _LRUCacheWrapperInstanceMethod(self, instance)
async def __call__(self, /, *fn_args: Any, **fn_kwargs: Any) -> _R:
if self.__closed:
raise RuntimeError(f"alru_cache is closed for {self}")

loop = asyncio.get_running_loop()

key = _make_key(fn_args, fn_kwargs, self.__typed)

cache_item = self.__cache.get(key)

if cache_item is not None:
self._cache_hit(key)
if not cache_item.fut.done():
return await asyncio.shield(cache_item.fut)

return cache_item.fut.result()

fut = loop.create_future()
coro = self.__wrapped__(*fn_args, **fn_kwargs)
task: asyncio.Task[_R] = loop.create_task(coro)
self.__tasks.add(task)
task.add_done_callback(partial(self._task_done_callback, fut, key))

self.__cache[key] = _CacheItem(fut, None)
self._cache_miss(key)
return await asyncio.shield(fut)


@final
Expand Down Expand Up @@ -293,11 +333,32 @@ async def __call__(self, /, *fn_args: Any, **fn_kwargs: Any) -> _R:
return await self.__wrapper(self.__instance, *fn_args, **fn_kwargs)


@overload
def _make_wrapper(
maxsize: Optional[int],
maxsize: int,
typed: bool,
ttl: Optional[float] = None,
) -> Callable[[_CBP[_R]], _LRUCacheWrapper[_R]]:
...


@overload
def _make_wrapper(
maxsize: Literal[None],
typed: bool,
ttl: Optional[float] = None,
) -> Callable[[_CBP[_R]], _LRUCacheWrapperUnbounded[_R]]:
...


def _make_wrapper(
maxsize: Optional[int],
typed: bool,
ttl: Optional[float] = None,
) -> Union[
Callable[[_CBP[_R]], _LRUCacheWrapper[_R]],
Callable[[_CBP[_R]], _LRUCacheWrapperUnbounded[_R]],
]:
def wrapper(fn: _CBP[_R]) -> _LRUCacheWrapper[_R]:
origin = fn

Expand All @@ -311,14 +372,25 @@ def wrapper(fn: _CBP[_R]) -> _LRUCacheWrapper[_R]:
if hasattr(fn, "_make_unbound_method"):
fn = fn._make_unbound_method()

wrapper = _LRUCacheWrapper(cast(_CB[_R], fn), maxsize, typed, ttl)
wrapper_cls = _LRUCacheWrapperUnbounded if maxsize is None else _LRUCacheWrapper
wrapper = wrapper_cls(cast(_CB[_R], fn), maxsize, typed, ttl)
if sys.version_info >= (3, 12):
wrapper = inspect.markcoroutinefunction(wrapper)
return wrapper

return wrapper


@overload
def alru_cache(
maxsize: Literal[None],
typed: bool = False,
*,
ttl: Optional[float] = None,
) -> Callable[[_CBP[_R]], _LRUCacheWrapperUnbounded[_R]]:
...


@overload
def alru_cache(
maxsize: Optional[int] = 128,
Expand All @@ -342,7 +414,11 @@ def alru_cache(
typed: bool = False,
*,
ttl: Optional[float] = None,
) -> Union[Callable[[_CBP[_R]], _LRUCacheWrapper[_R]], _LRUCacheWrapper[_R]]:
) -> Union[
Callable[[_CBP[_R]], _LRUCacheWrapper[_R]],
Callable[[_CBP[_R]], _LRUCacheWrapperUnbounded[_R]],
_LRUCacheWrapper[_R],
]:
if maxsize is None or isinstance(maxsize, int):
return _make_wrapper(maxsize, typed, ttl)
else:
Expand Down
Loading