diff --git a/dogpile/cache/api.py b/dogpile/cache/api.py index 6fbe842..f0c79dd 100644 --- a/dogpile/cache/api.py +++ b/dogpile/cache/api.py @@ -348,7 +348,8 @@ def get_serialized_multi( :meth:`.CacheRegion.get_multi` method, which will also be processed by the "key mangling" function if one was present. - :return: list of bytes objects + :return: list of bytes objects or the :data:`.NO_VALUE` contant + if not present. The default implementation of this method for :class:`.CacheBackend` returns the value of the :meth:`.CacheBackend.get_multi` method. @@ -543,7 +544,8 @@ def get_serialized_multi( :meth:`.CacheRegion.get_multi` method, which will also be processed by the "key mangling" function if one was present. - :return: list of bytes objects + :return: list of bytes objects or the :data:`.NO_VALUE` + constant if not present. .. versionadded:: 1.1 diff --git a/dogpile/cache/backends/file.py b/dogpile/cache/backends/file.py index 06047ae..b4a26c4 100644 --- a/dogpile/cache/backends/file.py +++ b/dogpile/cache/backends/file.py @@ -6,10 +6,15 @@ """ +from __future__ import annotations + from contextlib import contextmanager import dbm import os import threading +from typing import Literal +from typing import TypedDict +from typing import Union from ..api import BytesBackend from ..api import NO_VALUE @@ -18,6 +23,13 @@ __all__ = ["DBMBackend", "FileLock", "AbstractFileLock"] +class DBMBackendArguments(TypedDict, total=False): + filename: str + lock_factory: "AbstractFileLock" + rw_lockfile: Union[str, Literal[False], None] + dogpile_lockfile: Union[str, Literal[False], None] + + class DBMBackend(BytesBackend): """A file-backend using a dbm file to store keys. @@ -137,7 +149,7 @@ def release_write_lock(self): """ - def __init__(self, arguments): + def __init__(self, arguments: DBMBackendArguments): self.filename = os.path.abspath( os.path.normpath(arguments["filename"]) ) diff --git a/dogpile/cache/backends/memcached.py b/dogpile/cache/backends/memcached.py index eee3448..c839c96 100644 --- a/dogpile/cache/backends/memcached.py +++ b/dogpile/cache/backends/memcached.py @@ -6,20 +6,27 @@ """ +from __future__ import annotations + import random import threading import time import typing from typing import Any from typing import Mapping +from typing import Optional +from typing import Sequence +from typing import TypedDict +from typing import Union import warnings from ..api import CacheBackend from ..api import NO_VALUE from ... import util - if typing.TYPE_CHECKING: + import ssl + import bmemcached import memcache import pylibmc @@ -41,6 +48,50 @@ ) +class GenericMemcachedBackendArguments(TypedDict, total=False): + url: str + distributed_lock: bool + lock_timeout: int + + +class MemcachedArgsArguments(GenericMemcachedBackendArguments, total=False): + min_compress_len: int + memcached_expire_time: int + + +class MemcachedBackendArguments(GenericMemcachedBackendArguments, total=False): + min_compress_len: int + memcached_expire_time: int + dead_retry: int + socket_timeout: int + + +class BMemcachedBackendArguments( + GenericMemcachedBackendArguments, total=False +): + username: Optional[str] + password: Optional[bool] + tls_context: Optional["ssl.SSLContext"] + + +class PyMemcacheBackendArguments( + GenericMemcachedBackendArguments, total=False +): + serde: Optional[Any] + default_noreply: bool + tls_context: Optional["ssl.SSLContext"] + socket_keepalive: "pymemcache.client.base.KeepaliveOpts" + enable_retry_client: bool + retry_attempts: Optional[int] + retry_delay: Union[int, float, None] + retry_for: Optional[Sequence[Exception]] + do_not_retry_for: Optional[Sequence[Exception]] + hashclient_retry_attempts: int + hashclient_retry_timeout: int + hashclient_dead_timeout: int + memcached_expire_time: int + + class MemcachedLock: """Simple distributed lock using memcached.""" @@ -117,7 +168,7 @@ class GenericMemcachedBackend(CacheBackend): serializer = None deserializer = None - def __init__(self, arguments): + def __init__(self, arguments: GenericMemcachedBackendArguments): self._imports() # using a plain threading.local here. threading.local # automatically deletes the __dict__ when a thread ends, @@ -226,7 +277,7 @@ class MemcacheArgs(GenericMemcachedBackend): of the value using the compressor """ - def __init__(self, arguments): + def __init__(self, arguments: MemcachedArgsArguments): self.min_compress_len = arguments.get("min_compress_len", 0) self.set_arguments = {} @@ -273,7 +324,7 @@ class PylibmcBackend(MemcacheArgs, GenericMemcachedBackend): """ - def __init__(self, arguments): + def __init__(self, arguments: MemcachedArgsArguments): self.binary = arguments.get("binary", False) self.behaviors = arguments.get("behaviors", {}) super(PylibmcBackend, self).__init__(arguments) @@ -324,7 +375,7 @@ class MemcachedBackend(MemcacheArgs, GenericMemcachedBackend): """ - def __init__(self, arguments): + def __init__(self, arguments: MemcachedBackendArguments): self.dead_retry = arguments.get("dead_retry", 30) self.socket_timeout = arguments.get("socket_timeout", 3) super(MemcachedBackend, self).__init__(arguments) @@ -400,7 +451,7 @@ class BMemcachedBackend(GenericMemcachedBackend): """ - def __init__(self, arguments): + def __init__(self, arguments: BMemcachedBackendArguments): self.username = arguments.get("username", None) self.password = arguments.get("password", None) self.tls_context = arguments.get("tls_context", None) @@ -560,7 +611,7 @@ class PyMemcacheBackend(GenericMemcachedBackend): .. versionadded:: 1.1.5 - :param dead_timeout: Time in seconds before attempting to add a node + :param hashclient_dead_timeout: Time in seconds before attempting to add a node back in the pool in the HashClient's internal mechanisms. .. versionadded:: 1.1.5 @@ -589,7 +640,7 @@ class PyMemcacheBackend(GenericMemcachedBackend): """ # noqa E501 - def __init__(self, arguments): + def __init__(self, arguments: PyMemcacheBackendArguments): super().__init__(arguments) self.serde = arguments.get("serde", pymemcache.serde.pickle_serde) @@ -607,7 +658,9 @@ def __init__(self, arguments): self.hashclient_retry_timeout = arguments.get( "hashclient_retry_timeout", 1 ) - self.dead_timeout = arguments.get("hashclient_dead_timeout", 60) + self.hashclient_dead_timeout = arguments.get( + "hashclient_dead_timeout", 60 + ) if ( self.retry_delay is not None or self.retry_attempts is not None @@ -633,7 +686,7 @@ def _create_client(self): "tls_context": self.tls_context, "retry_attempts": self.hashclient_retry_attempts, "retry_timeout": self.hashclient_retry_timeout, - "dead_timeout": self.dead_timeout, + "dead_timeout": self.hashclient_dead_timeout, } if self.socket_keepalive is not None: _kwargs.update({"socket_keepalive": self.socket_keepalive}) diff --git a/dogpile/cache/backends/memory.py b/dogpile/cache/backends/memory.py index 0ace10b..752d6c9 100644 --- a/dogpile/cache/backends/memory.py +++ b/dogpile/cache/backends/memory.py @@ -10,11 +10,21 @@ """ +from __future__ import annotations + +from typing import Any +from typing import Dict +from typing import TypedDict + from ..api import CacheBackend from ..api import DefaultSerialization from ..api import NO_VALUE +class MemoryBackendArguments(TypedDict): + cache_dict: Dict[Any, Any] + + class MemoryBackend(CacheBackend): """A backend that uses a plain dictionary. @@ -49,7 +59,7 @@ class MemoryBackend(CacheBackend): """ - def __init__(self, arguments): + def __init__(self, arguments: MemoryBackendArguments): self._cache = arguments.get("cache_dict", {}) def get(self, key): diff --git a/dogpile/cache/backends/null.py b/dogpile/cache/backends/null.py index b4ad0fb..b349792 100644 --- a/dogpile/cache/backends/null.py +++ b/dogpile/cache/backends/null.py @@ -10,6 +10,11 @@ """ +from __future__ import annotations + +from typing import Any +from typing import Dict + from ..api import CacheBackend from ..api import NO_VALUE @@ -41,7 +46,7 @@ class NullBackend(CacheBackend): """ - def __init__(self, arguments): + def __init__(self, arguments: Dict[str, Any]): pass def get_mutex(self, key): diff --git a/dogpile/cache/backends/redis.py b/dogpile/cache/backends/redis.py index 68f84f5..7618f5a 100644 --- a/dogpile/cache/backends/redis.py +++ b/dogpile/cache/backends/redis.py @@ -6,13 +6,27 @@ """ -import typing +from __future__ import annotations + +from typing import Any +from typing import cast +from typing import Dict +from typing import List +from typing import Mapping +from typing import Optional +from typing import Sequence +from typing import Tuple +from typing import TYPE_CHECKING +from typing import TypedDict +from typing import Union import warnings from ..api import BytesBackend +from ..api import KeyType from ..api import NO_VALUE +from ..api import SerializedReturnType -if typing.TYPE_CHECKING: +if TYPE_CHECKING: import redis else: # delayed import @@ -21,6 +35,41 @@ __all__ = ("RedisBackend", "RedisSentinelBackend", "RedisClusterBackend") +class RedisBackendKwargs(TypedDict, total=False): + """ + TypedDict of kwargs for `RedisBackend` and derived classes + .. versionadded:: 1.4.1 + """ + + url: Optional[str] + host: str + username: Optional[str] + password: Optional[str] + port: int + db: int + redis_expiration_time: int + distributed_lock: bool + lock_timeout: int + socket_timeout: Optional[float] + socket_connect_timeout: Optional[float] + socket_keepalive: bool + socket_keepalive_options: Optional[Mapping[int, Union[int, bytes]]] + lock_sleep: float + connection_pool: Optional["redis.ConnectionPool"] + thread_local_lock: bool + connection_kwargs: Dict[str, Any] + + +class RedisSentinelBackendKwargs(RedisBackendKwargs): + sentinels: List[Tuple[str, str]] + service_name: str + sentinel_kwargs: Dict[str, Any] + + +class RedisClusterBackendKwargs(RedisBackendKwargs): + startup_nodes: List["redis.cluster.ClusterNode"] + + class RedisBackend(BytesBackend): r"""A `Redis `_ backend, using the `redis-py `_ driver. @@ -114,33 +163,29 @@ class RedisBackend(BytesBackend): .. versionadded:: 1.1.6 - - - """ - def __init__(self, arguments): - arguments = arguments.copy() + def __init__(self, arguments: RedisBackendKwargs): self._imports() - self.url = arguments.pop("url", None) - self.host = arguments.pop("host", "localhost") - self.username = arguments.pop("username", None) - self.password = arguments.pop("password", None) - self.port = arguments.pop("port", 6379) - self.db = arguments.pop("db", 0) - self.distributed_lock = arguments.pop("distributed_lock", False) - self.socket_timeout = arguments.pop("socket_timeout", None) - self.socket_connect_timeout = arguments.pop( + self.url = arguments.get("url", None) + self.host = arguments.get("host", "localhost") + self.username = arguments.get("username", None) + self.password = arguments.get("password", None) + self.port = arguments.get("port", 6379) + self.db = arguments.get("db", 0) + self.distributed_lock = arguments.get("distributed_lock", False) + self.socket_timeout = arguments.get("socket_timeout", None) + self.socket_connect_timeout = arguments.get( "socket_connect_timeout", None ) - self.socket_keepalive = arguments.pop("socket_keepalive", False) - self.socket_keepalive_options = arguments.pop( + self.socket_keepalive = arguments.get("socket_keepalive", False) + self.socket_keepalive_options = arguments.get( "socket_keepalive_options", None ) - self.lock_timeout = arguments.pop("lock_timeout", None) - self.lock_sleep = arguments.pop("lock_sleep", 0.1) - self.thread_local_lock = arguments.pop("thread_local_lock", True) - self.connection_kwargs = arguments.pop("connection_kwargs", {}) + self.lock_timeout = arguments.get("lock_timeout", None) + self.lock_sleep = arguments.get("lock_sleep", 0.1) + self.thread_local_lock = arguments.get("thread_local_lock", True) + self.connection_kwargs = arguments.get("connection_kwargs", {}) if self.distributed_lock and self.thread_local_lock: warnings.warn( @@ -148,16 +193,16 @@ def __init__(self, arguments): "set to False when distributed_lock is True" ) - self.redis_expiration_time = arguments.pop("redis_expiration_time", 0) - self.connection_pool = arguments.pop("connection_pool", None) + self.redis_expiration_time = arguments.get("redis_expiration_time", 0) + self.connection_pool = arguments.get("connection_pool", None) self._create_client() - def _imports(self): + def _imports(self) -> None: # defer imports until backend is used global redis import redis # noqa - def _create_client(self): + def _create_client(self) -> None: if self.connection_pool is not None: # the connection pool already has all other connection # options present within, so here we disregard socket_timeout @@ -195,7 +240,7 @@ def _create_client(self): self.writer_client = redis.StrictRedis(**args) self.reader_client = self.writer_client - def get_mutex(self, key): + def get_mutex(self, key: KeyType) -> Optional[_RedisLockWrapper]: if self.distributed_lock: return _RedisLockWrapper( self.writer_client.lock( @@ -208,25 +253,27 @@ def get_mutex(self, key): else: return None - def get_serialized(self, key): + def get_serialized(self, key: KeyType) -> SerializedReturnType: value = self.reader_client.get(key) if value is None: return NO_VALUE - return value + return cast(SerializedReturnType, value) - def get_serialized_multi(self, keys): + def get_serialized_multi( + self, keys: Sequence[KeyType] + ) -> Sequence[SerializedReturnType]: if not keys: return [] values = self.reader_client.mget(keys) return [v if v is not None else NO_VALUE for v in values] - def set_serialized(self, key, value): + def set_serialized(self, key: KeyType, value: bytes) -> None: if self.redis_expiration_time: self.writer_client.setex(key, self.redis_expiration_time, value) else: self.writer_client.set(key, value) - def set_serialized_multi(self, mapping): + def set_serialized_multi(self, mapping: Mapping[KeyType, bytes]) -> None: if not self.redis_expiration_time: self.writer_client.mset(mapping) else: @@ -235,23 +282,23 @@ def set_serialized_multi(self, mapping): pipe.setex(key, self.redis_expiration_time, value) pipe.execute() - def delete(self, key): + def delete(self, key: KeyType) -> None: self.writer_client.delete(key) - def delete_multi(self, keys): + def delete_multi(self, keys: Sequence[KeyType]) -> None: self.writer_client.delete(*keys) class _RedisLockWrapper: __slots__ = ("mutex", "__weakref__") - def __init__(self, mutex: typing.Any): + def __init__(self, mutex: Any): self.mutex = mutex - def acquire(self, wait: bool = True) -> typing.Any: + def acquire(self, wait: bool = True) -> Any: return self.mutex.acquire(blocking=wait) - def release(self) -> typing.Any: + def release(self) -> Any: return self.mutex.release() def locked(self) -> bool: @@ -356,13 +403,10 @@ class RedisSentinelBackend(RedisBackend): """ - def __init__(self, arguments): - arguments = arguments.copy() - - self.sentinels = arguments.pop("sentinels", None) - self.service_name = arguments.pop("service_name", "mymaster") - self.sentinel_kwargs = arguments.pop("sentinel_kwargs", {}) - + def __init__(self, arguments: RedisSentinelBackendKwargs): + self.sentinels = arguments.get("sentinels", None) + self.service_name = arguments.get("service_name", "mymaster") + self.sentinel_kwargs = arguments.get("sentinel_kwargs", {}) super().__init__( arguments={ "distributed_lock": True, @@ -371,7 +415,7 @@ def __init__(self, arguments): } ) - def _imports(self): + def _imports(self) -> None: # defer imports until backend is used global redis import redis.sentinel # noqa @@ -545,17 +589,16 @@ class RedisClusterBackend(RedisBackend): """ - def __init__(self, arguments): - arguments = arguments.copy() - self.startup_nodes = arguments.pop("startup_nodes", None) + def __init__(self, arguments: RedisClusterBackendKwargs): + self.startup_nodes = arguments.get("startup_nodes", None) super().__init__(arguments) - def _imports(self): + def _imports(self) -> None: global redis import redis.cluster - def _create_client(self): - redis_cluster: redis.cluster.RedisCluster[typing.Any] + def _create_client(self) -> None: + redis_cluster: redis.cluster.RedisCluster[Any] if self.url is not None: redis_cluster = redis.cluster.RedisCluster.from_url( self.url, **self.connection_kwargs @@ -565,5 +608,5 @@ def _create_client(self): startup_nodes=self.startup_nodes, **self.connection_kwargs, ) - self.writer_client = typing.cast("redis.Redis[bytes]", redis_cluster) + self.writer_client = cast("redis.Redis[bytes]", redis_cluster) self.reader_client = self.writer_client diff --git a/dogpile/cache/backends/valkey.py b/dogpile/cache/backends/valkey.py index 8604173..f6078b7 100644 --- a/dogpile/cache/backends/valkey.py +++ b/dogpile/cache/backends/valkey.py @@ -6,13 +6,24 @@ """ -import typing +from __future__ import annotations + +from typing import Any +from typing import cast +from typing import Dict +from typing import List +from typing import Mapping +from typing import Optional +from typing import Tuple +from typing import TYPE_CHECKING +from typing import TypedDict +from typing import Union import warnings from ..api import BytesBackend from ..api import NO_VALUE -if typing.TYPE_CHECKING: +if TYPE_CHECKING: import valkey else: # delayed import @@ -21,6 +32,36 @@ __all__ = ("ValkeyBackend", "ValkeySentinelBackend", "ValkeyClusterBackend") +class ValkeyBackendArguments(TypedDict, total=False): + url: Optional[str] + host: str + username: Optional[str] + password: Optional[str] + port: int + db: int + valkey_expiration_time: int + distributed_lock: bool + lock_timeout: Optional[int] + socket_timeout: Optional[float] + socket_connect_timeout: Optional[float] + socket_keepalive: bool + socket_keepalive_options: Optional[Mapping[int, Union[int, bytes]]] + lock_sleep: float = 0.1 + thread_local_lock: bool + connection_kwargs: Dict[str, Any] + connection_pool: Optional["valkey.ConnectionPool"] + + +class ValkeySentinelBackendArguments(ValkeyBackendArguments): + sentinels: Optional[List[Tuple[str, int]]] + service_name: str + sentinel_kwargs: Dict[str, Any] + + +class ValkeyClusterBackendArguments(ValkeyBackendArguments): + startup_nodes: List["valkey.cluster.ClusterNode"] + + class ValkeyBackend(BytesBackend): r"""A `Valkey `_ backend, using the `valkey-py `_ driver. @@ -85,9 +126,9 @@ class ValkeyBackend(BytesBackend): :param socket_keepalive_options: dict, socket keepalive options. Default is None (no options). - :param lock_sleep: integer, number of seconds to sleep when failed to + :param lock_sleep: float, number of seconds to sleep when failed to acquire a lock. This argument is only valid when - ``distributed_lock`` is ``True``. + ``distributed_lock`` is ``True``. Default is `0.1`, the Valkey default. :param connection_pool: ``valkey.ConnectionPool`` object. If provided, this object supersedes other connection arguments passed to the @@ -109,28 +150,27 @@ class ValkeyBackend(BytesBackend): """ - def __init__(self, arguments): - arguments = arguments.copy() + def __init__(self, arguments: ValkeyBackendArguments): self._imports() - self.url = arguments.pop("url", None) - self.host = arguments.pop("host", "localhost") - self.username = arguments.pop("username", None) - self.password = arguments.pop("password", None) - self.port = arguments.pop("port", 6379) - self.db = arguments.pop("db", 0) - self.distributed_lock = arguments.pop("distributed_lock", False) - self.socket_timeout = arguments.pop("socket_timeout", None) - self.socket_connect_timeout = arguments.pop( + self.url = arguments.get("url", None) + self.host = arguments.get("host", "localhost") + self.username = arguments.get("username", None) + self.password = arguments.get("password", None) + self.port = arguments.get("port", 6379) + self.db = arguments.get("db", 0) + self.distributed_lock = arguments.get("distributed_lock", False) + self.socket_timeout = arguments.get("socket_timeout", None) + self.socket_connect_timeout = arguments.get( "socket_connect_timeout", None ) - self.socket_keepalive = arguments.pop("socket_keepalive", False) - self.socket_keepalive_options = arguments.pop( + self.socket_keepalive = arguments.get("socket_keepalive", False) + self.socket_keepalive_options = arguments.get( "socket_keepalive_options", None ) - self.lock_timeout = arguments.pop("lock_timeout", None) - self.lock_sleep = arguments.pop("lock_sleep", 0.1) - self.thread_local_lock = arguments.pop("thread_local_lock", True) - self.connection_kwargs = arguments.pop("connection_kwargs", {}) + self.lock_timeout = arguments.get("lock_timeout", None) + self.lock_sleep = arguments.get("lock_sleep", 0.1) + self.thread_local_lock = arguments.get("thread_local_lock", True) + self.connection_kwargs = arguments.get("connection_kwargs", {}) if self.distributed_lock and self.thread_local_lock: warnings.warn( @@ -138,10 +178,10 @@ def __init__(self, arguments): "set to False when distributed_lock is True" ) - self.valkey_expiration_time = arguments.pop( + self.valkey_expiration_time = arguments.get( "valkey_expiration_time", 0 ) - self.connection_pool = arguments.pop("connection_pool", None) + self.connection_pool = arguments.get("connection_pool", None) self._create_client() def _imports(self): @@ -237,13 +277,13 @@ def delete_multi(self, keys): class _ValkeyLockWrapper: __slots__ = ("mutex", "__weakref__") - def __init__(self, mutex: typing.Any): + def __init__(self, mutex: Any): self.mutex = mutex - def acquire(self, wait: bool = True) -> typing.Any: + def acquire(self, wait: bool = True) -> Any: return self.mutex.acquire(blocking=wait) - def release(self) -> typing.Any: + def release(self) -> Any: return self.mutex.release() def locked(self) -> bool: @@ -348,12 +388,10 @@ class ValkeySentinelBackend(ValkeyBackend): """ - def __init__(self, arguments): - arguments = arguments.copy() - - self.sentinels = arguments.pop("sentinels", None) - self.service_name = arguments.pop("service_name", "mymaster") - self.sentinel_kwargs = arguments.pop("sentinel_kwargs", {}) + def __init__(self, arguments: ValkeySentinelBackendArguments): + self.sentinels = arguments.get("sentinels", None) + self.service_name = arguments.get("service_name", "mymaster") + self.sentinel_kwargs = arguments.get("sentinel_kwargs", {}) super().__init__( arguments={ @@ -537,17 +575,17 @@ class ValkeyClusterBackend(ValkeyBackend): """ # noqa: E501 - def __init__(self, arguments): - arguments = arguments.copy() - self.startup_nodes = arguments.pop("startup_nodes", None) - super().__init__(arguments) + def __init__(self, arguments: ValkeyClusterBackendArguments): + self.startup_nodes = arguments.get("startup_nodes", None) + _arguments_super = cast(ValkeyBackendArguments, arguments) + super().__init__(_arguments_super) def _imports(self): global valkey import valkey.cluster def _create_client(self): - valkey_cluster: valkey.cluster.ValkeyCluster[typing.Any] # type: ignore # noqa: E501 + valkey_cluster: valkey.cluster.ValkeyCluster[Any] # type: ignore # noqa: E501 if self.url is not None: valkey_cluster = valkey.cluster.ValkeyCluster.from_url( self.url, **self.connection_kwargs @@ -557,5 +595,5 @@ def _create_client(self): startup_nodes=self.startup_nodes, **self.connection_kwargs, ) - self.writer_client = typing.cast(valkey.Valkey[bytes], valkey_cluster) # type: ignore # noqa: E501 + self.writer_client = cast(valkey.Valkey[bytes], valkey_cluster) # type: ignore # noqa: E501 self.reader_client = self.writer_client diff --git a/dogpile/cache/proxy.py b/dogpile/cache/proxy.py index 08f598e..6d096c7 100644 --- a/dogpile/cache/proxy.py +++ b/dogpile/cache/proxy.py @@ -34,24 +34,24 @@ class ProxyBackend(CacheBackend): from dogpile.cache.proxy import ProxyBackend class MyFirstProxy(ProxyBackend): - def get_serialized(self, key): + def get_serialized(self, key: KeyType) -> SerializedReturnType: # ... custom code goes here ... return self.proxied.get_serialized(key) - def get(self, key): + def get(self, key: KeyType) -> BackendFormatted: # ... custom code goes here ... return self.proxied.get(key) - def set(self, key, value): + def set(self, key: KeyType, value: BackendSetType) -> None: # ... custom code goes here ... self.proxied.set(key) class MySecondProxy(ProxyBackend): - def get_serialized(self, key): + def get_serialized(self, key: KeyType) -> SerializedReturnType: # ... custom code goes here ... return self.proxied.get_serialized(key) - def get(self, key): + def get(self, key: KeyType) -> BackendFormatted: # ... custom code goes here ... return self.proxied.get(key) diff --git a/dogpile/cache/region.py b/dogpile/cache/region.py index 352b308..75d7c72 100644 --- a/dogpile/cache/region.py +++ b/dogpile/cache/region.py @@ -12,6 +12,8 @@ from typing import Any from typing import Callable from typing import cast +from typing import Dict +from typing import List from typing import Mapping from typing import Optional from typing import Sequence @@ -25,6 +27,7 @@ from . import exception from .api import BackendArguments from .api import BackendFormatted +from .api import CacheBackend from .api import CachedValue from .api import CacheMutex from .api import CacheReturnType @@ -553,7 +556,7 @@ def wrap(self, proxy: Union[ProxyBackend, Type[ProxyBackend]]) -> None: self.backend = proxy_instance.wrap(self.backend) - def _mutex(self, key): + def _mutex(self, key: KeyType) -> Any: return self._lock_registry.get(key) class _LockWrapper(CacheMutex): @@ -571,7 +574,7 @@ def release(self): def locked(self): return self.lock.locked() - def _create_mutex(self, key): + def _create_mutex(self, key: KeyType) -> Optional[Any]: mutex = self.backend.get_mutex(key) if mutex is not None: return mutex @@ -579,10 +582,10 @@ def _create_mutex(self, key): return self._LockWrapper() # cached value - _actual_backend = None + _actual_backend: Optional[CacheBackend] = None @property - def actual_backend(self): + def actual_backend(self) -> CacheBackend: """Return the ultimate backend underneath any proxies. The backend might be the result of one or more ``proxy.wrap`` @@ -596,9 +599,11 @@ def actual_backend(self): while hasattr(_backend, "proxied"): _backend = _backend.proxied self._actual_backend = _backend + if TYPE_CHECKING: + assert self._actual_backend return self._actual_backend - def invalidate(self, hard=True): + def invalidate(self, hard: bool = True) -> None: """Invalidate this :class:`.CacheRegion`. The default invalidation system works by setting @@ -648,7 +653,11 @@ def invalidate(self, hard=True): """ self.region_invalidator.invalidate(hard) - def configure_from_config(self, config_dict, prefix): + def configure_from_config( + self, + config_dict: Dict[str, Any], + prefix: str, + ) -> Self: """Configure from a configuration dictionary and a prefix. @@ -680,20 +689,20 @@ def configure_from_config(self, config_dict, prefix): ), _config_argument_dict=config_dict, _config_prefix="%sarguments." % prefix, - wrap=config_dict.get("%swrap" % prefix, None), + wrap=config_dict.get("%swrap" % prefix, ()), replace_existing_backend=config_dict.get( "%sreplace_existing_backend" % prefix, False ), ) @memoized_property - def backend(self): + def backend(self) -> CacheBackend: raise exception.RegionNotConfigured( "No backend is configured on this region." ) @property - def is_configured(self): + def is_configured(self) -> bool: """Return True if the backend has been configured via the :meth:`.CacheRegion.configure` method already. @@ -824,7 +833,9 @@ def _get_cache_value( return value def _unexpired_value_fn( - self, expiration_time: Optional[float], ignore_expiration: bool + self, + expiration_time: Optional[float], + ignore_expiration: bool = False, ) -> Callable[[CacheReturnType], CacheReturnType]: if ignore_expiration: return lambda value: value @@ -851,7 +862,12 @@ def value_fn(value): return value_fn - def get_multi(self, keys, expiration_time=None, ignore_expiration=False): + def get_multi( + self, + keys: Sequence[KeyType], + expiration_time: Optional[float] = None, + ignore_expiration: bool = False, + ) -> List[Union[ValuePayload, NoValueType]]: """Return multiple values from the cache, based on the given keys. Returns values as a list matching the keys given. @@ -902,7 +918,7 @@ def get_multi(self, keys, expiration_time=None, ignore_expiration=False): ] @contextlib.contextmanager - def _log_time(self, keys): + def _log_time(self, keys: Sequence[KeyType]): start_time = time.time() yield seconds = time.time() - start_time @@ -912,7 +928,11 @@ def _log_time(self, keys): {"seconds": seconds, "keys": repr_obj(keys)}, ) - def _is_cache_miss(self, value, orig_key): + def _is_cache_miss( + self, + value: CacheReturnType, + orig_key: KeyType, + ) -> bool: if value is NO_VALUE: log.debug("No value present for key: %r", orig_key) elif value.metadata["v"] != value_version: @@ -1154,7 +1174,7 @@ def get_or_create_multi( """ - def get_value(key): + def get_value(key: KeyType) -> Tuple[Any, Union[float, int]]: value = values.get(key, NO_VALUE) if self._is_cache_miss(value, orig_key): diff --git a/dogpile/cache/util.py b/dogpile/cache/util.py index 7fddaa5..560de14 100644 --- a/dogpile/cache/util.py +++ b/dogpile/cache/util.py @@ -1,10 +1,17 @@ from hashlib import sha1 +from typing import Any +from typing import Callable +from .api import KeyType from ..util import compat from ..util import langhelpers -def function_key_generator(namespace, fn, to_str=str): +def function_key_generator( + namespace: str, + fn: Callable, + to_str: Callable[[Any], str] = str, +) -> Callable: """Return a function that generates a string key, based on a given function as well as arguments to the returned function itself. @@ -45,7 +52,11 @@ def generate_key(*args, **kw): return generate_key -def function_multi_key_generator(namespace, fn, to_str=str): +def function_multi_key_generator( + namespace: str, + fn: Callable, + to_str: Callable[[Any], str] = str, +) -> Callable: if namespace is None: namespace = "%s:%s" % (fn.__module__, fn.__name__) else: @@ -67,7 +78,11 @@ def generate_keys(*args, **kw): return generate_keys -def kwarg_function_key_generator(namespace, fn, to_str=str): +def kwarg_function_key_generator( + namespace: str, + fn: Callable, + to_str: Callable[[Any], str] = str, +) -> Callable: """Return a function that generates a string key, based on a given function as well as arguments to the returned function itself. @@ -127,16 +142,17 @@ def generate_key(*args, **kwargs): return generate_key -def sha1_mangle_key(key): +def sha1_mangle_key(key: KeyType) -> str: """a SHA1 key mangler.""" - if isinstance(key, str): - key = key.encode("utf-8") + bkey = key.encode("utf-8") if isinstance(key, str) else key - return sha1(key).hexdigest() + return sha1(bkey).hexdigest() -def length_conditional_mangler(length, mangler): +def length_conditional_mangler( + length: int, mangler: Callable[[KeyType], str] +) -> Callable[[KeyType], str]: """a key mangler that mangles if the length of the key is past a certain threshold. diff --git a/dogpile/testing/fixtures.py b/dogpile/testing/fixtures.py index 4dff69f..cce5114 100644 --- a/dogpile/testing/fixtures.py +++ b/dogpile/testing/fixtures.py @@ -1,5 +1,7 @@ # mypy: ignore-errors +from __future__ import annotations + import collections import itertools import json @@ -7,6 +9,8 @@ from threading import Lock from threading import Thread import time +from typing import Any +from typing import Dict import uuid import pytest @@ -51,7 +55,7 @@ def _check_backend_available(cls, backend): pass backend: str - region_args = {} + region_args: Dict[str, Any] = {} config_args = {} extra_arguments = {} backend_argument_names = () diff --git a/tests/cache/test_memcached_backend.py b/tests/cache/test_memcached_backend.py index f3b106f..d1e2afa 100644 --- a/tests/cache/test_memcached_backend.py +++ b/tests/cache/test_memcached_backend.py @@ -10,8 +10,10 @@ from dogpile.cache.backends.memcached import GenericMemcachedBackend from dogpile.cache.backends.memcached import MemcachedBackend +from dogpile.cache.backends.memcached import MemcachedBackendArguments from dogpile.cache.backends.memcached import PylibmcBackend from dogpile.cache.backends.memcached import PyMemcacheBackend +from dogpile.cache.backends.memcached import PyMemcacheBackendArguments from dogpile.testing import eq_ from dogpile.testing import is_ from dogpile.testing.fixtures import _GenericBackendTestSuite @@ -242,7 +244,7 @@ def _mock_pymemcache_fixture(self): ) def test_pymemcache_hashclient_retry_attempts(self): - config_args = { + config_args: PyMemcacheBackendArguments = { "url": "127.0.0.1:11211", "hashclient_retry_attempts": 4, } @@ -265,7 +267,10 @@ def test_pymemcache_hashclient_retry_attempts(self): eq_(self.retrying_client.mock_calls, []) def test_pymemcache_hashclient_retry_timeout(self): - config_args = {"url": "127.0.0.1:11211", "hashclient_retry_timeout": 4} + config_args: PyMemcacheBackendArguments = { + "url": "127.0.0.1:11211", + "hashclient_retry_timeout": 4, + } with self._mock_pymemcache_fixture(): backend = MockPyMemcacheBackend(config_args) is_(backend._create_client(), self.hash_client()) @@ -284,7 +289,7 @@ def test_pymemcache_hashclient_retry_timeout(self): eq_(self.retrying_client.mock_calls, []) def test_pymemcache_hashclient_retry_timeout_w_enable_retry(self): - config_args = { + config_args: PyMemcacheBackendArguments = { "url": "127.0.0.1:11211", "hashclient_retry_timeout": 4, "enable_retry_client": True, @@ -317,7 +322,10 @@ def test_pymemcache_hashclient_retry_timeout_w_enable_retry(self): ) def test_pymemcache_dead_timeout(self): - config_args = {"url": "127.0.0.1:11211", "hashclient_dead_timeout": 4} + config_args: PyMemcacheBackendArguments = { + "url": "127.0.0.1:11211", + "hashclient_dead_timeout": 4, + } with self._mock_pymemcache_fixture(): backend = MockPyMemcacheBackend(config_args) backend._create_client() @@ -338,7 +346,10 @@ def test_pymemcache_dead_timeout(self): eq_(self.retrying_client.mock_calls, []) def test_pymemcache_enable_retry_client_not_set(self): - config_args = {"url": "127.0.0.1:11211", "retry_attempts": 2} + config_args: PyMemcacheBackendArguments = { + "url": "127.0.0.1:11211", + "retry_attempts": 2, + } with self._mock_pymemcache_fixture(): with mock.patch("warnings.warn") as warn_mock: @@ -352,7 +363,10 @@ def test_pymemcache_enable_retry_client_not_set(self): ) def test_pymemcache_memacached_expire_time(self): - config_args = {"url": "127.0.0.1:11211", "memcached_expire_time": 20} + config_args: PyMemcacheBackendArguments = { + "url": "127.0.0.1:11211", + "memcached_expire_time": 20, + } with self._mock_pymemcache_fixture(): backend = MockPyMemcacheBackend(config_args) backend.set("foo", "bar") @@ -446,7 +460,7 @@ def delete(self, key): class MemcachedBackendTest: def test_memcached_dead_retry(self): - config_args = { + config_args: MemcachedBackendArguments = { "url": "127.0.0.1:11211", "dead_retry": 4, } @@ -454,7 +468,7 @@ def test_memcached_dead_retry(self): eq_(backend._create_client().kw["dead_retry"], 4) def test_memcached_socket_timeout(self): - config_args = { + config_args: MemcachedBackendArguments = { "url": "127.0.0.1:11211", "socket_timeout": 6, } diff --git a/tests/cache/test_redis_backend.py b/tests/cache/test_redis_backend.py index 01472a9..c0520cc 100644 --- a/tests/cache/test_redis_backend.py +++ b/tests/cache/test_redis_backend.py @@ -2,11 +2,14 @@ import os from threading import Event import time +from typing import Type +from typing import TYPE_CHECKING from unittest.mock import Mock from unittest.mock import patch import pytest +from dogpile.cache.backends.redis import RedisBackendKwargs from dogpile.cache.region import _backend_loader from dogpile.testing import eq_ from dogpile.testing.fixtures import _GenericBackendFixture @@ -14,6 +17,9 @@ from dogpile.testing.fixtures import _GenericMutexTestSuite from dogpile.testing.fixtures import _GenericSerializerTestSuite +if TYPE_CHECKING: + import redis + REDIS_HOST = "127.0.0.1" REDIS_PORT = int(os.getenv("DOGPILE_REDIS_PORT", "6379")) expect_redis_running = os.getenv("DOGPILE_REDIS_PORT") is not None @@ -129,6 +135,7 @@ def blah(k): @patch("redis.StrictRedis", autospec=True) class RedisConnectionTest: backend = "dogpile.cache.redis" + backend_cls = Type["redis.StrictRedis"] @classmethod def setup_class(cls): @@ -147,7 +154,7 @@ def _test_helper(self, mock_obj, expected_args, connection_args=None): def test_connect_with_defaults(self, MockStrictRedis): # The defaults, used if keys are missing from the arguments dict. - arguments = { + arguments: RedisBackendKwargs = { "host": "localhost", "port": 6379, "db": 0, @@ -157,7 +164,7 @@ def test_connect_with_defaults(self, MockStrictRedis): self._test_helper(MockStrictRedis, expected, arguments) def test_connect_with_basics(self, MockStrictRedis): - arguments = { + arguments: RedisBackendKwargs = { "host": "127.0.0.1", "port": 6379, "db": 0, @@ -167,7 +174,7 @@ def test_connect_with_basics(self, MockStrictRedis): self._test_helper(MockStrictRedis, expected, arguments) def test_connect_with_password(self, MockStrictRedis): - arguments = { + arguments: RedisBackendKwargs = { "host": "127.0.0.1", "password": "some password", "port": 6379, @@ -182,7 +189,7 @@ def test_connect_with_password(self, MockStrictRedis): self._test_helper(MockStrictRedis, expected, arguments) def test_connect_with_username_and_password(self, MockStrictRedis): - arguments = { + arguments: RedisBackendKwargs = { "host": "127.0.0.1", "username": "redis", "password": "some password", @@ -192,7 +199,7 @@ def test_connect_with_username_and_password(self, MockStrictRedis): self._test_helper(MockStrictRedis, arguments) def test_connect_with_socket_timeout(self, MockStrictRedis): - arguments = { + arguments: RedisBackendKwargs = { "host": "127.0.0.1", "port": 6379, "socket_timeout": 0.5, @@ -203,7 +210,7 @@ def test_connect_with_socket_timeout(self, MockStrictRedis): self._test_helper(MockStrictRedis, expected, arguments) def test_connect_with_socket_connect_timeout(self, MockStrictRedis): - arguments = { + arguments: RedisBackendKwargs = { "host": "127.0.0.1", "port": 6379, "socket_timeout": 1.0, @@ -214,7 +221,7 @@ def test_connect_with_socket_connect_timeout(self, MockStrictRedis): self._test_helper(MockStrictRedis, expected, arguments) def test_connect_with_socket_keepalive(self, MockStrictRedis): - arguments = { + arguments: RedisBackendKwargs = { "host": "127.0.0.1", "port": 6379, "socket_keepalive": True, @@ -225,7 +232,7 @@ def test_connect_with_socket_keepalive(self, MockStrictRedis): self._test_helper(MockStrictRedis, expected, arguments) def test_connect_with_socket_keepalive_options(self, MockStrictRedis): - arguments = { + arguments: RedisBackendKwargs = { "host": "127.0.0.1", "port": 6379, "socket_keepalive": True, @@ -239,18 +246,23 @@ def test_connect_with_socket_keepalive_options(self, MockStrictRedis): def test_connect_with_connection_pool(self, MockStrictRedis): pool = Mock() - arguments = {"connection_pool": pool, "socket_timeout": 0.5} + arguments: RedisBackendKwargs = { + "connection_pool": pool, + "socket_timeout": 0.5, + } expected_args = {"connection_pool": pool} self._test_helper( MockStrictRedis, expected_args, connection_args=arguments ) def test_connect_with_url(self, MockStrictRedis): - arguments = {"url": "redis://redis:password@127.0.0.1:6379/0"} + arguments: RedisBackendKwargs = { + "url": "redis://redis:password@127.0.0.1:6379/0" + } self._test_helper(MockStrictRedis.from_url, arguments) def test_extra_arbitrary_args(self, MockStrictRedis): - arguments = { + arguments: RedisBackendKwargs = { "url": "redis://redis:password@127.0.0.1:6379/0", "connection_kwargs": { "ssl": True, diff --git a/tests/cache/test_region.py b/tests/cache/test_region.py index 6f8120e..bf5538a 100644 --- a/tests/cache/test_region.py +++ b/tests/cache/test_region.py @@ -24,7 +24,7 @@ from dogpile.testing.fixtures import MockBackend -def key_mangler(key): +def key_mangler(key: str) -> str: return "HI!" + key diff --git a/tests/cache/test_valkey_backend.py b/tests/cache/test_valkey_backend.py index f0e1c25..8d83a7b 100644 --- a/tests/cache/test_valkey_backend.py +++ b/tests/cache/test_valkey_backend.py @@ -2,6 +2,8 @@ import os from threading import Event import time +from typing import Type +from typing import TYPE_CHECKING from unittest.mock import Mock from unittest.mock import patch @@ -14,6 +16,9 @@ from dogpile.testing.fixtures import _GenericMutexTestSuite from dogpile.testing.fixtures import _GenericSerializerTestSuite +if TYPE_CHECKING: + import valkey + VALKEY_HOST = "127.0.0.1" VALKEY_PORT = int(os.getenv("DOGPILE_VALKEY_PORT", "6379")) expect_valkey_running = os.getenv("DOGPILE_VALKEY_PORT") is not None @@ -129,6 +134,7 @@ def blah(k): @patch("valkey.StrictValkey", autospec=True) class ValkeyConnectionTest: backend = "dogpile.cache.valkey" + backend_cls = Type["valkey.StrictValkey"] @classmethod def setup_class(cls):