Skip to content

Commit 9d40495

Browse files
committed
Updated Neo4jMessageHistory to allow for optional session node deletion
1 parent 769ad49 commit 9d40495

File tree

2 files changed

+53
-9
lines changed

2 files changed

+53
-9
lines changed

src/neo4j_graphrag/message_history.py

+25-7
Original file line numberDiff line numberDiff line change
@@ -27,7 +27,7 @@
2727

2828
CREATE_SESSION_NODE_QUERY = "MERGE (s:`{node_label}` {{id:$session_id}})"
2929

30-
CLEAR_SESSION_QUERY = (
30+
DELETE_SESSION_AND_MESSAGES_QUERY = (
3131
"MATCH (s:`{node_label}`) "
3232
"WHERE s.id = $session_id "
3333
"OPTIONAL MATCH p=(s)-[:LAST_MESSAGE]->(:Message)<-[:NEXT*0..]-(:Message) "
@@ -36,6 +36,14 @@
3636
"DETACH DELETE node;"
3737
)
3838

39+
DELETE_MESSAGES_QUERY = (
40+
"MATCH (s:`{node_label}`)-[:LAST_MESSAGE]->(last_message:Message) "
41+
"WHERE s.id = $session_id "
42+
"MATCH p=(last_message)<-[:NEXT*0..]-(:Message) "
43+
"UNWIND nodes(p) as node "
44+
"DETACH DELETE node;"
45+
)
46+
3947
GET_MESSAGES_QUERY = (
4048
"MATCH (s:`{node_label}`)-[:LAST_MESSAGE]->(last_message) "
4149
"WHERE s.id = $session_id MATCH p=(last_message)<-[:NEXT*0.."
@@ -202,9 +210,19 @@ def add_message(self, message: LLMMessage) -> None:
202210
},
203211
)
204212

205-
def clear(self) -> None:
206-
"""Clear the message history."""
207-
self._driver.execute_query(
208-
query_=CLEAR_SESSION_QUERY.format(node_label="Session"),
209-
parameters_={"session_id": self._session_id},
210-
)
213+
def clear(self, delete_session_node: bool = False) -> None:
214+
"""Clear the message history.
215+
216+
Args:
217+
delete_session_node (bool): Whether to delete the session node. Defaults to False.
218+
"""
219+
if delete_session_node:
220+
self._driver.execute_query(
221+
query_=DELETE_SESSION_AND_MESSAGES_QUERY.format(node_label="Session"),
222+
parameters_={"session_id": self._session_id},
223+
)
224+
else:
225+
self._driver.execute_query(
226+
query_=DELETE_MESSAGES_QUERY.format(node_label="Session"),
227+
parameters_={"session_id": self._session_id},
228+
)

tests/e2e/test_message_history_e2e.py

+28-2
Original file line numberDiff line numberDiff line change
@@ -62,7 +62,7 @@ def test_neo4j_message_history_add_messages(driver: neo4j.Driver) -> None:
6262
)
6363

6464

65-
def test_neo4j_message_history_clear(driver: neo4j.Driver) -> None:
65+
def test_neo4j_message_history_clear_messages(driver: neo4j.Driver) -> None:
6666
driver.execute_query(query_="MATCH (n) DETACH DELETE n;")
6767
message_history = Neo4jMessageHistory(session_id="123", driver=driver)
6868
message_history.add_messages(
@@ -74,12 +74,38 @@ def test_neo4j_message_history_clear(driver: neo4j.Driver) -> None:
7474
assert len(message_history.messages) == 2
7575
message_history.clear()
7676
assert len(message_history.messages) == 0
77+
# Test that the session node is not deleted
78+
results = driver.execute_query(
79+
query_="MATCH (s:`Session`) WHERE s.id = '123' RETURN s"
80+
)
81+
assert len(results.records) == 1
82+
assert results.records[0]["s"]["id"] == "123"
83+
assert list(results.records[0]["s"].labels) == ["Session"]
84+
85+
86+
def test_neo4j_message_history_clear_session_and_messages(driver: neo4j.Driver) -> None:
87+
driver.execute_query(query_="MATCH (n) DETACH DELETE n;")
88+
message_history = Neo4jMessageHistory(session_id="123", driver=driver)
89+
message_history.add_messages(
90+
[
91+
LLMMessage(role="system", content="You are a helpful assistant."),
92+
LLMMessage(role="user", content="Hello"),
93+
]
94+
)
95+
assert len(message_history.messages) == 2
96+
message_history.clear(delete_session_node=True)
97+
assert len(message_history.messages) == 0
98+
# Test that the session node is deleted
99+
results = driver.execute_query(
100+
query_="MATCH (s:`Session`) WHERE s.id = '123' RETURN s"
101+
)
102+
assert results.records == []
77103

78104

79105
def test_neo4j_message_history_clear_no_messages(driver: neo4j.Driver) -> None:
80106
driver.execute_query(query_="MATCH (n) DETACH DELETE n;")
81107
message_history = Neo4jMessageHistory(session_id="123", driver=driver)
82-
message_history.clear()
108+
message_history.clear(delete_session_node=True)
83109
results = driver.execute_query(
84110
query_="MATCH (s:`Session`) WHERE s.id = '123' RETURN s"
85111
)

0 commit comments

Comments
 (0)