diff --git a/aiocache/decorators.py b/aiocache/decorators.py index d2c41b24a..4b9ffc630 100644 --- a/aiocache/decorators.py +++ b/aiocache/decorators.py @@ -44,12 +44,38 @@ def __init__( self.cache = cache def __call__(self, f): - @functools.wraps(f) - async def wrapper(*args, **kwargs): - return await self.decorator(f, *args, **kwargs) + class CachedFunctionWrapper: + def __init__(self, decorator, func): + functools.update_wrapper(self, func) + self._decorator = decorator + self._func = func + self.cache = decorator.cache + + async def __call__(self, *args, **kwargs): + return await self._decorator.decorator(self._func, *args, **kwargs) + + async def refresh(self, *args, **kwargs): + """ + Force a refresh of the cache. + + This method recomputes the result by calling the original function, + then updates the cache with the new value for the given arguments. + """ + return await self._decorator.decorator( + self._func, *args, cache_read=False, cache_write=True, **kwargs + ) + + async def invalidate(self, *args, **kwargs): + """ + Invalidate the cache for the given key, or clear all cache if no key is provided. + """ + if not args and not kwargs: + await self.cache.clear() + else: + key = self._decorator.get_cache_key(self._func, args, kwargs) + await self.cache.delete(key) - wrapper.cache = self.cache - return wrapper + return CachedFunctionWrapper(self, f) async def decorator( self, f, *args, cache_read=True, cache_write=True, aiocache_wait_for_write=True, **kwargs @@ -229,12 +255,38 @@ def __init__( self.ttl = ttl def __call__(self, f): - @functools.wraps(f) - async def wrapper(*args, **kwargs): - return await self.decorator(f, *args, **kwargs) + class CachedFunctionWrapper: + def __init__(self, decorator, func): + functools.update_wrapper(self, func) + self._decorator = decorator + self._func = func + self.cache = decorator.cache + + async def __call__(self, *args, **kwargs): + return await self._decorator.decorator(self._func, *args, **kwargs) + + async def refresh(self, *args, **kwargs): + """ + Force a cache refresh for the given key. + + This method recomputes the result by calling the original function, + then updates the cache with the new value for the given key. + """ + return await self._decorator.decorator( + self._func, *args, cache_read=False, cache_write=True, **kwargs + ) + + async def invalidate(self, *args, **kwargs): + """Invalidate the cache for the provided key, or clear all if no key is provided.""" + if not args and not kwargs: + await self.cache.clear() + else: + # Invalidate each key in args + for key in args: + cache_key = self._decorator.key_builder(key, self._func, *args, **kwargs) + await self.cache.delete(cache_key) - wrapper.cache = self.cache - return wrapper + return CachedFunctionWrapper(self, f) async def decorator( self, f, *args, cache_read=True, cache_write=True, aiocache_wait_for_write=True, **kwargs diff --git a/docs/decorators.rst b/docs/decorators.rst index 68319ed62..7af12d59e 100644 --- a/docs/decorators.rst +++ b/docs/decorators.rst @@ -17,6 +17,22 @@ cached :language: python :linenos: +The ``@cached`` decorator returns a wrapper object that exposes cache control methods, such as ``.refresh()`` and ``.invalidate()``. Use ``.refresh()`` to force a cache refresh for the given arguments, bypassing the cache. + +**Example:** + +.. code-block:: python + + @cached() + async def compute(x): + return x * 2 + + await compute(1) # Uses cache if available + await compute.refresh(1) # Forces refresh, updates cache + + await compute.invalidate() # Invalidate all cache keys + await compute.invalidate(key) # Invalidate a specific cache key + .. _multi_cached: multi_cached diff --git a/tests/ut/test_decorators.py b/tests/ut/test_decorators.py index 7fe4c68ef..979ffdb4e 100644 --- a/tests/ut/test_decorators.py +++ b/tests/ut/test_decorators.py @@ -94,6 +94,39 @@ async def test_cache_write_disabled(self, decorator, decorator_call): assert decorator.cache.set.call_count == 0 assert stub.call_count == 1 + @pytest.mark.asyncio + async def test_cached_refresh_forces_update(self, monkeypatch): + calls = [] + cache = SimpleMemoryCache() + + @cached(cache=cache) + async def foo(x): + calls.append(x) + return x * 2 + + assert await foo(3) == 6 + assert await foo(3) == 6 + assert await foo.refresh(3) == 6 + assert calls == [3, 3] + + @pytest.mark.asyncio + async def test_cached_invalidate_key_and_all(self): + calls = [] + cache = SimpleMemoryCache() + + @cached(cache=cache) + async def foo(x): + calls.append(x) + return x * 2 + + await foo(1) + await foo(2) + await foo.invalidate(1) + await foo(1) + await foo.invalidate() + await foo(2) + assert calls == [1, 2, 1, 2] + async def test_disable_params_not_propagated(self, decorator, decorator_call): decorator.cache.get.return_value = None @@ -431,6 +464,38 @@ async def test_cache_write_disabled(self, decorator, decorator_call): assert decorator.cache.multi_set.call_count == 0 assert stub_dict.call_count == 1 + @pytest.mark.asyncio + async def test_multi_cached_refresh_forces_update(self, monkeypatch): + calls = [] + cache = SimpleMemoryCache() + + @multi_cached(cache=cache, keys_from_attr="keys") + async def bar(keys=None): + calls.append(tuple(keys)) + return {k: k * 10 for k in keys} + + assert await bar(keys=[1, 2]) == {1: 10, 2: 20} + assert await bar(keys=[1, 2]) == {1: 10, 2: 20} + assert await bar.refresh(keys=[1, 2]) == {1: 10, 2: 20} + assert calls == [(1, 2), (1, 2)] + + @pytest.mark.asyncio + async def test_multi_cached_invalidate_key_and_all(self): + calls = [] + cache = SimpleMemoryCache() + + @multi_cached(cache=cache, keys_from_attr="keys") + async def bar(keys=None): + calls.extend(keys) + return {k: k * 10 for k in keys} + + await bar(keys=[1, 2]) + await bar.invalidate(1) + await bar(keys=[1]) + await bar.invalidate() + await bar(keys=[2]) + assert calls == [1, 2, 1, 2] + async def test_disable_params_not_propagated(self, decorator, decorator_call): decorator.cache.multi_get.return_value = [None, None]