import logging
import warnings
-from typing import Any, Optional
+from typing import Any, List, Optional, Union

from pydantic import ValidationError

@@ -28,6 +28,7 @@
from neo4j_graphrag.generation.types import RagInitModel, RagResultModel, RagSearchModel
from neo4j_graphrag.llm import LLMInterface
from neo4j_graphrag.llm.types import LLMMessage
+from neo4j_graphrag.message_history import MessageHistory
from neo4j_graphrag.retrievers.base import Retriever
from neo4j_graphrag.types import RetrieverResult
@@ -84,7 +85,7 @@ def __init__(
    def search(
        self,
        query_text: str = "",
-        message_history: Optional[list[LLMMessage]] = None,
+        message_history: Optional[Union[List[LLMMessage], MessageHistory]] = None,
        examples: str = "",
        retriever_config: Optional[dict[str, Any]] = None,
        return_context: bool | None = None,
@@ -102,7 +103,8 @@ def search(

        Args:
            query_text (str): The user question.
-            message_history (Optional[list]): A collection previous messages, with each message having a specific role assigned.
+            message_history (Optional[Union[List[LLMMessage], MessageHistory]]): A collection previous messages,
+                with each message having a specific role assigned.
            examples (str): Examples added to the LLM prompt.
            retriever_config (Optional[dict]): Parameters passed to the retriever.
                search method; e.g.: top_k
@@ -127,7 +129,9 @@ def search(
            )
        except ValidationError as e:
            raise SearchValidationError(e.errors())
-        query = self.build_query(validated_data.query_text, message_history)
+        if isinstance(message_history, MessageHistory):
+            message_history = message_history.messages
+        query = self._build_query(validated_data.query_text, message_history)
        retriever_result: RetrieverResult = self.retriever.search(
            query_text=query, **validated_data.retriever_config
        )
@@ -147,12 +151,14 @@ def search(
            result["retriever_result"] = retriever_result
        return RagResultModel(**result)

-    def build_query(
-        self, query_text: str, message_history: Optional[list[LLMMessage]] = None
+    def _build_query(
+        self,
+        query_text: str,
+        message_history: Optional[List[LLMMessage]] = None,
    ) -> str:
        summary_system_message = "You are a summarization assistant. Summarize the given text in no more than 300 words."
        if message_history:
-            summarization_prompt = self.chat_summary_prompt(
+            summarization_prompt = self._chat_summary_prompt(
                message_history=message_history
            )
            summary = self.llm.invoke(
@@ -162,10 +168,9 @@ def build_query(
            return self.conversation_prompt(summary=summary, current_query=query_text)
        return query_text

-    def chat_summary_prompt(self, message_history: list[LLMMessage]) -> str:
+    def _chat_summary_prompt(self, message_history: List[LLMMessage]) -> str:
        message_list = [
-            ": ".join([f"{value}" for _, value in message.items()])
-            for message in message_history
+            f"{message['role']}: {message['content']}" for message in message_history
        ]
        history = "\n".join(message_list)
        return f"""
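Below is a minimal usage sketch of the widened `message_history` parameter. Only `GraphRAG`, `LLMMessage`, and `MessageHistory` appear in the diff above; the other names (`OpenAILLM`, `OpenAIEmbeddings`, `VectorRetriever`, `InMemoryMessageHistory` and its `messages=` constructor, plus the index name and connection details) are assumptions about the surrounding package and may need adjusting.

```python
import neo4j

from neo4j_graphrag.embeddings import OpenAIEmbeddings  # assumed import path
from neo4j_graphrag.generation import GraphRAG
from neo4j_graphrag.llm import OpenAILLM  # assumed LLM wrapper
from neo4j_graphrag.llm.types import LLMMessage
from neo4j_graphrag.message_history import InMemoryMessageHistory  # assumed concrete MessageHistory
from neo4j_graphrag.retrievers import VectorRetriever

driver = neo4j.GraphDatabase.driver("neo4j://localhost:7687", auth=("neo4j", "password"))
retriever = VectorRetriever(driver, index_name="moviePlots", embedder=OpenAIEmbeddings())
rag = GraphRAG(retriever=retriever, llm=OpenAILLM(model_name="gpt-4o"))

# Previous turns as plain LLMMessage dicts; the role/content keys are what
# _chat_summary_prompt reads when it summarizes the conversation.
messages: list[LLMMessage] = [
    {"role": "user", "content": "Who directed Back to the Future?"},
    {"role": "assistant", "content": "Back to the Future was directed by Robert Zemeckis."},
]

# Previously only a list could be passed; a MessageHistory object is now also
# accepted and is unwrapped to .messages inside search().
history = InMemoryMessageHistory(messages=messages)  # constructor signature assumed

result = rag.search(query_text="When was it released?", message_history=history)
print(result.answer)
```

Passing the raw `messages` list directly exercises the unchanged code path, while passing a `MessageHistory` object goes through the new `isinstance` branch in `search()`.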