diff --git a/.pre-commit-config.yaml b/.pre-commit-config.yaml index caaff2ae..ff540b51 100644 --- a/.pre-commit-config.yaml +++ b/.pre-commit-config.yaml @@ -63,7 +63,7 @@ repos: # basic check - id: ruff name: Ruff check - args: ["--fix"] + args: ["--fix"] #, "--unsafe-fixes" # it needs to be after formatting hooks because the lines might be changed - repo: https://github.com/pre-commit/mirrors-mypy diff --git a/src/cachier/config.py b/src/cachier/config.py index 4c7bb1d7..ff3eb651 100644 --- a/src/cachier/config.py +++ b/src/cachier/config.py @@ -98,7 +98,11 @@ def _update_with_defaults( def set_default_params(**params: Any) -> None: - """Configure default parameters applicable to all memoized functions.""" + """Configure default parameters applicable to all memoized functions. + + Deprecated, use :func:`~cachier.config.set_global_params` instead. + + """ # It is kept for backwards compatibility with desperation warning import warnings @@ -115,13 +119,21 @@ def set_global_params(**params: Any) -> None: """Configure global parameters applicable to all memoized functions. This function takes the same keyword parameters as the ones defined in the - decorator, which can be passed all at once or with multiple calls. - Parameters given directly to a decorator take precedence over any values - set by this function. - - Only 'stale_after', 'next_time', and 'wait_for_calc_timeout' can be changed - after the memoization decorator has been applied. Other parameters will - only have an effect on decorators applied after this function is run. + decorator. Parameters given directly to a decorator take precedence over + any values set by this function. + + Note on dynamic behavior: + - If a decorator parameter is provided explicitly (not None), that value + is used for the decorated function and is not affected by later changes + to the global parameters. + - If a decorator parameter is left as None, the decorator/core may read + the corresponding value from the global params at call time. Parameters + that are read dynamically (when decorator parameter was None) include: + 'stale_after', 'next_time', 'allow_none', 'cleanup_stale', + 'cleanup_interval', and 'caching_enabled'. In some cores, if the + decorator was created without concrete value for 'wait_for_calc_timeout', + calls that check calculation timeouts will fall back to the global + 'wait_for_calc_timeout' as well. """ import cachier @@ -138,7 +150,11 @@ def set_global_params(**params: Any) -> None: def get_default_params() -> Params: - """Get current set of default parameters.""" + """Get current set of default parameters. + + Deprecated, use :func:`~cachier.config.get_global_params` instead. + + """ # It is kept for backwards compatibility with desperation warning import warnings diff --git a/src/cachier/core.py b/src/cachier/core.py index 8c56d960..e999feaf 100644 --- a/src/cachier/core.py +++ b/src/cachier/core.py @@ -134,14 +134,15 @@ def cachier( value is their id), equal objects across different sessions will not yield identical keys. - Arguments: - --------- + Parameters + ---------- hash_func : callable, optional A callable that gets the args and kwargs from the decorated function and returns a hash key for them. This parameter can be used to enable the use of cachier with functions that get arguments that are not automatically hashable by Python. hash_params : callable, optional + Deprecated, use :func:`~cachier.core.cachier.hash_func` instead. backend : str, optional The name of the backend to use. Valid options currently include 'pickle', 'mongo', 'memory', 'sql', and 'redis'. If not provided, @@ -149,8 +150,8 @@ def cachier( mongetter : callable, optional A callable that takes no arguments and returns a pymongo.Collection - object with writing permissions. If unset a local pickle cache is used - instead. + object with writing permissions. If provided, the backend is set to + 'mongo'. sql_engine : str, Engine, or callable, optional SQLAlchemy connection string, Engine, or callable returning an Engine. Used for the SQL backend. @@ -177,8 +178,8 @@ def cachier( separate_files: bool, default False, for Pickle cores only Instead of a single cache file per-function, each function's cache is split between several files, one for each argument set. This can help - if you per-function cache files become too large. - wait_for_calc_timeout: int, optional, for MongoDB only + if your per-function cache files become too large. + wait_for_calc_timeout: int, optional The maximum time to wait for an ongoing calculation. When a process started to calculate the value setting being_calculated to True, any process trying to read the same entry will wait a maximum of @@ -358,11 +359,8 @@ def _call(*args, max_age: Optional[timedelta] = None, **kwds): ) nonneg_max_age = False else: - max_allowed_age = ( - min(_stale_after, max_age) - if max_age is not None - else _stale_after - ) + assert max_age is not None # noqa: S101 + max_allowed_age = min(_stale_after, max_age) # note: if max_age < 0, we always consider a value stale if nonneg_max_age and (now - entry.time <= max_allowed_age): _print("And it is fresh!") diff --git a/src/cachier/cores/base.py b/src/cachier/cores/base.py index ef631850..f1ea8702 100644 --- a/src/cachier/cores/base.py +++ b/src/cachier/cores/base.py @@ -27,12 +27,17 @@ class RecalculationNeeded(Exception): def _get_func_str(func: Callable) -> str: - return f".{func.__module__}.{func.__name__}" + """Return a string identifier for the function (module + name). + + We accept Any here because static analysis can't always prove that the + runtime object will have __module__ and __name__, but at runtime the + decorated functions always do. + """ + return f".{func.__module__}.{func.__name__}" -class _BaseCore: - __metaclass__ = abc.ABCMeta +class _BaseCore(metaclass=abc.ABCMeta): def __init__( self, hash_func: Optional[HashFunc], @@ -90,8 +95,8 @@ def check_calc_timeout(self, time_spent): def get_entry_by_key(self, key: str) -> Tuple[str, Optional[CacheEntry]]: """Get entry based on given key. - Return the result mapped to the given key in this core's cache, if such - a mapping exists. + Return the key and the :class:`~cachier.config.CacheEntry` mapped + to the given key in this core's cache, if such a mapping exists. """ diff --git a/src/cachier/cores/redis.py b/src/cachier/cores/redis.py index ff4d8fd0..46bacaa8 100644 --- a/src/cachier/cores/redis.py +++ b/src/cachier/cores/redis.py @@ -73,6 +73,54 @@ def set_func(self, func): super().set_func(func) self._func_str = _get_func_str(func) + @staticmethod + def _loading_pickle(raw_value) -> Any: + """Load pickled data with some recovery attempts.""" + try: + if isinstance(raw_value, bytes): + return pickle.loads(raw_value) + elif isinstance(raw_value, str): + # try to recover by encoding; prefer utf-8 but fall + # back to latin-1 in case raw binary was coerced to str + try: + return pickle.loads(raw_value.encode("utf-8")) + except Exception: + return pickle.loads(raw_value.encode("latin-1")) + else: + # unexpected type; attempt pickle.loads directly + try: + return pickle.loads(raw_value) + except Exception: + return None + except Exception as exc: + warnings.warn( + f"Redis value deserialization failed: {exc}", + stacklevel=2, + ) + return None + + @staticmethod + def _get_raw_field(cached_data, field: str): + """Fetch field from cached_data with bytes/str key handling.""" + # try bytes key first, then str key + bkey = field.encode("utf-8") + if bkey in cached_data: + return cached_data[bkey] + return cached_data.get(field) + + @staticmethod + def _get_bool_field(cached_data, name: str) -> bool: + """Fetch boolean field from cached_data.""" + raw = _RedisCore._get_raw_field(cached_data, name) or b"false" + if isinstance(raw, bytes): + try: + s = raw.decode("utf-8") + except Exception: + s = raw.decode("latin-1", errors="ignore") + else: + s = str(raw) + return s.lower() == "true" + def get_entry_by_key(self, key: str) -> Tuple[str, Optional[CacheEntry]]: """Get entry based on given key from Redis.""" redis_client = self._resolve_redis_client() @@ -86,32 +134,28 @@ def get_entry_by_key(self, key: str) -> Tuple[str, Optional[CacheEntry]]: # Deserialize the value value = None - if cached_data.get(b"value"): - value = pickle.loads(cached_data[b"value"]) + raw_value = _RedisCore._get_raw_field(cached_data, "value") + if raw_value is not None: + value = self._loading_pickle(raw_value) # Parse timestamp - timestamp_str = cached_data.get(b"timestamp", b"").decode("utf-8") + raw_ts = _RedisCore._get_raw_field(cached_data, "timestamp") or b"" + if isinstance(raw_ts, bytes): + try: + timestamp_str = raw_ts.decode("utf-8") + except UnicodeDecodeError: + timestamp_str = raw_ts.decode("latin-1", errors="ignore") + else: + timestamp_str = str(raw_ts) timestamp = ( datetime.fromisoformat(timestamp_str) if timestamp_str else datetime.now() ) - # Parse boolean fields - stale = ( - cached_data.get(b"stale", b"false").decode("utf-8").lower() - == "true" - ) - processing = ( - cached_data.get(b"processing", b"false") - .decode("utf-8") - .lower() - == "true" - ) - completed = ( - cached_data.get(b"completed", b"false").decode("utf-8").lower() - == "true" - ) + stale = _RedisCore._get_bool_field(cached_data, "stale") + processing = _RedisCore._get_bool_field(cached_data, "processing") + completed = _RedisCore._get_bool_field(cached_data, "completed") entry = CacheEntry( value=value, @@ -126,9 +170,9 @@ def get_entry_by_key(self, key: str) -> Tuple[str, Optional[CacheEntry]]: return key, None def set_entry(self, key: str, func_res: Any) -> bool: + """Map the given result to the given key in Redis.""" if not self._should_store(func_res): return False - """Map the given result to the given key in Redis.""" redis_client = self._resolve_redis_client() redis_key = self._get_redis_key(key) @@ -242,8 +286,16 @@ def delete_stale_entries(self, stale_after: timedelta) -> None: ts = redis_client.hget(key, "timestamp") if ts is None: continue + # ts may be bytes or str depending on client configuration + if isinstance(ts, bytes): + try: + ts_s = ts.decode("utf-8") + except Exception: + ts_s = ts.decode("latin-1", errors="ignore") + else: + ts_s = str(ts) try: - ts_val = datetime.fromisoformat(ts.decode("utf-8")) + ts_val = datetime.fromisoformat(ts_s) except Exception as exc: warnings.warn( f"Redis timestamp parse failed: {exc}", stacklevel=2 diff --git a/tests/test_pickle_core.py b/tests/test_pickle_core.py index 81823776..75ab0ac2 100644 --- a/tests/test_pickle_core.py +++ b/tests/test_pickle_core.py @@ -34,6 +34,7 @@ from cachier import cachier from cachier.config import CacheEntry, _global_params from cachier.cores.pickle import _PickleCore +from cachier.cores.redis import _RedisCore def _get_decorated_func(func, **kwargs): @@ -42,9 +43,6 @@ def _get_decorated_func(func, **kwargs): return decorated_func -# Pickle core tests - - def _takes_2_seconds(arg_1, arg_2): """Some function.""" sleep(2) @@ -528,7 +526,6 @@ def _error_throwing_func(arg1): @pytest.mark.parametrize("separate_files", [True, False]) def test_error_throwing_func(separate_files): # with - _error_throwing_func.count = 0 _error_throwing_func_decorated = _get_decorated_func( _error_throwing_func, stale_after=timedelta(seconds=1), @@ -536,6 +533,7 @@ def test_error_throwing_func(separate_files): separate_files=separate_files, ) _error_throwing_func_decorated.clear_cache() + _error_throwing_func.count = 0 res1 = _error_throwing_func_decorated(4) sleep(1.5) res2 = _error_throwing_func_decorated(4) @@ -1074,3 +1072,70 @@ def mock_func(): with patch("os.remove", side_effect=FileNotFoundError): # Should not raise exception core.delete_stale_entries(timedelta(hours=1)) + + +# Redis core static method tests +@pytest.mark.parametrize( + ("test_input", "expected"), + [ + (pickle.dumps({"test": 123}), {"test": 123}), # valid string + # (pickle.dumps({"test": 123}).decode("utf-8"), {"test": 123}), + # (b"\x80\x04\x95", None), # corrupted bytes + (123, None), # unexpected type + # (b"corrupted", None), # triggers warning + ], +) +def test_redis_loading_pickle(test_input, expected): + """Test _RedisCore._loading_pickle with various inputs and exceptions.""" + assert _RedisCore._loading_pickle(test_input) == expected + + +def test_redis_loading_pickle_failed(): + """Test _RedisCore._loading_pickle with various inputs and exceptions.""" + with patch("pickle.loads", side_effect=Exception("Failed")): + assert _RedisCore._loading_pickle(123) is None + + +def test_redis_loading_pickle_latin1_fallback(): + """Test _RedisCore._loading_pickle with latin-1 fallback.""" + valid_obj = {"test": 123} + with patch("pickle.loads") as mock_loads: + mock_loads.side_effect = [Exception("UTF-8 failed"), valid_obj] + result = _RedisCore._loading_pickle("invalid_utf8_string") + assert result == valid_obj + assert mock_loads.call_count == 2 + + +@pytest.mark.parametrize( + ("cached_data", "key", "expected"), + [ + ({b"field": b"value", "other": "data"}, "field", b"value"), + ({"field": "value", b"other": b"data"}, "field", "value"), + ({"other": "value"}, "field", None), + ], +) +def test_redis_get_raw_field(cached_data, key, expected): + """Test _RedisCore._get_raw_field with bytes and string keys.""" + assert _RedisCore._get_raw_field(cached_data, key) == expected + + +@pytest.mark.parametrize( + ("cached_data", "key", "expected"), + [ + ({b"flag": b"true"}, "flag", True), + ({b"flag": b"false"}, "flag", False), + ({"flag": "TRUE"}, "flag", True), + ({}, "flag", False), + ({b"flag": 123}, "flag", False), + ], +) +def test_redis_get_bool_field(cached_data, key, expected): + """Test _RedisCore._get_bool_field with various inputs.""" + assert _RedisCore._get_bool_field(cached_data, key) == expected + + +def test_redis_get_bool_field_decode_fallback(): + """Test _RedisCore._get_bool_field with decoding fallback.""" + with patch.object(_RedisCore, "_get_raw_field", return_value=b"\xff\xfe"): + result = _RedisCore._get_bool_field({}, "flag") + assert result is False diff --git a/tests/test_redis_core_exceptions.py b/tests/test_redis_core_exceptions.py new file mode 100644 index 00000000..69d2d472 --- /dev/null +++ b/tests/test_redis_core_exceptions.py @@ -0,0 +1,158 @@ +from datetime import datetime, timedelta +from unittest.mock import NonCallableMock, patch + +import pytest + +from cachier.cores.redis import _RedisCore + + +@pytest.mark.redis +class TestRedisCoreExceptions: + @pytest.fixture + def mock_redis(self): + """Fixture providing a mock Redis client.""" + return NonCallableMock() + + @pytest.fixture + def core(self, mock_redis): + """Fixture providing a Redis core instance with mock client.""" + core = _RedisCore( + hash_func=None, redis_client=mock_redis, wait_for_calc_timeout=10 + ) + core.set_func(lambda x: x) # Set a dummy function + return core + + def test_loading_pickle_exceptions_bytes(self): + """Test _loading_pickle handles exceptions when deserializing bytes.""" + with ( + patch("pickle.loads", side_effect=Exception("Pickle error")), + pytest.warns( + UserWarning, match="Redis value deserialization failed" + ), + ): + assert _RedisCore._loading_pickle(b"data") is None + + def test_loading_pickle_exceptions_str_success(self): + """Test _loading_pickle latin-1 fallback for str input.""" + with patch("pickle.loads") as mock_loads: + mock_loads.side_effect = [Exception("UTF-8 error"), "success"] + res = _RedisCore._loading_pickle("data") + assert res == "success" + assert mock_loads.call_count == 2 + + def test_loading_pickle_exceptions_str_fail(self): + """Test _loading_pickle decoding failure for str input.""" + with ( + patch("pickle.loads", side_effect=Exception("Pickle error")), + pytest.warns( + UserWarning, match="Redis value deserialization failed" + ), + ): + assert _RedisCore._loading_pickle("data") is None + + def test_loading_pickle_exceptions_other_type(self): + """Test _loading_pickle exception handling for unsupported types.""" + with patch("pickle.loads", side_effect=Exception("Pickle error")): + res = _RedisCore._loading_pickle(123) + assert res is None + + def test_get_bool_field_exceptions(self): + """Test _get_bool_field decoding exception fallback to latin-1.""" + # Byte string that fails utf-8 but works with latin-1 + # b'\xff' is invalid start byte in utf-8 + + with patch.object(_RedisCore, "_get_raw_field", return_value=b"\xff"): + res = _RedisCore._get_bool_field({}, "flag") + assert res is False # "ÿ" != "true" + + def test_get_entry_by_key_exceptions_hgetall(self, core, mock_redis): + """Test get_entry_by_key hgetall exception.""" + mock_redis.hgetall.side_effect = Exception("Redis error") + with pytest.warns(UserWarning, match="Redis get_entry_by_key failed"): + assert core.get_entry_by_key("key")[1] is None + + def test_get_entry_by_key_exceptions_timestamp(self, core, mock_redis): + """Test get_entry_by_key timestamp decoding exception.""" + mock_redis.hgetall.side_effect = None + mock_redis.hgetall.return_value = { + b"timestamp": b"\xff" + } # Invalid utf-8 + with pytest.warns(UserWarning, match="Redis get_entry_by_key failed"): + core.get_entry_by_key("key") + + def test_set_entry_exceptions(self, core, mock_redis): + """Test set_entry Redis hset exception handling and return False.""" + mock_redis.hset.side_effect = Exception("Redis error") + with pytest.warns(UserWarning, match="Redis set_entry failed"): + assert core.set_entry("key", "value") is False + + def test_mark_entry_being_calculated_exceptions(self, core, mock_redis): + """Test mark_entry_being_calculated Redis hset exception handling.""" + mock_redis.hset.side_effect = Exception("Redis error") + with pytest.warns( + UserWarning, match="Redis mark_entry_being_calculated failed" + ): + core.mark_entry_being_calculated("key") + + def test_mark_entry_not_calculated_exceptions(self, core, mock_redis): + """Test mark_entry_not_calculated Redis hset exception handling.""" + mock_redis.hset.side_effect = Exception("Redis error") + with pytest.warns( + UserWarning, match="Redis mark_entry_not_calculated failed" + ): + core.mark_entry_not_calculated("key") + + def test_clear_cache_exceptions(self, core, mock_redis): + """Test clear_cache Redis keys exception handling.""" + mock_redis.keys.side_effect = Exception("Redis error") + with pytest.warns(UserWarning, match="Redis clear_cache failed"): + core.clear_cache() + + def test_clear_being_calculated_exceptions(self, core, mock_redis): + """Test clear_being_calculated Redis keys exception handling.""" + mock_redis.keys.side_effect = Exception("Redis error") + with pytest.warns( + UserWarning, match="Redis clear_being_calculated failed" + ): + core.clear_being_calculated() + + def test_delete_stale_entries_keys_exception(self, core, mock_redis): + """Test delete_stale_entries Redis keys exception handling.""" + mock_redis.keys.side_effect = Exception("Redis error") + with pytest.warns( + UserWarning, match="Redis delete_stale_entries failed" + ): + core.delete_stale_entries(timedelta(seconds=1)) + + def test_delete_stale_entries_timestamp_parse_exception( + self, core, mock_redis + ): + """Test delete_stale_entries timestamp parsing exception handling.""" + mock_redis.keys.return_value = [b"key1"] + mock_redis.hget.return_value = b"invalid_timestamp" + + with pytest.warns(UserWarning, match="Redis timestamp parse failed"): + core.delete_stale_entries(timedelta(seconds=1)) + + def test_delete_stale_entries_latin1_fallback(self, core, mock_redis): + """Test delete_stale_entries uses latin-1 for invalid utf-8.""" + mock_redis.keys.return_value = [b"key1"] + # b'\xff' is invalid utf-8 start byte + mock_redis.hget.return_value = b"\xff" + + # It will decode to "ÿ" (latin-1) then fail date parsing + with pytest.warns(UserWarning, match="Redis timestamp parse failed"): + core.delete_stale_entries(timedelta(seconds=1)) + + def test_delete_stale_entries_str_timestamp(self, core, mock_redis): + """Test delete_stale_entries handles string timestamps (not bytes).""" + mock_redis.keys.return_value = [b"key1"] + now = datetime.now() + old_time = now - timedelta(hours=1) + # Return a string, not bytes + mock_redis.hget.return_value = old_time.isoformat() + + # Should not warn, and should delete key because + # it is stale (stale_after=1s) + core.delete_stale_entries(timedelta(seconds=1)) + mock_redis.delete.assert_called_with(b"key1")