diff --git a/dogpile/cache/api.py b/dogpile/cache/api.py
index 6fbe842..f0c79dd 100644
--- a/dogpile/cache/api.py
+++ b/dogpile/cache/api.py
@@ -348,7 +348,8 @@ def get_serialized_multi(
:meth:`.CacheRegion.get_multi` method, which will also be processed
by the "key mangling" function if one was present.
- :return: list of bytes objects
+ :return: list of bytes objects, or the :data:`.NO_VALUE`
+        constant for any key that is not present.
The default implementation of this method for :class:`.CacheBackend`
returns the value of the :meth:`.CacheBackend.get_multi` method.
@@ -543,7 +544,8 @@ def get_serialized_multi(
:meth:`.CacheRegion.get_multi` method, which will also be processed
by the "key mangling" function if one was present.
- :return: list of bytes objects
+ :return: list of bytes objects, or the :data:`.NO_VALUE`
+        constant for any key that is not present.
.. versionadded:: 1.1
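
As a hedged illustration of the contract this docstring now states, a minimal dict-backed ``BytesBackend`` could look as follows. The class name and its ``cache_dict`` argument are hypothetical, not part of this patch; the point is that every miss, single or multi, yields ``NO_VALUE`` rather than ``None``:

    from dogpile.cache.api import BytesBackend, NO_VALUE

    class DictBytesBackend(BytesBackend):
        """Hypothetical dict-backed backend; misses yield NO_VALUE."""

        def __init__(self, arguments):
            self._store = arguments.get("cache_dict", {})

        def get_serialized(self, key):
            return self._store.get(key, NO_VALUE)

        def get_serialized_multi(self, keys):
            # One entry per key; NO_VALUE marks keys that are not present.
            return [self._store.get(key, NO_VALUE) for key in keys]

        def set_serialized(self, key, value):
            self._store[key] = value

        def set_serialized_multi(self, mapping):
            self._store.update(mapping)

        def delete(self, key):
            self._store.pop(key, None)

        def delete_multi(self, keys):
            for key in keys:
                self._store.pop(key, None)
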
diff --git a/dogpile/cache/backends/file.py b/dogpile/cache/backends/file.py
index 06047ae..b4a26c4 100644
--- a/dogpile/cache/backends/file.py
+++ b/dogpile/cache/backends/file.py
@@ -6,10 +6,15 @@
"""
+from __future__ import annotations
+
from contextlib import contextmanager
import dbm
import os
import threading
+from typing import Callable
+from typing import Literal
+from typing import TypedDict
+from typing import Union
from ..api import BytesBackend
from ..api import NO_VALUE
@@ -18,6 +23,13 @@
__all__ = ["DBMBackend", "FileLock", "AbstractFileLock"]
+class DBMBackendArguments(TypedDict, total=False):
+ filename: str
+    lock_factory: Callable[[str], "AbstractFileLock"]
+ rw_lockfile: Union[str, Literal[False], None]
+ dogpile_lockfile: Union[str, Literal[False], None]
+
+
class DBMBackend(BytesBackend):
"""A file-backend using a dbm file to store keys.
@@ -137,7 +149,7 @@ def release_write_lock(self):
"""
- def __init__(self, arguments):
+ def __init__(self, arguments: DBMBackendArguments):
self.filename = os.path.abspath(
os.path.normpath(arguments["filename"])
)
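
A short usage sketch of the new ``DBMBackendArguments`` (the file path is illustrative): because the TypedDict is ``total=False``, any subset of keys type-checks, and ``rw_lockfile`` accepts ``Literal[False]`` to disable the lockfile. Note that ``lock_factory`` is annotated as a callable taking the lockfile path, since the backend invokes it to construct each lock:

    from dogpile.cache import make_region
    from dogpile.cache.backends.file import DBMBackendArguments

    args: DBMBackendArguments = {
        "filename": "/tmp/demo_cache.dbm",  # illustrative path
        "rw_lockfile": False,  # Literal[False] disables the rw lockfile
    }
    region = make_region().configure("dogpile.cache.dbm", arguments=args)
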
diff --git a/dogpile/cache/backends/memcached.py b/dogpile/cache/backends/memcached.py
index eee3448..c839c96 100644
--- a/dogpile/cache/backends/memcached.py
+++ b/dogpile/cache/backends/memcached.py
@@ -6,20 +6,27 @@
"""
+from __future__ import annotations
+
import random
import threading
import time
import typing
from typing import Any
from typing import Mapping
+from typing import Optional
+from typing import Sequence
+from typing import Type
+from typing import TypedDict
+from typing import Union
import warnings
from ..api import CacheBackend
from ..api import NO_VALUE
from ... import util
-
if typing.TYPE_CHECKING:
+ import ssl
+
import bmemcached
import memcache
import pylibmc
@@ -41,6 +48,50 @@
)
+class GenericMemcachedBackendArguments(TypedDict, total=False):
+ url: str
+ distributed_lock: bool
+ lock_timeout: int
+
+
+class MemcachedArgsArguments(GenericMemcachedBackendArguments, total=False):
+    min_compress_len: int
+    memcached_expire_time: int
+
+
+class PylibmcBackendArguments(MemcachedArgsArguments, total=False):
+    binary: bool
+    behaviors: Mapping[str, Any]
+
+
+class MemcachedBackendArguments(MemcachedArgsArguments, total=False):
+    dead_retry: int
+    socket_timeout: int
+
+
+class BMemcachedBackendArguments(
+ GenericMemcachedBackendArguments, total=False
+):
+ username: Optional[str]
+    password: Optional[str]
+ tls_context: Optional["ssl.SSLContext"]
+
+
+class PyMemcacheBackendArguments(
+ GenericMemcachedBackendArguments, total=False
+):
+    serde: Any
+    default_noreply: bool
+    tls_context: Optional["ssl.SSLContext"]
+    socket_keepalive: Optional["pymemcache.client.base.KeepaliveOpts"]
+    enable_retry_client: bool
+    retry_attempts: Optional[int]
+    retry_delay: Union[int, float, None]
+    retry_for: Optional[Sequence[Type[Exception]]]
+    do_not_retry_for: Optional[Sequence[Type[Exception]]]
+ hashclient_retry_attempts: int
+ hashclient_retry_timeout: int
+ hashclient_dead_timeout: int
+ memcached_expire_time: int
+
+
class MemcachedLock:
"""Simple distributed lock using memcached."""
@@ -117,7 +168,7 @@ class GenericMemcachedBackend(CacheBackend):
serializer = None
deserializer = None
- def __init__(self, arguments):
+ def __init__(self, arguments: GenericMemcachedBackendArguments):
self._imports()
# using a plain threading.local here. threading.local
# automatically deletes the __dict__ when a thread ends,
@@ -226,7 +277,7 @@ class MemcacheArgs(GenericMemcachedBackend):
of the value using the compressor
"""
- def __init__(self, arguments):
+ def __init__(self, arguments: MemcachedArgsArguments):
self.min_compress_len = arguments.get("min_compress_len", 0)
self.set_arguments = {}
@@ -273,7 +324,7 @@ class PylibmcBackend(MemcacheArgs, GenericMemcachedBackend):
"""
- def __init__(self, arguments):
+    def __init__(self, arguments: PylibmcBackendArguments):
self.binary = arguments.get("binary", False)
self.behaviors = arguments.get("behaviors", {})
super(PylibmcBackend, self).__init__(arguments)
@@ -324,7 +375,7 @@ class MemcachedBackend(MemcacheArgs, GenericMemcachedBackend):
"""
- def __init__(self, arguments):
+ def __init__(self, arguments: MemcachedBackendArguments):
self.dead_retry = arguments.get("dead_retry", 30)
self.socket_timeout = arguments.get("socket_timeout", 3)
super(MemcachedBackend, self).__init__(arguments)
@@ -400,7 +451,7 @@ class BMemcachedBackend(GenericMemcachedBackend):
"""
- def __init__(self, arguments):
+ def __init__(self, arguments: BMemcachedBackendArguments):
self.username = arguments.get("username", None)
self.password = arguments.get("password", None)
self.tls_context = arguments.get("tls_context", None)
@@ -560,7 +611,7 @@ class PyMemcacheBackend(GenericMemcachedBackend):
.. versionadded:: 1.1.5
- :param dead_timeout: Time in seconds before attempting to add a node
+ :param hashclient_dead_timeout: Time in seconds before attempting to add a node
back in the pool in the HashClient's internal mechanisms.
.. versionadded:: 1.1.5
@@ -589,7 +640,7 @@ class PyMemcacheBackend(GenericMemcachedBackend):
""" # noqa E501
- def __init__(self, arguments):
+ def __init__(self, arguments: PyMemcacheBackendArguments):
super().__init__(arguments)
self.serde = arguments.get("serde", pymemcache.serde.pickle_serde)
@@ -607,7 +658,9 @@ def __init__(self, arguments):
self.hashclient_retry_timeout = arguments.get(
"hashclient_retry_timeout", 1
)
- self.dead_timeout = arguments.get("hashclient_dead_timeout", 60)
+ self.hashclient_dead_timeout = arguments.get(
+ "hashclient_dead_timeout", 60
+ )
if (
self.retry_delay is not None
or self.retry_attempts is not None
@@ -633,7 +686,7 @@ def _create_client(self):
"tls_context": self.tls_context,
"retry_attempts": self.hashclient_retry_attempts,
"retry_timeout": self.hashclient_retry_timeout,
- "dead_timeout": self.dead_timeout,
+ "dead_timeout": self.hashclient_dead_timeout,
}
if self.socket_keepalive is not None:
_kwargs.update({"socket_keepalive": self.socket_keepalive})
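
With the ``dead_timeout`` attribute renamed to ``hashclient_dead_timeout``, a typed configuration might look as follows (the server address is illustrative). A checker can now catch a misspelled key such as the old name:

    from dogpile.cache import make_region
    from dogpile.cache.backends.memcached import PyMemcacheBackendArguments

    args: PyMemcacheBackendArguments = {
        "url": "127.0.0.1:11211",  # illustrative server address
        "enable_retry_client": True,
        "retry_attempts": 3,
        "hashclient_dead_timeout": 30,  # formerly stored as self.dead_timeout
    }
    region = make_region().configure("dogpile.cache.pymemcache", arguments=args)
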
diff --git a/dogpile/cache/backends/memory.py b/dogpile/cache/backends/memory.py
index 0ace10b..752d6c9 100644
--- a/dogpile/cache/backends/memory.py
+++ b/dogpile/cache/backends/memory.py
@@ -10,11 +10,21 @@
"""
+from __future__ import annotations
+
+from typing import Any
+from typing import Dict
+from typing import TypedDict
+
from ..api import CacheBackend
from ..api import DefaultSerialization
from ..api import NO_VALUE
+class MemoryBackendArguments(TypedDict, total=False):
+ cache_dict: Dict[Any, Any]
+
+
class MemoryBackend(CacheBackend):
"""A backend that uses a plain dictionary.
@@ -49,7 +59,7 @@ class MemoryBackend(CacheBackend):
"""
- def __init__(self, arguments):
+ def __init__(self, arguments: MemoryBackendArguments):
self._cache = arguments.get("cache_dict", {})
def get(self, key):
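
A brief sketch of the ``cache_dict`` argument the new TypedDict describes: passing a shared dict lets tests inspect, or several regions share, the underlying storage.

    from dogpile.cache import make_region

    shared_store: dict = {}
    region = make_region().configure(
        "dogpile.cache.memory", arguments={"cache_dict": shared_store}
    )
    region.set("k", "v")
    assert "k" in shared_store  # values live in the caller-supplied dict
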
diff --git a/dogpile/cache/backends/null.py b/dogpile/cache/backends/null.py
index b4ad0fb..b349792 100644
--- a/dogpile/cache/backends/null.py
+++ b/dogpile/cache/backends/null.py
@@ -10,6 +10,11 @@
"""
+from __future__ import annotations
+
+from typing import Any
+from typing import Dict
+
from ..api import CacheBackend
from ..api import NO_VALUE
@@ -41,7 +46,7 @@ class NullBackend(CacheBackend):
"""
- def __init__(self, arguments):
+ def __init__(self, arguments: Dict[str, Any]):
pass
def get_mutex(self, key):
diff --git a/dogpile/cache/backends/redis.py b/dogpile/cache/backends/redis.py
index 68f84f5..7618f5a 100644
--- a/dogpile/cache/backends/redis.py
+++ b/dogpile/cache/backends/redis.py
@@ -6,13 +6,27 @@
"""
-import typing
+from __future__ import annotations
+
+from typing import Any
+from typing import cast
+from typing import Dict
+from typing import List
+from typing import Mapping
+from typing import Optional
+from typing import Sequence
+from typing import Tuple
+from typing import TYPE_CHECKING
+from typing import TypedDict
+from typing import Union
import warnings
from ..api import BytesBackend
+from ..api import KeyType
from ..api import NO_VALUE
+from ..api import SerializedReturnType
-if typing.TYPE_CHECKING:
+if TYPE_CHECKING:
import redis
else:
# delayed import
@@ -21,6 +35,41 @@
__all__ = ("RedisBackend", "RedisSentinelBackend", "RedisClusterBackend")
+class RedisBackendKwargs(TypedDict, total=False):
+ """
+ TypedDict of kwargs for `RedisBackend` and derived classes
+ .. versionadded:: 1.4.1
+ """
+
+ url: Optional[str]
+ host: str
+ username: Optional[str]
+ password: Optional[str]
+ port: int
+ db: int
+ redis_expiration_time: int
+ distributed_lock: bool
+    lock_timeout: Optional[int]
+ socket_timeout: Optional[float]
+ socket_connect_timeout: Optional[float]
+ socket_keepalive: bool
+ socket_keepalive_options: Optional[Mapping[int, Union[int, bytes]]]
+ lock_sleep: float
+ connection_pool: Optional["redis.ConnectionPool"]
+ thread_local_lock: bool
+ connection_kwargs: Dict[str, Any]
+
+
+class RedisSentinelBackendKwargs(RedisBackendKwargs, total=False):
+    sentinels: Optional[List[Tuple[str, int]]]
+    service_name: str
+    sentinel_kwargs: Dict[str, Any]
+
+
+class RedisClusterBackendKwargs(RedisBackendKwargs, total=False):
+    startup_nodes: Optional[List["redis.cluster.ClusterNode"]]
+
+
class RedisBackend(BytesBackend):
r"""A `Redis `_ backend, using the
`redis-py `_ driver.
@@ -114,33 +163,29 @@ class RedisBackend(BytesBackend):
.. versionadded:: 1.1.6
-
-
-
"""
- def __init__(self, arguments):
- arguments = arguments.copy()
+ def __init__(self, arguments: RedisBackendKwargs):
self._imports()
- self.url = arguments.pop("url", None)
- self.host = arguments.pop("host", "localhost")
- self.username = arguments.pop("username", None)
- self.password = arguments.pop("password", None)
- self.port = arguments.pop("port", 6379)
- self.db = arguments.pop("db", 0)
- self.distributed_lock = arguments.pop("distributed_lock", False)
- self.socket_timeout = arguments.pop("socket_timeout", None)
- self.socket_connect_timeout = arguments.pop(
+ self.url = arguments.get("url", None)
+ self.host = arguments.get("host", "localhost")
+ self.username = arguments.get("username", None)
+ self.password = arguments.get("password", None)
+ self.port = arguments.get("port", 6379)
+ self.db = arguments.get("db", 0)
+ self.distributed_lock = arguments.get("distributed_lock", False)
+ self.socket_timeout = arguments.get("socket_timeout", None)
+ self.socket_connect_timeout = arguments.get(
"socket_connect_timeout", None
)
- self.socket_keepalive = arguments.pop("socket_keepalive", False)
- self.socket_keepalive_options = arguments.pop(
+ self.socket_keepalive = arguments.get("socket_keepalive", False)
+ self.socket_keepalive_options = arguments.get(
"socket_keepalive_options", None
)
- self.lock_timeout = arguments.pop("lock_timeout", None)
- self.lock_sleep = arguments.pop("lock_sleep", 0.1)
- self.thread_local_lock = arguments.pop("thread_local_lock", True)
- self.connection_kwargs = arguments.pop("connection_kwargs", {})
+ self.lock_timeout = arguments.get("lock_timeout", None)
+ self.lock_sleep = arguments.get("lock_sleep", 0.1)
+ self.thread_local_lock = arguments.get("thread_local_lock", True)
+ self.connection_kwargs = arguments.get("connection_kwargs", {})
if self.distributed_lock and self.thread_local_lock:
warnings.warn(
@@ -148,16 +193,16 @@ def __init__(self, arguments):
"set to False when distributed_lock is True"
)
- self.redis_expiration_time = arguments.pop("redis_expiration_time", 0)
- self.connection_pool = arguments.pop("connection_pool", None)
+ self.redis_expiration_time = arguments.get("redis_expiration_time", 0)
+ self.connection_pool = arguments.get("connection_pool", None)
self._create_client()
- def _imports(self):
+ def _imports(self) -> None:
# defer imports until backend is used
global redis
import redis # noqa
- def _create_client(self):
+ def _create_client(self) -> None:
if self.connection_pool is not None:
# the connection pool already has all other connection
# options present within, so here we disregard socket_timeout
@@ -195,7 +240,7 @@ def _create_client(self):
self.writer_client = redis.StrictRedis(**args)
self.reader_client = self.writer_client
- def get_mutex(self, key):
+ def get_mutex(self, key: KeyType) -> Optional[_RedisLockWrapper]:
if self.distributed_lock:
return _RedisLockWrapper(
self.writer_client.lock(
@@ -208,25 +253,27 @@ def get_mutex(self, key):
else:
return None
- def get_serialized(self, key):
+ def get_serialized(self, key: KeyType) -> SerializedReturnType:
value = self.reader_client.get(key)
if value is None:
return NO_VALUE
- return value
+ return cast(SerializedReturnType, value)
- def get_serialized_multi(self, keys):
+ def get_serialized_multi(
+ self, keys: Sequence[KeyType]
+ ) -> Sequence[SerializedReturnType]:
if not keys:
return []
values = self.reader_client.mget(keys)
return [v if v is not None else NO_VALUE for v in values]
- def set_serialized(self, key, value):
+ def set_serialized(self, key: KeyType, value: bytes) -> None:
if self.redis_expiration_time:
self.writer_client.setex(key, self.redis_expiration_time, value)
else:
self.writer_client.set(key, value)
- def set_serialized_multi(self, mapping):
+ def set_serialized_multi(self, mapping: Mapping[KeyType, bytes]) -> None:
if not self.redis_expiration_time:
self.writer_client.mset(mapping)
else:
@@ -235,23 +282,23 @@ def set_serialized_multi(self, mapping):
pipe.setex(key, self.redis_expiration_time, value)
pipe.execute()
- def delete(self, key):
+ def delete(self, key: KeyType) -> None:
self.writer_client.delete(key)
- def delete_multi(self, keys):
+ def delete_multi(self, keys: Sequence[KeyType]) -> None:
self.writer_client.delete(*keys)
class _RedisLockWrapper:
__slots__ = ("mutex", "__weakref__")
- def __init__(self, mutex: typing.Any):
+ def __init__(self, mutex: Any):
self.mutex = mutex
- def acquire(self, wait: bool = True) -> typing.Any:
+ def acquire(self, wait: bool = True) -> Any:
return self.mutex.acquire(blocking=wait)
- def release(self) -> typing.Any:
+ def release(self) -> Any:
return self.mutex.release()
def locked(self) -> bool:
@@ -356,13 +403,10 @@ class RedisSentinelBackend(RedisBackend):
"""
- def __init__(self, arguments):
- arguments = arguments.copy()
-
- self.sentinels = arguments.pop("sentinels", None)
- self.service_name = arguments.pop("service_name", "mymaster")
- self.sentinel_kwargs = arguments.pop("sentinel_kwargs", {})
-
+ def __init__(self, arguments: RedisSentinelBackendKwargs):
+ self.sentinels = arguments.get("sentinels", None)
+ self.service_name = arguments.get("service_name", "mymaster")
+ self.sentinel_kwargs = arguments.get("sentinel_kwargs", {})
super().__init__(
arguments={
"distributed_lock": True,
@@ -371,7 +415,7 @@ def __init__(self, arguments):
}
)
- def _imports(self):
+ def _imports(self) -> None:
# defer imports until backend is used
global redis
import redis.sentinel # noqa
@@ -545,17 +589,16 @@ class RedisClusterBackend(RedisBackend):
"""
- def __init__(self, arguments):
- arguments = arguments.copy()
- self.startup_nodes = arguments.pop("startup_nodes", None)
+ def __init__(self, arguments: RedisClusterBackendKwargs):
+ self.startup_nodes = arguments.get("startup_nodes", None)
super().__init__(arguments)
- def _imports(self):
+ def _imports(self) -> None:
global redis
import redis.cluster
- def _create_client(self):
- redis_cluster: redis.cluster.RedisCluster[typing.Any]
+ def _create_client(self) -> None:
+ redis_cluster: redis.cluster.RedisCluster[Any]
if self.url is not None:
redis_cluster = redis.cluster.RedisCluster.from_url(
self.url, **self.connection_kwargs
@@ -565,5 +608,5 @@ def _create_client(self):
startup_nodes=self.startup_nodes,
**self.connection_kwargs,
)
- self.writer_client = typing.cast("redis.Redis[bytes]", redis_cluster)
+ self.writer_client = cast("redis.Redis[bytes]", redis_cluster)
self.reader_client = self.writer_client
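
A sketch of the new typed arguments in use (connection details illustrative); with ``RedisBackendKwargs`` a checker flags unknown keys and wrong value types before the backend is constructed:

    from dogpile.cache import make_region
    from dogpile.cache.backends.redis import RedisBackendKwargs

    args: RedisBackendKwargs = {
        "host": "127.0.0.1",  # illustrative connection details
        "port": 6379,
        "db": 0,
        "redis_expiration_time": 7200,
        "distributed_lock": True,
        "thread_local_lock": False,  # avoids the warning under distributed_lock
    }
    region = make_region().configure("dogpile.cache.redis", arguments=args)
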
diff --git a/dogpile/cache/backends/valkey.py b/dogpile/cache/backends/valkey.py
index 8604173..f6078b7 100644
--- a/dogpile/cache/backends/valkey.py
+++ b/dogpile/cache/backends/valkey.py
@@ -6,13 +6,24 @@
"""
-import typing
+from __future__ import annotations
+
+from typing import Any
+from typing import cast
+from typing import Dict
+from typing import List
+from typing import Mapping
+from typing import Optional
+from typing import Tuple
+from typing import TYPE_CHECKING
+from typing import TypedDict
+from typing import Union
import warnings
from ..api import BytesBackend
from ..api import NO_VALUE
-if typing.TYPE_CHECKING:
+if TYPE_CHECKING:
import valkey
else:
# delayed import
@@ -21,6 +32,36 @@
__all__ = ("ValkeyBackend", "ValkeySentinelBackend", "ValkeyClusterBackend")
+class ValkeyBackendArguments(TypedDict, total=False):
+ url: Optional[str]
+ host: str
+ username: Optional[str]
+ password: Optional[str]
+ port: int
+ db: int
+ valkey_expiration_time: int
+ distributed_lock: bool
+ lock_timeout: Optional[int]
+ socket_timeout: Optional[float]
+ socket_connect_timeout: Optional[float]
+ socket_keepalive: bool
+ socket_keepalive_options: Optional[Mapping[int, Union[int, bytes]]]
+    lock_sleep: float
+ thread_local_lock: bool
+ connection_kwargs: Dict[str, Any]
+ connection_pool: Optional["valkey.ConnectionPool"]
+
+
+class ValkeySentinelBackendArguments(ValkeyBackendArguments, total=False):
+ sentinels: Optional[List[Tuple[str, int]]]
+ service_name: str
+ sentinel_kwargs: Dict[str, Any]
+
+
+class ValkeyClusterBackendArguments(ValkeyBackendArguments, total=False):
+    startup_nodes: Optional[List["valkey.cluster.ClusterNode"]]
+
+
class ValkeyBackend(BytesBackend):
r"""A `Valkey `_ backend, using the
`valkey-py `_ driver.
@@ -85,9 +126,9 @@ class ValkeyBackend(BytesBackend):
:param socket_keepalive_options: dict, socket keepalive options.
Default is None (no options).
-    :param lock_sleep: integer, number of seconds to sleep when failed to
+    :param lock_sleep: float, number of seconds to sleep when failing to
     acquire a lock. This argument is only valid when
-    ``distributed_lock`` is ``True``.
+    ``distributed_lock`` is ``True``. Default is ``0.1``, matching the
+    valkey-py lock default.
:param connection_pool: ``valkey.ConnectionPool`` object. If provided,
this object supersedes other connection arguments passed to the
@@ -109,28 +150,27 @@ class ValkeyBackend(BytesBackend):
"""
- def __init__(self, arguments):
- arguments = arguments.copy()
+ def __init__(self, arguments: ValkeyBackendArguments):
self._imports()
- self.url = arguments.pop("url", None)
- self.host = arguments.pop("host", "localhost")
- self.username = arguments.pop("username", None)
- self.password = arguments.pop("password", None)
- self.port = arguments.pop("port", 6379)
- self.db = arguments.pop("db", 0)
- self.distributed_lock = arguments.pop("distributed_lock", False)
- self.socket_timeout = arguments.pop("socket_timeout", None)
- self.socket_connect_timeout = arguments.pop(
+ self.url = arguments.get("url", None)
+ self.host = arguments.get("host", "localhost")
+ self.username = arguments.get("username", None)
+ self.password = arguments.get("password", None)
+ self.port = arguments.get("port", 6379)
+ self.db = arguments.get("db", 0)
+ self.distributed_lock = arguments.get("distributed_lock", False)
+ self.socket_timeout = arguments.get("socket_timeout", None)
+ self.socket_connect_timeout = arguments.get(
"socket_connect_timeout", None
)
- self.socket_keepalive = arguments.pop("socket_keepalive", False)
- self.socket_keepalive_options = arguments.pop(
+ self.socket_keepalive = arguments.get("socket_keepalive", False)
+ self.socket_keepalive_options = arguments.get(
"socket_keepalive_options", None
)
- self.lock_timeout = arguments.pop("lock_timeout", None)
- self.lock_sleep = arguments.pop("lock_sleep", 0.1)
- self.thread_local_lock = arguments.pop("thread_local_lock", True)
- self.connection_kwargs = arguments.pop("connection_kwargs", {})
+ self.lock_timeout = arguments.get("lock_timeout", None)
+ self.lock_sleep = arguments.get("lock_sleep", 0.1)
+ self.thread_local_lock = arguments.get("thread_local_lock", True)
+ self.connection_kwargs = arguments.get("connection_kwargs", {})
if self.distributed_lock and self.thread_local_lock:
warnings.warn(
@@ -138,10 +178,10 @@ def __init__(self, arguments):
"set to False when distributed_lock is True"
)
- self.valkey_expiration_time = arguments.pop(
+ self.valkey_expiration_time = arguments.get(
"valkey_expiration_time", 0
)
- self.connection_pool = arguments.pop("connection_pool", None)
+ self.connection_pool = arguments.get("connection_pool", None)
self._create_client()
def _imports(self):
@@ -237,13 +277,13 @@ def delete_multi(self, keys):
class _ValkeyLockWrapper:
__slots__ = ("mutex", "__weakref__")
- def __init__(self, mutex: typing.Any):
+ def __init__(self, mutex: Any):
self.mutex = mutex
- def acquire(self, wait: bool = True) -> typing.Any:
+ def acquire(self, wait: bool = True) -> Any:
return self.mutex.acquire(blocking=wait)
- def release(self) -> typing.Any:
+ def release(self) -> Any:
return self.mutex.release()
def locked(self) -> bool:
@@ -348,12 +388,10 @@ class ValkeySentinelBackend(ValkeyBackend):
"""
- def __init__(self, arguments):
- arguments = arguments.copy()
-
- self.sentinels = arguments.pop("sentinels", None)
- self.service_name = arguments.pop("service_name", "mymaster")
- self.sentinel_kwargs = arguments.pop("sentinel_kwargs", {})
+ def __init__(self, arguments: ValkeySentinelBackendArguments):
+ self.sentinels = arguments.get("sentinels", None)
+ self.service_name = arguments.get("service_name", "mymaster")
+ self.sentinel_kwargs = arguments.get("sentinel_kwargs", {})
super().__init__(
arguments={
@@ -537,17 +575,17 @@ class ValkeyClusterBackend(ValkeyBackend):
""" # noqa: E501
- def __init__(self, arguments):
- arguments = arguments.copy()
- self.startup_nodes = arguments.pop("startup_nodes", None)
- super().__init__(arguments)
+    def __init__(self, arguments: ValkeyClusterBackendArguments):
+        self.startup_nodes = arguments.get("startup_nodes", None)
+        super().__init__(arguments)
def _imports(self):
global valkey
import valkey.cluster
def _create_client(self):
- valkey_cluster: valkey.cluster.ValkeyCluster[typing.Any] # type: ignore # noqa: E501
+ valkey_cluster: valkey.cluster.ValkeyCluster[Any] # type: ignore # noqa: E501
if self.url is not None:
valkey_cluster = valkey.cluster.ValkeyCluster.from_url(
self.url, **self.connection_kwargs
@@ -557,5 +595,5 @@ def _create_client(self):
startup_nodes=self.startup_nodes,
**self.connection_kwargs,
)
- self.writer_client = typing.cast(valkey.Valkey[bytes], valkey_cluster) # type: ignore # noqa: E501
+ self.writer_client = cast(valkey.Valkey[bytes], valkey_cluster) # type: ignore # noqa: E501
self.reader_client = self.writer_client
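
The sentinel variant follows the same pattern; a hedged sketch with illustrative addresses, assuming the backend is registered as ``dogpile.cache.valkey_sentinel`` in the same way the redis sentinel backend is:

    from dogpile.cache import make_region
    from dogpile.cache.backends.valkey import ValkeySentinelBackendArguments

    args: ValkeySentinelBackendArguments = {
        "sentinels": [("192.168.1.10", 26379), ("192.168.1.11", 26379)],
        "service_name": "mymaster",
        "valkey_expiration_time": 3600,
    }
    region = make_region().configure(
        "dogpile.cache.valkey_sentinel", arguments=args
    )
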
diff --git a/dogpile/cache/proxy.py b/dogpile/cache/proxy.py
index 08f598e..6d096c7 100644
--- a/dogpile/cache/proxy.py
+++ b/dogpile/cache/proxy.py
@@ -34,24 +34,24 @@ class ProxyBackend(CacheBackend):
from dogpile.cache.proxy import ProxyBackend
class MyFirstProxy(ProxyBackend):
- def get_serialized(self, key):
+ def get_serialized(self, key: KeyType) -> SerializedReturnType:
# ... custom code goes here ...
return self.proxied.get_serialized(key)
- def get(self, key):
+ def get(self, key: KeyType) -> BackendFormatted:
# ... custom code goes here ...
return self.proxied.get(key)
- def set(self, key, value):
+ def set(self, key: KeyType, value: BackendSetType) -> None:
# ... custom code goes here ...
-            self.proxied.set(key)
+            self.proxied.set(key, value)
class MySecondProxy(ProxyBackend):
- def get_serialized(self, key):
+ def get_serialized(self, key: KeyType) -> SerializedReturnType:
# ... custom code goes here ...
return self.proxied.get_serialized(key)
- def get(self, key):
+ def get(self, key: KeyType) -> BackendFormatted:
# ... custom code goes here ...
return self.proxied.get(key)
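
For reference, the typed proxies above are attached via the existing ``wrap`` parameter of ``configure()``; a minimal sketch using the memory backend:

    from dogpile.cache import make_region

    region = make_region().configure(
        "dogpile.cache.memory",
        wrap=[MyFirstProxy, MySecondProxy],  # each proxy wraps the backend in a chain
    )
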
diff --git a/dogpile/cache/region.py b/dogpile/cache/region.py
index 352b308..75d7c72 100644
--- a/dogpile/cache/region.py
+++ b/dogpile/cache/region.py
@@ -12,6 +12,8 @@
from typing import Any
from typing import Callable
from typing import cast
+from typing import Dict
+from typing import List
from typing import Mapping
from typing import Optional
from typing import Sequence
@@ -25,6 +27,7 @@
from . import exception
from .api import BackendArguments
from .api import BackendFormatted
+from .api import CacheBackend
from .api import CachedValue
from .api import CacheMutex
from .api import CacheReturnType
@@ -553,7 +556,7 @@ def wrap(self, proxy: Union[ProxyBackend, Type[ProxyBackend]]) -> None:
self.backend = proxy_instance.wrap(self.backend)
- def _mutex(self, key):
+ def _mutex(self, key: KeyType) -> Any:
return self._lock_registry.get(key)
class _LockWrapper(CacheMutex):
@@ -571,7 +574,7 @@ def release(self):
def locked(self):
return self.lock.locked()
- def _create_mutex(self, key):
+    def _create_mutex(self, key: KeyType) -> CacheMutex:
mutex = self.backend.get_mutex(key)
if mutex is not None:
return mutex
@@ -579,10 +582,10 @@ def _create_mutex(self, key):
return self._LockWrapper()
# cached value
- _actual_backend = None
+ _actual_backend: Optional[CacheBackend] = None
@property
- def actual_backend(self):
+ def actual_backend(self) -> CacheBackend:
"""Return the ultimate backend underneath any proxies.
The backend might be the result of one or more ``proxy.wrap``
@@ -596,9 +599,11 @@ def actual_backend(self):
while hasattr(_backend, "proxied"):
_backend = _backend.proxied
self._actual_backend = _backend
+ if TYPE_CHECKING:
+ assert self._actual_backend
return self._actual_backend
- def invalidate(self, hard=True):
+ def invalidate(self, hard: bool = True) -> None:
"""Invalidate this :class:`.CacheRegion`.
The default invalidation system works by setting
@@ -648,7 +653,11 @@ def invalidate(self, hard=True):
"""
self.region_invalidator.invalidate(hard)
- def configure_from_config(self, config_dict, prefix):
+ def configure_from_config(
+ self,
+ config_dict: Dict[str, Any],
+ prefix: str,
+ ) -> Self:
"""Configure from a configuration dictionary
and a prefix.
@@ -680,20 +689,20 @@ def configure_from_config(self, config_dict, prefix):
),
_config_argument_dict=config_dict,
_config_prefix="%sarguments." % prefix,
- wrap=config_dict.get("%swrap" % prefix, None),
+ wrap=config_dict.get("%swrap" % prefix, ()),
replace_existing_backend=config_dict.get(
"%sreplace_existing_backend" % prefix, False
),
)
@memoized_property
- def backend(self):
+ def backend(self) -> CacheBackend:
raise exception.RegionNotConfigured(
"No backend is configured on this region."
)
@property
- def is_configured(self):
+ def is_configured(self) -> bool:
"""Return True if the backend has been configured via the
:meth:`.CacheRegion.configure` method already.
@@ -824,7 +833,9 @@ def _get_cache_value(
return value
def _unexpired_value_fn(
- self, expiration_time: Optional[float], ignore_expiration: bool
+ self,
+ expiration_time: Optional[float],
+ ignore_expiration: bool = False,
) -> Callable[[CacheReturnType], CacheReturnType]:
if ignore_expiration:
return lambda value: value
@@ -851,7 +862,12 @@ def value_fn(value):
return value_fn
- def get_multi(self, keys, expiration_time=None, ignore_expiration=False):
+ def get_multi(
+ self,
+ keys: Sequence[KeyType],
+ expiration_time: Optional[float] = None,
+ ignore_expiration: bool = False,
+ ) -> List[Union[ValuePayload, NoValueType]]:
"""Return multiple values from the cache, based on the given keys.
Returns values as a list matching the keys given.
@@ -902,7 +918,7 @@ def get_multi(self, keys, expiration_time=None, ignore_expiration=False):
]
@contextlib.contextmanager
- def _log_time(self, keys):
+ def _log_time(self, keys: Sequence[KeyType]):
start_time = time.time()
yield
seconds = time.time() - start_time
@@ -912,7 +928,11 @@ def _log_time(self, keys):
{"seconds": seconds, "keys": repr_obj(keys)},
)
- def _is_cache_miss(self, value, orig_key):
+ def _is_cache_miss(
+ self,
+ value: CacheReturnType,
+ orig_key: KeyType,
+ ) -> bool:
if value is NO_VALUE:
log.debug("No value present for key: %r", orig_key)
elif value.metadata["v"] != value_version:
@@ -1154,7 +1174,7 @@ def get_or_create_multi(
"""
- def get_value(key):
+ def get_value(key: KeyType) -> Tuple[Any, Union[float, int]]:
value = values.get(key, NO_VALUE)
if self._is_cache_miss(value, orig_key):
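
The new ``get_multi`` annotation makes the mixed payload explicit; given a configured region, a caller filters misses by identity against ``NO_VALUE`` (keys illustrative):

    from dogpile.cache.api import NO_VALUE

    keys = ["user:1", "user:2"]
    values = region.get_multi(keys)
    found = {
        key: value
        for key, value in zip(keys, values)
        if value is not NO_VALUE  # NO_VALUE marks keys that were absent
    }
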
diff --git a/dogpile/cache/util.py b/dogpile/cache/util.py
index 7fddaa5..560de14 100644
--- a/dogpile/cache/util.py
+++ b/dogpile/cache/util.py
@@ -1,10 +1,17 @@
from hashlib import sha1
+from typing import Any
+from typing import Callable
+from typing import Optional
+from .api import KeyType
from ..util import compat
from ..util import langhelpers
-def function_key_generator(namespace, fn, to_str=str):
+def function_key_generator(
+    namespace: Optional[str],
+ fn: Callable,
+ to_str: Callable[[Any], str] = str,
+) -> Callable:
"""Return a function that generates a string
key, based on a given function as well as
arguments to the returned function itself.
@@ -45,7 +52,11 @@ def generate_key(*args, **kw):
return generate_key
-def function_multi_key_generator(namespace, fn, to_str=str):
+def function_multi_key_generator(
+    namespace: Optional[str],
+ fn: Callable,
+ to_str: Callable[[Any], str] = str,
+) -> Callable:
if namespace is None:
namespace = "%s:%s" % (fn.__module__, fn.__name__)
else:
@@ -67,7 +78,11 @@ def generate_keys(*args, **kw):
return generate_keys
-def kwarg_function_key_generator(namespace, fn, to_str=str):
+def kwarg_function_key_generator(
+    namespace: Optional[str],
+ fn: Callable,
+ to_str: Callable[[Any], str] = str,
+) -> Callable:
"""Return a function that generates a string
key, based on a given function as well as
arguments to the returned function itself.
@@ -127,16 +142,17 @@ def generate_key(*args, **kwargs):
return generate_key
-def sha1_mangle_key(key):
+def sha1_mangle_key(key: KeyType) -> str:
"""a SHA1 key mangler."""
- if isinstance(key, str):
- key = key.encode("utf-8")
+ bkey = key.encode("utf-8") if isinstance(key, str) else key
- return sha1(key).hexdigest()
+ return sha1(bkey).hexdigest()
-def length_conditional_mangler(length, mangler):
+def length_conditional_mangler(
+ length: int, mangler: Callable[[KeyType], str]
+) -> Callable[[KeyType], str]:
"""a key mangler that mangles if the length of the key is
past a certain threshold.
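
A usage sketch of the two manglers typed above: keys under the length threshold pass through unchanged, longer ones are replaced by their SHA-1 hex digest.

    from dogpile.cache.util import length_conditional_mangler, sha1_mangle_key

    # Mangle only keys longer than 250 characters (memcached's key limit).
    mangler = length_conditional_mangler(250, sha1_mangle_key)
    assert mangler("short-key") == "short-key"
    assert len(mangler("x" * 300)) == 40  # 40-char hex SHA-1 digest
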
diff --git a/dogpile/testing/fixtures.py b/dogpile/testing/fixtures.py
index 4dff69f..cce5114 100644
--- a/dogpile/testing/fixtures.py
+++ b/dogpile/testing/fixtures.py
@@ -1,5 +1,7 @@
# mypy: ignore-errors
+from __future__ import annotations
+
import collections
import itertools
import json
@@ -7,6 +9,8 @@
from threading import Lock
from threading import Thread
import time
+from typing import Any
+from typing import Dict
import uuid
import pytest
@@ -51,7 +55,7 @@ def _check_backend_available(cls, backend):
pass
backend: str
- region_args = {}
+ region_args: Dict[str, Any] = {}
config_args = {}
extra_arguments = {}
backend_argument_names = ()
diff --git a/tests/cache/test_memcached_backend.py b/tests/cache/test_memcached_backend.py
index f3b106f..d1e2afa 100644
--- a/tests/cache/test_memcached_backend.py
+++ b/tests/cache/test_memcached_backend.py
@@ -10,8 +10,10 @@
from dogpile.cache.backends.memcached import GenericMemcachedBackend
from dogpile.cache.backends.memcached import MemcachedBackend
+from dogpile.cache.backends.memcached import MemcachedBackendArguments
from dogpile.cache.backends.memcached import PylibmcBackend
from dogpile.cache.backends.memcached import PyMemcacheBackend
+from dogpile.cache.backends.memcached import PyMemcacheBackendArguments
from dogpile.testing import eq_
from dogpile.testing import is_
from dogpile.testing.fixtures import _GenericBackendTestSuite
@@ -242,7 +244,7 @@ def _mock_pymemcache_fixture(self):
)
def test_pymemcache_hashclient_retry_attempts(self):
- config_args = {
+ config_args: PyMemcacheBackendArguments = {
"url": "127.0.0.1:11211",
"hashclient_retry_attempts": 4,
}
@@ -265,7 +267,10 @@ def test_pymemcache_hashclient_retry_attempts(self):
eq_(self.retrying_client.mock_calls, [])
def test_pymemcache_hashclient_retry_timeout(self):
- config_args = {"url": "127.0.0.1:11211", "hashclient_retry_timeout": 4}
+ config_args: PyMemcacheBackendArguments = {
+ "url": "127.0.0.1:11211",
+ "hashclient_retry_timeout": 4,
+ }
with self._mock_pymemcache_fixture():
backend = MockPyMemcacheBackend(config_args)
is_(backend._create_client(), self.hash_client())
@@ -284,7 +289,7 @@ def test_pymemcache_hashclient_retry_timeout(self):
eq_(self.retrying_client.mock_calls, [])
def test_pymemcache_hashclient_retry_timeout_w_enable_retry(self):
- config_args = {
+ config_args: PyMemcacheBackendArguments = {
"url": "127.0.0.1:11211",
"hashclient_retry_timeout": 4,
"enable_retry_client": True,
@@ -317,7 +322,10 @@ def test_pymemcache_hashclient_retry_timeout_w_enable_retry(self):
)
def test_pymemcache_dead_timeout(self):
- config_args = {"url": "127.0.0.1:11211", "hashclient_dead_timeout": 4}
+ config_args: PyMemcacheBackendArguments = {
+ "url": "127.0.0.1:11211",
+ "hashclient_dead_timeout": 4,
+ }
with self._mock_pymemcache_fixture():
backend = MockPyMemcacheBackend(config_args)
backend._create_client()
@@ -338,7 +346,10 @@ def test_pymemcache_dead_timeout(self):
eq_(self.retrying_client.mock_calls, [])
def test_pymemcache_enable_retry_client_not_set(self):
- config_args = {"url": "127.0.0.1:11211", "retry_attempts": 2}
+ config_args: PyMemcacheBackendArguments = {
+ "url": "127.0.0.1:11211",
+ "retry_attempts": 2,
+ }
with self._mock_pymemcache_fixture():
with mock.patch("warnings.warn") as warn_mock:
@@ -352,7 +363,10 @@ def test_pymemcache_enable_retry_client_not_set(self):
)
def test_pymemcache_memacached_expire_time(self):
- config_args = {"url": "127.0.0.1:11211", "memcached_expire_time": 20}
+ config_args: PyMemcacheBackendArguments = {
+ "url": "127.0.0.1:11211",
+ "memcached_expire_time": 20,
+ }
with self._mock_pymemcache_fixture():
backend = MockPyMemcacheBackend(config_args)
backend.set("foo", "bar")
@@ -446,7 +460,7 @@ def delete(self, key):
class MemcachedBackendTest:
def test_memcached_dead_retry(self):
- config_args = {
+ config_args: MemcachedBackendArguments = {
"url": "127.0.0.1:11211",
"dead_retry": 4,
}
@@ -454,7 +468,7 @@ def test_memcached_dead_retry(self):
eq_(backend._create_client().kw["dead_retry"], 4)
def test_memcached_socket_timeout(self):
- config_args = {
+ config_args: MemcachedBackendArguments = {
"url": "127.0.0.1:11211",
"socket_timeout": 6,
}
diff --git a/tests/cache/test_redis_backend.py b/tests/cache/test_redis_backend.py
index 01472a9..c0520cc 100644
--- a/tests/cache/test_redis_backend.py
+++ b/tests/cache/test_redis_backend.py
@@ -2,11 +2,14 @@
import os
from threading import Event
import time
+from typing import Type
+from typing import TYPE_CHECKING
from unittest.mock import Mock
from unittest.mock import patch
import pytest
+from dogpile.cache.backends.redis import RedisBackendKwargs
from dogpile.cache.region import _backend_loader
from dogpile.testing import eq_
from dogpile.testing.fixtures import _GenericBackendFixture
@@ -14,6 +17,9 @@
from dogpile.testing.fixtures import _GenericMutexTestSuite
from dogpile.testing.fixtures import _GenericSerializerTestSuite
+if TYPE_CHECKING:
+ import redis
+
REDIS_HOST = "127.0.0.1"
REDIS_PORT = int(os.getenv("DOGPILE_REDIS_PORT", "6379"))
expect_redis_running = os.getenv("DOGPILE_REDIS_PORT") is not None
@@ -129,6 +135,7 @@ def blah(k):
@patch("redis.StrictRedis", autospec=True)
class RedisConnectionTest:
backend = "dogpile.cache.redis"
+ backend_cls = Type["redis.StrictRedis"]
@classmethod
def setup_class(cls):
@@ -147,7 +154,7 @@ def _test_helper(self, mock_obj, expected_args, connection_args=None):
def test_connect_with_defaults(self, MockStrictRedis):
# The defaults, used if keys are missing from the arguments dict.
- arguments = {
+ arguments: RedisBackendKwargs = {
"host": "localhost",
"port": 6379,
"db": 0,
@@ -157,7 +164,7 @@ def test_connect_with_defaults(self, MockStrictRedis):
self._test_helper(MockStrictRedis, expected, arguments)
def test_connect_with_basics(self, MockStrictRedis):
- arguments = {
+ arguments: RedisBackendKwargs = {
"host": "127.0.0.1",
"port": 6379,
"db": 0,
@@ -167,7 +174,7 @@ def test_connect_with_basics(self, MockStrictRedis):
self._test_helper(MockStrictRedis, expected, arguments)
def test_connect_with_password(self, MockStrictRedis):
- arguments = {
+ arguments: RedisBackendKwargs = {
"host": "127.0.0.1",
"password": "some password",
"port": 6379,
@@ -182,7 +189,7 @@ def test_connect_with_password(self, MockStrictRedis):
self._test_helper(MockStrictRedis, expected, arguments)
def test_connect_with_username_and_password(self, MockStrictRedis):
- arguments = {
+ arguments: RedisBackendKwargs = {
"host": "127.0.0.1",
"username": "redis",
"password": "some password",
@@ -192,7 +199,7 @@ def test_connect_with_username_and_password(self, MockStrictRedis):
self._test_helper(MockStrictRedis, arguments)
def test_connect_with_socket_timeout(self, MockStrictRedis):
- arguments = {
+ arguments: RedisBackendKwargs = {
"host": "127.0.0.1",
"port": 6379,
"socket_timeout": 0.5,
@@ -203,7 +210,7 @@ def test_connect_with_socket_timeout(self, MockStrictRedis):
self._test_helper(MockStrictRedis, expected, arguments)
def test_connect_with_socket_connect_timeout(self, MockStrictRedis):
- arguments = {
+ arguments: RedisBackendKwargs = {
"host": "127.0.0.1",
"port": 6379,
"socket_timeout": 1.0,
@@ -214,7 +221,7 @@ def test_connect_with_socket_connect_timeout(self, MockStrictRedis):
self._test_helper(MockStrictRedis, expected, arguments)
def test_connect_with_socket_keepalive(self, MockStrictRedis):
- arguments = {
+ arguments: RedisBackendKwargs = {
"host": "127.0.0.1",
"port": 6379,
"socket_keepalive": True,
@@ -225,7 +232,7 @@ def test_connect_with_socket_keepalive(self, MockStrictRedis):
self._test_helper(MockStrictRedis, expected, arguments)
def test_connect_with_socket_keepalive_options(self, MockStrictRedis):
- arguments = {
+ arguments: RedisBackendKwargs = {
"host": "127.0.0.1",
"port": 6379,
"socket_keepalive": True,
@@ -239,18 +246,23 @@ def test_connect_with_socket_keepalive_options(self, MockStrictRedis):
def test_connect_with_connection_pool(self, MockStrictRedis):
pool = Mock()
- arguments = {"connection_pool": pool, "socket_timeout": 0.5}
+ arguments: RedisBackendKwargs = {
+ "connection_pool": pool,
+ "socket_timeout": 0.5,
+ }
expected_args = {"connection_pool": pool}
self._test_helper(
MockStrictRedis, expected_args, connection_args=arguments
)
def test_connect_with_url(self, MockStrictRedis):
- arguments = {"url": "redis://redis:password@127.0.0.1:6379/0"}
+ arguments: RedisBackendKwargs = {
+ "url": "redis://redis:password@127.0.0.1:6379/0"
+ }
self._test_helper(MockStrictRedis.from_url, arguments)
def test_extra_arbitrary_args(self, MockStrictRedis):
- arguments = {
+ arguments: RedisBackendKwargs = {
"url": "redis://redis:password@127.0.0.1:6379/0",
"connection_kwargs": {
"ssl": True,
diff --git a/tests/cache/test_region.py b/tests/cache/test_region.py
index 6f8120e..bf5538a 100644
--- a/tests/cache/test_region.py
+++ b/tests/cache/test_region.py
@@ -24,7 +24,7 @@
from dogpile.testing.fixtures import MockBackend
-def key_mangler(key):
+def key_mangler(key: str) -> str:
return "HI!" + key
diff --git a/tests/cache/test_valkey_backend.py b/tests/cache/test_valkey_backend.py
index f0e1c25..8d83a7b 100644
--- a/tests/cache/test_valkey_backend.py
+++ b/tests/cache/test_valkey_backend.py
@@ -2,6 +2,8 @@
import os
from threading import Event
import time
+from typing import Type
+from typing import TYPE_CHECKING
from unittest.mock import Mock
from unittest.mock import patch
@@ -14,6 +16,9 @@
from dogpile.testing.fixtures import _GenericMutexTestSuite
from dogpile.testing.fixtures import _GenericSerializerTestSuite
+if TYPE_CHECKING:
+ import valkey
+
VALKEY_HOST = "127.0.0.1"
VALKEY_PORT = int(os.getenv("DOGPILE_VALKEY_PORT", "6379"))
expect_valkey_running = os.getenv("DOGPILE_VALKEY_PORT") is not None
@@ -129,6 +134,7 @@ def blah(k):
@patch("valkey.StrictValkey", autospec=True)
class ValkeyConnectionTest:
backend = "dogpile.cache.valkey"
+ backend_cls = Type["valkey.StrictValkey"]
@classmethod
def setup_class(cls):