Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
29 changes: 19 additions & 10 deletions langgraph/checkpoint/redis/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -67,17 +67,20 @@ def __init__(
redis_client: Optional[Union[Redis, RedisCluster]] = None,
connection_args: Optional[Dict[str, Any]] = None,
ttl: Optional[Dict[str, Any]] = None,
checkpoint_prefix: str = CHECKPOINT_PREFIX,
checkpoint_blob_prefix: str = CHECKPOINT_BLOB_PREFIX,
checkpoint_write_prefix: str = CHECKPOINT_WRITE_PREFIX,
) -> None:
super().__init__(
redis_url=redis_url,
redis_client=redis_client,
connection_args=connection_args,
ttl=ttl,
checkpoint_prefix=checkpoint_prefix,
checkpoint_blob_prefix=checkpoint_blob_prefix,
checkpoint_write_prefix=checkpoint_write_prefix,
)
# Pre-compute common prefixes for performance
self._checkpoint_prefix = CHECKPOINT_PREFIX
self._checkpoint_blob_prefix = CHECKPOINT_BLOB_PREFIX
self._checkpoint_write_prefix = CHECKPOINT_WRITE_PREFIX
# Prefixes are now set in BaseRedisSaver.__init__
self._separator = REDIS_KEY_SEPARATOR

# Instance-level cache for frequently used keys (limited size to prevent memory issues)
Expand Down Expand Up @@ -116,13 +119,13 @@ def configure_client(

def create_indexes(self) -> None:
self.checkpoints_index = SearchIndex.from_dict(
self.SCHEMAS[0], redis_client=self._redis
self.checkpoints_schema, redis_client=self._redis
)
self.checkpoint_blobs_index = SearchIndex.from_dict(
self.SCHEMAS[1], redis_client=self._redis
self.blobs_schema, redis_client=self._redis
)
self.checkpoint_writes_index = SearchIndex.from_dict(
self.SCHEMAS[2], redis_client=self._redis
self.writes_schema, redis_client=self._redis
)

def _make_redis_checkpoint_key_cached(
Expand Down Expand Up @@ -848,7 +851,7 @@ def _get_write_keys_from_search(
write_results = self.checkpoint_writes_index.search(write_query)

return [
BaseRedisSaver._make_redis_checkpoint_writes_key(
self._make_redis_checkpoint_writes_key(
to_storage_safe_id(thread_id),
to_storage_safe_str(checkpoint_ns),
to_storage_safe_id(checkpoint_id),
Expand Down Expand Up @@ -1119,6 +1122,9 @@ def from_conn_string(
redis_client: Optional[Union[Redis, RedisCluster]] = None,
connection_args: Optional[Dict[str, Any]] = None,
ttl: Optional[Dict[str, Any]] = None,
checkpoint_prefix: str = CHECKPOINT_PREFIX,
checkpoint_blob_prefix: str = CHECKPOINT_BLOB_PREFIX,
checkpoint_write_prefix: str = CHECKPOINT_WRITE_PREFIX,
) -> Iterator[RedisSaver]:
"""Create a new RedisSaver instance."""
saver: Optional[RedisSaver] = None
Expand All @@ -1128,6 +1134,9 @@ def from_conn_string(
redis_client=redis_client,
connection_args=connection_args,
ttl=ttl,
checkpoint_prefix=checkpoint_prefix,
checkpoint_blob_prefix=checkpoint_blob_prefix,
checkpoint_write_prefix=checkpoint_write_prefix,
)

yield saver
Expand Down Expand Up @@ -1615,7 +1624,7 @@ def delete_thread(self, thread_id: str) -> None:
channel = getattr(doc, "channel", "")
version = getattr(doc, "version", "")

blob_key = BaseRedisSaver._make_redis_checkpoint_blob_key(
blob_key = self._make_redis_checkpoint_blob_key(
storage_safe_thread_id, checkpoint_ns, channel, version
)
keys_to_delete.append(blob_key)
Expand All @@ -1635,7 +1644,7 @@ def delete_thread(self, thread_id: str) -> None:
task_id = getattr(doc, "task_id", "")
idx = getattr(doc, "idx", 0)

write_key = BaseRedisSaver._make_redis_checkpoint_writes_key(
write_key = self._make_redis_checkpoint_writes_key(
storage_safe_thread_id, checkpoint_ns, checkpoint_id, task_id, idx
)
keys_to_delete.append(write_key)
Expand Down
49 changes: 28 additions & 21 deletions langgraph/checkpoint/redis/aio.py
Original file line number Diff line number Diff line change
Expand Up @@ -40,7 +40,13 @@
from redisvl.query.filter import Num, Tag
from ulid import ULID

from langgraph.checkpoint.redis.base import BaseRedisSaver
from langgraph.checkpoint.redis.base import (
BaseRedisSaver,
CHECKPOINT_BLOB_PREFIX,
CHECKPOINT_PREFIX,
CHECKPOINT_WRITE_PREFIX,
REDIS_KEY_SEPARATOR,
)
from langgraph.checkpoint.redis.key_registry import (
AsyncCheckpointKeyRegistry as AsyncKeyRegistry,
)
Expand Down Expand Up @@ -81,30 +87,25 @@ def __init__(
redis_client: Optional[Union[AsyncRedis, AsyncRedisCluster]] = None,
connection_args: Optional[Dict[str, Any]] = None,
ttl: Optional[Dict[str, Any]] = None,
checkpoint_prefix: str = CHECKPOINT_PREFIX,
checkpoint_blob_prefix: str = CHECKPOINT_BLOB_PREFIX,
checkpoint_write_prefix: str = CHECKPOINT_WRITE_PREFIX,
) -> None:
super().__init__(
redis_url=redis_url,
redis_client=redis_client,
connection_args=connection_args,
ttl=ttl,
checkpoint_prefix=checkpoint_prefix,
checkpoint_blob_prefix=checkpoint_blob_prefix,
checkpoint_write_prefix=checkpoint_write_prefix,
)
self.loop = asyncio.get_running_loop()

# Instance-level cache for frequently used keys (limited size to prevent memory issues)
self._key_cache: Dict[str, str] = {}
self._key_cache_max_size = 1000 # Configurable limit

# Pre-compute common prefixes for performance
from langgraph.checkpoint.redis.base import (
CHECKPOINT_BLOB_PREFIX,
CHECKPOINT_PREFIX,
CHECKPOINT_WRITE_PREFIX,
REDIS_KEY_SEPARATOR,
)

self._checkpoint_prefix = CHECKPOINT_PREFIX
self._checkpoint_blob_prefix = CHECKPOINT_BLOB_PREFIX
self._checkpoint_write_prefix = CHECKPOINT_WRITE_PREFIX
self._separator = REDIS_KEY_SEPARATOR

def configure_client(
Expand All @@ -128,13 +129,13 @@ def configure_client(
def create_indexes(self) -> None:
"""Create indexes without connecting to Redis."""
self.checkpoints_index = AsyncSearchIndex.from_dict(
self.SCHEMAS[0], redis_client=self._redis
self.checkpoints_schema, redis_client=self._redis
)
self.checkpoint_blobs_index = AsyncSearchIndex.from_dict(
self.SCHEMAS[1], redis_client=self._redis
self.blobs_schema, redis_client=self._redis
)
self.checkpoint_writes_index = AsyncSearchIndex.from_dict(
self.SCHEMAS[2], redis_client=self._redis
self.writes_schema, redis_client=self._redis
)

def _make_redis_checkpoint_key_cached(
Expand Down Expand Up @@ -375,7 +376,7 @@ async def aget_tuple(self, config: RunnableConfig) -> Optional[CheckpointTuple]:
storage_safe_checkpoint_id = to_storage_safe_id(checkpoint_id)

# Construct direct key for checkpoint data
checkpoint_key = BaseRedisSaver._make_redis_checkpoint_key(
checkpoint_key = self._make_redis_checkpoint_key(
storage_safe_thread_id,
storage_safe_checkpoint_ns,
storage_safe_checkpoint_id,
Expand Down Expand Up @@ -476,7 +477,7 @@ async def aget_tuple(self, config: RunnableConfig) -> Optional[CheckpointTuple]:
# If we didn't get TTL from pipeline (i.e., came from else branch), fetch it now
if "current_ttl" not in locals():
# Get the checkpoint key
checkpoint_key = BaseRedisSaver._make_redis_checkpoint_key(
checkpoint_key = self._make_redis_checkpoint_key(
to_storage_safe_id(doc_thread_id),
to_storage_safe_str(doc_checkpoint_ns),
to_storage_safe_id(doc_checkpoint_id),
Expand Down Expand Up @@ -1054,7 +1055,7 @@ async def aput(
}

# Prepare checkpoint key
checkpoint_key = BaseRedisSaver._make_redis_checkpoint_key(
checkpoint_key = self._make_redis_checkpoint_key(
storage_safe_thread_id,
storage_safe_checkpoint_ns,
storage_safe_checkpoint_id,
Expand Down Expand Up @@ -1441,12 +1442,18 @@ async def from_conn_string(
redis_client: Optional[Union[AsyncRedis, AsyncRedisCluster]] = None,
connection_args: Optional[Dict[str, Any]] = None,
ttl: Optional[Dict[str, Any]] = None,
checkpoint_prefix: str = CHECKPOINT_PREFIX,
checkpoint_blob_prefix: str = CHECKPOINT_BLOB_PREFIX,
checkpoint_write_prefix: str = CHECKPOINT_WRITE_PREFIX,
) -> AsyncIterator[AsyncRedisSaver]:
async with cls(
redis_url=redis_url,
redis_client=redis_client,
connection_args=connection_args,
ttl=ttl,
checkpoint_prefix=checkpoint_prefix,
checkpoint_blob_prefix=checkpoint_blob_prefix,
checkpoint_write_prefix=checkpoint_write_prefix,
) as saver:
yield saver

Expand Down Expand Up @@ -1980,7 +1987,7 @@ async def adelete_thread(self, thread_id: str) -> None:
checkpoint_namespaces.add(checkpoint_ns)

# Delete checkpoint key
checkpoint_key = BaseRedisSaver._make_redis_checkpoint_key(
checkpoint_key = self._make_redis_checkpoint_key(
storage_safe_thread_id, checkpoint_ns, checkpoint_id
)
keys_to_delete.append(checkpoint_key)
Expand All @@ -2004,7 +2011,7 @@ async def adelete_thread(self, thread_id: str) -> None:
channel = getattr(doc, "channel", "")
version = getattr(doc, "version", "")

blob_key = BaseRedisSaver._make_redis_checkpoint_blob_key(
blob_key = self._make_redis_checkpoint_blob_key(
storage_safe_thread_id, checkpoint_ns, channel, version
)
keys_to_delete.append(blob_key)
Expand All @@ -2024,7 +2031,7 @@ async def adelete_thread(self, thread_id: str) -> None:
task_id = getattr(doc, "task_id", "")
idx = getattr(doc, "idx", 0)

write_key = BaseRedisSaver._make_redis_checkpoint_writes_key(
write_key = self._make_redis_checkpoint_writes_key(
storage_safe_thread_id, checkpoint_ns, checkpoint_id, task_id, idx
)
keys_to_delete.append(write_key)
Expand Down
71 changes: 17 additions & 54 deletions langgraph/checkpoint/redis/ashallow.py
Original file line number Diff line number Diff line change
Expand Up @@ -38,52 +38,6 @@
to_storage_safe_str,
)

SCHEMAS = [
{
"index": {
"name": "checkpoints",
"prefix": CHECKPOINT_PREFIX + REDIS_KEY_SEPARATOR,
"storage_type": "json",
},
"fields": [
{"name": "thread_id", "type": "tag"},
{"name": "checkpoint_ns", "type": "tag"},
{"name": "source", "type": "tag"},
{"name": "step", "type": "numeric"},
],
},
{
"index": {
"name": "checkpoints_blobs",
"prefix": CHECKPOINT_BLOB_PREFIX + REDIS_KEY_SEPARATOR,
"storage_type": "json",
},
"fields": [
{"name": "thread_id", "type": "tag"},
{"name": "checkpoint_ns", "type": "tag"},
{"name": "channel", "type": "tag"},
{"name": "type", "type": "tag"},
],
},
{
"index": {
"name": "checkpoint_writes",
"prefix": CHECKPOINT_WRITE_PREFIX + REDIS_KEY_SEPARATOR,
"storage_type": "json",
},
"fields": [
{"name": "thread_id", "type": "tag"},
{"name": "checkpoint_ns", "type": "tag"},
{"name": "checkpoint_id", "type": "tag"},
{"name": "task_id", "type": "tag"},
{"name": "idx", "type": "numeric"},
{"name": "channel", "type": "tag"},
{"name": "type", "type": "tag"},
],
},
]


class AsyncShallowRedisSaver(BaseRedisSaver[AsyncRedis, AsyncSearchIndex]):
"""Async Redis implementation that only stores the most recent checkpoint."""

Expand All @@ -101,12 +55,18 @@ def __init__(
redis_client: Optional[AsyncRedis] = None,
connection_args: Optional[dict[str, Any]] = None,
ttl: Optional[dict[str, Any]] = None,
checkpoint_prefix: str = CHECKPOINT_PREFIX,
checkpoint_blob_prefix: str = CHECKPOINT_BLOB_PREFIX,
checkpoint_write_prefix: str = CHECKPOINT_WRITE_PREFIX,
) -> None:
super().__init__(
redis_url=redis_url,
redis_client=redis_client,
connection_args=connection_args,
ttl=ttl,
checkpoint_prefix=checkpoint_prefix,
checkpoint_blob_prefix=checkpoint_blob_prefix,
checkpoint_write_prefix=checkpoint_write_prefix,
)
self.loop = asyncio.get_running_loop()

Expand All @@ -115,9 +75,6 @@ def __init__(
self._key_cache_max_size = 1000 # Configurable limit
self._channel_cache: Dict[str, Any] = {}

# Cache commonly used prefixes
self._checkpoint_prefix = CHECKPOINT_PREFIX
self._checkpoint_write_prefix = CHECKPOINT_WRITE_PREFIX
self._separator = REDIS_KEY_SEPARATOR

async def __aenter__(self) -> AsyncShallowRedisSaver:
Expand Down Expand Up @@ -158,13 +115,19 @@ async def from_conn_string(
redis_client: Optional[AsyncRedis] = None,
connection_args: Optional[dict[str, Any]] = None,
ttl: Optional[dict[str, Any]] = None,
checkpoint_prefix: str = CHECKPOINT_PREFIX,
checkpoint_blob_prefix: str = CHECKPOINT_BLOB_PREFIX,
checkpoint_write_prefix: str = CHECKPOINT_WRITE_PREFIX,
) -> AsyncIterator[AsyncShallowRedisSaver]:
"""Create a new AsyncShallowRedisSaver instance."""
async with cls(
redis_url=redis_url,
redis_client=redis_client,
connection_args=connection_args,
ttl=ttl,
checkpoint_prefix=checkpoint_prefix,
checkpoint_blob_prefix=checkpoint_blob_prefix,
checkpoint_write_prefix=checkpoint_write_prefix,
) as saver:
yield saver

Expand Down Expand Up @@ -733,14 +696,14 @@ def configure_client(
def create_indexes(self) -> None:
"""Create indexes without connecting to Redis."""
self.checkpoints_index = AsyncSearchIndex.from_dict(
self.SCHEMAS[0], redis_client=self._redis
self.checkpoints_schema, redis_client=self._redis
)
# Shallow implementation doesn't use blobs, but base class requires the attribute
self.checkpoint_blobs_index = AsyncSearchIndex.from_dict(
self.SCHEMAS[1], redis_client=self._redis
self.blobs_schema, redis_client=self._redis
)
self.checkpoint_writes_index = AsyncSearchIndex.from_dict(
self.SCHEMAS[2], redis_client=self._redis
self.writes_schema, redis_client=self._redis
)

def get_tuple(self, config: RunnableConfig) -> Optional[CheckpointTuple]:
Expand Down Expand Up @@ -837,7 +800,7 @@ def _make_redis_checkpoint_writes_key_cached(
)
if cache_key not in self._key_cache:
self._key_cache[cache_key] = (
BaseRedisSaver._make_redis_checkpoint_writes_key(
self._make_redis_checkpoint_writes_key(
thread_id, checkpoint_ns, checkpoint_id, task_id, idx
)
)
Expand Down Expand Up @@ -884,7 +847,7 @@ def _make_shallow_redis_checkpoint_blob_key_cached(
if len(self._key_cache) >= self._key_cache_max_size:
# Remove oldest entry when cache is full
self._key_cache.pop(next(iter(self._key_cache)))
self._key_cache[cache_key] = BaseRedisSaver._make_redis_checkpoint_blob_key(
self._key_cache[cache_key] = self._make_redis_checkpoint_blob_key(
thread_id, checkpoint_ns, channel, version
)
return self._key_cache[cache_key]
Loading