Skip to content

Commit 3eb99e5

Browse files
committed
Address PR comments: refactor retriever-to-tool conversion
- Add abstract get_parameters() method to Retriever base class - Add convert_to_tool() instance method to Retriever class - Implement get_parameters() for all concrete retriever classes - Remove automatic query_text injection in ToolsRetriever - Update example to use new convert_to_tool() method - Remove unnecessary description from ObjectParameter in example
1 parent d7a0104 commit 3eb99e5

File tree

5 files changed

+832
-26
lines changed

5 files changed

+832
-26
lines changed
Lines changed: 152 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,152 @@
1+
# Copyright (c) "Neo4j"
2+
# Neo4j Sweden AB [https://neo4j.com]
3+
# #
4+
# Licensed under the Apache License, Version 2.0 (the "License");
5+
# you may not use this file except in compliance with the License.
6+
# You may obtain a copy of the License at
7+
# #
8+
# https://www.apache.org/licenses/LICENSE-2.0
9+
# #
10+
# Unless required by applicable law or agreed to in writing, software
11+
# distributed under the License is distributed on an "AS IS" BASIS,
12+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13+
# See the License for the specific language governing permissions and
14+
# limitations under the License.
15+
#
16+
"""
17+
Example demonstrating how to create multiple domain-specific tools from retrievers.
18+
19+
This example shows:
20+
1. How to create multiple tools from the same retriever type for different use cases
21+
2. How to provide custom parameter descriptions for each tool
22+
3. How type inference works automatically while descriptions are explicit
23+
"""
24+
25+
import neo4j
26+
from typing import cast, Any, Optional
27+
from unittest.mock import MagicMock
28+
29+
from neo4j_graphrag.retrievers.base import Retriever
30+
from neo4j_graphrag.types import RawSearchResult
31+
32+
33+
class MockVectorRetriever(Retriever):
34+
"""A mock vector retriever for demonstration purposes."""
35+
36+
VERIFY_NEO4J_VERSION = False
37+
38+
def __init__(self, driver: neo4j.Driver, index_name: str):
39+
super().__init__(driver)
40+
self.index_name = index_name
41+
42+
def get_search_results(
43+
self,
44+
query_vector: Optional[list[float]] = None,
45+
query_text: Optional[str] = None,
46+
top_k: int = 5,
47+
effective_search_ratio: int = 1,
48+
filters: Optional[dict[str, Any]] = None,
49+
) -> RawSearchResult:
50+
"""Get vector search results (mocked for demonstration)."""
51+
# Return empty results for demo
52+
return RawSearchResult(records=[], metadata={"index": self.index_name})
53+
54+
55+
def main() -> None:
56+
"""Demonstrate creating multiple domain-specific tools from retrievers."""
57+
58+
# Create mock driver (in real usage, this would be actual Neo4j driver)
59+
driver = cast(Any, MagicMock())
60+
61+
# Create retrievers for different domains using the same retriever type
62+
# In practice, these would point to different vector indexes
63+
64+
# Movie recommendations retriever
65+
movie_retriever = MockVectorRetriever(
66+
driver=driver,
67+
index_name="movie_embeddings"
68+
)
69+
70+
# Product search retriever
71+
product_retriever = MockVectorRetriever(
72+
driver=driver,
73+
index_name="product_embeddings"
74+
)
75+
76+
# Document search retriever
77+
document_retriever = MockVectorRetriever(
78+
driver=driver,
79+
index_name="document_embeddings"
80+
)
81+
82+
# Convert each retriever to a domain-specific tool with custom descriptions
83+
84+
# 1. Movie recommendation tool
85+
movie_tool = movie_retriever.convert_to_tool(
86+
name="movie_search",
87+
description="Find movie recommendations based on plot, genre, or actor preferences",
88+
parameter_descriptions={
89+
"query_text": "Movie title, plot description, genre, or actor name",
90+
"query_vector": "Pre-computed embedding vector for movie search",
91+
"top_k": "Number of movie recommendations to return (1-20)",
92+
"filters": "Optional filters for genre, year, rating, etc.",
93+
"effective_search_ratio": "Search pool multiplier for better accuracy"
94+
}
95+
)
96+
97+
# 2. Product search tool
98+
product_tool = product_retriever.convert_to_tool(
99+
name="product_search",
100+
description="Search for products matching customer needs and preferences",
101+
parameter_descriptions={
102+
"query_text": "Product name, description, or customer need",
103+
"query_vector": "Pre-computed embedding for product matching",
104+
"top_k": "Maximum number of product results (1-50)",
105+
"filters": "Filters for price range, brand, category, availability",
106+
"effective_search_ratio": "Breadth vs precision trade-off for search"
107+
}
108+
)
109+
110+
# 3. Document search tool
111+
document_tool = document_retriever.convert_to_tool(
112+
name="document_search",
113+
description="Find relevant documents and knowledge articles",
114+
parameter_descriptions={
115+
"query_text": "Question, keywords, or topic to search for",
116+
"query_vector": "Semantic embedding for document retrieval",
117+
"top_k": "Number of relevant documents to retrieve (1-10)",
118+
"filters": "Document type, date range, or department filters"
119+
}
120+
)
121+
122+
# Demonstrate that each tool has distinct, meaningful descriptions
123+
tools = [movie_tool, product_tool, document_tool]
124+
125+
for tool in tools:
126+
print(f"\n=== {tool.get_name().upper()} ===")
127+
print(f"Description: {tool.get_description()}")
128+
print("Parameters:")
129+
130+
params = tool.get_parameters()
131+
for param_name, param_def in params["properties"].items():
132+
required = "required" if param_name in params.get("required", []) else "optional"
133+
print(f" - {param_name} ({param_def['type']}, {required}): {param_def['description']}")
134+
135+
# Show how the same parameter type gets different contextual descriptions
136+
print(f"\n=== PARAMETER COMPARISON ===")
137+
print("Same parameter 'query_text' with different contextual descriptions:")
138+
139+
for tool in tools:
140+
params = tool.get_parameters()
141+
query_text_desc = params["properties"]["query_text"]["description"]
142+
print(f" {tool.get_name()}: {query_text_desc}")
143+
144+
print(f"\nSame parameter 'top_k' with different contextual descriptions:")
145+
for tool in tools:
146+
params = tool.get_parameters()
147+
top_k_desc = params["properties"]["top_k"]["description"]
148+
print(f" {tool.get_name()}: {top_k_desc}")
149+
150+
151+
if __name__ == "__main__":
152+
main()

examples/retrieve/tools/retriever_to_tool_example.py

Lines changed: 19 additions & 22 deletions
Original file line numberDiff line numberDiff line change
@@ -17,8 +17,8 @@
1717
Example demonstrating how to convert a retriever to a tool.
1818
1919
This example shows:
20-
1. How to convert a custom StaticRetriever to a Tool
21-
2. How to define parameters for the tool
20+
1. How to convert a custom StaticRetriever to a Tool using the convert_to_tool method
21+
2. How to define parameters for the tool in the retriever class
2222
3. How to execute the tool
2323
"""
2424

@@ -32,7 +32,6 @@
3232
StringParameter,
3333
ObjectParameter,
3434
)
35-
from neo4j_graphrag.tools.utils import convert_retriever_to_tool
3635

3736

3837
# Create a Retriever that returns static results about Neo4j
@@ -50,7 +49,15 @@ def __init__(self, driver: neo4j.Driver):
5049
def get_search_results(
5150
self, query_text: Optional[str] = None, **kwargs: Any
5251
) -> RawSearchResult:
53-
"""Return static information about Neo4j regardless of the query."""
52+
"""Return static information about Neo4j regardless of the query.
53+
54+
Args:
55+
query_text (Optional[str]): The query about Neo4j (any query will return general Neo4j information)
56+
**kwargs (Any): Additional keyword arguments (not used)
57+
58+
Returns:
59+
RawSearchResult: Static Neo4j information with metadata
60+
"""
5461
# Create formatted Neo4j information
5562
neo4j_info = (
5663
"# Neo4j Graph Database\n\n"
@@ -73,26 +80,16 @@ def get_search_results(
7380

7481

7582
def main() -> None:
76-
# Convert a StaticRetriever to a tool with specific parameters
83+
# Convert a StaticRetriever to a tool using the new convert_to_tool method
7784
static_retriever = StaticRetriever(driver=cast(Any, MagicMock()))
7885

79-
# Define parameters for the static retriever tool
80-
static_parameters = ObjectParameter(
81-
description="Parameters for the Neo4j information retriever",
82-
properties={
83-
"query_text": StringParameter(
84-
description="The query about Neo4j (any query will return general Neo4j information)",
85-
required=True,
86-
),
87-
},
88-
)
89-
90-
# Convert the retriever to a tool with specific parameters
91-
static_tool = convert_retriever_to_tool(
92-
retriever=static_retriever,
93-
description="Get general information about Neo4j graph database",
94-
parameters=static_parameters,
86+
# Convert the retriever to a tool with custom parameter descriptions
87+
static_tool = static_retriever.convert_to_tool(
9588
name="Neo4jInfoTool",
89+
description="Get general information about Neo4j graph database",
90+
parameter_descriptions={
91+
"query_text": "Any query about Neo4j (the tool returns general information regardless)"
92+
},
9693
)
9794

9895
# Print tool information
@@ -107,7 +104,7 @@ def main() -> None:
107104
# Execute the static retriever tool
108105
print("\nExecuting the static retriever tool...")
109106
static_result = static_tool.execute(
110-
query="What is Neo4j?",
107+
query_text="What is Neo4j?",
111108
)
112109
print("Static Search Results:")
113110
for i, item in enumerate(static_result):

0 commit comments

Comments
 (0)