Skip to content

Commit c33f9c8

Browse files
authored
Text2Cypher custom prompt: doc, example and bug fix (#229)
* Doc + bug fix * Do not change the behavior, just document they said * Use same order for patched functions and check order of mocked object
1 parent 140a057 commit c33f9c8

File tree

6 files changed

+139
-28
lines changed

6 files changed

+139
-28
lines changed

docs/source/api.rst

Lines changed: 9 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -338,12 +338,21 @@ RagTemplate
338338

339339
.. autoclass:: neo4j_graphrag.generation.prompts.RagTemplate
340340
:members:
341+
:exclude-members: format
341342

342343
ERExtractionTemplate
343344
--------------------
344345

345346
.. autoclass:: neo4j_graphrag.generation.prompts.ERExtractionTemplate
346347
:members:
348+
:exclude-members: format
349+
350+
Text2CypherTemplate
351+
--------------------
352+
353+
.. autoclass:: neo4j_graphrag.generation.prompts.Text2CypherTemplate
354+
:members:
355+
:exclude-members: format
347356

348357

349358
****

examples/README.md

Lines changed: 3 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -58,7 +58,8 @@ are listed in [the last section of this file](#customize).
5858

5959
- [Control result format for VectorRetriever](customize/retrievers/result_formatter_vector_retriever.py)
6060
- [Control result format for VectorCypherRetriever](customize/retrievers/result_formatter_vector_cypher_retriever.py)
61-
61+
- [Use pre-filters](customize/retrievers/use_pre_filters.py)
62+
- [Text2Cypher: use a custom prompt](customize/retrievers/text2cypher_custom_prompt.py)
6263

6364
### LLMs
6465

@@ -74,7 +75,7 @@ are listed in [the last section of this file](#customize).
7475

7576
### Prompts
7677

77-
- [Using a custom prompt](old/graphrag_custom_prompt.py)
78+
- [Using a custom prompt for RAG](customize/answer/custom_prompt.py)
7879

7980

8081
### Embedders
Lines changed: 76 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,76 @@
1+
"""This example shows how to provide a custom prompt to Text2CypherRetriever.
2+
3+
Example using the OpenAILLM, hence the OPENAI_API_KEY needs to be set in the
4+
environment for this example to run.
5+
"""
6+
7+
import neo4j
8+
from neo4j_graphrag.llm import OpenAILLM
9+
from neo4j_graphrag.retrievers import Text2CypherRetriever
10+
from neo4j_graphrag.schema import get_schema
11+
12+
# Define database credentials
13+
URI = "neo4j+s://demo.neo4jlabs.com"
14+
AUTH = ("recommendations", "recommendations")
15+
DATABASE = "recommendations"
16+
17+
# Create LLM object
18+
llm = OpenAILLM(model_name="gpt-4o", model_params={"temperature": 0})
19+
20+
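# (Optional) Specify your own Neo4j schema. Note: when a custom prompt is set, the schema given to the constructor is not used; pass it via prompt_params["schema"] at search time instead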
21+
# (also see get_structured_schema and get_schema functions)
22+
neo4j_schema = """
23+
Node properties:
24+
User {name: STRING}
25+
Person {name: STRING, born: INTEGER}
26+
Movie {tagline: STRING, title: STRING, released: INTEGER}
27+
Relationship properties:
28+
ACTED_IN {roles: LIST}
29+
DIRECTED {}
30+
REVIEWED {summary: STRING, rating: INTEGER}
31+
The relationships:
32+
(:Person)-[:ACTED_IN]->(:Movie)
33+
(:Person)-[:DIRECTED]->(:Movie)
34+
(:User)-[:REVIEWED]->(:Movie)
35+
"""
36+
37+
prompt = """Task: Generate a Cypher statement for querying a Neo4j graph database from a user input.
38+
39+
Do not use any properties or relationships not included in the schema.
40+
Do not include triple backticks ``` or any additional text except the generated Cypher statement in your response.
41+
42+
Always filter movies that have not already been reviewed by the user with name: '{user_name}' using for instance:
43+
(m:Movie)<-[:REVIEWED]-(:User {{name: <the_user_name>}})
44+
45+
Schema:
46+
{schema}
47+
48+
Input:
49+
{query_text}
50+
51+
Cypher query:
52+
"""
53+
54+
with neo4j.GraphDatabase.driver(URI, auth=AUTH) as driver:
55+
# Initialize the retriever
56+
retriever = Text2CypherRetriever(
57+
driver=driver,
58+
llm=llm,
59+
neo4j_schema=neo4j_schema,
60+
# here we provide a custom prompt
61+
custom_prompt=prompt,
62+
neo4j_database=DATABASE,
63+
)
64+
65+
# Generate a Cypher query using the LLM, send it to the Neo4j database, and return the results
66+
query_text = "Which movies did Hugo Weaving star in?"
67+
print(
68+
retriever.search(
69+
query_text=query_text,
70+
prompt_params={
71+
# you have to specify all placeholders except the {query_text} one
72+
"schema": get_schema(driver),
73+
"user_name": "the user asking question",
74+
},
75+
)
76+
)

examples/retrieve/text2cypher_search.py

Lines changed: 1 addition & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -22,13 +22,11 @@
2222
Movie {tagline: STRING, title: STRING, released: INTEGER}
2323
Relationship properties:
2424
ACTED_IN {roles: LIST}
25+
DIRECTED {}
2526
REVIEWED {summary: STRING, rating: INTEGER}
2627
The relationships:
2728
(:Person)-[:ACTED_IN]->(:Movie)
2829
(:Person)-[:DIRECTED]->(:Movie)
29-
(:Person)-[:PRODUCED]->(:Movie)
30-
(:Person)-[:WROTE]->(:Movie)
31-
(:Person)-[:FOLLOWS]->(:Person)
3230
(:Person)-[:REVIEWED]->(:Movie)
3331
"""
3432

src/neo4j_graphrag/retrievers/text2cypher.py

Lines changed: 18 additions & 20 deletions
Original file line numberDiff line numberDiff line change
@@ -48,7 +48,7 @@ class Text2CypherRetriever(Retriever):
4848
"""
4949
Allows for the retrieval of records from a Neo4j database using natural language.
5050
Converts a user's natural language query to a Cypher query using an LLM,
51-
then retrieves records from a Neo4j database using the generated Cypher query
51+
then retrieves records from a Neo4j database using the generated Cypher query.
5252
5353
Args:
5454
driver (neo4j.Driver): The Neo4j Python driver.
@@ -98,23 +98,23 @@ def __init__(
9898
self.examples = validated_data.examples
9999
self.result_formatter = validated_data.result_formatter
100100
self.custom_prompt = validated_data.custom_prompt
101-
try:
101+
if validated_data.custom_prompt:
102+
neo4j_schema = ""
103+
else:
102104
if (
103-
not validated_data.custom_prompt
104-
): # don't need schema for a custom prompt
105-
self.neo4j_schema = (
106-
validated_data.neo4j_schema_model.neo4j_schema
107-
if validated_data.neo4j_schema_model
108-
else get_schema(validated_data.driver_model.driver)
109-
)
105+
validated_data.neo4j_schema_model
106+
and validated_data.neo4j_schema_model.neo4j_schema
107+
):
108+
neo4j_schema = validated_data.neo4j_schema_model.neo4j_schema
110109
else:
111-
self.neo4j_schema = ""
112-
113-
except (Neo4jError, DriverError) as e:
114-
error_message = getattr(e, "message", str(e))
115-
raise SchemaFetchError(
116-
f"Failed to fetch schema for Text2CypherRetriever: {error_message}"
117-
) from e
110+
try:
111+
neo4j_schema = get_schema(validated_data.driver_model.driver)
112+
except (Neo4jError, DriverError) as e:
113+
error_message = getattr(e, "message", str(e))
114+
raise SchemaFetchError(
115+
f"Failed to fetch schema for Text2CypherRetriever: {error_message}"
116+
) from e
117+
self.neo4j_schema = neo4j_schema
118118

119119
def get_search_results(
120120
self, query_text: str, prompt_params: Optional[Dict[str, Any]] = None
@@ -142,12 +142,10 @@ def get_search_results(
142142

143143
if prompt_params is not None:
144144
# parse the schema and examples inputs
145-
examples_to_use = prompt_params.get("examples") or (
145+
examples_to_use = prompt_params.pop("examples", None) or (
146146
"\n".join(self.examples) if self.examples else ""
147147
)
148-
schema_to_use = prompt_params.get("schema") or self.neo4j_schema
149-
prompt_params.pop("examples", None)
150-
prompt_params.pop("schema", None)
148+
schema_to_use = prompt_params.pop("schema", None) or self.neo4j_schema
151149
else:
152150
examples_to_use = "\n".join(self.examples) if self.examples else ""
153151
schema_to_use = self.neo4j_schema

tests/unit/retrievers/test_text2cypher.py

Lines changed: 32 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -19,6 +19,7 @@
1919
from neo4j.exceptions import CypherSyntaxError, Neo4jError
2020
from neo4j_graphrag.exceptions import (
2121
RetrieverInitializationError,
22+
SchemaFetchError,
2223
SearchValidationError,
2324
Text2CypherRetrievalError,
2425
)
@@ -39,8 +40,8 @@ def test_t2c_retriever_initialization(driver: MagicMock, llm: MagicMock) -> None
3940
@patch("neo4j_graphrag.retrievers.base.Retriever._verify_version")
4041
@patch("neo4j_graphrag.retrievers.text2cypher.get_schema")
4142
def test_t2c_retriever_schema_retrieval(
42-
_verify_version_mock: MagicMock,
4343
get_schema_mock: MagicMock,
44+
_verify_version_mock: MagicMock,
4445
driver: MagicMock,
4546
llm: MagicMock,
4647
) -> None:
@@ -51,13 +52,13 @@ def test_t2c_retriever_schema_retrieval(
5152
@patch("neo4j_graphrag.retrievers.base.Retriever._verify_version")
5253
@patch("neo4j_graphrag.retrievers.text2cypher.get_schema")
5354
def test_t2c_retriever_schema_retrieval_failure(
54-
_verify_version_mock: MagicMock,
5555
get_schema_mock: MagicMock,
56+
_verify_version_mock: MagicMock,
5657
driver: MagicMock,
5758
llm: MagicMock,
5859
) -> None:
5960
get_schema_mock.side_effect = Neo4jError
60-
with pytest.raises(Neo4jError):
61+
with pytest.raises(SchemaFetchError):
6162
Text2CypherRetriever(driver, llm)
6263

6364

@@ -310,3 +311,31 @@ def test_t2c_retriever_with_custom_prompt_bad_prompt_params(
310311
llm.invoke.assert_called_once_with(
311312
"""This is a custom prompt. test ['example A', 'example B']"""
312313
)
314+
315+
316+
@patch("neo4j_graphrag.retrievers.base.Retriever._verify_version")
317+
@patch("neo4j_graphrag.retrievers.text2cypher.get_schema")
318+
def test_t2c_retriever_with_custom_prompt_and_schema(
319+
get_schema_mock: MagicMock,
320+
_verify_version_mock: MagicMock,
321+
driver: MagicMock,
322+
llm: MagicMock,
323+
neo4j_record: MagicMock,
324+
) -> None:
325+
prompt = "This is a custom prompt. {query_text} {schema}"
326+
query = "test"
327+
328+
driver.execute_query.return_value = (
329+
[neo4j_record],
330+
None,
331+
None,
332+
)
333+
334+
retriever = Text2CypherRetriever(driver=driver, llm=llm, custom_prompt=prompt)
335+
retriever.search(
336+
query_text=query,
337+
prompt_params={},
338+
)
339+
340+
get_schema_mock.assert_not_called()
341+
llm.invoke.assert_called_once_with("""This is a custom prompt. test """)

0 commit comments

Comments
 (0)