Skip to content

Commit cf571b1

Browse files
committed
Moves LLMMessage to avoid a cicular import with LLM classes
1 parent c944dca commit cf571b1

File tree

10 files changed

+25
-18
lines changed

10 files changed

+25
-18
lines changed

examples/customize/llms/custom_llm.py

+1-1
Original file line numberDiff line numberDiff line change
@@ -3,8 +3,8 @@
33
from typing import Any, List, Optional, Union
44

55
from neo4j_graphrag.llm import LLMInterface, LLMResponse
6-
from neo4j_graphrag.llm.types import LLMMessage
76
from neo4j_graphrag.message_history import MessageHistory
7+
from neo4j_graphrag.types import LLMMessage
88

99

1010
class CustomLLM(LLMInterface):

src/neo4j_graphrag/generation/graphrag.py

+1-2
Original file line numberDiff line numberDiff line change
@@ -27,10 +27,9 @@
2727
from neo4j_graphrag.generation.prompts import RagTemplate
2828
from neo4j_graphrag.generation.types import RagInitModel, RagResultModel, RagSearchModel
2929
from neo4j_graphrag.llm import LLMInterface
30-
from neo4j_graphrag.llm.types import LLMMessage
3130
from neo4j_graphrag.message_history import MessageHistory
3231
from neo4j_graphrag.retrievers.base import Retriever
33-
from neo4j_graphrag.types import RetrieverResult
32+
from neo4j_graphrag.types import LLMMessage, RetrieverResult
3433

3534
logger = logging.getLogger(__name__)
3635

src/neo4j_graphrag/llm/types.py

+10-5
Original file line numberDiff line numberDiff line change
@@ -1,14 +1,19 @@
1-
from typing import Literal, TypedDict
1+
import warnings
2+
from typing import Literal
23

34
from pydantic import BaseModel
45

6+
from neo4j_graphrag.types import LLMMessage as _LLMMessage
57

6-
class LLMResponse(BaseModel):
7-
content: str
8+
warnings.warn(
9+
"LLMMessage has been moved to neo4j_graphrag.types. Please update your imports.",
10+
DeprecationWarning,
11+
stacklevel=2,
12+
)
13+
LLMMessage = _LLMMessage
814

915

10-
class LLMMessage(TypedDict):
11-
role: Literal["system", "user", "assistant"]
16+
class LLMResponse(BaseModel):
1217
content: str
1318

1419

src/neo4j_graphrag/message_history.py

+2-2
Original file line numberDiff line numberDiff line change
@@ -82,8 +82,8 @@ class InMemoryMessageHistory(MessageHistory):
8282
8383
.. code-block:: python
8484
85-
from neo4j_graphrag.llm.types import LLMMessage
8685
from neo4j_graphrag.message_history import InMemoryMessageHistory
86+
from neo4j_graphrag.types import LLMMessage
8787
8888
history = InMemoryMessageHistory()
8989
@@ -125,8 +125,8 @@ class Neo4jMessageHistory(MessageHistory):
125125
.. code-block:: python
126126
127127
import neo4j
128-
from neo4j_graphrag.llm.types import LLMMessage
129128
from neo4j_graphrag.message_history import Neo4jMessageHistory
129+
from neo4j_graphrag.types import LLMMessage
130130
131131
driver = neo4j.GraphDatabase.driver(URI, auth=AUTH)
132132

src/neo4j_graphrag/types.py

+6-1
Original file line numberDiff line numberDiff line change
@@ -15,7 +15,7 @@
1515
from __future__ import annotations
1616

1717
from enum import Enum
18-
from typing import Any, Callable, Literal, Optional, Union
18+
from typing import Any, Callable, Literal, Optional, TypedDict, Union
1919

2020
import neo4j
2121
from pydantic import (
@@ -263,3 +263,8 @@ def validate_session_id(cls, v: Union[str, int]) -> Union[str, int]:
263263
if isinstance(v, str) and len(v) == 0:
264264
raise ValueError("session_id cannot be empty")
265265
return v
266+
267+
268+
class LLMMessage(TypedDict):
269+
role: Literal["system", "user", "assistant"]
270+
content: str

tests/e2e/test_graphrag_e2e.py

+1-2
Original file line numberDiff line numberDiff line change
@@ -21,10 +21,9 @@
2121
from neo4j_graphrag.generation.graphrag import GraphRAG
2222
from neo4j_graphrag.generation.types import RagResultModel
2323
from neo4j_graphrag.llm import LLMResponse
24-
from neo4j_graphrag.llm.types import LLMMessage
2524
from neo4j_graphrag.message_history import Neo4jMessageHistory
2625
from neo4j_graphrag.retrievers import VectorCypherRetriever
27-
from neo4j_graphrag.types import RetrieverResult, RetrieverResultItem
26+
from neo4j_graphrag.types import LLMMessage, RetrieverResult, RetrieverResultItem
2827

2928
from tests.e2e.conftest import BiologyEmbedder
3029
from tests.e2e.utils import build_data_objects, populate_neo4j

tests/e2e/test_message_history_e2e.py

+1-1
Original file line numberDiff line numberDiff line change
@@ -13,8 +13,8 @@
1313
# See the License for the specific language governing permissions and
1414
# limitations under the License.
1515
import neo4j
16-
from neo4j_graphrag.llm.types import LLMMessage
1716
from neo4j_graphrag.message_history import Neo4jMessageHistory
17+
from neo4j_graphrag.types import LLMMessage
1818

1919

2020
def test_neo4j_message_history_add_message(driver: neo4j.Driver) -> None:

tests/unit/llm/test_vertexai_llm.py

+1-1
Original file line numberDiff line numberDiff line change
@@ -19,8 +19,8 @@
1919

2020
import pytest
2121
from neo4j_graphrag.exceptions import LLMGenerationError
22-
from neo4j_graphrag.llm.types import LLMMessage
2322
from neo4j_graphrag.llm.vertexai_llm import VertexAILLM
23+
from neo4j_graphrag.types import LLMMessage
2424
from vertexai.generative_models import Content, Part
2525

2626

tests/unit/test_graphrag.py

+1-2
Original file line numberDiff line numberDiff line change
@@ -21,9 +21,8 @@
2121
from neo4j_graphrag.generation.prompts import RagTemplate
2222
from neo4j_graphrag.generation.types import RagResultModel
2323
from neo4j_graphrag.llm import LLMResponse
24-
from neo4j_graphrag.llm.types import LLMMessage
2524
from neo4j_graphrag.message_history import InMemoryMessageHistory
26-
from neo4j_graphrag.types import RetrieverResult, RetrieverResultItem
25+
from neo4j_graphrag.types import LLMMessage, RetrieverResult, RetrieverResultItem
2726

2827

2928
def test_graphrag_prompt_template() -> None:

tests/unit/test_message_history.py

+1-1
Original file line numberDiff line numberDiff line change
@@ -15,8 +15,8 @@
1515
from unittest.mock import MagicMock
1616

1717
import pytest
18-
from neo4j_graphrag.llm.types import LLMMessage
1918
from neo4j_graphrag.message_history import InMemoryMessageHistory, Neo4jMessageHistory
19+
from neo4j_graphrag.types import LLMMessage
2020
from pydantic import ValidationError
2121

2222

0 commit comments

Comments
 (0)