Skip to content

Commit 69c9c68

Browse files
authored
Fixes import error & adds delete_session_node option to Neo4jMessageHistory (neo4j#282)
* Moves LLMMessage to avoid a cicular import with LLM classes * Updated more imports * Updates docs * Updated more imports * Updated Neo4jMessageHistory to allow for optional session node deletion * Updated LLMMessage deprecation warning
1 parent c944dca commit 69c9c68

18 files changed

+93
-41
lines changed

docs/source/types.rst

+1-1
Original file line numberDiff line numberDiff line change
@@ -31,7 +31,7 @@ LLMResponse
3131
LLMMessage
3232
===========
3333

34-
.. autoclass:: neo4j_graphrag.llm.types.LLMMessage
34+
.. autoclass:: neo4j_graphrag.types.LLMMessage
3535

3636

3737
RagResultModel

examples/customize/llms/custom_llm.py

+1-1
Original file line numberDiff line numberDiff line change
@@ -3,8 +3,8 @@
33
from typing import Any, List, Optional, Union
44

55
from neo4j_graphrag.llm import LLMInterface, LLMResponse
6-
from neo4j_graphrag.llm.types import LLMMessage
76
from neo4j_graphrag.message_history import MessageHistory
7+
from neo4j_graphrag.types import LLMMessage
88

99

1010
class CustomLLM(LLMInterface):

src/neo4j_graphrag/generation/graphrag.py

+1-2
Original file line numberDiff line numberDiff line change
@@ -27,10 +27,9 @@
2727
from neo4j_graphrag.generation.prompts import RagTemplate
2828
from neo4j_graphrag.generation.types import RagInitModel, RagResultModel, RagSearchModel
2929
from neo4j_graphrag.llm import LLMInterface
30-
from neo4j_graphrag.llm.types import LLMMessage
3130
from neo4j_graphrag.message_history import MessageHistory
3231
from neo4j_graphrag.retrievers.base import Retriever
33-
from neo4j_graphrag.types import RetrieverResult
32+
from neo4j_graphrag.types import LLMMessage, RetrieverResult
3433

3534
logger = logging.getLogger(__name__)
3635

src/neo4j_graphrag/llm/anthropic_llm.py

+1-1
Original file line numberDiff line numberDiff line change
@@ -21,12 +21,12 @@
2121
from neo4j_graphrag.llm.base import LLMInterface
2222
from neo4j_graphrag.llm.types import (
2323
BaseMessage,
24-
LLMMessage,
2524
LLMResponse,
2625
MessageList,
2726
UserMessage,
2827
)
2928
from neo4j_graphrag.message_history import MessageHistory
29+
from neo4j_graphrag.types import LLMMessage
3030

3131
if TYPE_CHECKING:
3232
from anthropic.types.message_param import MessageParam

src/neo4j_graphrag/llm/base.py

+2-4
Original file line numberDiff line numberDiff line change
@@ -18,11 +18,9 @@
1818
from typing import Any, List, Optional, Union
1919

2020
from neo4j_graphrag.message_history import MessageHistory
21+
from neo4j_graphrag.types import LLMMessage
2122

22-
from .types import (
23-
LLMMessage,
24-
LLMResponse,
25-
)
23+
from .types import LLMResponse
2624

2725

2826
class LLMInterface(ABC):

src/neo4j_graphrag/llm/cohere_llm.py

+1-1
Original file line numberDiff line numberDiff line change
@@ -22,13 +22,13 @@
2222
from neo4j_graphrag.llm.base import LLMInterface
2323
from neo4j_graphrag.llm.types import (
2424
BaseMessage,
25-
LLMMessage,
2625
LLMResponse,
2726
MessageList,
2827
SystemMessage,
2928
UserMessage,
3029
)
3130
from neo4j_graphrag.message_history import MessageHistory
31+
from neo4j_graphrag.types import LLMMessage
3232

3333
if TYPE_CHECKING:
3434
from cohere import ChatMessages

src/neo4j_graphrag/llm/mistralai_llm.py

+1-1
Original file line numberDiff line numberDiff line change
@@ -23,13 +23,13 @@
2323
from neo4j_graphrag.llm.base import LLMInterface
2424
from neo4j_graphrag.llm.types import (
2525
BaseMessage,
26-
LLMMessage,
2726
LLMResponse,
2827
MessageList,
2928
SystemMessage,
3029
UserMessage,
3130
)
3231
from neo4j_graphrag.message_history import MessageHistory
32+
from neo4j_graphrag.types import LLMMessage
3333

3434
try:
3535
from mistralai import Messages, Mistral

src/neo4j_graphrag/llm/ollama_llm.py

+1-1
Original file line numberDiff line numberDiff line change
@@ -20,11 +20,11 @@
2020

2121
from neo4j_graphrag.exceptions import LLMGenerationError
2222
from neo4j_graphrag.message_history import MessageHistory
23+
from neo4j_graphrag.types import LLMMessage
2324

2425
from .base import LLMInterface
2526
from .types import (
2627
BaseMessage,
27-
LLMMessage,
2828
LLMResponse,
2929
MessageList,
3030
SystemMessage,

src/neo4j_graphrag/llm/openai_llm.py

+1-1
Original file line numberDiff line numberDiff line change
@@ -20,12 +20,12 @@
2020
from pydantic import ValidationError
2121

2222
from neo4j_graphrag.message_history import MessageHistory
23+
from neo4j_graphrag.types import LLMMessage
2324

2425
from ..exceptions import LLMGenerationError
2526
from .base import LLMInterface
2627
from .types import (
2728
BaseMessage,
28-
LLMMessage,
2929
LLMResponse,
3030
MessageList,
3131
SystemMessage,

src/neo4j_graphrag/llm/types.py

+14-5
Original file line numberDiff line numberDiff line change
@@ -1,14 +1,23 @@
1-
from typing import Literal, TypedDict
1+
import warnings
2+
from typing import Any, Literal
23

34
from pydantic import BaseModel
45

6+
from neo4j_graphrag.types import LLMMessage as _LLMMessage
57

6-
class LLMResponse(BaseModel):
7-
content: str
88

9+
def __getattr__(name: str) -> Any:
10+
if name == "LLMMessage":
11+
warnings.warn(
12+
"LLMMessage has been moved to neo4j_graphrag.types. Please update your imports.",
13+
DeprecationWarning,
14+
stacklevel=2,
15+
)
16+
return _LLMMessage
17+
raise AttributeError(f"module {__name__!r} has no attribute {name!r}")
918

10-
class LLMMessage(TypedDict):
11-
role: Literal["system", "user", "assistant"]
19+
20+
class LLMResponse(BaseModel):
1221
content: str
1322

1423

src/neo4j_graphrag/llm/vertexai_llm.py

+2-1
Original file line numberDiff line numberDiff line change
@@ -19,8 +19,9 @@
1919

2020
from neo4j_graphrag.exceptions import LLMGenerationError
2121
from neo4j_graphrag.llm.base import LLMInterface
22-
from neo4j_graphrag.llm.types import BaseMessage, LLMMessage, LLMResponse, MessageList
22+
from neo4j_graphrag.llm.types import BaseMessage, LLMResponse, MessageList
2323
from neo4j_graphrag.message_history import MessageHistory
24+
from neo4j_graphrag.types import LLMMessage
2425

2526
try:
2627
from vertexai.generative_models import (

src/neo4j_graphrag/message_history.py

+28-12
Original file line numberDiff line numberDiff line change
@@ -19,17 +19,15 @@
1919
import neo4j
2020
from pydantic import PositiveInt
2121

22-
from neo4j_graphrag.llm.types import (
23-
LLMMessage,
24-
)
2522
from neo4j_graphrag.types import (
23+
LLMMessage,
2624
Neo4jDriverModel,
2725
Neo4jMessageHistoryModel,
2826
)
2927

3028
CREATE_SESSION_NODE_QUERY = "MERGE (s:`{node_label}` {{id:$session_id}})"
3129

32-
CLEAR_SESSION_QUERY = (
30+
DELETE_SESSION_AND_MESSAGES_QUERY = (
3331
"MATCH (s:`{node_label}`) "
3432
"WHERE s.id = $session_id "
3533
"OPTIONAL MATCH p=(s)-[:LAST_MESSAGE]->(:Message)<-[:NEXT*0..]-(:Message) "
@@ -38,6 +36,14 @@
3836
"DETACH DELETE node;"
3937
)
4038

39+
DELETE_MESSAGES_QUERY = (
40+
"MATCH (s:`{node_label}`)-[:LAST_MESSAGE]->(last_message:Message) "
41+
"WHERE s.id = $session_id "
42+
"MATCH p=(last_message)<-[:NEXT*0..]-(:Message) "
43+
"UNWIND nodes(p) as node "
44+
"DETACH DELETE node;"
45+
)
46+
4147
GET_MESSAGES_QUERY = (
4248
"MATCH (s:`{node_label}`)-[:LAST_MESSAGE]->(last_message) "
4349
"WHERE s.id = $session_id MATCH p=(last_message)<-[:NEXT*0.."
@@ -82,8 +88,8 @@ class InMemoryMessageHistory(MessageHistory):
8288
8389
.. code-block:: python
8490
85-
from neo4j_graphrag.llm.types import LLMMessage
8691
from neo4j_graphrag.message_history import InMemoryMessageHistory
92+
from neo4j_graphrag.types import LLMMessage
8793
8894
history = InMemoryMessageHistory()
8995
@@ -125,8 +131,8 @@ class Neo4jMessageHistory(MessageHistory):
125131
.. code-block:: python
126132
127133
import neo4j
128-
from neo4j_graphrag.llm.types import LLMMessage
129134
from neo4j_graphrag.message_history import Neo4jMessageHistory
135+
from neo4j_graphrag.types import LLMMessage
130136
131137
driver = neo4j.GraphDatabase.driver(URI, auth=AUTH)
132138
@@ -204,9 +210,19 @@ def add_message(self, message: LLMMessage) -> None:
204210
},
205211
)
206212

207-
def clear(self) -> None:
208-
"""Clear the message history."""
209-
self._driver.execute_query(
210-
query_=CLEAR_SESSION_QUERY.format(node_label="Session"),
211-
parameters_={"session_id": self._session_id},
212-
)
213+
def clear(self, delete_session_node: bool = False) -> None:
214+
"""Clear the message history.
215+
216+
Args:
217+
delete_session_node (bool): Whether to delete the session node. Defaults to False.
218+
"""
219+
if delete_session_node:
220+
self._driver.execute_query(
221+
query_=DELETE_SESSION_AND_MESSAGES_QUERY.format(node_label="Session"),
222+
parameters_={"session_id": self._session_id},
223+
)
224+
else:
225+
self._driver.execute_query(
226+
query_=DELETE_MESSAGES_QUERY.format(node_label="Session"),
227+
parameters_={"session_id": self._session_id},
228+
)

src/neo4j_graphrag/types.py

+6-1
Original file line numberDiff line numberDiff line change
@@ -15,7 +15,7 @@
1515
from __future__ import annotations
1616

1717
from enum import Enum
18-
from typing import Any, Callable, Literal, Optional, Union
18+
from typing import Any, Callable, Literal, Optional, TypedDict, Union
1919

2020
import neo4j
2121
from pydantic import (
@@ -263,3 +263,8 @@ def validate_session_id(cls, v: Union[str, int]) -> Union[str, int]:
263263
if isinstance(v, str) and len(v) == 0:
264264
raise ValueError("session_id cannot be empty")
265265
return v
266+
267+
268+
class LLMMessage(TypedDict):
269+
role: Literal["system", "user", "assistant"]
270+
content: str

tests/e2e/test_graphrag_e2e.py

+1-2
Original file line numberDiff line numberDiff line change
@@ -21,10 +21,9 @@
2121
from neo4j_graphrag.generation.graphrag import GraphRAG
2222
from neo4j_graphrag.generation.types import RagResultModel
2323
from neo4j_graphrag.llm import LLMResponse
24-
from neo4j_graphrag.llm.types import LLMMessage
2524
from neo4j_graphrag.message_history import Neo4jMessageHistory
2625
from neo4j_graphrag.retrievers import VectorCypherRetriever
27-
from neo4j_graphrag.types import RetrieverResult, RetrieverResultItem
26+
from neo4j_graphrag.types import LLMMessage, RetrieverResult, RetrieverResultItem
2827

2928
from tests.e2e.conftest import BiologyEmbedder
3029
from tests.e2e.utils import build_data_objects, populate_neo4j

tests/e2e/test_message_history_e2e.py

+29-3
Original file line numberDiff line numberDiff line change
@@ -13,8 +13,8 @@
1313
# See the License for the specific language governing permissions and
1414
# limitations under the License.
1515
import neo4j
16-
from neo4j_graphrag.llm.types import LLMMessage
1716
from neo4j_graphrag.message_history import Neo4jMessageHistory
17+
from neo4j_graphrag.types import LLMMessage
1818

1919

2020
def test_neo4j_message_history_add_message(driver: neo4j.Driver) -> None:
@@ -62,7 +62,7 @@ def test_neo4j_message_history_add_messages(driver: neo4j.Driver) -> None:
6262
)
6363

6464

65-
def test_neo4j_message_history_clear(driver: neo4j.Driver) -> None:
65+
def test_neo4j_message_history_clear_messages(driver: neo4j.Driver) -> None:
6666
driver.execute_query(query_="MATCH (n) DETACH DELETE n;")
6767
message_history = Neo4jMessageHistory(session_id="123", driver=driver)
6868
message_history.add_messages(
@@ -74,12 +74,38 @@ def test_neo4j_message_history_clear(driver: neo4j.Driver) -> None:
7474
assert len(message_history.messages) == 2
7575
message_history.clear()
7676
assert len(message_history.messages) == 0
77+
# Test that the session node is not deleted
78+
results = driver.execute_query(
79+
query_="MATCH (s:`Session`) WHERE s.id = '123' RETURN s"
80+
)
81+
assert len(results.records) == 1
82+
assert results.records[0]["s"]["id"] == "123"
83+
assert list(results.records[0]["s"].labels) == ["Session"]
84+
85+
86+
def test_neo4j_message_history_clear_session_and_messages(driver: neo4j.Driver) -> None:
87+
driver.execute_query(query_="MATCH (n) DETACH DELETE n;")
88+
message_history = Neo4jMessageHistory(session_id="123", driver=driver)
89+
message_history.add_messages(
90+
[
91+
LLMMessage(role="system", content="You are a helpful assistant."),
92+
LLMMessage(role="user", content="Hello"),
93+
]
94+
)
95+
assert len(message_history.messages) == 2
96+
message_history.clear(delete_session_node=True)
97+
assert len(message_history.messages) == 0
98+
# Test that the session node is deleted
99+
results = driver.execute_query(
100+
query_="MATCH (s:`Session`) WHERE s.id = '123' RETURN s"
101+
)
102+
assert results.records == []
77103

78104

79105
def test_neo4j_message_history_clear_no_messages(driver: neo4j.Driver) -> None:
80106
driver.execute_query(query_="MATCH (n) DETACH DELETE n;")
81107
message_history = Neo4jMessageHistory(session_id="123", driver=driver)
82-
message_history.clear()
108+
message_history.clear(delete_session_node=True)
83109
results = driver.execute_query(
84110
query_="MATCH (s:`Session`) WHERE s.id = '123' RETURN s"
85111
)

tests/unit/llm/test_vertexai_llm.py

+1-1
Original file line numberDiff line numberDiff line change
@@ -19,8 +19,8 @@
1919

2020
import pytest
2121
from neo4j_graphrag.exceptions import LLMGenerationError
22-
from neo4j_graphrag.llm.types import LLMMessage
2322
from neo4j_graphrag.llm.vertexai_llm import VertexAILLM
23+
from neo4j_graphrag.types import LLMMessage
2424
from vertexai.generative_models import Content, Part
2525

2626

tests/unit/test_graphrag.py

+1-2
Original file line numberDiff line numberDiff line change
@@ -21,9 +21,8 @@
2121
from neo4j_graphrag.generation.prompts import RagTemplate
2222
from neo4j_graphrag.generation.types import RagResultModel
2323
from neo4j_graphrag.llm import LLMResponse
24-
from neo4j_graphrag.llm.types import LLMMessage
2524
from neo4j_graphrag.message_history import InMemoryMessageHistory
26-
from neo4j_graphrag.types import RetrieverResult, RetrieverResultItem
25+
from neo4j_graphrag.types import LLMMessage, RetrieverResult, RetrieverResultItem
2726

2827

2928
def test_graphrag_prompt_template() -> None:

tests/unit/test_message_history.py

+1-1
Original file line numberDiff line numberDiff line change
@@ -15,8 +15,8 @@
1515
from unittest.mock import MagicMock
1616

1717
import pytest
18-
from neo4j_graphrag.llm.types import LLMMessage
1918
from neo4j_graphrag.message_history import InMemoryMessageHistory, Neo4jMessageHistory
19+
from neo4j_graphrag.types import LLMMessage
2020
from pydantic import ValidationError
2121

2222

0 commit comments

Comments
 (0)