|
19 | 19 | import neo4j
|
20 | 20 | from pydantic import PositiveInt
|
21 | 21 |
|
22 |
| -from neo4j_graphrag.llm.types import ( |
23 |
| - LLMMessage, |
24 |
| -) |
25 | 22 | from neo4j_graphrag.types import (
|
| 23 | + LLMMessage, |
26 | 24 | Neo4jDriverModel,
|
27 | 25 | Neo4jMessageHistoryModel,
|
28 | 26 | )
|
29 | 27 |
|
30 | 28 | CREATE_SESSION_NODE_QUERY = "MERGE (s:`{node_label}` {{id:$session_id}})"
|
31 | 29 |
|
32 |
| -CLEAR_SESSION_QUERY = ( |
| 30 | +DELETE_SESSION_AND_MESSAGES_QUERY = ( |
33 | 31 | "MATCH (s:`{node_label}`) "
|
34 | 32 | "WHERE s.id = $session_id "
|
35 | 33 | "OPTIONAL MATCH p=(s)-[:LAST_MESSAGE]->(:Message)<-[:NEXT*0..]-(:Message) "
|
|
38 | 36 | "DETACH DELETE node;"
|
39 | 37 | )
|
40 | 38 |
|
| 39 | +DELETE_MESSAGES_QUERY = ( |
| 40 | + "MATCH (s:`{node_label}`)-[:LAST_MESSAGE]->(last_message:Message) " |
| 41 | + "WHERE s.id = $session_id " |
| 42 | + "MATCH p=(last_message)<-[:NEXT*0..]-(:Message) " |
| 43 | + "UNWIND nodes(p) as node " |
| 44 | + "DETACH DELETE node;" |
| 45 | +) |
| 46 | + |
41 | 47 | GET_MESSAGES_QUERY = (
|
42 | 48 | "MATCH (s:`{node_label}`)-[:LAST_MESSAGE]->(last_message) "
|
43 | 49 | "WHERE s.id = $session_id MATCH p=(last_message)<-[:NEXT*0.."
|
@@ -82,8 +88,8 @@ class InMemoryMessageHistory(MessageHistory):
|
82 | 88 |
|
83 | 89 | .. code-block:: python
|
84 | 90 |
|
85 |
| - from neo4j_graphrag.llm.types import LLMMessage |
86 | 91 | from neo4j_graphrag.message_history import InMemoryMessageHistory
|
| 92 | + from neo4j_graphrag.types import LLMMessage |
87 | 93 |
|
88 | 94 | history = InMemoryMessageHistory()
|
89 | 95 |
|
@@ -125,8 +131,8 @@ class Neo4jMessageHistory(MessageHistory):
|
125 | 131 | .. code-block:: python
|
126 | 132 |
|
127 | 133 | import neo4j
|
128 |
| - from neo4j_graphrag.llm.types import LLMMessage |
129 | 134 | from neo4j_graphrag.message_history import Neo4jMessageHistory
|
| 135 | + from neo4j_graphrag.types import LLMMessage |
130 | 136 |
|
131 | 137 | driver = neo4j.GraphDatabase.driver(URI, auth=AUTH)
|
132 | 138 |
|
@@ -204,9 +210,19 @@ def add_message(self, message: LLMMessage) -> None:
|
204 | 210 | },
|
205 | 211 | )
|
206 | 212 |
|
207 |
| - def clear(self) -> None: |
208 |
| - """Clear the message history.""" |
209 |
| - self._driver.execute_query( |
210 |
| - query_=CLEAR_SESSION_QUERY.format(node_label="Session"), |
211 |
| - parameters_={"session_id": self._session_id}, |
212 |
| - ) |
| 213 | + def clear(self, delete_session_node: bool = False) -> None: |
| 214 | + """Clear the message history. |
| 215 | +
|
| 216 | + Args: |
| 217 | + delete_session_node (bool): Whether to delete the session node. Defaults to False. |
| 218 | + """ |
| 219 | + if delete_session_node: |
| 220 | + self._driver.execute_query( |
| 221 | + query_=DELETE_SESSION_AND_MESSAGES_QUERY.format(node_label="Session"), |
| 222 | + parameters_={"session_id": self._session_id}, |
| 223 | + ) |
| 224 | + else: |
| 225 | + self._driver.execute_query( |
| 226 | + query_=DELETE_MESSAGES_QUERY.format(node_label="Session"), |
| 227 | + parameters_={"session_id": self._session_id}, |
| 228 | + ) |
0 commit comments