Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 1 addition & 0 deletions src/gradient/_utils/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -29,6 +29,7 @@
get_required_header as get_required_header,
maybe_coerce_boolean as maybe_coerce_boolean,
maybe_coerce_integer as maybe_coerce_integer,
ResponseCache as ResponseCache,
)
from ._compat import (
get_args as get_args,
Expand Down
78 changes: 78 additions & 0 deletions src/gradient/_utils/_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -419,3 +419,81 @@ def json_safe(data: object) -> object:
return data.isoformat()

return data


# Response Caching Classes
class ResponseCache:
"""Simple in-memory response cache with TTL support."""

def __init__(self, max_size: int = 100, default_ttl: int = 300) -> None:
"""Initialize the cache.

Args:
max_size: Maximum number of cached responses
default_ttl: Default time-to-live in seconds
"""
self.max_size: int = max_size
self.default_ttl: int = default_ttl
self._cache: dict[str, tuple[Any, float]] = {}
self._access_order: list[str] = []

def _make_key(self, method: str, url: str, params: dict[str, Any] | None = None, data: Any = None) -> str:
"""Generate a cache key from request details."""
import hashlib
import json

key_data = {
"method": method.upper(),
"url": url,
"params": params or {},
"data": json.dumps(data, sort_keys=True) if data else None
}
key_str = json.dumps(key_data, sort_keys=True)
return hashlib.md5(key_str.encode()).hexdigest()

def get(self, method: str, url: str, params: dict[str, Any] | None = None, data: Any = None) -> Any | None:
"""Get a cached response if available and not expired."""
import time

key = self._make_key(method, url, params, data)
if key in self._cache:
response, expiry = self._cache[key]
if time.time() < expiry:
# Move to end (most recently used)
self._access_order.remove(key)
self._access_order.append(key)
return response
else:
# Expired, remove it
del self._cache[key]
self._access_order.remove(key)
return None

def set(self, method: str, url: str, response: Any, ttl: int | None = None,
params: dict[str, Any] | None = None, data: Any = None) -> None:
"""Cache a response with optional TTL."""
import time

key = self._make_key(method, url, params, data)
expiry = time.time() + (ttl or self.default_ttl)

# Remove if already exists
if key in self._cache:
self._access_order.remove(key)

# Evict least recently used if at capacity
if len(self._cache) >= self.max_size:
lru_key = self._access_order.pop(0)
del self._cache[lru_key]

self._cache[key] = (response, expiry)
self._access_order.append(key)

def clear(self) -> None:
"""Clear all cached responses."""
self._cache.clear()
self._access_order.clear()

def size(self) -> int:
"""Get current cache size."""
return len(self._cache)
83 changes: 83 additions & 0 deletions tests/test_response_cache.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,83 @@
"""Tests for response caching functionality."""

import time
import pytest
from gradient._utils import ResponseCache


class TestResponseCache:
"""Test response caching functionality."""

def test_cache_basic_operations(self):
"""Test basic cache operations."""
cache = ResponseCache(max_size=3, default_ttl=1)

# Test set and get
cache.set("GET", "/api/test", {"data": "value"})
result = cache.get("GET", "/api/test")
assert result == {"data": "value"}

# Test cache miss
result = cache.get("GET", "/api/missing")
assert result is None

def test_cache_with_params(self):
"""Test caching with query parameters."""
cache = ResponseCache()

# Set with params
cache.set("GET", "/api/search", {"results": []}, params={"q": "test"})

# Get with same params should hit
result = cache.get("GET", "/api/search", params={"q": "test"})
assert result == {"results": []}

# Get with different params should miss
result = cache.get("GET", "/api/search", params={"q": "other"})
assert result is None

def test_cache_ttl(self):
"""Test cache TTL functionality."""
cache = ResponseCache(default_ttl=0.1) # Very short TTL

cache.set("GET", "/api/test", {"data": "value"})

# Should hit immediately
result = cache.get("GET", "/api/test")
assert result == {"data": "value"}

# Wait for expiry
time.sleep(0.2)

# Should miss after expiry
result = cache.get("GET", "/api/test")
assert result is None

def test_cache_max_size(self):
"""Test cache size limits with LRU eviction."""
cache = ResponseCache(max_size=2)

# Fill cache
cache.set("GET", "/api/1", "data1")
cache.set("GET", "/api/2", "data2")
assert cache.size() == 2

# Add third item (should evict first)
cache.set("GET", "/api/3", "data3")
assert cache.size() == 2

# First item should be gone
assert cache.get("GET", "/api/1") is None
assert cache.get("GET", "/api/2") == "data2"
assert cache.get("GET", "/api/3") == "data3"

def test_cache_clear(self):
"""Test cache clearing."""
cache = ResponseCache()

cache.set("GET", "/api/test", {"data": "value"})
assert cache.size() == 1

cache.clear()
assert cache.size() == 0
assert cache.get("GET", "/api/test") is None