72 changes: 62 additions & 10 deletions aiocache/decorators.py
@Dreamsorcerer Dreamsorcerer (Member) commented on May 3, 2025:

If you refer back to the example in the issue I posted, my intention was to replace the cached class completely. This currently feels just as awkward as before, with the wrapper class referencing the decorator class.

My expectation is that the cached decorator will literally just be defined as:

def cached(func):
    return Wrapper(func)

All other logic can exist in the Wrapper class.
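For illustration, a minimal sketch of that shape (the Wrapper name comes from the comment above; the naive key builder and the SimpleMemoryCache default are assumptions made for the sketch, not part of this PR):

import functools

from aiocache import SimpleMemoryCache


class Wrapper:
    """Holds all of the caching logic; the decorator only constructs it."""

    def __init__(self, func, cache=None):
        functools.update_wrapper(self, func)
        self.func = func
        self.cache = cache if cache is not None else SimpleMemoryCache()

    def _key(self, args, kwargs):
        # Naive key builder, for illustration only.
        return f"{self.func.__qualname__}:{args}:{kwargs}"

    async def __call__(self, *args, **kwargs):
        key = self._key(args, kwargs)
        value = await self.cache.get(key)
        if value is not None:
            return value
        value = await self.func(*args, **kwargs)
        await self.cache.set(key, value)
        return value


def cached(func):
    return Wrapper(func)

With that shape, methods like refresh and invalidate become plain methods on Wrapper instead of attributes bolted onto a closure.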

@@ -44,12 +44,38 @@ def __init__(
         self.cache = cache

     def __call__(self, f):
-        @functools.wraps(f)
-        async def wrapper(*args, **kwargs):
-            return await self.decorator(f, *args, **kwargs)
+        class CachedFunctionWrapper:
+            def __init__(self, decorator, func):
+                functools.update_wrapper(self, func)
+                self._decorator = decorator
+                self._func = func
+                self.cache = decorator.cache
+
+            async def __call__(self, *args, **kwargs):
+                return await self._decorator.decorator(self._func, *args, **kwargs)
+
+            async def refresh(self, *args, **kwargs):
+                """
+                Force a refresh of the cache.
+
+                This method recomputes the result by calling the original function,
+                then updates the cache with the new value for the given arguments.
+                """
+                return await self._decorator.decorator(
+                    self._func, *args, cache_read=False, cache_write=True, **kwargs
+                )
+
+            async def invalidate(self, *args, **kwargs):
+                """
+                Invalidate the cached value for the given arguments, or clear the entire cache if none are provided.
+                """
+                if not args and not kwargs:
+                    await self.cache.clear()
+                else:
+                    key = self._decorator.get_cache_key(self._func, args, kwargs)
+                    await self.cache.delete(key)
Comment on lines +57 to +76 (Member):

Can we skip this in the initial PR and add it in a followup PR?


-        wrapper.cache = self.cache
-        return wrapper
+        return CachedFunctionWrapper(self, f)

     async def decorator(
         self, f, *args, cache_read=True, cache_write=True, aiocache_wait_for_write=True, **kwargs
@@ -229,12 +255,38 @@ def __init__(
         self.ttl = ttl

     def __call__(self, f):
-        @functools.wraps(f)
-        async def wrapper(*args, **kwargs):
-            return await self.decorator(f, *args, **kwargs)
+        class CachedFunctionWrapper:
+            def __init__(self, decorator, func):
+                functools.update_wrapper(self, func)
+                self._decorator = decorator
+                self._func = func
+                self.cache = decorator.cache
+
+            async def __call__(self, *args, **kwargs):
+                return await self._decorator.decorator(self._func, *args, **kwargs)
+
+            async def refresh(self, *args, **kwargs):
+                """
+                Force a cache refresh for the given key.
+
+                This method recomputes the result by calling the original function,
+                then updates the cache with the new value for the given key.
+                """
+                return await self._decorator.decorator(
+                    self._func, *args, cache_read=False, cache_write=True, **kwargs
+                )
+
+            async def invalidate(self, *args, **kwargs):
+                """Invalidate the cache for the provided key(s), or clear the entire cache if no key is provided."""
+                if not args and not kwargs:
+                    await self.cache.clear()
+                else:
+                    # Invalidate each key in args
+                    for key in args:
+                        cache_key = self._decorator.key_builder(key, self._func, *args, **kwargs)
+                        await self.cache.delete(cache_key)

-        wrapper.cache = self.cache
-        return wrapper
+        return CachedFunctionWrapper(self, f)

     async def decorator(
         self, f, *args, cache_read=True, cache_write=True, aiocache_wait_for_write=True, **kwargs
16 changes: 16 additions & 0 deletions docs/decorators.rst
@@ -17,6 +17,22 @@ cached
    :language: python
    :linenos:

The ``@cached`` decorator returns a wrapper object that exposes cache-control methods such as ``.refresh()`` and ``.invalidate()``. ``.refresh()`` recomputes the value for the given arguments, bypassing any cached value and writing the fresh result back to the cache; ``.invalidate()`` removes the cached entry for the given key, or clears the whole cache when called with no arguments.

**Example:**

.. code-block:: python

    @cached()
    async def compute(x):
        return x * 2

    await compute(1)          # Uses cache if available
    await compute.refresh(1)  # Forces refresh, updates cache

    await compute.invalidate()   # Invalidate all cache keys
    await compute.invalidate(1)  # Invalidate the cache entry for these arguments
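
The unit tests added in this PR show that ``@multi_cached`` wrappers expose the same methods; a minimal sketch of the expected usage (mirroring those tests, not yet covered in this section of the docs):

.. code-block:: python

    @multi_cached(keys_from_attr="keys")
    async def fetch(keys=None):
        return {k: k * 10 for k in keys}

    await fetch(keys=[1, 2])          # Caches each key individually
    await fetch.refresh(keys=[1, 2])  # Recomputes and rewrites both entries
    await fetch.invalidate(1)         # Drop the entry for key 1
    await fetch.invalidate()          # Clear the cache entirely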

.. _multi_cached:

multi_cached
65 changes: 65 additions & 0 deletions tests/ut/test_decorators.py
@@ -94,6 +94,39 @@ async def test_cache_write_disabled(self, decorator, decorator_call):
        assert decorator.cache.set.call_count == 0
        assert stub.call_count == 1

    @pytest.mark.asyncio
    async def test_cached_refresh_forces_update(self, monkeypatch):
        calls = []
        cache = SimpleMemoryCache()

        @cached(cache=cache)
        async def foo(x):
            calls.append(x)
            return x * 2

        assert await foo(3) == 6
        assert await foo(3) == 6
        assert await foo.refresh(3) == 6
        assert calls == [3, 3]

    @pytest.mark.asyncio
    async def test_cached_invalidate_key_and_all(self):
        calls = []
        cache = SimpleMemoryCache()

        @cached(cache=cache)
        async def foo(x):
            calls.append(x)
            return x * 2

        await foo(1)
        await foo(2)
        await foo.invalidate(1)
        await foo(1)
        await foo.invalidate()
        await foo(2)
        assert calls == [1, 2, 1, 2]

    async def test_disable_params_not_propagated(self, decorator, decorator_call):
        decorator.cache.get.return_value = None

@@ -431,6 +464,38 @@ async def test_cache_write_disabled(self, decorator, decorator_call):
        assert decorator.cache.multi_set.call_count == 0
        assert stub_dict.call_count == 1

    @pytest.mark.asyncio
    async def test_multi_cached_refresh_forces_update(self, monkeypatch):
        calls = []
        cache = SimpleMemoryCache()

        @multi_cached(cache=cache, keys_from_attr="keys")
        async def bar(keys=None):
            calls.append(tuple(keys))
            return {k: k * 10 for k in keys}

        assert await bar(keys=[1, 2]) == {1: 10, 2: 20}
        assert await bar(keys=[1, 2]) == {1: 10, 2: 20}
        assert await bar.refresh(keys=[1, 2]) == {1: 10, 2: 20}
        assert calls == [(1, 2), (1, 2)]

    @pytest.mark.asyncio
    async def test_multi_cached_invalidate_key_and_all(self):
        calls = []
        cache = SimpleMemoryCache()

        @multi_cached(cache=cache, keys_from_attr="keys")
        async def bar(keys=None):
            calls.extend(keys)
            return {k: k * 10 for k in keys}

        await bar(keys=[1, 2])
        await bar.invalidate(1)
        await bar(keys=[1])
        await bar.invalidate()
        await bar(keys=[2])
        assert calls == [1, 2, 1, 2]

    async def test_disable_params_not_propagated(self, decorator, decorator_call):
        decorator.cache.multi_get.return_value = [None, None]