
Commit 7ae5d6f

vichayturen, ChenZiHong-Gavin, and imbajin authored
feat(llm): support async streaming output in RAG answer block (#190)
Follow-up to #172. To achieve asynchronization, we compromised by changing `gremlin_generate_operator` to a synchronous generation mode; this can be switched back to asynchronous mode once full asynchronization is achieved in the subsequent agentization work.

Co-authored-by: chenzihong <[email protected]>
Co-authored-by: chenzihong <[email protected]>
Co-authored-by: imbajin <[email protected]>
1 parent ca28faf commit 7ae5d6f

File tree

9 files changed: +484, -104 lines

hugegraph-llm/src/hugegraph_llm/demo/rag_demo/admin_block.py

Lines changed: 1 addition & 1 deletion
@@ -50,7 +50,7 @@ async def log_stream(log_path: str, lines: int = 125):
 def read_llm_server_log(lines=250):
     log_path = "logs/llm-server.log"
     try:
-        with open(log_path, "r", encoding='utf-8') as f:
+        with open(log_path, "r", encoding='utf-8', errors="replace") as f:
             return ''.join(deque(f, maxlen=lines))
     except FileNotFoundError:
         log.critical("Log file not found: %s", log_path)
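Why errors="replace" matters here: without it, a single non-UTF-8 byte in llm-server.log makes read_llm_server_log raise UnicodeDecodeError; with it, the offending bytes decode to U+FFFD and the log tail is still returned. A minimal, self-contained sketch (the file name and its contents are illustrative, not from the repo):

    # Minimal sketch: reading the tail of a log file that contains invalid UTF-8 bytes.
    from collections import deque

    with open("example.log", "wb") as f:
        f.write(b"first line\n\xff\xfe garbled line\nlast line\n")  # \xff\xfe is not valid UTF-8

    with open("example.log", "r", encoding="utf-8", errors="replace") as f:
        print("".join(deque(f, maxlen=2)))  # last two lines; bad bytes appear as U+FFFD instead of raising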

hugegraph-llm/src/hugegraph_llm/demo/rag_demo/rag_block.py

Lines changed: 127 additions & 37 deletions
@@ -18,14 +18,15 @@
 # pylint: disable=E1101
 
 import os
-from typing import Tuple, Literal, Optional
+from typing import AsyncGenerator, Tuple, Literal, Optional
 
 import gradio as gr
 import pandas as pd
 from gradio.utils import NamedString
 
 from hugegraph_llm.config import resource_path, prompt, huge_settings, llm_settings
 from hugegraph_llm.operators.graph_rag_task import RAGPipeline
+from hugegraph_llm.operators.llm_op.answer_synthesize import AnswerSynthesize
 from hugegraph_llm.utils.log import log
 
 
@@ -56,25 +57,10 @@ def rag_answer(
     4. Synthesize the final answer.
     5. Run the pipeline and return the results.
     """
-
-    gremlin_prompt = gremlin_prompt or prompt.gremlin_generate_prompt
-    should_update_prompt = (
-        prompt.default_question != text
-        or prompt.answer_prompt != answer_prompt
-        or prompt.keywords_extract_prompt != keywords_extract_prompt
-        or prompt.gremlin_generate_prompt != gremlin_prompt
-        or prompt.custom_rerank_info != custom_related_information
-    )
-    if should_update_prompt:
-        prompt.custom_rerank_info = custom_related_information
-        prompt.default_question = text
-        prompt.answer_prompt = answer_prompt
-        prompt.keywords_extract_prompt = keywords_extract_prompt
-        prompt.gremlin_generate_prompt = gremlin_prompt
-        prompt.update_yaml_file()
-
-    vector_search = vector_only_answer or graph_vector_answer
-    graph_search = graph_only_answer or graph_vector_answer
+    graph_search, gremlin_prompt, vector_search = update_ui_configs(answer_prompt, custom_related_information,
+                                                                    graph_only_answer, graph_vector_answer,
+                                                                    gremlin_prompt, keywords_extract_prompt, text,
+                                                                    vector_only_answer)
     if raw_answer is False and not vector_search and not graph_search:
         gr.Warning("Please select at least one generate mode.")
         return "", "", "", ""
@@ -121,6 +107,106 @@ def rag_answer(
         raise gr.Error(f"An unexpected error occurred: {str(e)}")
 
 
+def update_ui_configs(answer_prompt, custom_related_information, graph_only_answer, graph_vector_answer, gremlin_prompt,
+                      keywords_extract_prompt, text, vector_only_answer):
+    gremlin_prompt = gremlin_prompt or prompt.gremlin_generate_prompt
+    should_update_prompt = (
+        prompt.default_question != text
+        or prompt.answer_prompt != answer_prompt
+        or prompt.keywords_extract_prompt != keywords_extract_prompt
+        or prompt.gremlin_generate_prompt != gremlin_prompt
+        or prompt.custom_rerank_info != custom_related_information
+    )
+    if should_update_prompt:
+        prompt.custom_rerank_info = custom_related_information
+        prompt.default_question = text
+        prompt.answer_prompt = answer_prompt
+        prompt.keywords_extract_prompt = keywords_extract_prompt
+        prompt.gremlin_generate_prompt = gremlin_prompt
+        prompt.update_yaml_file()
+    vector_search = vector_only_answer or graph_vector_answer
+    graph_search = graph_only_answer or graph_vector_answer
+    return graph_search, gremlin_prompt, vector_search
+
+
+async def rag_answer_streaming(
+    text: str,
+    raw_answer: bool,
+    vector_only_answer: bool,
+    graph_only_answer: bool,
+    graph_vector_answer: bool,
+    graph_ratio: float,
+    rerank_method: Literal["bleu", "reranker"],
+    near_neighbor_first: bool,
+    custom_related_information: str,
+    answer_prompt: str,
+    keywords_extract_prompt: str,
+    gremlin_tmpl_num: Optional[int] = 2,
+    gremlin_prompt: Optional[str] = None,
+) -> AsyncGenerator[Tuple[str, str, str, str], None]:
+    """
+    Generate an answer using the RAG (Retrieval-Augmented Generation) pipeline.
+    1. Initialize the RAGPipeline.
+    2. Select vector search or graph search based on parameters.
+    3. Merge, deduplicate, and rerank the results.
+    4. Synthesize the final answer.
+    5. Run the pipeline and return the results.
+    """
+
+    graph_search, gremlin_prompt, vector_search = update_ui_configs(answer_prompt, custom_related_information,
+                                                                    graph_only_answer, graph_vector_answer,
+                                                                    gremlin_prompt, keywords_extract_prompt, text,
+                                                                    vector_only_answer)
+    if raw_answer is False and not vector_search and not graph_search:
+        gr.Warning("Please select at least one generate mode.")
+        yield "", "", "", ""
+        return
+
+    rag = RAGPipeline()
+    if vector_search:
+        rag.query_vector_index()
+    if graph_search:
+        rag.extract_keywords(extract_template=keywords_extract_prompt).keywords_to_vid().import_schema(
+            huge_settings.graph_name
+        ).query_graphdb(
+            num_gremlin_generate_example=gremlin_tmpl_num,
+            gremlin_prompt=gremlin_prompt,
+        )
+    rag.merge_dedup_rerank(
+        graph_ratio,
+        rerank_method,
+        near_neighbor_first,
+    )
+    # rag.synthesize_answer(raw_answer, vector_only_answer, graph_only_answer, graph_vector_answer, answer_prompt)
+
+    try:
+        context = rag.run(verbose=True, query=text, vector_search=vector_search, graph_search=graph_search)
+        if context.get("switch_to_bleu"):
+            gr.Warning("Online reranker fails, automatically switches to local bleu rerank.")
+        answer_synthesize = AnswerSynthesize(
+            raw_answer=raw_answer,
+            vector_only_answer=vector_only_answer,
+            graph_only_answer=graph_only_answer,
+            graph_vector_answer=graph_vector_answer,
+            prompt_template=answer_prompt,
+        )
+        async for context in answer_synthesize.run_streaming(context):
+            if context.get("switch_to_bleu"):
+                gr.Warning("Online reranker fails, automatically switches to local bleu rerank.")
+            yield (
+                context.get("raw_answer", ""),
+                context.get("vector_only_answer", ""),
+                context.get("graph_only_answer", ""),
+                context.get("graph_vector_answer", ""),
+            )
+    except ValueError as e:
+        log.critical(e)
+        raise gr.Error(str(e))
+    except Exception as e:
+        log.critical(e)
+        raise gr.Error(f"An unexpected error occurred: {str(e)}")
+
+
 def create_rag_block():
     # pylint: disable=R0915 (too-many-statements),C0301
     gr.Markdown("""## 1. HugeGraph RAG Query""")
@@ -130,13 +216,17 @@ def create_rag_block():
 
         # TODO: Only support inline formula now. Should support block formula
         gr.Markdown("Basic LLM Answer", elem_classes="output-box-label")
-        raw_out = gr.Markdown(elem_classes="output-box", show_copy_button=True, latex_delimiters=[{"left":"$", "right":"$", "display":False}])
+        raw_out = gr.Markdown(elem_classes="output-box", show_copy_button=True,
+                              latex_delimiters=[{"left": "$", "right": "$", "display": False}])
         gr.Markdown("Vector-only Answer", elem_classes="output-box-label")
-        vector_only_out = gr.Markdown(elem_classes="output-box", show_copy_button=True, latex_delimiters=[{"left":"$", "right":"$", "display":False}])
+        vector_only_out = gr.Markdown(elem_classes="output-box", show_copy_button=True,
+                                      latex_delimiters=[{"left": "$", "right": "$", "display": False}])
         gr.Markdown("Graph-only Answer", elem_classes="output-box-label")
-        graph_only_out = gr.Markdown(elem_classes="output-box", show_copy_button=True, latex_delimiters=[{"left":"$", "right":"$", "display":False}])
+        graph_only_out = gr.Markdown(elem_classes="output-box", show_copy_button=True,
+                                     latex_delimiters=[{"left": "$", "right": "$", "display": False}])
         gr.Markdown("Graph-Vector Answer", elem_classes="output-box-label")
-        graph_vector_out = gr.Markdown(elem_classes="output-box", show_copy_button=True, latex_delimiters=[{"left":"$", "right":"$", "display":False}])
+        graph_vector_out = gr.Markdown(elem_classes="output-box", show_copy_button=True,
+                                       latex_delimiters=[{"left": "$", "right": "$", "display": False}])
 
     answer_prompt_input = gr.Textbox(
         value=prompt.answer_prompt, label="Query Prompt", show_copy_button=True, lines=7
@@ -184,7 +274,7 @@ def toggle_slider(enable):
     btn = gr.Button("Answer Question", variant="primary")
 
     btn.click( # pylint: disable=no-member
-        fn=rag_answer,
+        fn=rag_answer_streaming,
         inputs=[
             inp,
             raw_radio,
@@ -254,13 +344,13 @@ def several_rag_answer(
         is_vector_only_answer: bool,
         is_graph_only_answer: bool,
         is_graph_vector_answer: bool,
-        graph_ratio: float,
-        rerank_method: Literal["bleu", "reranker"],
-        near_neighbor_first: bool,
-        custom_related_information: str,
+        graph_ratio_ui: float,
+        rerank_method_ui: Literal["bleu", "reranker"],
+        near_neighbor_first_ui: bool,
+        custom_related_information_ui: str,
         answer_prompt: str,
         keywords_extract_prompt: str,
-        answer_max_line_count: int = 1,
+        answer_max_line_count_ui: int = 1,
         progress=gr.Progress(track_tqdm=True),
     ):
         df = pd.read_excel(questions_path, dtype=str)
@@ -273,10 +363,10 @@ def several_rag_answer(
             is_vector_only_answer,
             is_graph_only_answer,
             is_graph_vector_answer,
-            graph_ratio,
-            rerank_method,
-            near_neighbor_first,
-            custom_related_information,
+            graph_ratio_ui,
+            rerank_method_ui,
+            near_neighbor_first_ui,
+            custom_related_information_ui,
             answer_prompt,
             keywords_extract_prompt,
         )
@@ -285,9 +375,9 @@ def several_rag_answer(
             df.at[index, "Graph-only Answer"] = graph_only_answer
             df.at[index, "Graph-Vector Answer"] = graph_vector_answer
             progress((index + 1, total_rows))
-        answers_path = os.path.join(resource_path, "demo", "questions_answers.xlsx")
-        df.to_excel(answers_path, index=False)
-        return df.head(answer_max_line_count), answers_path
+        answers_path_ui = os.path.join(resource_path, "demo", "questions_answers.xlsx")
+        df.to_excel(answers_path_ui, index=False)
+        return df.head(answer_max_line_count_ui), answers_path_ui
 
     with gr.Row():
         with gr.Column():
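For orientation, a minimal, self-contained sketch of the streaming contract the UI now depends on: btn.click re-renders the four output boxes each time the async generator yields a new 4-tuple. fake_stream below is a stand-in for rag_answer_streaming, not code from the repo.

    import asyncio
    from typing import AsyncGenerator, Tuple

    async def fake_stream(question: str) -> AsyncGenerator[Tuple[str, str, str, str], None]:
        # Stand-in for rag_answer_streaming: grow the "Basic LLM Answer" box token by token.
        answer = ""
        for token in ["HugeGraph ", "is ", "a ", "graph ", "database."]:
            answer += token
            yield answer, "", "", ""  # (raw, vector-only, graph-only, graph-vector)

    async def main() -> None:
        async for raw, vector_only, graph_only, graph_vector in fake_stream("What is HugeGraph?"):
            print(repr(raw))  # each iteration corresponds to one UI refresh in Gradio

    asyncio.run(main())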

hugegraph-llm/src/hugegraph_llm/models/llms/base.py

Lines changed: 12 additions & 3 deletions
@@ -16,7 +16,7 @@
 # under the License.
 
 from abc import ABC, abstractmethod
-from typing import Any, List, Optional, Callable, Dict
+from typing import Any, AsyncGenerator, Generator, List, Optional, Callable, Dict
 
 
 class BaseLLM(ABC):
@@ -43,8 +43,17 @@ def generate_streaming(
         self,
         messages: Optional[List[Dict[str, Any]]] = None,
         prompt: Optional[str] = None,
-        on_token_callback: Callable = None,
-    ) -> List[Any]:
+        on_token_callback: Optional[Callable] = None,
+    ) -> Generator[str, None, None]:
+        """Comment"""
+
+    @abstractmethod
+    async def agenerate_streaming(
+        self,
+        messages: Optional[List[Dict[str, Any]]] = None,
+        prompt: Optional[str] = None,
+        on_token_callback: Optional[Callable] = None,
+    ) -> AsyncGenerator[str, None]:
         """Comment"""
 
     @abstractmethod
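To make the new contract concrete, here is a hypothetical minimal backend implementing both streaming methods. EchoLLM is illustrative only; it deliberately does not subclass BaseLLM, and the remaining abstract methods are omitted so the snippet stays runnable on its own.

    import asyncio
    from typing import Any, AsyncGenerator, Callable, Dict, Generator, List, Optional

    class EchoLLM:
        """Echoes the prompt back one word at a time (hypothetical backend)."""

        def generate_streaming(
            self,
            messages: Optional[List[Dict[str, Any]]] = None,
            prompt: Optional[str] = None,
            on_token_callback: Optional[Callable] = None,
        ) -> Generator[str, None, None]:
            text = prompt or messages[-1]["content"]
            for word in text.split():
                if on_token_callback:
                    on_token_callback(word)
                yield word + " "

        async def agenerate_streaming(
            self,
            messages: Optional[List[Dict[str, Any]]] = None,
            prompt: Optional[str] = None,
            on_token_callback: Optional[Callable] = None,
        ) -> AsyncGenerator[str, None]:
            for token in self.generate_streaming(messages, prompt, on_token_callback):
                await asyncio.sleep(0)  # yield control, mimicking a real async backend
                yield token

    async def demo() -> None:
        async for tok in EchoLLM().agenerate_streaming(prompt="hello streaming world"):
            print(tok, end="")

    asyncio.run(demo())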

hugegraph-llm/src/hugegraph_llm/models/llms/litellm.py

Lines changed: 30 additions & 1 deletion
@@ -15,7 +15,7 @@
 # specific language governing permissions and limitations
 # under the License.
 
-from typing import Callable, List, Optional, Dict, Any
+from typing import Callable, List, Optional, Dict, Any, AsyncGenerator
 
 import tiktoken
 from litellm import completion, acompletion
@@ -137,6 +137,35 @@ def generate_streaming(
             log.error("Error in streaming LiteLLM call: %s", e)
             return f"Error: {str(e)}"
 
+    async def agenerate_streaming(
+        self,
+        messages: Optional[List[Dict[str, Any]]] = None,
+        prompt: Optional[str] = None,
+        on_token_callback: Optional[Callable] = None,
+    ) -> AsyncGenerator[str, None]:
+        """Generate a response to the query messages/prompt in async streaming mode."""
+        if messages is None:
+            assert prompt is not None, "Messages or prompt must be provided."
+            messages = [{"role": "user", "content": prompt}]
+        try:
+            response = await acompletion(
+                model=self.model,
+                messages=messages,
+                temperature=self.temperature,
+                max_tokens=self.max_tokens,
+                api_key=self.api_key,
+                base_url=self.api_base,
+                stream=True,
+            )
+            async for chunk in response:
+                if chunk.choices[0].delta.content:
+                    if on_token_callback:
+                        on_token_callback(chunk)
+                    yield chunk.choices[0].delta.content
+        except (RateLimitError, BudgetExceededError, APIError) as e:
+            log.error("Error in async streaming LiteLLM call: %s", e)
+            yield f"Error: {str(e)}"
+
     def num_tokens_from_string(self, string: str) -> int:
         """Get token count from string."""
         try:
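A hedged usage sketch of the new async API: any client exposing agenerate_streaming can be drained like this. `client` is assumed to be an already-constructed LLM wrapper (constructor arguments are config-specific and not shown). Note the design choice above: API errors are yielded as an "Error: ..." string rather than raised, so the caller can display them in place of an answer.

    import asyncio

    async def collect_stream(client, prompt: str) -> str:
        """Print tokens as they arrive and return the full answer."""
        parts = []
        async for token in client.agenerate_streaming(prompt=prompt):
            print(token, end="", flush=True)
            parts.append(token)
        return "".join(parts)

    # Example (requires a configured client):
    # answer = asyncio.run(collect_stream(client, "Summarize HugeGraph in one sentence."))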

hugegraph-llm/src/hugegraph_llm/models/llms/ollama.py

Lines changed: 37 additions & 10 deletions
@@ -17,7 +17,7 @@
 
 
 import json
-from typing import Any, List, Optional, Callable, Dict
+from typing import Any, AsyncGenerator, Generator, List, Optional, Callable, Dict
 
 import ollama
 from retry import retry
@@ -89,22 +89,49 @@ def generate_streaming(
         self,
         messages: Optional[List[Dict[str, Any]]] = None,
         prompt: Optional[str] = None,
-        on_token_callback: Callable = None,
-    ) -> List[Any]:
+        on_token_callback: Optional[Callable] = None,
+    ) -> Generator[str, None, None]:
         """Comment"""
         if messages is None:
             assert prompt is not None, "Messages or prompt must be provided."
             messages = [{"role": "user", "content": prompt}]
-        stream = self.client.chat(
+
+        for chunk in self.client.chat(
             model=self.model,
             messages=messages,
             stream=True
-        )
-        chunks = []
-        for chunk in stream:
-            on_token_callback(chunk["message"]["content"])
-            chunks.append(chunk)
-        return chunks
+        ):
+            token = chunk["message"]["content"]
+            if on_token_callback:
+                on_token_callback(token)
+            yield token
+
+    async def agenerate_streaming(
+        self,
+        messages: Optional[List[Dict[str, Any]]] = None,
+        prompt: Optional[str] = None,
+        on_token_callback: Optional[Callable] = None,
+    ) -> AsyncGenerator[str, None]:
+        """Comment"""
+        if messages is None:
+            assert prompt is not None, "Messages or prompt must be provided."
+            messages = [{"role": "user", "content": prompt}]
+
+        try:
+            async_generator = await self.async_client.chat(
+                model=self.model,
+                messages=messages,
+                stream=True
+            )
+            async for chunk in async_generator:
+                token = chunk.get("message", {}).get("content", "")
+                if on_token_callback:
+                    on_token_callback(token)
+                yield token
+        except Exception as e:
+            print(f"Retrying LLM call {e}")
+            raise e
+
 
     def num_tokens_from_string(
         self,
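The refactor makes generate_streaming lazy (a generator of tokens instead of a returned list of chunks) and adds an async path on the Ollama async client. A minimal sketch of the underlying ollama library calls this relies on, assuming a local Ollama server and an already-pulled model; the model name is illustrative.

    import asyncio
    import ollama

    async def main() -> None:
        client = ollama.AsyncClient()  # defaults to the local Ollama server
        stream = await client.chat(
            model="llama3",  # illustrative; use any model you have pulled locally
            messages=[{"role": "user", "content": "Say hi in five words."}],
            stream=True,  # awaiting with stream=True yields an async iterator of chunks
        )
        async for chunk in stream:
            print(chunk["message"]["content"], end="", flush=True)

    asyncio.run(main())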
