diff --git a/README.md b/README.md index b8ef9ce7..41b1bcff 100644 --- a/README.md +++ b/README.md @@ -38,7 +38,7 @@ Strands Agents Tools is a community-driven project that provides a powerful set - πŸ“ **File Operations** - Read, write, and edit files with syntax highlighting and intelligent modifications - πŸ–₯️ **Shell Integration** - Execute and interact with shell commands securely -- 🧠 **Memory** - Store user and agent memories across agent runs to provide personalized experiences with both Mem0 and Amazon Bedrock Knowledge Bases +- 🧠 **Memory** - Store user and agent memories across agent runs to provide personalized experiences with both Mem0, Amazon Bedrock Knowledge Bases, Elasticsearch, and MongoDB Atlas - πŸ•ΈοΈ **Web Infrastructure** - Perform web searches, extract page content, and crawl websites with Tavily and Exa-powered tools - 🌐 **HTTP Client** - Make API requests with comprehensive authentication support - πŸ’¬ **Slack Client** - Real-time Slack events, message processing, and Slack API access @@ -146,6 +146,7 @@ Below is a comprehensive table of all available tools, how to use them with an a | use_computer | `agent.tool.use_computer(action="click", x=100, y=200, app_name="Chrome") ` | Desktop automation, GUI interaction, screen capture | | search_video | `agent.tool.search_video(query="people discussing AI")` | Semantic video search using TwelveLabs' Marengo model | | chat_video | `agent.tool.chat_video(prompt="What are the main topics?", video_id="video_123")` | Interactive video analysis using TwelveLabs' Pegasus model | +| mongodb_memory | `agent.tool.mongodb_memory(action="record", content="User prefers vegetarian pizza", connection_string="mongodb+srv://...", database_name="memories")` | Store and retrieve memories using MongoDB Atlas with semantic search via AWS Bedrock Titan embeddings | \* *These tools do not work on windows* @@ -886,6 +887,79 @@ result = agent.tool.elasticsearch_memory( ) ``` +### MongoDB Atlas Memory + +**Note**: 
This tool requires AWS account credentials to generate embeddings using Amazon Bedrock Titan models. + +```python +from strands import Agent +from strands_tools.mongodb_memory import mongodb_memory + +# Create agent with direct tool usage +agent = Agent(tools=[mongodb_memory]) + +# Store a memory with semantic embeddings +result = agent.tool.mongodb_memory( + action="record", + content="User prefers vegetarian pizza with extra cheese", + metadata={"category": "food_preferences", "type": "dietary"}, + connection_string="mongodb+srv://username:password@cluster0.mongodb.net/?retryWrites=true&w=majority", + database_name="memories", + collection_name="user_memories", + namespace="user_123" +) + +# Search memories using semantic similarity (vector search) +result = agent.tool.mongodb_memory( + action="retrieve", + query="food preferences and dietary restrictions", + max_results=5, + connection_string="mongodb+srv://username:password@cluster0.mongodb.net/?retryWrites=true&w=majority", + database_name="memories", + collection_name="user_memories", + namespace="user_123" +) + +# Use configuration dictionary for cleaner code +config = { + "connection_string": "mongodb+srv://username:password@cluster0.mongodb.net/?retryWrites=true&w=majority", + "database_name": "memories", + "collection_name": "user_memories", + "namespace": "user_123" +} + +# List all memories with pagination +result = agent.tool.mongodb_memory( + action="list", + max_results=10, + **config +) + +# Get specific memory by ID +result = agent.tool.mongodb_memory( + action="get", + memory_id="mem_1234567890_abcd1234", + **config +) + +# Delete a memory +result = agent.tool.mongodb_memory( + action="delete", + memory_id="mem_1234567890_abcd1234", + **config +) + +# Use environment variables for connection +# Set MONGODB_ATLAS_CLUSTER_URI in your environment +result = agent.tool.mongodb_memory( + action="record", + content="User prefers vegetarian pizza", + database_name="memories", + 
collection_name="user_memories", + namespace="user_123" +) +``` + ## 🌍 Environment Variables Configuration Agents Tools provides extensive customization through environment variables. This allows you to configure tool behavior without modifying code, making it ideal for different environments (development, testing, production). @@ -1117,6 +1191,19 @@ The Mem0 Memory Tool supports three different backend configurations: | TWELVELABS_MARENGO_INDEX_ID | Default index ID for search_video tool | None | | TWELVELABS_PEGASUS_INDEX_ID | Default index ID for chat_video tool | None | +#### MongoDB Atlas Memory Tool + +| Environment Variable | Description | Default | +|----------------------|-------------|---------| +| MONGODB_ATLAS_CLUSTER_URI | MongoDB Atlas connection string | None | +| MONGODB_DEFAULT_DATABASE | Default database name for MongoDB operations | memories | +| MONGODB_DEFAULT_COLLECTION | Default collection name for MongoDB operations | user_memories | +| MONGODB_DEFAULT_NAMESPACE | Default namespace for memory isolation | default | +| MONGODB_DEFAULT_MAX_RESULTS | Default maximum results for list operations | 50 | +| MONGODB_DEFAULT_MIN_SCORE | Default minimum relevance score for filtering results | 0.4 | + +**Note**: This tool requires AWS account credentials to generate embeddings using Amazon Bedrock Titan models. + ## Contributing ❀️ diff --git a/docs/mongodb_memory_tool.md b/docs/mongodb_memory_tool.md new file mode 100644 index 00000000..b2433dbc --- /dev/null +++ b/docs/mongodb_memory_tool.md @@ -0,0 +1,662 @@ +# MongoDB Atlas Memory Tool + +The MongoDB Atlas Memory Tool provides comprehensive memory management capabilities using MongoDB Atlas as the backend with vector embeddings for semantic search. It uses the direct tool pattern where tools are imported and used directly with agents. 
+ +## Features + +- **Semantic Search**: Automatic embedding generation using Amazon Bedrock Titan for vector similarity search +- **Memory Management**: Store, retrieve, list, get, and delete memory operations +- **Index Management**: Automatic vector search index creation with proper configuration +- **Namespace Support**: Organize memories by namespace for multi-user scenarios +- **Pagination**: Support for paginated results in list and retrieve operations +- **Error Handling**: Comprehensive error handling with clear error messages + +## Installation + +Install the required dependencies: + +```bash +pip install strands-agents-tools[mongodb_memory] +``` + +This will install: +- `pymongo>=4.0.0,<5.0.0` - MongoDB Python client + +## Prerequisites + +1. **MongoDB Atlas**: You need a MongoDB Atlas cluster with: + - Connection URI (mongodb+srv format) - [How to find your connection string](https://www.mongodb.com/docs/atlas/connect-to-database-deployment/) + - Database user with read/write permissions - [Create database user](https://www.mongodb.com/docs/atlas/security-add-mongodb-users/) + - Vector Search enabled (Atlas Search) - [Enable Atlas Search](https://www.mongodb.com/docs/atlas/atlas-search/create-index/) + +2. **Amazon Bedrock**: Access to Amazon Bedrock for embedding generation: + - AWS credentials configured + - Access to `amazon.titan-embed-text-v2:0` model (or custom embedding model) + +### Getting Your MongoDB Atlas Connection URI + +If you're new to MongoDB Atlas: + +1. **Sign up for MongoDB Atlas**: Visit [MongoDB Atlas](https://www.mongodb.com/cloud/atlas) and create a free account +2. **Create a cluster**: Follow the setup wizard to create your first cluster (free tier available) +3. **Create a database user**: Go to Database Access β†’ Add New Database User with read/write permissions +4. **Configure network access**: Go to Network Access β†’ Add IP Address (add your current IP or 0.0.0.0/0 for testing) +5. 
**Get connection string**: + - Go to your cluster in the Atlas dashboard + - Click "Connect" button + - Choose "Connect your application" + - Select "Python" as the driver + - Copy the connection string (it will look like: `mongodb+srv://username:password@cluster0.xxxxx.mongodb.net/`) + - Replace `<password>` with your actual database user password + +**Important**: Your connection URI should be in the format `mongodb+srv://username:password@cluster0.xxxxx.mongodb.net/` without additional query parameters. The tool will handle SSL and other connection settings automatically. + +For detailed instructions, see the [official MongoDB Atlas documentation](https://www.mongodb.com/docs/atlas/connect-to-database-deployment/). + +## Quick Start + +### Class-Based Usage (Recommended) + +```python +from strands_tools.mongodb_memory import MongoDBMemoryTool + +# Initialize the tool +memory_tool = MongoDBMemoryTool() + +# Store a memory +result = memory_tool.record_memory( + content="User prefers vegetarian pizza with extra cheese", + cluster_uri="mongodb+srv://user:password@cluster.mongodb.net/", + database_name="memory_db", + collection_name="memories", + namespace="user_123" +) + +# Search memories +result = memory_tool.retrieve_memories( + query="food preferences", + cluster_uri="mongodb+srv://user:password@cluster.mongodb.net/", + database_name="memory_db", + collection_name="memories", + namespace="user_123", + max_results=5 +) +``` + +### Standalone Function Usage + +```python +from strands_tools.mongodb_memory import mongodb_memory + +# Store a memory +result = mongodb_memory( + action="record", + content="User prefers vegetarian pizza with extra cheese", + cluster_uri="mongodb+srv://user:password@cluster.mongodb.net/", + database_name="memory_db", + collection_name="memories", + namespace="user_123" +) + +# Search memories +result = mongodb_memory( + action="retrieve", + query="food preferences", + cluster_uri="mongodb+srv://user:password@cluster.mongodb.net/", + 
database_name="memory_db", + collection_name="memories", + namespace="user_123", + max_results=5 +) +``` + +### Environment Variables + +You can also use environment variables for configuration: + +```bash +export MONGODB_ATLAS_CLUSTER_URI="mongodb+srv://user:password@cluster.mongodb.net/" +export MONGODB_DATABASE_NAME="memory_db" +export MONGODB_COLLECTION_NAME="memories" +export MONGODB_NAMESPACE="user_123" +export MONGODB_EMBEDDING_MODEL="amazon.titan-embed-text-v2:0" +export AWS_REGION="us-west-2" +``` + +Then use the tool with minimal parameters (environment variables will be used): + +```python +# Class-based usage +memory_tool = MongoDBMemoryTool() +result = memory_tool.record_memory( + content="User prefers vegetarian pizza" + # cluster_uri, database_name, etc. will be read from environment variables +) + +# Standalone function usage +result = mongodb_memory( + action="record", + content="User prefers vegetarian pizza" + # cluster_uri, database_name, etc. will be read from environment variables +) +``` + +## Usage Examples + +### 1. Store Memories + +```python +# Store a simple memory +result = agent.tool.mongodb_memory( + action="record", + content="User prefers vegetarian pizza with extra cheese and no onions", + cluster_uri="mongodb+srv://user:password@cluster.mongodb.net/", + database_name="memory_db", + collection_name="memories", + namespace="user_123" +) + +# Store a memory with metadata +result = agent.tool.mongodb_memory( + action="record", + content="Meeting scheduled for next Tuesday at 2 PM with the development team", + metadata={ + "category": "meetings", + "priority": "high", + "participants": ["dev_team"], + "date": "2024-01-16" + }, + cluster_uri="mongodb+srv://user:password@cluster.mongodb.net/", + database_name="memory_db", + collection_name="memories", + namespace="user_123" +) +``` + +### 2. 
Semantic Search + +```python +# Search for food-related memories +result = agent.tool.mongodb_memory( + action="retrieve", + query="food preferences and dietary restrictions", + max_results=5, + cluster_uri="mongodb+srv://user:password@cluster.mongodb.net/", + database_name="memory_db", + collection_name="memories", + namespace="user_123" +) + +# Search for meeting information +result = agent.tool.mongodb_memory( + action="retrieve", + query="upcoming meetings and appointments", + max_results=10, + cluster_uri="mongodb+srv://user:password@cluster.mongodb.net/", + database_name="memory_db", + collection_name="memories", + namespace="user_123" +) +``` + +### 3. List All Memories + +```python +# List recent memories +result = agent.tool.mongodb_memory( + action="list", + max_results=20, + cluster_uri="mongodb+srv://user:password@cluster.mongodb.net/", + database_name="memory_db", + collection_name="memories", + namespace="user_123" +) + +# List with pagination +result = agent.tool.mongodb_memory( + action="list", + max_results=10, + next_token="10", # Start from the 11th result + cluster_uri="mongodb+srv://user:password@cluster.mongodb.net/", + database_name="memory_db", + collection_name="memories", + namespace="user_123" +) +``` + +### 4. Get Specific Memory + +```python +# Retrieve a specific memory by ID +result = agent.tool.mongodb_memory( + action="get", + memory_id="mem_1704567890123_abc12345", + cluster_uri="mongodb+srv://user:password@cluster.mongodb.net/", + database_name="memory_db", + collection_name="memories", + namespace="user_123" +) +``` + +### 5. 
Delete Memory + +```python +# Delete a specific memory +result = memory_tool.delete_memory( + memory_id="mem_1704567890123_abc12345", + cluster_uri="mongodb+srv://user:password@cluster.mongodb.net/", + database_name="memory_db", + collection_name="memories", + namespace="user_123" +) +``` + +## Advanced Configuration + +### Using Configuration Dictionary + +For cleaner code, you can use a configuration dictionary: + +```python +config = { + "cluster_uri": "mongodb+srv://user:password@cluster.mongodb.net/", + "database_name": "memory_db", + "collection_name": "memories", + "namespace": "user_123", + "region": "us-east-1" +} + +# Initialize tool +memory_tool = MongoDBMemoryTool() + +# Store memory +result = memory_tool.record_memory( + content="User prefers vegetarian pizza", + **config +) + +# Search memories +result = memory_tool.retrieve_memories( + query="food preferences", + max_results=5, + **config +) +``` + +### Custom Embedding Model + +```python +result = memory_tool.record_memory( + content="User prefers vegetarian pizza", + cluster_uri="mongodb+srv://user:password@cluster.mongodb.net/", + database_name="memory_db", + collection_name="memories", + embedding_model="amazon.titan-embed-text-v1:0", # Different model + region="us-east-1" +) +``` + +### Multiple Namespaces + +```python +# User-specific memories +result = memory_tool.record_memory( + content="Alice likes Italian food", + cluster_uri="mongodb+srv://user:password@cluster.mongodb.net/", + database_name="memory_db", + collection_name="memories", + namespace="user_alice" +) + +# System-wide memories +result = memory_tool.record_memory( + content="System maintenance scheduled", + cluster_uri="mongodb+srv://user:password@cluster.mongodb.net/", + database_name="memory_db", + collection_name="memories", + namespace="system_global" +) +``` + +## Response Format + +All operations return a standardized response format: + +```python +{ + "status": "success", # or "error" + "content": [ + { + "text": "Memory 
stored successfully: {...}" + } + ] +} +``` + +### Successful Record Response + +```json +{ + "status": "success", + "content": [ + { + "text": "Memory stored successfully: {\"memory_id\": \"mem_1704567890123_abc12345\", \"content\": \"User prefers vegetarian pizza\", \"namespace\": \"user_123\", \"timestamp\": \"2024-01-06T20:31:30.123456Z\", \"result\": \"created\"}" + } + ] +} +``` + +### Successful Retrieve Response + +```json +{ + "status": "success", + "content": [ + { + "text": "Memories retrieved successfully: {\"memories\": [{\"memory_id\": \"mem_123\", \"content\": \"User prefers vegetarian pizza\", \"timestamp\": \"2024-01-06T20:31:30Z\", \"metadata\": {\"category\": \"food\"}, \"score\": 0.95}], \"total\": 1, \"max_score\": 0.95}" + } + ] +} +``` + +## Collection Structure + +The tool automatically creates a MongoDB collection with documents structured as follows: + +```json +{ + "_id": "ObjectId", + "memory_id": "mem_1704567890123_abc12345", + "content": "User prefers vegetarian pizza with extra cheese", + "embedding": [0.1, 0.2, 0.3, ...], // 1024-dimensional vector + "namespace": "user_123", + "timestamp": "2024-01-06T20:31:30.123456Z", + "metadata": { + "category": "food", + "priority": "medium" + } +} +``` + +### Vector Search Index + +The tool automatically creates a vector search index with the following configuration: + +```json +{ + "fields": [ + { + "type": "vector", + "path": "embedding", + "numDimensions": 1024, + "similarity": "cosine" + }, + { + "type": "filter", + "path": "namespace" + } + ] +} +``` + +## Error Handling + +The tool provides comprehensive error handling: + +### Connection Errors + +```python +# Invalid connection URI +memory_tool = MongoDBMemoryTool() +result = memory_tool.record_memory( + content="test", + cluster_uri="mongodb+srv://invalid:credentials@invalid.mongodb.net/" +) +# Returns: {"status": "error", "content": [{"text": "Unable to connect to MongoDB Atlas cluster"}]} +``` + +### Missing Parameters + +```python +# 
Missing required content for record action +memory_tool = MongoDBMemoryTool() +result = memory_tool.record_memory( + cluster_uri="mongodb+srv://user:password@cluster.mongodb.net/" +) +# Returns: {"status": "error", "content": [{"text": "content is required"}]} + +# Missing connection parameters +result = memory_tool.record_memory(content="test") +# Returns: {"status": "error", "content": [{"text": "cluster_uri is required"}]} +``` + +### Memory Not Found + +```python +# Non-existent memory ID +result = memory_tool.get_memory( + memory_id="nonexistent", + cluster_uri="mongodb+srv://user:password@cluster.mongodb.net/" +) +# Returns: {"status": "error", "content": [{"text": "API error: Memory nonexistent not found"}]} +``` + +## Performance Considerations + +### Embedding Generation + +- Embeddings are generated using Amazon Bedrock Titan model +- Each record and retrieve operation requires embedding generation +- Consider caching strategies for frequently accessed queries + +### Index Optimization + +- The tool creates optimized vector search indices +- Uses cosine similarity for semantic matching +- Configures appropriate index settings for performance + +### Pagination + +- Use pagination for large result sets +- `max_results` parameter controls batch size +- `next_token` enables efficient pagination using skip/limit + +## Best Practices + +### 1. Configuration Management + +Create reusable configuration objects: + +```python +# Create a base configuration +base_config = { + "cluster_uri": "mongodb+srv://user:password@cluster.mongodb.net/", + "database_name": "memory_db", + "region": "us-east-1" +} + +# User-specific configuration +def get_user_config(user_id): + return { + **base_config, + "collection_name": "user_memories", + "namespace": f"user_{user_id}" + } + +# Usage +user_config = get_user_config("alice") +result = agent.tool.mongodb_memory( + action="record", + content="Alice likes Italian food", + **user_config +) +``` + +### 2. 
Namespace Organization + +The namespace parameter is crucial for data isolation and multi-tenant memory management: + +```python +# User-based namespaces +user_namespace = f"user_{user_id}" + +# Session-based namespaces +session_namespace = f"session_{session_id}" + +# Hierarchical namespaces +org_user_namespace = f"org_{org_id}_user_{user_id}" + +# Feature-based namespaces +chat_namespace = "feature_chat" +task_namespace = "feature_tasks" +``` + +### 3. Metadata Usage + +```python +# Use structured metadata for better organization +result = memory_tool.record_memory( + content="Important project deadline", + metadata={ + "type": "deadline", + "project": "project_alpha", + "priority": "high", + "due_date": "2024-02-01", + "assigned_to": ["alice", "bob"] + }, + **config +) +``` + +### 4. Error Handling + +```python +def safe_memory_operation(memory_tool, operation_method, **kwargs): + try: + result = operation_method(**kwargs) + if result["status"] == "error": + logger.error(f"Memory operation failed: {result['content'][0]['text']}") + return None + return result + except Exception as e: + logger.error(f"Unexpected error in memory operation: {e}") + return None + +# Usage example: +memory_tool = MongoDBMemoryTool() +result = safe_memory_operation( + memory_tool, + memory_tool.record_memory, + content="Test memory", + cluster_uri="mongodb+srv://user:password@cluster.mongodb.net/", + database_name="memory_db", + collection_name="memories", + namespace="user_123" +) +``` + +### 5. 
Batch Operations + +```python +# Store multiple related memories +memories = [ + "User likes Italian food", + "User is allergic to nuts", + "User prefers evening meetings" +] + +config = { + "cluster_uri": "mongodb+srv://user:password@cluster.mongodb.net/", + "database_name": "memory_db", + "collection_name": "memories", + "namespace": "user_123" +} + +memory_tool = MongoDBMemoryTool() +for content in memories: + memory_tool.record_memory( + content=content, + metadata={"batch": "user_preferences", "timestamp": datetime.now().isoformat()}, + **config + ) +``` + +## Troubleshooting + +### Common Issues + +1. **Connection Timeout** + - Check MongoDB Atlas cluster status + - Verify network connectivity and IP whitelist + - Increase connection timeout settings + +2. **Authentication Errors** + - Verify connection URI format + - Check database user credentials + - Ensure user has proper permissions + +3. **Vector Search Index Issues** + - Verify Atlas Search is enabled + - Check index creation status + - Ensure proper index configuration + +4. **Embedding Generation Failures** + - Verify AWS credentials + - Check Bedrock model access + - Ensure proper IAM permissions + +### Debug Mode + +Enable debug logging for troubleshooting: + +```python +import logging +logging.basicConfig(level=logging.DEBUG) + +# This will show detailed MongoDB and Bedrock API calls +memory_tool = MongoDBMemoryTool() +result = memory_tool.record_memory( + content="test", + cluster_uri="mongodb+srv://user:password@cluster.mongodb.net/" +) +``` + +### Vector Search Index Creation + +If vector search is not working, manually create the index in MongoDB Atlas: + +1. Go to Atlas Search in your MongoDB Atlas dashboard +2. Create a new search index on your collection +3. Use the JSON configuration provided in the Collection Structure section +4. 
Wait for the index to build (this can take several minutes) + +## Security Considerations + +### Connection Security + +- Use strong passwords for database users +- Enable IP whitelisting in MongoDB Atlas +- Use connection string with SSL/TLS enabled +- Store connection URIs securely (environment variables, secrets manager) + +### Data Privacy + +- Use appropriate namespaces for data isolation +- Consider encryption at rest (MongoDB Atlas feature) +- Implement proper access controls +- Regular security audits + +### Network Security + +- Use VPC peering for production environments +- Implement proper firewall rules +- Monitor database access logs +- Use private endpoints when available + +## Support and Resources + +- [MongoDB Atlas Documentation](https://docs.atlas.mongodb.com/) +- [MongoDB Atlas Vector Search](https://docs.atlas.mongodb.com/atlas-search/vector-search/) +- [Amazon Bedrock Documentation](https://docs.aws.amazon.com/bedrock/) +- [Strands Agents Framework](https://strandsagents.com/) +- [GitHub Issues](https://github.com/strands-agents/tools/issues) diff --git a/pyproject.toml b/pyproject.toml index a89a686d..cdcefd3c 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -113,9 +113,12 @@ twelvelabs = [ elasticsearch_memory = [ "elasticsearch>=8.0.0,<9.0.0" ] +mongodb_memory = [ + "pymongo>=4.0.0,<5.0.0", +] [tool.hatch.envs.hatch-static-analysis] -features = ["mem0_memory", "local_chromium_browser", "agent_core_browser", "agent_core_code_interpreter", "a2a_client", "diagram", "rss", "use_computer", "twelvelabs", "elasticsearch_memory"] +features = ["mem0_memory", "local_chromium_browser", "agent_core_browser", "agent_core_code_interpreter", "a2a_client", "diagram", "rss", "use_computer", "twelvelabs", "elasticsearch_memory", "mongodb_memory"] dependencies = [ "strands-agents>=1.0.0", "mypy>=0.981,<1.0.0", @@ -134,7 +137,7 @@ lint-check = [ lint-fix = ["ruff check --fix"] [tool.hatch.envs.hatch-test] -features = ["mem0_memory", 
"local_chromium_browser", "agent_core_browser", "agent_core_code_interpreter", "a2a_client", "diagram", "rss", "use_computer", "twelvelabs", "elasticsearch_memory"] +features = ["mem0_memory", "local_chromium_browser", "agent_core_browser", "agent_core_code_interpreter", "a2a_client", "diagram", "rss", "use_computer", "twelvelabs", "elasticsearch_memory", "mongodb_memory"] extra-dependencies = [ "moto>=5.1.0,<6.0.0", "pytest>=8.0.0,<9.0.0", diff --git a/src/strands_tools/mongodb_memory.py b/src/strands_tools/mongodb_memory.py new file mode 100644 index 00000000..885ee648 --- /dev/null +++ b/src/strands_tools/mongodb_memory.py @@ -0,0 +1,1110 @@ +""" +Tool for managing memories using MongoDB Atlas with semantic search capabilities. + +This module provides comprehensive memory management capabilities using +MongoDB Atlas as the backend with vector embeddings for semantic search. + +Key Features: +------------ +1. Memory Management: + β€’ record: Store new memories with automatic embedding generation + β€’ retrieve: Semantic search using vector embeddings and MongoDB Atlas Vector Search + β€’ list: List all memories with pagination support + β€’ get: Retrieve specific memories by memory ID + β€’ delete: Remove specific memories by memory ID + +2. Semantic Search: + β€’ Automatic embedding generation using Amazon Bedrock Titan + β€’ Vector similarity search with cosine similarity + β€’ MongoDB Atlas Vector Search with $vectorSearch aggregation + β€’ Namespace-based filtering + +3. Collection Management: + β€’ Automatic collection creation with proper structure + β€’ Vector search index configuration for semantic search + β€’ Optimized for semantic search performance + +4. 
Error Handling: + β€’ Connection validation + β€’ Parameter validation + β€’ Graceful API error handling + β€’ Clear error messages + +Usage Examples: +-------------- +```python +from strands import Agent +from strands_tools.mongodb_memory import mongodb_memory + +# Store a memory +result = mongodb_memory( + action="record", + content="User prefers vegetarian pizza with extra cheese", + cluster_uri="mongodb+srv://user:password@cluster.mongodb.net/", + database_name="memory_db", + collection_name="memories", + namespace="user_123" +) + +# Search memories +result = mongodb_memory( + action="retrieve", + query="food preferences", + cluster_uri="mongodb+srv://user:password@cluster.mongodb.net/", + database_name="memory_db", + collection_name="memories", + namespace="user_123", + max_results=5 +) +``` + +Environment Variables: +--------------------- +```bash +# Required +export MONGODB_ATLAS_CLUSTER_URI="mongodb+srv://user:pass@cluster.mongodb.net/" + +# Optional +export MONGODB_DATABASE_NAME="custom_memories_db" # Default: "strands_memory" +export MONGODB_COLLECTION_NAME="custom_memories" # Default: "memories" +export MONGODB_NAMESPACE="custom_namespace" # Default: "default" +export MONGODB_EMBEDDING_MODEL="amazon.titan-embed-text-v2:0" +export AWS_REGION="us-east-1" # Default: "us-west-2" +``` +""" + +import json +import logging +import os +import re +import time +import uuid +from datetime import datetime, timezone +from enum import Enum +from typing import Any, Dict, List, Optional + +import boto3 +from pymongo import MongoClient +from pymongo.collection import Collection +from pymongo.cursor import Cursor +from pymongo.errors import ConnectionFailure +from strands import tool + +# Set up logging +logger = logging.getLogger(__name__) + + +# Custom exceptions for better error handling +class MongoDBMemoryError(Exception): + """Base exception for MongoDB memory operations.""" + + pass + + +class MongoDBConnectionError(MongoDBMemoryError): + """Raised when connection 
class MongoDBMemoryNotFoundError(MongoDBMemoryError):
    """Signals that the requested memory record does not exist."""


class MongoDBEmbeddingError(MongoDBMemoryError):
    """Signals that an embedding could not be produced."""


class MongoDBValidationError(MongoDBMemoryError):
    """Signals that caller-supplied parameters failed validation."""


class MemoryAction(str, Enum):
    """Actions the memory tool understands; each value is the action's string name."""

    RECORD = "record"
    RETRIEVE = "retrieve"
    LIST = "list"
    GET = "get"
    DELETE = "delete"


# Parameters that must be supplied for each action (an empty list means none).
REQUIRED_PARAMS = {
    MemoryAction.RECORD: ["content"],
    MemoryAction.RETRIEVE: ["query"],
    MemoryAction.LIST: [],
    MemoryAction.GET: ["memory_id"],
    MemoryAction.DELETE: ["memory_id"],
}

# Fallback values used when the caller does not provide a setting.
DEFAULT_DATABASE_NAME = "strands_memory"
DEFAULT_COLLECTION_NAME = "memories"
DEFAULT_EMBEDDING_MODEL = "amazon.titan-embed-text-v2:0"
DEFAULT_EMBEDDING_DIMS = 1024  # Titan v2 returns 1024 dimensions
DEFAULT_MAX_RESULTS = 10
DEFAULT_VECTOR_INDEX_NAME = "vector_index"
DEFAULT_AWS_REGION = "us-west-2"
DEFAULT_NAMESPACE = "default"

# Values for MongoDB projection documents.
INCLUDE_FIELD = 1
EXCLUDE_FIELD = 0

# Size caps that keep tool results from becoming "too large".
MAX_RESPONSE_SIZE = 70000  # upper bound on characters in a whole response
MAX_CONTENT_LENGTH = 12000  # upper bound on characters kept per memory
MAX_MEMORIES_IN_RESPONSE = 5  # upper bound on memories returned at once

# Seconds to wait for index creation.
INDEX_CREATION_TIMEOUT = 5


def _ensure_vector_search_index(collection: Collection, index_name: str = DEFAULT_VECTOR_INDEX_NAME) -> None:
    """
    Make sure a search index named ``index_name`` exists on ``collection``.

    The existing search indexes are listed first; when none carries the
    requested name, a new one is submitted that indexes ``embedding`` as a
    ``DEFAULT_EMBEDDING_DIMS``-dimensional cosine ``knnVector`` and
    ``namespace`` as a string.

    This is deliberately best-effort: any failure is logged as a warning and
    swallowed so the rest of the tool keeps working without the index.

    NOTE(review): the tool documentation shipped with this change describes a
    ``fields``-style definition (``type: "vector"`` / ``type: "filter"``),
    while this function submits a ``mappings``/``knnVector`` definition.
    Confirm which form the retrieval query actually requires.

    Args:
        collection: MongoDB collection to create the index on.
        index_name: Name of the search index to look for / create.
    """
    try:
        known_names = [existing.get("name") for existing in collection.list_search_indexes()]
        if index_name in known_names:
            return

        embedding_field = {
            "type": "knnVector",
            "dimensions": DEFAULT_EMBEDDING_DIMS,
            "similarity": "cosine",
        }
        mapped_fields = {
            "embedding": embedding_field,
            "namespace": {"type": "string"},
        }
        collection.create_search_index(
            {
                "name": index_name,
                "definition": {"mappings": {"dynamic": False, "fields": mapped_fields}},
            }
        )
        logger.info(f"Created vector search index: {index_name}")
        logger.info("Index creation initiated - it may take a few minutes to become available")

    except Exception as err:
        # Intentionally not re-raised: the tool may run without vector search.
        logger.warning(f"Could not create vector search index {index_name}: {str(err)}")
        logger.info("Vector search index should be created manually in MongoDB Atlas UI")


def _generate_embedding(bedrock_runtime: Any, text: str, embedding_model: str) -> List[float]:
    """
    Ask Amazon Bedrock for the embedding vector of ``text``.

    The model is invoked with ``{"inputText": text}`` and the ``embedding``
    field of the JSON response is returned once its length has been checked
    against ``DEFAULT_EMBEDDING_DIMS``.

    Args:
        bedrock_runtime: Client object exposing ``invoke_model``.
        text: Text to embed.
        embedding_model: Model ID passed to Bedrock as ``modelId``.

    Returns:
        The embedding as a list of ``DEFAULT_EMBEDDING_DIMS`` values.

    Raises:
        MongoDBEmbeddingError: If the response is not valid JSON, the vector
            has an unexpected length, or any other error occurs during the
            call (the original exception is chained as the cause).
    """
    try:
        request_payload = json.dumps({"inputText": text})
        response = bedrock_runtime.invoke_model(modelId=embedding_model, body=request_payload)

        try:
            parsed = json.loads(response["body"].read())
        except json.JSONDecodeError as decode_err:
            raise MongoDBEmbeddingError(f"Invalid JSON response from Bedrock: {str(decode_err)}") from decode_err

        # Titan embedding responses carry the vector under the "embedding" key:
        # https://docs.aws.amazon.com/bedrock/latest/userguide/model-parameters-titan-embed-text.html
        vector = parsed["embedding"]

        actual_dims = len(vector)
        if actual_dims != DEFAULT_EMBEDDING_DIMS:
            raise MongoDBEmbeddingError(f"Expected {DEFAULT_EMBEDDING_DIMS} dimensions, got {actual_dims}")

        return vector

    except MongoDBEmbeddingError:
        # Already in the module's error type; do not wrap it a second time.
        raise
    except Exception as err:
        raise MongoDBEmbeddingError(f"Embedding generation failed: {str(err)}") from err


def _truncate_content(content: str, max_length: int = MAX_CONTENT_LENGTH) -> str:
    """
    Cap ``content`` at ``max_length`` characters to keep responses small.

    Content within the limit is returned untouched; longer content is cut to
    its first ``max_length`` characters with ``"..."`` appended (so the result
    is three characters longer than ``max_length``).
    """
    too_long = len(content) > max_length
    return content[:max_length] + "..." if too_long else content
def _optimize_response_size(response: Dict, action: str) -> Dict:
    """
    Optimize response size to prevent 'tool result too large' errors.

    The ``response`` dict is modified in place and also returned.

    Args:
        response: Result dict produced by a memory operation
        action: Name of the action that produced it; only "list" and "retrieve"
            responses are changed

    Returns:
        The same dict, with at most ``MAX_MEMORIES_IN_RESPONSE`` memories, each
        memory's "content" truncated, and (for "retrieve") a compact "search_info"
    """
    # For list and retrieve operations, limit the number of memories and truncate content
    if action in ["list", "retrieve"] and "memories" in response:
        memories = response["memories"]

        # Limit number of memories in response and flag that some were dropped
        if len(memories) > MAX_MEMORIES_IN_RESPONSE:
            memories = memories[:MAX_MEMORIES_IN_RESPONSE]
            response["memories"] = memories
            response["truncated"] = True
            response["showing"] = len(memories)

        # Truncate content in each memory
        for memory in memories:
            if "content" in memory:
                memory["content"] = _truncate_content(memory["content"])

    # Replace verbose search_info for retrieve operations with a fixed short summary
    if action == "retrieve" and "search_info" in response:
        response["search_info"] = {"type": "vector_search", "model": "titan-v2"}

    return response


def _validate_response_size(response_text: str) -> str:
    """
    Validate and truncate response if it exceeds size limits.

    Args:
        response_text: Serialized response text

    Returns:
        ``response_text`` unchanged when it fits in ``MAX_RESPONSE_SIZE``; otherwise its
        first ``MAX_RESPONSE_SIZE - 100`` characters followed by a truncation notice
    """
    if len(response_text) <= MAX_RESPONSE_SIZE:
        return response_text

    # If response is too large, truncate and add warning
    truncated = response_text[: MAX_RESPONSE_SIZE - 100]  # Leave room for warning
    return f"{truncated}... [Response truncated due to size limit]"


def _mask_connection_string(connection_string: str) -> str:
    """
    Mask sensitive information in MongoDB connection string for logging/error messages.

    Everything before the LAST "@" is treated as credentials. A password that contains
    an unescaped "@" is an invalid URI, but it is exactly the kind of string that ends
    up in an error message, so no part of it may leak. The trade-off is that an "@"
    appearing later (for example in a query option) causes more of the string to be
    hidden than strictly necessary.

    Args:
        connection_string: MongoDB connection string that may contain credentials

    Returns:
        "[EMPTY]" for an empty value; ``mongodb+srv://<2 chars>***:***@<host part>`` for
        an SRV string with ``user:password`` credentials; ``***@<host part>`` for any
        other string containing "@"; a fixed placeholder otherwise
    """
    placeholder = "***[MASKED_CONNECTION_STRING]***"

    if not connection_string:
        return "[EMPTY]"

    try:
        credentials, at_sign, host_part = connection_string.rpartition("@")
        if not at_sign:
            return placeholder

        srv_prefix = "mongodb+srv://"
        if host_part and credentials.startswith(srv_prefix):
            username, colon, password = credentials[len(srv_prefix) :].partition(":")
            if username and colon and password:
                # Keep a two-character hint of the user name only when it is longer than that
                masked_username = username[:2] + "***" if len(username) > 2 else "***"
                return f"{srv_prefix}{masked_username}:***@{host_part}"

        # Other schemes or unusual credential shapes: reveal only what follows the last "@"
        return f"***@{host_part}"
    except Exception:
        # Masking must never raise (e.g. on a non-string value); fall back to hiding everything
        return placeholder


def _validate_connection_string(cluster_uri: str) -> bool:
    """
    Validate MongoDB connection string format.

    This is a shape check only: the value must be a non-empty string that starts with
    ``mongodb+srv://`` or ``mongodb://`` and contains an "@".

    Args:
        cluster_uri: MongoDB connection string to validate

    Returns:
        True if connection string appears valid, False otherwise
    """
    if not cluster_uri or not isinstance(cluster_uri, str):
        return False

    return cluster_uri.startswith(("mongodb+srv://", "mongodb://")) and "@" in cluster_uri


def _generate_memory_id() -> str:
    """
    Generate a unique memory ID.

    Returns:
        ``mem_<milliseconds since epoch>_<first 8 characters of a UUID4>``
    """
    timestamp = int(time.time() * 1000)  # milliseconds
    unique_id = str(uuid.uuid4())[:8]
    return f"mem_{timestamp}_{unique_id}"
def _parse_skip_token(next_token: Optional[str]) -> int:
    """
    Convert a pagination token into the number of documents to skip.

    Args:
        next_token: Token previously returned as "next_token" (a stringified offset),
            or None/empty for the first page

    Returns:
        The offset as a non-negative integer (0 when no token is given)

    Raises:
        ValueError: If the token is not a non-negative integer
    """
    if not next_token:
        return 0
    try:
        skip_count = int(next_token)
    except (TypeError, ValueError) as e:
        raise ValueError(f"Invalid next_token {next_token!r}: expected a non-negative integer") from e
    if skip_count < 0:
        raise ValueError(f"Invalid next_token {next_token!r}: expected a non-negative integer")
    return skip_count


def _retrieve_memories(
    collection: Collection,
    bedrock_runtime: Any,
    namespace: str,
    embedding_model: str,
    query: str,
    max_results: int,
    next_token: Optional[str] = None,
    index_name: str = DEFAULT_VECTOR_INDEX_NAME,
) -> Dict:
    """
    Retrieve memories using semantic search.

    Args:
        collection: MongoDB collection
        bedrock_runtime: Bedrock runtime client
        namespace: Memory namespace
        embedding_model: Embedding model ID
        query: Search query
        max_results: Maximum number of results per page
        next_token: Pagination token (stringified skip count)
        index_name: Vector search index name

    Returns:
        Dict with "memories", "total", "max_score", "search_info" and, when more
        results remain, "next_token"

    Raises:
        ValueError: If ``next_token`` is not a non-negative integer
        MongoDBEmbeddingError: If the query embedding cannot be generated
    """
    # Validate the token first so a bad token does not cost an embedding call
    skip_count = _parse_skip_token(next_token)

    # Generate embedding for query
    query_embedding = _generate_embedding(bedrock_runtime, query, embedding_model)

    # $vectorSearch applies its "limit" BEFORE the later $skip stage runs. With
    # limit == max_results every page after the first was empty, so the search must
    # return enough neighbours to cover all earlier pages plus this one.
    page_end = skip_count + max_results
    atlas_max_candidates = 10000  # documented upper bound for numCandidates
    num_candidates = min(page_end * 3, atlas_max_candidates)  # 3x candidates for search quality

    # Perform semantic search using MongoDB Atlas Vector Search
    pipeline = [
        {
            "$vectorSearch": {
                "index": index_name,
                "path": "embedding",
                "queryVector": query_embedding,
                "numCandidates": num_candidates,
                "limit": page_end,
                "filter": {"namespace": {"$eq": namespace}},
            }
        },
        {"$skip": skip_count},
        {"$limit": max_results},
        {"$addFields": {"score": {"$meta": "vectorSearchScore"}}},
        # MongoDB projection syntax: INCLUDE_FIELD = include field, EXCLUDE_FIELD = exclude field
        # Only the listed fields are returned, which leaves out _id and the embedding vector
        {
            "$project": {
                "memory_id": INCLUDE_FIELD,
                "content": INCLUDE_FIELD,
                "timestamp": INCLUDE_FIELD,
                "metadata": INCLUDE_FIELD,
                "score": INCLUDE_FIELD,
                "_id": EXCLUDE_FIELD,
            }
        },
    ]

    results = list(collection.aggregate(pipeline))

    # Get total count for pagination (the reported total is capped at 1000 by this query)
    total_pipeline = [
        {
            "$vectorSearch": {
                "index": index_name,
                "path": "embedding",
                "queryVector": query_embedding,
                "numCandidates": 1000,  # Higher limit for count
                "limit": 1000,
                "filter": {"namespace": {"$eq": namespace}},
            }
        },
        {"$count": "total"},
    ]

    # Best effort: counting must not fail the search. When it does fail, report what is
    # known to exist: the skipped documents plus this page.
    fallback_total = skip_count + len(results)
    try:
        total_result = list(collection.aggregate(total_pipeline))
        total_count = total_result[0]["total"] if total_result else fallback_total
    except Exception as e:
        logger.debug("Could not count vector search results: %s", e)
        total_count = fallback_total

    # Format results
    memories = []
    max_score = 0
    for doc in results:
        score = doc.get("score", 0)
        memories.append(
            {
                "memory_id": doc["memory_id"],
                "content": doc["content"],
                "timestamp": doc["timestamp"],
                "metadata": doc.get("metadata", {}),
                "score": score,
            }
        )
        max_score = max(max_score, score)

    result = {
        "memories": memories,
        "total": total_count,
        "max_score": max_score,
        "search_info": {
            "query_embedding_generated": True,
            "search_type": "MongoDB Atlas Vector Search",
            "embedding_model": embedding_model,
            "embedding_dimensions": DEFAULT_EMBEDDING_DIMS,
            "similarity_function": "cosine",
        },
    }

    # Add next_token if there are more results
    if page_end < total_count:
        result["next_token"] = str(page_end)

    return result


def _list_memories(collection: Collection, namespace: str, max_results: int, next_token: Optional[str] = None) -> Dict:
    """
    List all memories in the namespace, newest timestamp first.

    Args:
        collection: MongoDB collection
        namespace: Memory namespace
        max_results: Maximum number of results per page
        next_token: Pagination token (stringified skip count)

    Returns:
        Dict with "memories", "total" and, when more results remain, "next_token"

    Raises:
        ValueError: If ``next_token`` is not a non-negative integer
    """
    skip_count = _parse_skip_token(next_token)

    # Query for memories in namespace
    # MongoDB projection syntax: INCLUDE_FIELD = include field, EXCLUDE_FIELD = exclude field
    # Only the listed fields are returned, which leaves out _id and the embedding vector
    cursor: Cursor = (
        collection.find(
            {"namespace": namespace},
            {
                "memory_id": INCLUDE_FIELD,
                "content": INCLUDE_FIELD,
                "timestamp": INCLUDE_FIELD,
                "metadata": INCLUDE_FIELD,
                "_id": EXCLUDE_FIELD,
            },
        )
        .sort("timestamp", -1)
        .skip(skip_count)
        .limit(max_results)
    )

    memories = list(cursor)

    # Get total count
    total_count = collection.count_documents({"namespace": namespace})

    result = {"memories": memories, "total": total_count}

    # Add next_token if there are more results
    if skip_count + max_results < total_count:
        result["next_token"] = str(skip_count + max_results)

    return result


def _get_memory(collection: Collection, namespace: str, memory_id: str) -> Dict:
    """
    Get a specific memory by ID.

    The document is looked up by ``memory_id`` alone and its namespace is compared
    afterwards, so a memory stored under another namespace is reported as not found.

    Args:
        collection: MongoDB collection
        namespace: Memory namespace
        memory_id: Memory ID to retrieve

    Returns:
        Dict with "memory_id", "content", "timestamp", "metadata" and "namespace"

    Raises:
        MongoDBMemoryNotFoundError: If no such memory exists or it is in another namespace
        MongoDBMemoryError: If the lookup fails for any other reason
    """
    try:
        # MongoDB projection syntax: INCLUDE_FIELD = include field, EXCLUDE_FIELD = exclude field
        # Only the listed fields are returned, which leaves out _id and the embedding vector
        doc = collection.find_one(
            {"memory_id": memory_id},
            {
                "memory_id": INCLUDE_FIELD,
                "content": INCLUDE_FIELD,
                "timestamp": INCLUDE_FIELD,
                "metadata": INCLUDE_FIELD,
                "namespace": INCLUDE_FIELD,
                "_id": EXCLUDE_FIELD,
            },
        )

        if not doc:
            raise MongoDBMemoryNotFoundError(f"Memory {memory_id} not found")

        # Verify namespace
        if doc.get("namespace") != namespace:
            raise MongoDBMemoryNotFoundError(f"Memory {memory_id} not found in namespace {namespace}")

        return {
            "memory_id": doc["memory_id"],
            "content": doc["content"],
            "timestamp": doc["timestamp"],
            "metadata": doc.get("metadata", {}),
            "namespace": doc["namespace"],
        }

    except MongoDBMemoryNotFoundError:
        # Keep the specific "not found" error instead of wrapping it below
        raise
    except Exception as e:
        raise MongoDBMemoryError(f"Failed to get memory {memory_id}: {str(e)}") from e


def _delete_memory(collection: Collection, namespace: str, memory_id: str) -> Dict:
    """
    Delete a specific memory by ID.

    Args:
        collection: MongoDB collection
        namespace: Memory namespace
        memory_id: Memory ID to delete

    Returns:
        ``{"memory_id": <id>, "result": "deleted"}``

    Raises:
        MongoDBMemoryNotFoundError: If the memory does not exist in this namespace, or
            nothing was deleted
        MongoDBMemoryError: If the lookup or deletion fails for any other reason
    """
    try:
        # First verify the memory exists and is in correct namespace
        _get_memory(collection, namespace, memory_id)

        # Delete the memory
        result = collection.delete_one({"memory_id": memory_id, "namespace": namespace})

        # The document can vanish between the lookup and the delete
        if result.deleted_count == 0:
            raise MongoDBMemoryNotFoundError(f"Memory {memory_id} not found")

        # Return minimal response to avoid size issues
        return {"memory_id": memory_id, "result": "deleted"}

    except MongoDBMemoryNotFoundError:
        raise
    except Exception as e:
        raise MongoDBMemoryError(f"Failed to delete memory {memory_id}: {str(e)}") from e
+ + Args: + cluster_uri: MongoDB Atlas cluster URI (kept private from agents) + database_name: Name of the MongoDB database + collection_name: Name of the MongoDB collection + embedding_model: Amazon Bedrock model for embeddings + region: AWS region for Bedrock service + vector_index_name: Name of the vector search index + """ + # Private attributes - not accessible to agents + self._cluster_uri = cluster_uri or os.getenv("MONGODB_ATLAS_CLUSTER_URI") + self._database_name = database_name or os.getenv("MONGODB_DATABASE_NAME", DEFAULT_DATABASE_NAME) + self._collection_name = collection_name or os.getenv("MONGODB_COLLECTION_NAME", DEFAULT_COLLECTION_NAME) + self._embedding_model = embedding_model or os.getenv("MONGODB_EMBEDDING_MODEL", DEFAULT_EMBEDDING_MODEL) + self._region = region or os.getenv("AWS_REGION", DEFAULT_AWS_REGION) + self._vector_index_name = vector_index_name or DEFAULT_VECTOR_INDEX_NAME + + # Validate credentials during initialization + if not self._cluster_uri: + raise MongoDBValidationError("cluster_uri is required for MongoDB Memory Tool initialization") + + if not _validate_connection_string(self._cluster_uri): + raise MongoDBValidationError("Invalid MongoDB connection string format") + + @tool + def mongodb_memory( + self, + action: str, + content: Optional[str] = None, + query: Optional[str] = None, + memory_id: Optional[str] = None, + max_results: Optional[int] = None, + next_token: Optional[str] = None, + metadata: Optional[Dict] = None, + namespace: Optional[str] = None, + ) -> Dict: + """ + Work with MongoDB Atlas memories - create, search, retrieve, list, and manage memory records. + + This tool helps agents store and access memories using MongoDB Atlas with semantic search + capabilities, allowing them to remember important information across conversations. + + Note: Credentials are securely managed by the class and not exposed to agents. 
+ + Key Capabilities: + - Store new memories with automatic embedding generation + - Search for memories using semantic similarity + - Browse and list all stored memories + - Retrieve specific memories by ID + - Delete unwanted memories + + Supported Actions: + ----------------- + Memory Management: + - record: Store a new memory with semantic embeddings + - retrieve: Find relevant memories using semantic search + - list: Browse all stored memories with pagination + - get: Fetch a specific memory by ID + - delete: Remove a specific memory + + Args: + action: The memory operation to perform (one of: "record", "retrieve", "list", "get", "delete") + content: For record action: Text content to store as a memory + query: Search terms for semantic search (required for retrieve action) + memory_id: ID of a specific memory (required for get and delete actions) + max_results: Maximum number of results to return (optional, default: 10) + next_token: Pagination token for list action (optional) + metadata: Additional metadata to store with the memory (optional) + namespace: Namespace for memory operations (defaults to 'default') + + Returns: + Dict: Response containing the requested memory information or operation status + """ + try: + # Use private configuration (credentials not exposed to agents) + namespace = namespace or os.getenv("MONGODB_NAMESPACE", DEFAULT_NAMESPACE) + max_results = max_results or DEFAULT_MAX_RESULTS + + # Initialize MongoDB client with secure error handling + try: + client = MongoClient(self._cluster_uri, serverSelectionTimeoutMS=5000) + # Test connection + client.admin.command("ping") + + database = client[self._database_name] + collection = database[self._collection_name] + + except ConnectionFailure as e: + # Use masked connection string in error messages for security + masked_uri = _mask_connection_string(self._cluster_uri) + logger.error(f"MongoDB connection failed for {masked_uri}: {str(e)}") + return { + "status": "error", + "content": [{"text": 
f"Unable to connect to MongoDB cluster at {masked_uri}"}], + } + except Exception as e: + # Use masked connection string in error messages for security + masked_uri = _mask_connection_string(self._cluster_uri) + logger.error(f"MongoDB client initialization failed for {masked_uri}: {str(e)}") + return { + "status": "error", + "content": [{"text": f"Failed to initialize MongoDB client for {masked_uri}"}], + } + + # Initialize Amazon Bedrock client for embeddings + try: + bedrock_runtime = boto3.client("bedrock-runtime", region_name=self._region) + except Exception as e: + return {"status": "error", "content": [{"text": f"Failed to initialize Bedrock client: {str(e)}"}]} + + # Ensure vector search index exists for retrieve operations + if action in [MemoryAction.RETRIEVE.value]: + _ensure_vector_search_index(collection, self._vector_index_name) + + # Validate action + try: + action_enum = MemoryAction(action) + except ValueError: + return { + "status": "error", + "content": [ + { + "text": f"Action '{action}' is not supported. 
" + f"Supported actions: {', '.join([a.value for a in MemoryAction])}" + } + ], + } + + # Validate required parameters + param_values = { + "content": content, + "query": query, + "memory_id": memory_id, + } + + missing_params = [param for param in REQUIRED_PARAMS[action_enum] if param_values.get(param) is None] + + if missing_params: + return { + "status": "error", + "content": [ + { + "text": ( + f"The following parameters are required for {action_enum.value} action: " + f"{', '.join(missing_params)}" + ) + } + ], + } + + # Execute the appropriate action + try: + if action_enum == MemoryAction.RECORD: + response = _record_memory( + collection, bedrock_runtime, namespace, self._embedding_model, content, metadata + ) + return { + "status": "success", + "content": [{"text": "Memory stored successfully"}, {"json": response}], + } + + elif action_enum == MemoryAction.RETRIEVE: + response = _retrieve_memories( + collection, + bedrock_runtime, + namespace, + self._embedding_model, + query, + max_results, + next_token, + self._vector_index_name, + ) + # Optimize response size for retrieve operations + optimized_response = _optimize_response_size(response, "retrieve") + return { + "status": "success", + "content": [{"text": "Memories retrieved successfully"}, {"json": optimized_response}], + } + + elif action_enum == MemoryAction.LIST: + response = _list_memories(collection, namespace, max_results, next_token) + # Optimize response size for list operations + optimized_response = _optimize_response_size(response, "list") + return { + "status": "success", + "content": [{"text": "Memories listed successfully"}, {"json": optimized_response}], + } + + elif action_enum == MemoryAction.GET: + response = _get_memory(collection, namespace, memory_id) + return { + "status": "success", + "content": [{"text": "Memory retrieved successfully"}, {"json": response}], + } + + elif action_enum == MemoryAction.DELETE: + response = _delete_memory(collection, namespace, memory_id) + return { 
@tool
def mongodb_memory(
    action: str,
    content: Optional[str] = None,
    query: Optional[str] = None,
    memory_id: Optional[str] = None,
    max_results: Optional[int] = None,
    next_token: Optional[str] = None,
    metadata: Optional[Dict] = None,
    cluster_uri: Optional[str] = None,
    database_name: Optional[str] = None,
    collection_name: Optional[str] = None,
    namespace: Optional[str] = None,
    embedding_model: Optional[str] = None,
    region: Optional[str] = None,
    vector_index_name: Optional[str] = None,
) -> Dict:
    """
    Work with MongoDB Atlas memories - create, search, retrieve, list, and manage memory records.

    This tool helps agents store and access memories using MongoDB Atlas with semantic search
    capabilities, allowing them to remember important information across conversations.

    Key Capabilities:
    - Store new memories with automatic embedding generation
    - Search for memories using semantic similarity
    - Browse and list all stored memories
    - Retrieve specific memories by ID
    - Delete unwanted memories

    Supported Actions:
    -----------------
    Memory Management:
    - record: Store a new memory with semantic embeddings
    - retrieve: Find relevant memories using semantic search
    - list: List all memories with pagination support
    - get: Retrieve specific memories by memory ID
    - delete: Remove specific memories by memory ID

    Args:
        action: The memory operation to perform (one of: "record", "retrieve", "list", "get", "delete")
        content: For record action: Text content to store as a memory
        query: Search terms for semantic search (required for retrieve action)
        memory_id: ID of a specific memory (required for get and delete actions)
        max_results: Maximum number of results to return (optional, default: 10)
        next_token: Pagination token for list and retrieve actions (optional)
        metadata: Additional metadata to store with the memory (optional)
        cluster_uri: MongoDB Atlas cluster URI (optional if set via environment)
        database_name: Name of the MongoDB database (optional, defaults to 'strands_memory')
        collection_name: Name of the MongoDB collection (optional, defaults to 'memories')
        namespace: Namespace for memory operations (defaults to 'default')
        embedding_model: Amazon Bedrock model for embeddings (defaults to Titan)
        region: AWS region for Bedrock service (defaults to 'us-west-2')
        vector_index_name: Name of the vector search index (defaults to 'vector_index')

    Returns:
        Dict: Response containing the requested memory information or operation status
    """
    # Created per call; closed in the ``finally`` below so its connections are released
    client = None
    try:
        # Get values from environment variables if not provided
        cluster_uri = cluster_uri or os.getenv("MONGODB_ATLAS_CLUSTER_URI")
        database_name = database_name or os.getenv("MONGODB_DATABASE_NAME", DEFAULT_DATABASE_NAME)
        collection_name = collection_name or os.getenv("MONGODB_COLLECTION_NAME", DEFAULT_COLLECTION_NAME)
        embedding_model = embedding_model or os.getenv("MONGODB_EMBEDDING_MODEL", DEFAULT_EMBEDDING_MODEL)
        region = region or os.getenv("AWS_REGION", DEFAULT_AWS_REGION)
        vector_index_name = vector_index_name or DEFAULT_VECTOR_INDEX_NAME
        namespace = namespace or os.getenv("MONGODB_NAMESPACE", DEFAULT_NAMESPACE)
        max_results = max_results or DEFAULT_MAX_RESULTS

        # Validate required parameters
        if not cluster_uri:
            return {
                "status": "error",
                "content": [
                    {
                        "text": (
                            "cluster_uri is required for MongoDB Memory Tool. "
                            "Set MONGODB_ATLAS_CLUSTER_URI environment variable or provide cluster_uri parameter."
                        )
                    }
                ],
            }

        if not _validate_connection_string(cluster_uri):
            return {"status": "error", "content": [{"text": "Invalid MongoDB connection string format"}]}

        # Validate the request itself before touching the network: an unsupported action
        # or a missing argument should not cost a cluster connection.
        try:
            action_enum = MemoryAction(action)
        except ValueError:
            return {
                "status": "error",
                "content": [
                    {
                        "text": f"Action '{action}' is not supported. "
                        f"Supported actions: {', '.join([a.value for a in MemoryAction])}"
                    }
                ],
            }

        param_values = {
            "content": content,
            "query": query,
            "memory_id": memory_id,
        }

        missing_params = [param for param in REQUIRED_PARAMS[action_enum] if param_values.get(param) is None]

        if missing_params:
            return {
                "status": "error",
                "content": [
                    {
                        "text": (
                            f"The following parameters are required for {action_enum.value} action: "
                            f"{', '.join(missing_params)}"
                        )
                    }
                ],
            }

        # Initialize MongoDB client with secure error handling
        try:
            client = MongoClient(cluster_uri, serverSelectionTimeoutMS=5000)
            # Test connection
            client.admin.command("ping")

            database = client[database_name]
            collection = database[collection_name]

        except ConnectionFailure as e:
            # Use masked connection string in error messages for security
            masked_uri = _mask_connection_string(cluster_uri)
            logger.error(f"MongoDB connection failed for {masked_uri}: {str(e)}")
            return {"status": "error", "content": [{"text": f"Unable to connect to MongoDB cluster at {masked_uri}"}]}
        except Exception as e:
            # Use masked connection string in error messages for security
            masked_uri = _mask_connection_string(cluster_uri)
            logger.error(f"MongoDB client initialization failed for {masked_uri}: {str(e)}")
            return {"status": "error", "content": [{"text": f"Failed to initialize MongoDB client for {masked_uri}"}]}

        # Initialize Amazon Bedrock client for embeddings
        try:
            bedrock_runtime = boto3.client("bedrock-runtime", region_name=region)
        except Exception as e:
            return {"status": "error", "content": [{"text": f"Failed to initialize Bedrock client: {str(e)}"}]}

        # Ensure vector search index exists for retrieve operations
        if action_enum == MemoryAction.RETRIEVE:
            _ensure_vector_search_index(collection, vector_index_name)

        # Execute the appropriate action
        try:
            if action_enum == MemoryAction.RECORD:
                response = _record_memory(collection, bedrock_runtime, namespace, embedding_model, content, metadata)
                return {
                    "status": "success",
                    "content": [{"text": "Memory stored successfully"}, {"json": response}],
                }

            elif action_enum == MemoryAction.RETRIEVE:
                response = _retrieve_memories(
                    collection,
                    bedrock_runtime,
                    namespace,
                    embedding_model,
                    query,
                    max_results,
                    next_token,
                    vector_index_name,
                )
                # Optimize response size for retrieve operations
                optimized_response = _optimize_response_size(response, "retrieve")
                return {
                    "status": "success",
                    "content": [{"text": "Memories retrieved successfully"}, {"json": optimized_response}],
                }

            elif action_enum == MemoryAction.LIST:
                response = _list_memories(collection, namespace, max_results, next_token)
                # Optimize response size for list operations
                optimized_response = _optimize_response_size(response, "list")
                return {
                    "status": "success",
                    "content": [{"text": "Memories listed successfully"}, {"json": optimized_response}],
                }

            elif action_enum == MemoryAction.GET:
                response = _get_memory(collection, namespace, memory_id)
                return {
                    "status": "success",
                    "content": [{"text": "Memory retrieved successfully"}, {"json": response}],
                }

            elif action_enum == MemoryAction.DELETE:
                # Only the confirmation text is returned; the helper's result dict is not needed
                _delete_memory(collection, namespace, memory_id)
                return {
                    "status": "success",
                    "content": [{"text": f"Memory deleted successfully: {memory_id}"}],
                }

        except Exception as e:
            # Tool boundary: report the failure to the agent instead of raising
            error_msg = f"API error: {str(e)}"
            logger.error(error_msg)
            return {"status": "error", "content": [{"text": error_msg}]}

    except Exception as e:
        logger.error(f"Unexpected error in mongodb_memory tool: {str(e)}")
        return {"status": "error", "content": [{"text": str(e)}]}

    finally:
        # Runs on every return path above. All query results are materialized into
        # lists before returning, so nothing still depends on the open client.
        if client is not None:
            try:
                client.close()
            except Exception as e:
                logger.debug(f"Error while closing MongoDB client: {str(e)}")
+""" + +import json +import os +from unittest import mock +from unittest.mock import MagicMock + +import pytest +from strands import Agent + +from src.strands_tools.mongodb_memory import mongodb_memory, MongoDBMemoryTool + + +@pytest.fixture +def mock_mongodb_client(): + """Mock MongoDB client to avoid actual connections.""" + with mock.patch("src.strands_tools.mongodb_memory.MongoClient") as mock_mongo: + # Create mock client instance + mock_client = MagicMock() + mock_mongo.return_value = mock_client + + # Configure admin.command to return success (ping test) + mock_client.admin.command.return_value = {"ok": 1} + + # Create mock database and collection + mock_database = MagicMock() + mock_collection = MagicMock() + mock_client.__getitem__.return_value = mock_database + mock_database.__getitem__.return_value = mock_collection + + # Configure collection methods + mock_collection.list_search_indexes.return_value = [] + mock_collection.create_search_index.return_value = None + + yield { + "mongo_class": mock_mongo, + "client": mock_client, + "database": mock_database, + "collection": mock_collection, + } + + +@pytest.fixture +def mock_bedrock_client(): + """Mock Amazon Bedrock client for embeddings.""" + with mock.patch("boto3.client") as mock_boto_client: + # Create mock bedrock runtime client + mock_bedrock = MagicMock() + + # Configure boto3.client to return our mock for bedrock-runtime + def client_side_effect(service, **kwargs): + if service == "bedrock-runtime": + return mock_bedrock + return MagicMock() + + mock_boto_client.side_effect = client_side_effect + + # Configure embedding response + mock_response = MagicMock() + mock_response.__getitem__.return_value.read.return_value = json.dumps( + { + "embedding": [0.1] * 1024 # Mock 1024-dimensional embedding (Titan v2) + } + ).encode() + mock_bedrock.invoke_model.return_value = mock_response + + yield { + "boto_client": mock_boto_client, + "bedrock": mock_bedrock, + } + + +@pytest.fixture +def 
agent(mock_mongodb_client, mock_bedrock_client): + """Create an agent with the direct mongodb_memory tool.""" + return Agent(tools=[mongodb_memory]) + + +@pytest.fixture +def config(): + """Configuration parameters for testing.""" + return { + "cluster_uri": "mongodb+srv://test:test@cluster.mongodb.net/", + "database_name": "test_db", + "collection_name": "test_collection", + "namespace": "test_namespace", + "region": "us-east-1", + } + + +def test_missing_required_params(mock_mongodb_client, mock_bedrock_client): + """Test tool with missing required parameters.""" + agent = Agent(tools=[mongodb_memory]) + + # Test missing cluster_uri + result = agent.tool.mongodb_memory(action="record", content="test") + assert result["status"] == "error" + assert "cluster_uri is required for MongoDB Memory Tool" in result["content"][0]["text"] + + +def test_connection_failure(mock_mongodb_client, mock_bedrock_client): + """Test tool with connection failure.""" + agent = Agent(tools=[mongodb_memory]) + + # Configure admin.command to raise ConnectionFailure + from pymongo.errors import ConnectionFailure + + mock_mongodb_client["client"].admin.command.side_effect = ConnectionFailure("Connection failed") + + result = agent.tool.mongodb_memory( + action="record", content="test", cluster_uri="mongodb+srv://test:test@cluster.mongodb.net/" + ) + + assert result["status"] == "error" + assert "Unable to connect to MongoDB cluster" in result["content"][0]["text"] + + +def test_vector_index_creation(mock_mongodb_client, mock_bedrock_client, config): + """Test that vector search index is created with proper configuration.""" + agent = Agent(tools=[mongodb_memory]) + + # Configure mock responses + mock_mongodb_client["collection"].insert_one.return_value = MagicMock(inserted_id="test_id") + + agent.tool.mongodb_memory(action="record", content="Test content", **config) + + # Verify index creation was called for record (it shouldn't be) + # Index creation only happens for retrieve operations + 
mock_mongodb_client["collection"].create_search_index.assert_not_called() + + # Test retrieve action which should create index + mock_mongodb_client["collection"].aggregate.return_value = [] + agent.tool.mongodb_memory(action="retrieve", query="test query", **config) + + # Verify index creation was called + mock_mongodb_client["collection"].create_search_index.assert_called_once() + + # Get the call arguments + call_args = mock_mongodb_client["collection"].create_search_index.call_args[0][0] + assert call_args["name"] == "vector_index" + assert call_args["definition"]["mappings"]["fields"]["embedding"]["type"] == "knnVector" + assert call_args["definition"]["mappings"]["fields"]["embedding"]["dimensions"] == 1024 + assert call_args["definition"]["mappings"]["fields"]["embedding"]["similarity"] == "cosine" + assert call_args["definition"]["mappings"]["fields"]["namespace"]["type"] == "string" + + +def test_record_memory(mock_mongodb_client, mock_bedrock_client, config): + """Test recording a memory.""" + agent = Agent(tools=[mongodb_memory]) + + # Configure mock responses + mock_result = MagicMock() + mock_result.inserted_id = "test_object_id" + mock_mongodb_client["collection"].insert_one.return_value = mock_result + + # Call the tool + result = agent.tool.mongodb_memory( + action="record", content="Test memory content", metadata={"category": "test"}, **config + ) + + # Verify success response + assert result["status"] == "success" + assert "text" in result["content"][0] + assert "Memory stored successfully" in result["content"][0]["text"] + # Get JSON from second content item + assert "json" in result["content"][1] + response_data = result["content"][1]["json"] + assert "memory_id" in response_data + assert response_data["content"] == "Test memory content" + + # Verify MongoDB insert was called + mock_mongodb_client["collection"].insert_one.assert_called_once() + + # Verify embedding generation was called + 
mock_bedrock_client["bedrock"].invoke_model.assert_called_once() + + +def test_retrieve_memories(mock_mongodb_client, mock_bedrock_client, config): + """Test retrieving memories with semantic search.""" + agent = Agent(tools=[mongodb_memory]) + + # Configure mock search response + mock_mongodb_client["collection"].aggregate.return_value = [ + { + "memory_id": "mem_123", + "content": "Test content", + "timestamp": "2023-01-01T00:00:00Z", + "metadata": {}, + "score": 0.95, + } + ] + + # Call the tool + result = agent.tool.mongodb_memory(action="retrieve", query="test query", max_results=5, **config) + + # Verify success response + assert result["status"] == "success" + assert "text" in result["content"][0] + assert "Memories retrieved successfully" in result["content"][0]["text"] + # Get JSON from second content item + assert "json" in result["content"][1] + response_data = result["content"][1]["json"] + assert "memories" in response_data + assert len(response_data["memories"]) >= 0 + + # Verify aggregate was called with vector search pipeline + mock_mongodb_client["collection"].aggregate.assert_called() + call_args = mock_mongodb_client["collection"].aggregate.call_args[0][0] + assert "$vectorSearch" in call_args[0] + assert call_args[0]["$vectorSearch"]["path"] == "embedding" + + # Verify embedding generation for query + mock_bedrock_client["bedrock"].invoke_model.assert_called_once() + + +def test_list_memories(mock_mongodb_client, mock_bedrock_client, config): + """Test listing all memories.""" + agent = Agent(tools=[mongodb_memory]) + + # Configure mock find response + mock_cursor = MagicMock() + mock_cursor.sort.return_value = mock_cursor + mock_cursor.skip.return_value = mock_cursor + mock_cursor.limit.return_value = mock_cursor + mock_cursor.__iter__.return_value = [ + { + "memory_id": "mem_123", + "content": "Test content 1", + "timestamp": "2023-01-01T00:00:00Z", + "metadata": {}, + }, + { + "memory_id": "mem_456", + "content": "Test content 2", + 
"timestamp": "2023-01-02T00:00:00Z", + "metadata": {}, + }, + ] + + mock_mongodb_client["collection"].find.return_value = mock_cursor + mock_mongodb_client["collection"].count_documents.return_value = 2 + + # Call the tool + result = agent.tool.mongodb_memory(action="list", max_results=10, **config) + + # Verify success response + assert result["status"] == "success" + assert "text" in result["content"][0] + assert "Memories listed successfully" in result["content"][0]["text"] + # Get JSON from second content item + assert "json" in result["content"][1] + response_data = result["content"][1]["json"] + assert "memories" in response_data + assert "total" in response_data + + # Verify find was called with proper query + mock_mongodb_client["collection"].find.assert_called_once() + call_args = mock_mongodb_client["collection"].find.call_args[0] + assert call_args[0] == {"namespace": "test_namespace"} + + +def test_get_memory(mock_mongodb_client, mock_bedrock_client, config): + """Test getting a specific memory by ID.""" + agent = Agent(tools=[mongodb_memory]) + + # Configure mock find_one response + mock_mongodb_client["collection"].find_one.return_value = { + "memory_id": "mem_123", + "content": "Test content", + "timestamp": "2023-01-01T00:00:00Z", + "metadata": {"category": "test"}, + "namespace": "test_namespace", + } + + # Call the tool + result = agent.tool.mongodb_memory(action="get", memory_id="mem_123", **config) + + # Verify success response + assert result["status"] == "success" + assert "text" in result["content"][0] + assert "Memory retrieved successfully" in result["content"][0]["text"] + # Get JSON from second content item + assert "json" in result["content"][1] + response_data = result["content"][1]["json"] + assert "memory_id" in response_data + assert response_data["memory_id"] == "mem_123" + + # Verify find_one was called + mock_mongodb_client["collection"].find_one.assert_called_once() + call_args = 
mock_mongodb_client["collection"].find_one.call_args[0] + assert call_args[0] == {"memory_id": "mem_123"} + + +def test_delete_memory(mock_mongodb_client, mock_bedrock_client, config): + """Test deleting a memory.""" + agent = Agent(tools=[mongodb_memory]) + + # Configure mock responses + mock_mongodb_client["collection"].find_one.return_value = { + "memory_id": "mem_123", + "content": "Test content", + "timestamp": "2023-01-01T00:00:00Z", + "metadata": {}, + "namespace": "test_namespace", + } + + mock_delete_result = MagicMock() + mock_delete_result.deleted_count = 1 + mock_mongodb_client["collection"].delete_one.return_value = mock_delete_result + + # Call the tool + result = agent.tool.mongodb_memory(action="delete", memory_id="mem_123", **config) + + # Verify success response + assert result["status"] == "success" + assert "text" in result["content"][0] + assert "Memory deleted successfully: mem_123" in result["content"][0]["text"] + + # Verify delete was called + mock_mongodb_client["collection"].delete_one.assert_called_once() + call_args = mock_mongodb_client["collection"].delete_one.call_args[0] + assert call_args[0] == {"memory_id": "mem_123", "namespace": "test_namespace"} + + +def test_unsupported_action(mock_mongodb_client, mock_bedrock_client, config): + """Test tool with an unsupported action.""" + agent = Agent(tools=[mongodb_memory]) + + result = agent.tool.mongodb_memory(action="unsupported_action", **config) + + # Verify error response + assert result["status"] == "error" + assert "is not supported" in result["content"][0]["text"] + assert "record" in result["content"][0]["text"] + assert "retrieve" in result["content"][0]["text"] + + +def test_missing_required_parameters(mock_mongodb_client, mock_bedrock_client, config): + """Test tool with missing required parameters.""" + agent = Agent(tools=[mongodb_memory]) + + # Test record action without content + result = agent.tool.mongodb_memory(action="record", **config) + + # Verify error response + 
assert result["status"] == "error" + assert "parameters are required" in result["content"][0]["text"] + assert "content" in result["content"][0]["text"] + + # Test retrieve action without query + result = agent.tool.mongodb_memory(action="retrieve", **config) + + # Verify error response + assert result["status"] == "error" + assert "parameters are required" in result["content"][0]["text"] + assert "query" in result["content"][0]["text"] + + # Test get action without memory_id + result = agent.tool.mongodb_memory(action="get", **config) + + # Verify error response + assert result["status"] == "error" + assert "parameters are required" in result["content"][0]["text"] + assert "memory_id" in result["content"][0]["text"] + + +def test_mongodb_api_error_handling(mock_mongodb_client, mock_bedrock_client, config): + """Test handling of MongoDB API errors.""" + agent = Agent(tools=[mongodb_memory]) + + # Set up mock to raise an exception + mock_mongodb_client["collection"].insert_one.side_effect = Exception("MongoDB error") + + # Call the tool + result = agent.tool.mongodb_memory(action="record", content="Test content", **config) + + # Verify error response + assert result["status"] == "error" + assert "API error" in result["content"][0]["text"] + assert "MongoDB error" in result["content"][0]["text"] + + +def test_bedrock_api_error_handling(mock_mongodb_client, mock_bedrock_client, config): + """Test handling of Bedrock API errors.""" + agent = Agent(tools=[mongodb_memory]) + + # Set up mock to raise an exception + mock_bedrock_client["bedrock"].invoke_model.side_effect = Exception("Bedrock error") + + # Call the tool + result = agent.tool.mongodb_memory(action="record", content="Test content", **config) + + # Verify error response + assert result["status"] == "error" + assert "API error" in result["content"][0]["text"] + assert "Embedding generation failed" in result["content"][0]["text"] + + +def test_memory_not_found(mock_mongodb_client, mock_bedrock_client, config): + 
"""Test handling when memory is not found.""" + agent = Agent(tools=[mongodb_memory]) + + # Configure mock to return None (not found) + mock_mongodb_client["collection"].find_one.return_value = None + + # Call the tool + result = agent.tool.mongodb_memory(action="get", memory_id="nonexistent", **config) + + # Verify error response + assert result["status"] == "error" + assert "Memory nonexistent not found" in result["content"][0]["text"] + + +def test_namespace_validation(mock_mongodb_client, mock_bedrock_client, config): + """Test that memories are properly filtered by namespace.""" + agent = Agent(tools=[mongodb_memory]) + + # Configure mock find_one response with wrong namespace + mock_mongodb_client["collection"].find_one.return_value = { + "memory_id": "mem_123", + "content": "Test content", + "namespace": "wrong_namespace", + } + + # Call the tool + result = agent.tool.mongodb_memory(action="get", memory_id="mem_123", **config) + + # Verify error response + assert result["status"] == "error" + assert "not found in namespace test_namespace" in result["content"][0]["text"] + + +def test_pagination_support(mock_mongodb_client, mock_bedrock_client, config): + """Test pagination support in list and retrieve operations.""" + agent = Agent(tools=[mongodb_memory]) + + # Configure mock find response with pagination + mock_cursor = MagicMock() + mock_cursor.sort.return_value = mock_cursor + mock_cursor.skip.return_value = mock_cursor + mock_cursor.limit.return_value = mock_cursor + mock_cursor.__iter__.return_value = [ + { + "memory_id": "mem_123", + "content": "Test content", + "timestamp": "2023-01-01T00:00:00Z", + "metadata": {}, + } + ] + + mock_mongodb_client["collection"].find.return_value = mock_cursor + mock_mongodb_client["collection"].count_documents.return_value = 20 # More results available + + # Test list with pagination + agent.tool.mongodb_memory(action="list", max_results=5, next_token="10", **config) + + # Verify skip was called with correct offset + 
mock_cursor.skip.assert_called_with(10) + mock_cursor.limit.assert_called_with(5) + + +def test_environment_variable_defaults(mock_mongodb_client, mock_bedrock_client): + """Test that environment variables are used for defaults.""" + agent = Agent(tools=[mongodb_memory]) + + with mock.patch.dict( + os.environ, + { + "MONGODB_ATLAS_CLUSTER_URI": "mongodb+srv://env:env@cluster.mongodb.net/", + "MONGODB_DATABASE_NAME": "env_db", + "MONGODB_COLLECTION_NAME": "env_collection", + "MONGODB_NAMESPACE": "env_namespace", + "MONGODB_EMBEDDING_MODEL": "env_model", + "AWS_REGION": "env_region", + }, + ): + # Configure mock responses + mock_result = MagicMock() + mock_result.inserted_id = "test_id" + mock_mongodb_client["collection"].insert_one.return_value = mock_result + + # Call tool without explicit parameters (should use env vars) + result = agent.tool.mongodb_memory(action="record", content="Test content") + + # Verify success (means env vars were used correctly) + assert result["status"] == "success" + assert "text" in result["content"][0] + assert "Memory stored successfully" in result["content"][0]["text"] + # Get JSON from second content item + assert "json" in result["content"][1] + response_data = result["content"][1]["json"] + assert "memory_id" in response_data + + +def test_agent_tool_usage(mock_mongodb_client, mock_bedrock_client): + """Test using the mongodb_memory tool through agent.tool pattern.""" + # Configure mock responses + mock_result = MagicMock() + mock_result.inserted_id = "test_id" + mock_mongodb_client["collection"].insert_one.return_value = mock_result + + # Create agent with direct tool usage - this demonstrates the standard pattern + agent = Agent(tools=[mongodb_memory]) + + # Test calling the tool through agent.tool with configuration parameters + result = agent.tool.mongodb_memory( + action="record", + content="Test memory content", + cluster_uri="mongodb+srv://test:test@cluster.mongodb.net/", + database_name="test_db", + 
collection_name="test_collection", + namespace="test_namespace", + ) + + # Verify success response + assert result["status"] == "success" + assert "Memory stored successfully" in result["content"][0]["text"] + + # Verify MongoDB insert was called + mock_mongodb_client["collection"].insert_one.assert_called_once() + + # Verify embedding generation was called + mock_bedrock_client["bedrock"].invoke_model.assert_called_once() + + +def test_custom_embedding_model(mock_mongodb_client, mock_bedrock_client, config): + """Test using custom embedding model.""" + agent = Agent(tools=[mongodb_memory]) + + # Configure mock responses + mock_result = MagicMock() + mock_result.inserted_id = "test_id" + mock_mongodb_client["collection"].insert_one.return_value = mock_result + + # Call tool with custom embedding model + result = agent.tool.mongodb_memory( + action="record", content="Test memory content", embedding_model="amazon.titan-embed-text-v1:0", **config + ) + + # Verify success response + assert result["status"] == "success" + assert "Memory stored successfully" in result["content"][0]["text"] + + # Verify Bedrock was called with custom model + mock_bedrock_client["bedrock"].invoke_model.assert_called_once() + call_args = mock_bedrock_client["bedrock"].invoke_model.call_args + assert call_args[1]["modelId"] == "amazon.titan-embed-text-v1:0" + + +def test_multiple_namespaces(mock_mongodb_client, mock_bedrock_client, config): + """Test using different namespaces for data isolation.""" + agent = Agent(tools=[mongodb_memory]) + + # Configure mock responses + mock_result = MagicMock() + mock_result.inserted_id = "test_id" + mock_mongodb_client["collection"].insert_one.return_value = mock_result + + # Store memory in user namespace + result1 = agent.tool.mongodb_memory( + action="record", + content="Alice likes Italian food", + namespace="user_alice", + **{k: v for k, v in config.items() if k != "namespace"}, + ) + + # Store memory in system namespace + result2 = 
agent.tool.mongodb_memory( + action="record", + content="System maintenance scheduled", + namespace="system_global", + **{k: v for k, v in config.items() if k != "namespace"}, + ) + + # Verify both operations succeeded + assert result1["status"] == "success" + assert result2["status"] == "success" + + # Verify both calls were made + assert mock_mongodb_client["collection"].insert_one.call_count == 2 + + +def test_configuration_dictionary_pattern(mock_mongodb_client, mock_bedrock_client): + """Test using configuration dictionary for cleaner code.""" + agent = Agent(tools=[mongodb_memory]) + + # Configure mock responses + mock_result = MagicMock() + mock_result.inserted_id = "test_id" + mock_mongodb_client["collection"].insert_one.return_value = mock_result + + mock_mongodb_client["collection"].aggregate.return_value = [ + { + "memory_id": "mem_123", + "content": "Test content", + "timestamp": "2023-01-01T00:00:00Z", + "metadata": {}, + "score": 0.95, + } + ] + + # Create configuration dictionary + config = { + "cluster_uri": "mongodb+srv://test:test@cluster.mongodb.net/", + "database_name": "memories_db", + "collection_name": "memories", + "namespace": "user_123", + "region": "us-east-1", + } + + # Store memory using config dictionary + result1 = agent.tool.mongodb_memory(action="record", content="User prefers vegetarian pizza", **config) + + # Search memories using config dictionary + result2 = agent.tool.mongodb_memory(action="retrieve", query="food preferences", max_results=5, **config) + + # Verify both operations succeeded + assert result1["status"] == "success" + assert result2["status"] == "success" + assert "Memory stored successfully" in result1["content"][0]["text"] + assert "Memories retrieved successfully" in result2["content"][0]["text"] + + +def test_batch_operations(mock_mongodb_client, mock_bedrock_client, config): + """Test storing multiple related memories in batch.""" + agent = Agent(tools=[mongodb_memory]) + + # Configure mock responses + 
mock_result = MagicMock() + mock_result.inserted_id = "test_id" + mock_mongodb_client["collection"].insert_one.return_value = mock_result + + # Store multiple related memories + memories = ["User likes Italian food", "User is allergic to nuts", "User prefers evening meetings"] + + results = [] + for content in memories: + result = agent.tool.mongodb_memory( + action="record", + content=content, + metadata={"batch": "user_preferences", "category": "preferences"}, + **config, + ) + results.append(result) + + # Verify all operations succeeded + for result in results: + assert result["status"] == "success" + assert "Memory stored successfully" in result["content"][0]["text"] + + # Verify correct number of calls were made + assert mock_mongodb_client["collection"].insert_one.call_count == len(memories) + + +def test_error_handling_scenarios(mock_mongodb_client, mock_bedrock_client, config): + """Test comprehensive error handling scenarios.""" + agent = Agent(tools=[mongodb_memory]) + + # Test connection errors + from pymongo.errors import ConnectionFailure + + mock_mongodb_client["client"].admin.command.side_effect = ConnectionFailure("Connection failed") + result = agent.tool.mongodb_memory(action="record", content="test", **config) + assert result["status"] == "error" + assert "Unable to connect to MongoDB cluster" in result["content"][0]["text"] + + # Reset admin.command to return success for subsequent tests + mock_mongodb_client["client"].admin.command.side_effect = None + mock_mongodb_client["client"].admin.command.return_value = {"ok": 1} + + # Test MongoDB API errors + mock_mongodb_client["collection"].insert_one.side_effect = Exception("MongoDB connection failed") + result = agent.tool.mongodb_memory(action="record", content="test", **config) + assert result["status"] == "error" + assert "API error" in result["content"][0]["text"] + + # Reset side effect + mock_mongodb_client["collection"].insert_one.side_effect = None + + # Test Bedrock API errors + 
mock_bedrock_client["bedrock"].invoke_model.side_effect = Exception("Bedrock access denied") + result = agent.tool.mongodb_memory(action="record", content="test", **config) + assert result["status"] == "error" + assert "Embedding generation failed" in result["content"][0]["text"] + + +def test_metadata_usage_scenarios(mock_mongodb_client, mock_bedrock_client, config): + """Test various metadata usage patterns.""" + agent = Agent(tools=[mongodb_memory]) + + # Configure mock responses + mock_result = MagicMock() + mock_result.inserted_id = "test_id" + mock_mongodb_client["collection"].insert_one.return_value = mock_result + + # Test structured metadata + structured_metadata = { + "type": "deadline", + "project": "project_alpha", + "priority": "high", + "due_date": "2024-02-01", + "assigned_to": ["alice", "bob"], + } + + result = agent.tool.mongodb_memory( + action="record", content="Important project deadline", metadata=structured_metadata, **config + ) + + assert result["status"] == "success" + assert "Memory stored successfully" in result["content"][0]["text"] + + # Verify the insert call included metadata + mock_mongodb_client["collection"].insert_one.assert_called() + call_args = mock_mongodb_client["collection"].insert_one.call_args[0][0] + assert call_args["metadata"] == structured_metadata + + +def test_performance_scenarios(mock_mongodb_client, mock_bedrock_client, config): + """Test performance-related scenarios like pagination.""" + agent = Agent(tools=[mongodb_memory]) + + # Configure mock find response with pagination + mock_cursor = MagicMock() + mock_cursor.sort.return_value = mock_cursor + mock_cursor.skip.return_value = mock_cursor + mock_cursor.limit.return_value = mock_cursor + mock_cursor.__iter__.return_value = [ + { + "memory_id": f"mem_{i}", + "content": f"Test content {i}", + "timestamp": "2023-01-01T00:00:00Z", + "metadata": {}, + } + for i in range(5) + ] + + mock_mongodb_client["collection"].find.return_value = mock_cursor + 
mock_mongodb_client["collection"].count_documents.return_value = 25 # More results available + + # Test pagination with next_token + result = agent.tool.mongodb_memory(action="list", max_results=5, next_token="10", **config) + + assert result["status"] == "success" + assert "Memories listed successfully" in result["content"][0]["text"] + + # Verify pagination parameters were used + mock_cursor.skip.assert_called_with(10) + mock_cursor.limit.assert_called_with(5) + + +def test_security_scenarios(mock_mongodb_client, mock_bedrock_client): + """Test security-related scenarios like namespace isolation.""" + agent = Agent(tools=[mongodb_memory]) + + # Configure mock find_one response with wrong namespace + mock_mongodb_client["collection"].find_one.return_value = { + "memory_id": "mem_123", + "content": "Test content", + "namespace": "wrong_namespace", + } + + # Test namespace validation + result = agent.tool.mongodb_memory( + action="get", + memory_id="mem_123", + cluster_uri="mongodb+srv://test:test@cluster.mongodb.net/", + database_name="test_db", + collection_name="test_collection", + namespace="correct_namespace", + ) + + assert result["status"] == "error" + assert "not found in namespace correct_namespace" in result["content"][0]["text"] + + +def test_troubleshooting_scenarios(mock_mongodb_client, mock_bedrock_client, config): + """Test troubleshooting scenarios mentioned in documentation.""" + agent = Agent(tools=[mongodb_memory]) + + # Test index creation failure - now it should succeed with warning, not error + mock_mongodb_client["collection"].create_search_index.side_effect = Exception("Index creation failed") + mock_mongodb_client["collection"].aggregate.return_value = [] + result = agent.tool.mongodb_memory(action="retrieve", query="test", **config) + assert result["status"] == "success" # Should succeed despite index creation failure + + # Reset side effect + mock_mongodb_client["collection"].create_search_index.side_effect = None + + # Test authentication 
errors (simulated by connection failure) + from pymongo.errors import ConnectionFailure + + mock_mongodb_client["client"].admin.command.side_effect = ConnectionFailure("Authentication failed") + result = agent.tool.mongodb_memory(action="record", content="test", **config) + assert result["status"] == "error" + assert "Unable to connect to MongoDB cluster" in result["content"][0]["text"] + + +def test_vector_search_pipeline_structure(mock_mongodb_client, mock_bedrock_client, config): + """Test that the vector search pipeline is structured correctly.""" + agent = Agent(tools=[mongodb_memory]) + + # Configure mock aggregate response + mock_mongodb_client["collection"].aggregate.return_value = [ + { + "memory_id": "mem_123", + "content": "Test content", + "timestamp": "2023-01-01T00:00:00Z", + "metadata": {}, + "score": 0.95, + } + ] + + # Call retrieve action + agent.tool.mongodb_memory(action="retrieve", query="test query", **config) + + # Verify aggregate was called + mock_mongodb_client["collection"].aggregate.assert_called() + + # Get the pipeline structure - there should be two calls to aggregate + # First call is the main search pipeline, second is for total count + aggregate_calls = mock_mongodb_client["collection"].aggregate.call_args_list + assert len(aggregate_calls) >= 1 + + # Get the first (main search) pipeline + main_pipeline = aggregate_calls[0][0][0] + + # Verify pipeline structure + assert len(main_pipeline) == 5 # Should have vectorSearch, skip, limit, addFields, project stages + assert "$vectorSearch" in main_pipeline[0] + assert "$skip" in main_pipeline[1] + assert "$limit" in main_pipeline[2] + assert "$addFields" in main_pipeline[3] + assert "$project" in main_pipeline[4] + + # Verify vectorSearch configuration + vector_search = main_pipeline[0]["$vectorSearch"] + assert vector_search["index"] == "vector_index" + assert vector_search["path"] == "embedding"