Skip to content

Commit ae41958

Browse files
bsboddentylerhutchersonvishal-bala
authored
fix: resolve working memory key via search index when user_id/namespace omitted (#239)
* fix: resolve working memory key via search index when user_id/namespace omitted (#235) PUT /v1/working-memory/{session_id} stores user_id and namespace from the request body into the Redis key, but GET/DELETE construct the key from query params. When callers omit those params, a different key is produced and the session appears missing (404). Add an index-based fallback: when a direct key lookup returns nothing, query the working memory search index by session_id to resolve the actual Redis key. This preserves multi-tenant isolation (namespace + user_id remain in keys) while making GET/DELETE work without requiring callers to repeat scoping parameters. Changes: - Add _resolve_working_memory_key_via_index() helper in working_memory.py - Use fallback in get_working_memory() and delete_working_memory() - Return stored namespace/user_id from JSON data instead of caller params - Add user_id/namespace params to MCP get_working_memory tool - Add comprehensive tests covering the fix and multi-tenant isolation * Update agent_memory_server/working_memory.py Co-authored-by: Vishal Bala <vishal-bala@users.noreply.github.com> * Fix linter issues --------- Co-authored-by: Tyler Hutcherson <tyler.hutcherson@redis.com> Co-authored-by: Vishal Bala <vishal-bala@users.noreply.github.com> Co-authored-by: Vishal Bala <vishalbala.1994@gmail.com>
1 parent 32fb411 commit ae41958

3 files changed

Lines changed: 495 additions & 3 deletions

File tree

agent_memory_server/mcp.py

Lines changed: 8 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -979,20 +979,27 @@ async def set_working_memory(
979979
@mcp_app.tool()
980980
async def get_working_memory(
981981
session_id: str,
982+
user_id: str | None = None,
983+
namespace: str | None = None,
982984
recent_messages_limit: int | None = None,
983985
) -> WorkingMemory:
984986
"""
985987
Get working memory for a session. This works like the GET /sessions/{id}/memory API endpoint.
986988
987989
Args:
988990
session_id: The session ID to retrieve working memory for
991+
user_id: Optional user ID to scope the session lookup
992+
namespace: Optional namespace to scope the session lookup
989993
recent_messages_limit: Optional limit on number of recent messages to return (most recent first)
990994
991995
Returns:
992996
Working memory containing messages, context, and structured memory records
993997
"""
994998
result = await working_memory_core.get_working_memory(
995-
session_id=session_id, recent_messages_limit=recent_messages_limit
999+
session_id=session_id,
1000+
user_id=user_id,
1001+
namespace=namespace,
1002+
recent_messages_limit=recent_messages_limit,
9961003
)
9971004
if result is None:
9981005
return WorkingMemory(session_id=session_id, messages=[], memories=[])

agent_memory_server/working_memory.py

Lines changed: 110 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -356,6 +356,84 @@ async def list_sessions(
356356
return 0, []
357357

358358

359+
async def _resolve_working_memory_key_via_index(
360+
redis_client: Redis,
361+
session_id: str,
362+
user_id: str | None = None,
363+
namespace: str | None = None,
364+
) -> str | None:
365+
"""
366+
Resolve the actual Redis key for a working memory session using the search index.
367+
368+
When a direct key lookup fails (e.g. because user_id/namespace were provided
369+
during PUT but omitted during GET), this function queries the search index
370+
by session_id to find the document and derive the correct Redis key.
371+
372+
If multiple sessions share the same session_id (different namespace/user_id)
373+
and the caller did not supply enough filters to disambiguate, the function
374+
logs a warning and returns ``None`` to avoid silently returning the wrong
375+
session.
376+
377+
Args:
378+
redis_client: Redis client
379+
session_id: The session ID to look up
380+
user_id: Optional user_id filter (narrows results if multiple sessions
381+
share an ID)
382+
namespace: Optional namespace filter
383+
384+
Returns:
385+
The Redis key string if exactly one match is found, None otherwise
386+
"""
387+
388+
from agent_memory_server.working_memory_index import get_working_memory_index
389+
390+
try:
391+
index = await get_working_memory_index(redis_client)
392+
393+
filter_expression = Tag("session_id") == session_id
394+
if namespace:
395+
filter_expression &= Tag("namespace") == namespace
396+
if user_id:
397+
filter_expression &= Tag("user_id") == user_id
398+
399+
# Request up to 2 results so we can detect ambiguity.
400+
filter_query = FilterQuery(
401+
filter_expression=filter_expression,
402+
return_fields=["session_id", "namespace", "user_id"],
403+
num_results=2,
404+
)
405+
406+
raw_results = await index.search(filter_query)
407+
docs = getattr(raw_results, "docs", raw_results) or []
408+
409+
if not docs:
410+
return None
411+
412+
total = getattr(raw_results, "total", len(docs))
413+
if total > 1:
414+
logger.warning(
415+
"Ambiguous working-memory lookup for session_id=%s: "
416+
"%d sessions matched. Provide namespace/user_id to disambiguate.",
417+
session_id,
418+
total,
419+
)
420+
return None
421+
422+
doc = docs[0]
423+
# RedisVL returns doc.id as the full Redis key
424+
doc_key = getattr(doc, "id", None)
425+
if doc_key:
426+
if isinstance(doc_key, bytes):
427+
doc_key = doc_key.decode("utf-8")
428+
return doc_key
429+
430+
return None
431+
432+
except Exception as e:
433+
logger.debug(f"Index-based key resolution failed for session {session_id}: {e}")
434+
return None
435+
436+
359437
async def get_working_memory(
360438
session_id: str,
361439
user_id: str | None = None,
@@ -418,6 +496,21 @@ async def get_working_memory(
418496
)
419497
# If key_type is "none", the key doesn't exist - working_memory_data stays None
420498

499+
# Fallback: if direct key lookup failed, try resolving via the search
500+
# index. This handles the case where PUT stored with user_id/namespace
501+
# but GET was called without them (issue #235).
502+
if not working_memory_data:
503+
resolved_key = await _resolve_working_memory_key_via_index(
504+
redis_client, session_id, user_id, namespace
505+
)
506+
if resolved_key and resolved_key != key:
507+
logger.debug(
508+
f"Resolved working memory key via index: {resolved_key} "
509+
f"(original key: {key})"
510+
)
511+
key = resolved_key
512+
working_memory_data = await redis_client.json().get(key)
513+
421514
if not working_memory_data:
422515
logger.debug(
423516
f"No working memory found for parameters: {session_id}, {user_id}, {namespace}"
@@ -467,14 +560,19 @@ async def get_working_memory(
467560
MemoryStrategyConfig()
468561
) # Default to discrete strategy
469562

563+
# Use stored values for namespace/user_id — the caller may not have
564+
# provided them (index-fallback path, issue #235).
565+
stored_namespace = working_memory_data.get("namespace") or namespace
566+
stored_user_id = working_memory_data.get("user_id") or user_id
567+
470568
return WorkingMemory(
471569
messages=messages,
472570
memories=memories,
473571
context=working_memory_data.get("context"),
474-
user_id=working_memory_data.get("user_id"),
572+
user_id=stored_user_id,
475573
tokens=working_memory_data.get("tokens", 0),
476574
session_id=session_id,
477-
namespace=namespace,
575+
namespace=stored_namespace,
478576
ttl_seconds=working_memory_data.get("ttl_seconds", None),
479577
data=working_memory_data.get("data") or {},
480578
long_term_memory_strategy=long_term_memory_strategy,
@@ -589,6 +687,16 @@ async def delete_working_memory(
589687
)
590688

591689
try:
690+
# Check if the key exists; if not, try resolving via the search index
691+
# (same fallback as get_working_memory for issue #235).
692+
exists = await redis_client.exists(key)
693+
if not exists:
694+
resolved_key = await _resolve_working_memory_key_via_index(
695+
redis_client, session_id, user_id, namespace
696+
)
697+
if resolved_key:
698+
key = resolved_key
699+
592700
# Delete the JSON key - the working memory search index automatically
593701
# removes the document from the index when the key is deleted
594702
await redis_client.delete(key)

0 commit comments

Comments
 (0)