Skip to content

Commit b893584

Browse files
authored
Neo4j message history (neo4j#273)
* Added message history classes
* Updated Neo4jMessageHistoryModel
* Fixed spelling error
* Fixed tests
* Added test_graphrag_happy_path_with_neo4j_message_history
* Updated LLMs
* Added missing copyright headers
* Refactored graphrag
* Added docstrings to message history classes
* Added message history examples
* Updated docs
* Updated CHANGELOG
* Removed Neo4jMessageHistory __del__ method
* Makes the build_query and chat_summary_prompt methods in the GraphRAG class private
* Added a threading lock to InMemoryMessageHistory
* Removed node_label parameter from Neo4jMessageHistory
* Updated CLEAR_SESSION_QUERY
* Fixed CLEAR_SESSION_QUERY
1 parent 4ce3b56 commit b893584

23 files changed

+908
-81
lines changed

CHANGELOG.md

+3
Original file line numberDiff line numberDiff line change
@@ -8,6 +8,9 @@
88
- Support for effective_search_ratio parameter in vector and hybrid searches.
99
- Introduced upsert_vectors utility function for batch upserting embeddings to vector indexes.
1010
- Introduced `extract_cypher` function to enhance Cypher query extraction and formatting in `Text2CypherRetriever`.
11+
- Introduced Neo4jMessageHistory and InMemoryMessageHistory classes for managing LLM message histories.
12+
- Added examples and documentation for using message history with Neo4j and in-memory storage.
13+
- Updated LLM and GraphRAG classes to support new message history classes.
1114

1215
### Changed
1316

docs/source/api.rst

+9
Original file line numberDiff line numberDiff line change
@@ -403,6 +403,15 @@ Database Interaction
403403
.. autofunction:: neo4j_graphrag.schema.format_schema
404404

405405

406+
***************
407+
Message History
408+
***************
409+
410+
.. autoclass:: neo4j_graphrag.message_history.InMemoryMessageHistory
411+
412+
.. autoclass:: neo4j_graphrag.message_history.Neo4jMessageHistory
413+
414+
406415
******
407416
Errors
408417
******

docs/source/index.rst

+1
Original file line numberDiff line numberDiff line change
@@ -148,6 +148,7 @@ Note that the below example is not the only way you can upsert data into your Ne
148148

149149

150150
.. code:: python
151+
151152
from neo4j import GraphDatabase
152153
from neo4j_graphrag.indexes import upsert_vectors
153154
from neo4j_graphrag.types import EntityType

docs/source/user_guide_rag.rst

+1
Original file line numberDiff line numberDiff line change
@@ -917,6 +917,7 @@ Populate a Vector Index
917917
==========================
918918

919919
.. code:: python
920+
920921
from random import random
921922
922923
from neo4j import GraphDatabase

examples/README.md

+2-1
Original file line numberDiff line numberDiff line change
@@ -52,7 +52,7 @@ are listed in [the last section of this file](#customize).
5252

5353
- [End to end GraphRAG](./answer/graphrag.py)
5454
- [GraphRAG with message history](./question_answering/graphrag_with_message_history.py)
55-
55+
- [GraphRAG with Neo4j message history](./question_answering/graphrag_with_neo4j_message_history.py)
5656

5757
## Customize
5858

@@ -75,6 +75,7 @@ are listed in [the last section of this file](#customize).
7575
- [Custom LLM](./customize/llms/custom_llm.py)
7676

7777
- [Message history](./customize/llms/llm_with_message_history.py)
78+
- [Message history with Neo4j](./customize/llms/llm_with_neo4j_message_history.py)
7879
- [System Instruction](./customize/llms/llm_with_system_instructions.py)
7980

8081

examples/customize/llms/custom_llm.py

+4-3
Original file line numberDiff line numberDiff line change
@@ -1,9 +1,10 @@
11
import random
22
import string
3-
from typing import Any, Optional
3+
from typing import Any, List, Optional, Union
44

55
from neo4j_graphrag.llm import LLMInterface, LLMResponse
66
from neo4j_graphrag.llm.types import LLMMessage
7+
from neo4j_graphrag.message_history import MessageHistory
78

89

910
class CustomLLM(LLMInterface):
@@ -15,7 +16,7 @@ def __init__(
1516
def invoke(
1617
self,
1718
input: str,
18-
message_history: Optional[list[LLMMessage]] = None,
19+
message_history: Optional[Union[List[LLMMessage], MessageHistory]] = None,
1920
system_instruction: Optional[str] = None,
2021
) -> LLMResponse:
2122
content: str = (
@@ -26,7 +27,7 @@ def invoke(
2627
async def ainvoke(
2728
self,
2829
input: str,
29-
message_history: Optional[list[LLMMessage]] = None,
30+
message_history: Optional[Union[List[LLMMessage], MessageHistory]] = None,
3031
system_instruction: Optional[str] = None,
3132
) -> LLMResponse:
3233
raise NotImplementedError()
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,59 @@
"""This example illustrates the message_history feature
of the LLMInterface by mocking a conversation between a user
and an LLM about Tom Hanks.

Neo4j is used as the database for storing the message history.

OpenAILLM can be replaced by any supported LLM from this package.
"""

import neo4j
from neo4j_graphrag.llm import LLMResponse, OpenAILLM
from neo4j_graphrag.message_history import Neo4jMessageHistory

# Define database credentials
URI = "neo4j+s://demo.neo4jlabs.com"
AUTH = ("recommendations", "recommendations")
DATABASE = "recommendations"
INDEX = "moviePlotsEmbedding"

# set api key here or in the OPENAI_API_KEY env var
api_key = None

llm = OpenAILLM(model_name="gpt-4o", api_key=api_key)

questions = [
    "What are some movies Tom Hanks starred in?",
    "Is he also a director?",
    "Wow, that's impressive. And what about his personal life, does he have children?",
]

driver = neo4j.GraphDatabase.driver(
    URI,
    auth=AUTH,
    database=DATABASE,
)

try:
    # Message history is persisted in Neo4j under this session_id;
    # window=10 bounds how many past messages are fed back to the LLM.
    history = Neo4jMessageHistory(session_id="123", driver=driver, window=10)

    for question in questions:
        # Ask the LLM, providing everything said so far as context.
        res: LLMResponse = llm.invoke(
            question,
            message_history=history,
        )
        # Persist both sides of the exchange so the next turn can see them.
        history.add_message(
            {
                "role": "user",
                "content": question,
            }
        )
        history.add_message(
            {
                "role": "assistant",
                "content": res.content,
            }
        )

        print("#" * 50, question)
        print(res.content)
        print("#" * 50)
finally:
    # Always release the driver's connection pool, even if a request fails.
    driver.close()
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,87 @@
"""End to end example of building a RAG pipeline backed by a Neo4j database,
simulating a chat with message history which is also stored in Neo4j.

Requires OPENAI_API_KEY to be in the env var.
"""

import neo4j
from neo4j_graphrag.embeddings.openai import OpenAIEmbeddings
from neo4j_graphrag.generation import GraphRAG
from neo4j_graphrag.llm import OpenAILLM
from neo4j_graphrag.message_history import Neo4jMessageHistory
from neo4j_graphrag.retrievers import VectorCypherRetriever

# Define database credentials
URI = "neo4j+s://demo.neo4jlabs.com"
AUTH = ("recommendations", "recommendations")
DATABASE = "recommendations"
INDEX = "moviePlotsEmbedding"


driver = neo4j.GraphDatabase.driver(
    URI,
    auth=AUTH,
)

embedder = OpenAIEmbeddings()

# Vector search over movie plot embeddings, enriched with each movie's
# actors and directors via Cypher subqueries.
retriever = VectorCypherRetriever(
    driver,
    index_name=INDEX,
    retrieval_query="""
    WITH node as movie, score
    CALL(movie) {
        MATCH (movie)<-[:ACTED_IN]-(p:Person)
        RETURN collect(p.name) as actors
    }
    CALL(movie) {
        MATCH (movie)<-[:DIRECTED]-(p:Person)
        RETURN collect(p.name) as directors
    }
    RETURN movie.title as title, movie.plot as plot, movie.year as year, actors, directors
    """,
    embedder=embedder,
    neo4j_database=DATABASE,
)

llm = OpenAILLM(model_name="gpt-4o", model_params={"temperature": 0})

rag = GraphRAG(
    retriever=retriever,
    llm=llm,
)

# Conversation state lives in Neo4j; window=10 caps the context length.
chat_history = Neo4jMessageHistory(session_id="123", driver=driver, window=10)

questions = (
    "Who starred in the Apollo 13 movies?",
    "Who was its director?",
    "In which year was this movie released?",
)

for question in questions:
    rag_response = rag.search(
        question,
        return_context=False,
        message_history=chat_history,
    )

    answer = rag_response.answer
    print("#" * 50, question)
    print(answer)
    print("#" * 50)

    # Record this turn so the follow-up questions can resolve "its"/"this".
    chat_history.add_message({"role": "user", "content": question})
    chat_history.add_message({"role": "assistant", "content": answer})

driver.close()

src/neo4j_graphrag/generation/graphrag.py

+15-10
Original file line numberDiff line numberDiff line change
@@ -16,7 +16,7 @@
1616

1717
import logging
1818
import warnings
19-
from typing import Any, Optional
19+
from typing import Any, List, Optional, Union
2020

2121
from pydantic import ValidationError
2222

@@ -28,6 +28,7 @@
2828
from neo4j_graphrag.generation.types import RagInitModel, RagResultModel, RagSearchModel
2929
from neo4j_graphrag.llm import LLMInterface
3030
from neo4j_graphrag.llm.types import LLMMessage
31+
from neo4j_graphrag.message_history import MessageHistory
3132
from neo4j_graphrag.retrievers.base import Retriever
3233
from neo4j_graphrag.types import RetrieverResult
3334

@@ -84,7 +85,7 @@ def __init__(
8485
def search(
8586
self,
8687
query_text: str = "",
87-
message_history: Optional[list[LLMMessage]] = None,
88+
message_history: Optional[Union[List[LLMMessage], MessageHistory]] = None,
8889
examples: str = "",
8990
retriever_config: Optional[dict[str, Any]] = None,
9091
return_context: bool | None = None,
@@ -102,7 +103,8 @@ def search(
102103
103104
Args:
104105
query_text (str): The user question.
105-
message_history (Optional[list]): A collection previous messages, with each message having a specific role assigned.
106+
message_history (Optional[Union[List[LLMMessage], MessageHistory]]): A collection of previous messages,
107+
with each message having a specific role assigned.
106108
examples (str): Examples added to the LLM prompt.
107109
retriever_config (Optional[dict]): Parameters passed to the retriever.
108110
search method; e.g.: top_k
@@ -127,7 +129,9 @@ def search(
127129
)
128130
except ValidationError as e:
129131
raise SearchValidationError(e.errors())
130-
query = self.build_query(validated_data.query_text, message_history)
132+
if isinstance(message_history, MessageHistory):
133+
message_history = message_history.messages
134+
query = self._build_query(validated_data.query_text, message_history)
131135
retriever_result: RetrieverResult = self.retriever.search(
132136
query_text=query, **validated_data.retriever_config
133137
)
@@ -147,12 +151,14 @@ def search(
147151
result["retriever_result"] = retriever_result
148152
return RagResultModel(**result)
149153

150-
def build_query(
151-
self, query_text: str, message_history: Optional[list[LLMMessage]] = None
154+
def _build_query(
155+
self,
156+
query_text: str,
157+
message_history: Optional[List[LLMMessage]] = None,
152158
) -> str:
153159
summary_system_message = "You are a summarization assistant. Summarize the given text in no more than 300 words."
154160
if message_history:
155-
summarization_prompt = self.chat_summary_prompt(
161+
summarization_prompt = self._chat_summary_prompt(
156162
message_history=message_history
157163
)
158164
summary = self.llm.invoke(
@@ -162,10 +168,9 @@ def build_query(
162168
return self.conversation_prompt(summary=summary, current_query=query_text)
163169
return query_text
164170

165-
def chat_summary_prompt(self, message_history: list[LLMMessage]) -> str:
171+
def _chat_summary_prompt(self, message_history: List[LLMMessage]) -> str:
166172
message_list = [
167-
": ".join([f"{value}" for _, value in message.items()])
168-
for message in message_history
173+
f"{message['role']}: {message['content']}" for message in message_history
169174
]
170175
history = "\n".join(message_list)
171176
return f"""

0 commit comments

Comments
 (0)