Skip to content
Open
Show file tree
Hide file tree
Changes from 1 commit
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
6 changes: 6 additions & 0 deletions asyncio_redis_rate_limit/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -49,6 +49,7 @@ class RateLimiter:
'_backend',
'_cache_prefix',
'_lock',
'_use_nx_on_expire'
)

def __init__(
Expand All @@ -58,13 +59,15 @@ def __init__(
backend: AnyRedis,
*,
cache_prefix: str,
use_nx_on_expire: bool = True,
) -> None:
"""In the future other backends might be supported as well."""
self._unique_key = unique_key
self._rate_spec = rate_spec
self._backend = backend
self._cache_prefix = cache_prefix
self._lock = asyncio.Lock()
self._use_nx_on_expire = use_nx_on_expire

async def __aenter__(self: _RateLimiterT) -> _RateLimiterT:
"""
Expand Down Expand Up @@ -110,6 +113,7 @@ async def _run_pipeline(
pipeline.incr(cache_key),
cache_key,
self._rate_spec.seconds,
use_nx=self._use_nx_on_expire,
).execute()
return current_rate # type: ignore[no-any-return]

Expand All @@ -130,6 +134,7 @@ def rate_limit( # noqa: WPS320
backend: AnyRedis,
*,
cache_prefix: str = 'aio-rate-limit',
use_nx_on_expire: bool = True,
) -> Callable[
[_CoroutineFunction[_ParamsT, _ResultT]],
_CoroutineFunction[_ParamsT, _ResultT],
Expand Down Expand Up @@ -167,6 +172,7 @@ async def factory(
backend=backend,
rate_spec=rate_spec,
cache_prefix=cache_prefix,
use_nx_on_expire=use_nx_on_expire,
):
return await function(*args, **kwargs)
return factory
Expand Down
5 changes: 5 additions & 0 deletions asyncio_redis_rate_limit/compat.py
Original file line number Diff line number Diff line change
Expand Up @@ -44,8 +44,13 @@ def pipeline_expire(
pipeline: Any,
cache_key: str,
seconds: int,
*,
use_nx: bool = True,
) -> AnyPipeline:
"""Compatibility mode for `.expire(..., nx=True)` command."""
if not use_nx:
return pipeline.expire(cache_key, seconds) # type: ignore

if isinstance(pipeline, _AsyncPipeline):
return pipeline.expire(cache_key, seconds, nx=True) # type: ignore
# `aioredis` somehow does not have this boolean argument in `.expire`,
Expand Down
27 changes: 27 additions & 0 deletions tests/test_examples.py
Original file line number Diff line number Diff line change
Expand Up @@ -25,6 +25,8 @@ def __call__(
self,
requests: int = ...,
seconds: int = ...,
*,
use_nx_on_expire: bool = ...,
) -> _LimitedSig:
"""We use this callback to construct `limited` test function."""

Expand Down Expand Up @@ -246,6 +248,31 @@ async def test_ten_reqs_in_two_secs2(
await asyncio.sleep(1 + 0.5)
await function()

@pytest.mark.repeat(5)
async def test_ten_reqs_in_two_secs_without_nx(
limited: _LimitedCallback,
) -> None:
"""Ensure that several gathered coroutines do respect the rate limit."""
function = limited(requests=10, seconds=2, use_nx_on_expire=False)

# Or just consume all:
for attempt in range(10):
await function(attempt)

# This one will fail:
with pytest.raises(RateLimitError):
await function()

# Now, let's move time to the next second:
await asyncio.sleep(1)

# This one will also fail:
with pytest.raises(RateLimitError):
await function()

# Next attempts will pass:
await asyncio.sleep(1 + 0.5)
await function()

class _Counter:
def __init__(self) -> None:
Expand Down