Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
11 changes: 7 additions & 4 deletions memori/core/memory.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,6 +15,7 @@
import litellm # noqa: F401
from litellm import success_callback # noqa: F401

_ = litellm # Mark as intentionally imported
LITELLM_AVAILABLE = True
except ImportError:
LITELLM_AVAILABLE = False
Expand Down Expand Up @@ -2634,12 +2635,14 @@ def get_auto_ingest_system_prompt(self, user_input: str) -> str:
Get auto-ingest context as system prompt for direct injection.
Returns relevant memories based on user input as formatted system prompt.
Use this for auto_ingest mode.
Note: Context retrieval is handled by _get_auto_ingest_context().
This function only formats pre-retrieved context.
"""
try:
# For now, use recent short-term memories as a simple approach
# This avoids the search engine issues and still provides context
# TODO: Use user_input for intelligent context retrieval
context = self._get_conscious_context() # Get recent short-term memories
# Get recent short-term memories as fallback context
# The actual intelligent retrieval is handled by _get_auto_ingest_context()
context = self._get_conscious_context()

if not context:
return ""
Expand Down
229 changes: 221 additions & 8 deletions tests/openai_support/openai_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -28,7 +28,7 @@
"organization": os.getenv("OPENAI_ORGANIZATION"), # Optional organization
}

is_valid = config["api_key"] and not config["api_key"].startswith("sk-your-")
is_valid = bool(config["api_key"]) and not config["api_key"].startswith("sk-your-")

return is_valid, config

Expand Down Expand Up @@ -73,15 +73,19 @@

memory.enable()

# Create OpenAI client
# Create OpenAI client with explicit timeout
try:
client_kwargs = {"api_key": openai_config["api_key"]}
client_kwargs = {
"api_key": openai_config["api_key"],
"timeout": 30, # Prevent hanging on network issues
}

if openai_config["base_url"]:
client_kwargs["base_url"] = openai_config["base_url"]
if openai_config["organization"]:
client_kwargs["organization"] = openai_config["organization"]

# Create client directly; memori.enable() handles interception
client = OpenAI(**client_kwargs)

# Test connection first
Expand Down Expand Up @@ -121,16 +125,18 @@

success_count += 1

# Small delay to avoid rate limiting
time.sleep(0.5)
# Small delay between API calls to avoid rate limits
time.sleep(0.2)

except Exception as e:
print(f"[{i}/{len(test_inputs)}] Error: {e}")
error_count += 1

if "rate_limit" in str(e).lower() or "429" in str(e):
print("Rate limit hit, waiting 60 seconds...")
time.sleep(60)
# Exponential backoff: 2, 4, 8, 16, 32, max 60 seconds
wait = min(60, 2 ** min(i, 5))
print(f"Rate limit hit, waiting {wait} seconds...")
time.sleep(wait)
elif "quota" in str(e).lower():
print("Quota exceeded - stopping test")
break
Expand Down Expand Up @@ -160,8 +166,9 @@

print(f"\n✓ OpenAI Test '{test_name}' completed.")
print(f" Database saved at: {db_path}")
total = max(1, len(test_inputs)) # Prevent divide-by-zero
print(
f" Success rate: {success_count}/{len(test_inputs)} ({100*success_count/len(test_inputs):.1f}%)\n"
f" Success rate: {success_count}/{len(test_inputs)} ({100*success_count/total:.1f}%)\n"
)

return success_count > 0
Expand Down Expand Up @@ -291,5 +298,211 @@
return successful_tests > 0


def test_auto_ingest_intelligent_retrieval():
"""
Test _get_auto_ingest_context() for intelligent context retrieval.
This function is the actual implementation that handles:
- Database search with user_input
- Fallback to recent memories
- Recursion guard protection
- Search engine integration
- Error handling
"""
from unittest.mock import patch, MagicMock

print("\n" + "=" * 60)
print("Testing _get_auto_ingest_context() Intelligent Retrieval")
print("=" * 60 + "\n")

# Create temp database
db_dir = "test_databases_openai/auto_ingest_test"
os.makedirs(db_dir, exist_ok=True)
db_path = f"{db_dir}/memory.db"

# Initialize Memori with auto_ingest
memori = Memori(
database_connect=f"sqlite:///{db_path}",
auto_ingest=True,
namespace="test_namespace",
)

test_passed = 0
test_total = 8

# Test 1: Direct database search returns results
print("\n[Test 1/8] Direct database search returns results...")
mock_search_results = [
{"searchable_content": "Result A", "category_primary": "fact"},
{"searchable_content": "Result B", "category_primary": "preference"},
{"searchable_content": "Result C", "category_primary": "skill"},
]

with patch.object(
memori.db_manager, "search_memories", return_value=mock_search_results
) as mock_search:
result = memori._get_auto_ingest_context("What are my preferences?")

# Verify results returned with metadata
if len(result) == 3 and result[0].get("retrieval_method") == "direct_database_search":
print("[OK] Test 1 passed: Direct search returns 3 results with metadata")
test_passed += 1
else:
print(f"[FAIL] Test 1 failed: got {len(result)} results, metadata: {result[0].get('retrieval_method') if result else 'N/A'}")

# Test 2: Empty input returns empty list
print("\n[Test 2/8] Empty input returns empty list...")
result = memori._get_auto_ingest_context("")

if result == []:
print("[OK] Test 2 passed: Empty input returns []")
test_passed += 1
else:
print(f"[FAIL] Test 2 failed: Expected [], got {result}")

# Test 3: Fallback to recent memories when search returns empty
print("\n[Test 3/8] Fallback to recent memories when search empty...")
mock_fallback = [
{"searchable_content": "Recent memory 1", "category_primary": "fact"},
{"searchable_content": "Recent memory 2", "category_primary": "preference"},
]

# First call returns empty, second call (fallback) returns results
with patch.object(
memori.db_manager, "search_memories", side_effect=[[], mock_fallback]
) as mock_search:
result = memori._get_auto_ingest_context("query with no results")

# Check fallback was used and metadata added
if len(result) == 2 and result[0].get("retrieval_method") == "recent_memories_fallback":
print("[OK] Test 3 passed: Fallback to recent memories works")
test_passed += 1
else:
print(f"[FAIL] Test 3 failed: got {len(result)} results, metadata: {result[0].get('retrieval_method') if result else 'N/A'}")

# Test 4: Recursion guard prevents infinite loops
print("\n[Test 4/8] Recursion guard prevents infinite loops...")
memori._in_context_retrieval = True

mock_results = [{"searchable_content": "Safe result", "category_primary": "fact"}]
with patch.object(
memori.db_manager, "search_memories", return_value=mock_results
) as mock_search:
result = memori._get_auto_ingest_context("test recursion")

# Should use direct search and return results
if result == mock_results:
print("[OK] Test 4 passed: Recursion guard triggers direct search")
test_passed += 1
else:
print(f"[FAIL] Test 4 failed: Expected direct search results")

# Reset recursion guard
memori._in_context_retrieval = False

# Test 5: Search engine fallback when direct search fails
print("\n[Test 5/8] Search engine fallback when direct search empty...")
mock_search_engine = MagicMock()
mock_engine_results = [
{"searchable_content": "Engine result", "category_primary": "fact"}
]
mock_search_engine.execute_search.return_value = mock_engine_results
memori.search_engine = mock_search_engine

with patch.object(
memori.db_manager, "search_memories", side_effect=[[], []] # Both direct and fallback empty
):
result = memori._get_auto_ingest_context("advanced query")

# Check search engine was used
if len(result) == 1 and result[0].get("retrieval_method") == "search_engine":
print("[OK] Test 5 passed: Search engine fallback works")
test_passed += 1
else:
print(f"[FAIL] Test 5 failed: got {len(result)} results, metadata: {result[0].get('retrieval_method') if result else 'N/A'}")

# Reset search engine
memori.search_engine = None

# Test 6: Error handling - graceful degradation
print("\n[Test 6/8] Error handling with graceful degradation...")

# First call fails, fallback succeeds
mock_fallback = [{"searchable_content": "Fallback", "category_primary": "fact"}]
with patch.object(
memori.db_manager,
"search_memories",
side_effect=[Exception("DB error"), mock_fallback]
):
result = memori._get_auto_ingest_context("test error handling")

# Should fallback to recent memories
if len(result) == 1 and result[0].get("retrieval_method") == "recent_memories_fallback":
print("[OK] Test 6 passed: Error handled, fallback used")
test_passed += 1
else:
print(f"[FAIL] Test 6 failed: got {len(result)} results")

# Test 7: Verify search called with correct parameters
print("\n[Test 7/8] Verify search called with correct parameters...")
with patch.object(
memori.db_manager,
"search_memories",
return_value=[{"searchable_content": "Test", "category_primary": "fact"}]
) as mock_search:
user_query = "find my API keys"
result = memori._get_auto_ingest_context(user_query)

Check warning

Code scanning / CodeQL

Variable defined multiple times Warning test

This assignment to 'result' is unnecessary as it is
redefined
before this value is used.

# Check search was called with correct params
if mock_search.called:
call = mock_search.call_args
called_query = call.kwargs.get("query") if call.kwargs else call.args[0]
called_namespace = call.kwargs.get("namespace") if call.kwargs else None
called_limit = call.kwargs.get("limit") if call.kwargs else None

query_match = called_query == user_query
namespace_match = called_namespace == "test_namespace"
limit_match = called_limit == 5

if query_match and namespace_match and limit_match:
print("[OK] Test 7 passed: search_memories called with correct params")
test_passed += 1
else:
print(f"[FAIL] Test 7 failed: query={query_match}, ns={namespace_match}, limit={limit_match}")
else:
print("[FAIL] Test 7 failed: search_memories not called")

# Test 8: Retrieval metadata is added to results
print("\n[Test 8/8] Retrieval metadata added to all results...")
mock_results = [
{"searchable_content": "Item 1", "category_primary": "fact"},
{"searchable_content": "Item 2", "category_primary": "preference"},
]

with patch.object(
memori.db_manager, "search_memories", return_value=mock_results
):
result = memori._get_auto_ingest_context("metadata test")

# Check all results have metadata
all_have_metadata = all(
r.get("retrieval_method") and r.get("retrieval_query")
for r in result
)

if all_have_metadata and result[0]["retrieval_query"] == "metadata test":
print("[OK] Test 8 passed: All results have retrieval metadata")
test_passed += 1
else:
print(f"[FAIL] Test 8 failed: metadata missing or incorrect")

# Summary
print("\n" + "=" * 60)
print(f"_get_auto_ingest_context() Tests: {test_passed}/{test_total} passed")
print("=" * 60 + "\n")

return test_passed == test_total


if __name__ == "__main__":
main()
Loading