feat(llm): support async streaming output in RAG answer block (#190)
Follow-up to #172.

To support async streaming here, we compromised by switching `gremlin_generate_operator` to a synchronous generation mode. It can be switched back to asynchronous mode once full asynchronization is achieved in the subsequent agentization work.
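For reference, the key change at the LLM client level is an async generator alongside the existing sync one; a minimal sketch of how each is consumed (client construction is elided, and `llm` is assumed to be any `BaseLLM` implementation):

```python
def consume_sync(llm, question: str) -> str:
    # Existing path: generate_streaming is a plain generator of tokens.
    return "".join(llm.generate_streaming(prompt=question))


async def consume_async(llm, question: str) -> str:
    # New path: agenerate_streaming is an async generator, so the RAG answer
    # block can yield partial output to the UI while tokens are still arriving.
    answer = ""
    async for token in llm.agenerate_streaming(prompt=question):
        answer += token
    return answer
```

In the Gradio block changed below, the async path is driven directly with `async for` inside `rag_answer_streaming`, so no explicit `asyncio.run` is needed.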

---------

Co-authored-by: chenzihong <[email protected]>
Co-authored-by: chenzihong <[email protected]>
Co-authored-by: imbajin <[email protected]>
4 people authored Mar 6, 2025
1 parent ca28faf commit 7ae5d6f
Showing 9 changed files with 484 additions and 104 deletions.
@@ -50,7 +50,7 @@ async def log_stream(log_path: str, lines: int = 125):
def read_llm_server_log(lines=250):
log_path = "logs/llm-server.log"
try:
with open(log_path, "r", encoding='utf-8') as f:
with open(log_path, "r", encoding='utf-8', errors="replace") as f:
return ''.join(deque(f, maxlen=lines))
except FileNotFoundError:
log.critical("Log file not found: %s", log_path)
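As a side note on the change above: `errors="replace"` keeps log tailing alive when the file contains bytes that are not valid UTF-8. A minimal illustration (the log path is just an example):

```python
from collections import deque

# Without errors="replace", a single invalid byte (e.g. a truncated multi-byte
# character at a log-rotation boundary) raises UnicodeDecodeError and the whole
# read fails; with it, bad bytes decode to U+FFFD and the tail is still returned.
with open("logs/llm-server.log", "r", encoding="utf-8", errors="replace") as f:
    print("".join(deque(f, maxlen=250)))
```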
164 changes: 127 additions & 37 deletions hugegraph-llm/src/hugegraph_llm/demo/rag_demo/rag_block.py
@@ -18,14 +18,15 @@
# pylint: disable=E1101

import os
from typing import Tuple, Literal, Optional
from typing import AsyncGenerator, Tuple, Literal, Optional

import gradio as gr
import pandas as pd
from gradio.utils import NamedString

from hugegraph_llm.config import resource_path, prompt, huge_settings, llm_settings
from hugegraph_llm.operators.graph_rag_task import RAGPipeline
from hugegraph_llm.operators.llm_op.answer_synthesize import AnswerSynthesize
from hugegraph_llm.utils.log import log


@@ -56,25 +57,10 @@ def rag_answer(
4. Synthesize the final answer.
5. Run the pipeline and return the results.
"""

gremlin_prompt = gremlin_prompt or prompt.gremlin_generate_prompt
should_update_prompt = (
prompt.default_question != text
or prompt.answer_prompt != answer_prompt
or prompt.keywords_extract_prompt != keywords_extract_prompt
or prompt.gremlin_generate_prompt != gremlin_prompt
or prompt.custom_rerank_info != custom_related_information
)
if should_update_prompt:
prompt.custom_rerank_info = custom_related_information
prompt.default_question = text
prompt.answer_prompt = answer_prompt
prompt.keywords_extract_prompt = keywords_extract_prompt
prompt.gremlin_generate_prompt = gremlin_prompt
prompt.update_yaml_file()

vector_search = vector_only_answer or graph_vector_answer
graph_search = graph_only_answer or graph_vector_answer
graph_search, gremlin_prompt, vector_search = update_ui_configs(answer_prompt, custom_related_information,
graph_only_answer, graph_vector_answer,
gremlin_prompt, keywords_extract_prompt, text,
vector_only_answer)
if raw_answer is False and not vector_search and not graph_search:
gr.Warning("Please select at least one generate mode.")
return "", "", "", ""
@@ -121,6 +107,106 @@ def rag_answer(
raise gr.Error(f"An unexpected error occurred: {str(e)}")


def update_ui_configs(answer_prompt, custom_related_information, graph_only_answer, graph_vector_answer, gremlin_prompt,
keywords_extract_prompt, text, vector_only_answer):
gremlin_prompt = gremlin_prompt or prompt.gremlin_generate_prompt
should_update_prompt = (
prompt.default_question != text
or prompt.answer_prompt != answer_prompt
or prompt.keywords_extract_prompt != keywords_extract_prompt
or prompt.gremlin_generate_prompt != gremlin_prompt
or prompt.custom_rerank_info != custom_related_information
)
if should_update_prompt:
prompt.custom_rerank_info = custom_related_information
prompt.default_question = text
prompt.answer_prompt = answer_prompt
prompt.keywords_extract_prompt = keywords_extract_prompt
prompt.gremlin_generate_prompt = gremlin_prompt
prompt.update_yaml_file()
vector_search = vector_only_answer or graph_vector_answer
graph_search = graph_only_answer or graph_vector_answer
return graph_search, gremlin_prompt, vector_search


async def rag_answer_streaming(
text: str,
raw_answer: bool,
vector_only_answer: bool,
graph_only_answer: bool,
graph_vector_answer: bool,
graph_ratio: float,
rerank_method: Literal["bleu", "reranker"],
near_neighbor_first: bool,
custom_related_information: str,
answer_prompt: str,
keywords_extract_prompt: str,
gremlin_tmpl_num: Optional[int] = 2,
gremlin_prompt: Optional[str] = None,
) -> AsyncGenerator[Tuple[str, str, str, str], None]:
"""
Generate an answer using the RAG (Retrieval-Augmented Generation) pipeline, streaming the output.
1. Initialize the RAGPipeline.
2. Select vector search or graph search based on parameters.
3. Merge, deduplicate, and rerank the results.
4. Synthesize the final answer.
5. Run the pipeline and stream the results.
"""

graph_search, gremlin_prompt, vector_search = update_ui_configs(answer_prompt, custom_related_information,
graph_only_answer, graph_vector_answer,
gremlin_prompt, keywords_extract_prompt, text,
vector_only_answer)
if raw_answer is False and not vector_search and not graph_search:
gr.Warning("Please select at least one generate mode.")
yield "", "", "", ""
return

rag = RAGPipeline()
if vector_search:
rag.query_vector_index()
if graph_search:
rag.extract_keywords(extract_template=keywords_extract_prompt).keywords_to_vid().import_schema(
huge_settings.graph_name
).query_graphdb(
num_gremlin_generate_example=gremlin_tmpl_num,
gremlin_prompt=gremlin_prompt,
)
rag.merge_dedup_rerank(
graph_ratio,
rerank_method,
near_neighbor_first,
)
# rag.synthesize_answer(raw_answer, vector_only_answer, graph_only_answer, graph_vector_answer, answer_prompt)

try:
context = rag.run(verbose=True, query=text, vector_search=vector_search, graph_search=graph_search)
if context.get("switch_to_bleu"):
gr.Warning("Online reranker fails, automatically switches to local bleu rerank.")
answer_synthesize = AnswerSynthesize(
raw_answer=raw_answer,
vector_only_answer=vector_only_answer,
graph_only_answer=graph_only_answer,
graph_vector_answer=graph_vector_answer,
prompt_template=answer_prompt,
)
async for context in answer_synthesize.run_streaming(context):
if context.get("switch_to_bleu"):
gr.Warning("Online reranker fails, automatically switches to local bleu rerank.")
yield (
context.get("raw_answer", ""),
context.get("vector_only_answer", ""),
context.get("graph_only_answer", ""),
context.get("graph_vector_answer", ""),
)
except ValueError as e:
log.critical(e)
raise gr.Error(str(e))
except Exception as e:
log.critical(e)
raise gr.Error(f"An unexpected error occurred: {str(e)}")


def create_rag_block():
# pylint: disable=R0915 (too-many-statements),C0301
gr.Markdown("""## 1. HugeGraph RAG Query""")
@@ -130,13 +216,17 @@ def create_rag_block():

# TODO: Only inline formulas are supported now; block formulas should be supported as well
gr.Markdown("Basic LLM Answer", elem_classes="output-box-label")
raw_out = gr.Markdown(elem_classes="output-box", show_copy_button=True, latex_delimiters=[{"left":"$", "right":"$", "display":False}])
raw_out = gr.Markdown(elem_classes="output-box", show_copy_button=True,
latex_delimiters=[{"left": "$", "right": "$", "display": False}])
gr.Markdown("Vector-only Answer", elem_classes="output-box-label")
vector_only_out = gr.Markdown(elem_classes="output-box", show_copy_button=True, latex_delimiters=[{"left":"$", "right":"$", "display":False}])
vector_only_out = gr.Markdown(elem_classes="output-box", show_copy_button=True,
latex_delimiters=[{"left": "$", "right": "$", "display": False}])
gr.Markdown("Graph-only Answer", elem_classes="output-box-label")
graph_only_out = gr.Markdown(elem_classes="output-box", show_copy_button=True, latex_delimiters=[{"left":"$", "right":"$", "display":False}])
graph_only_out = gr.Markdown(elem_classes="output-box", show_copy_button=True,
latex_delimiters=[{"left": "$", "right": "$", "display": False}])
gr.Markdown("Graph-Vector Answer", elem_classes="output-box-label")
graph_vector_out = gr.Markdown(elem_classes="output-box", show_copy_button=True, latex_delimiters=[{"left":"$", "right":"$", "display":False}])
graph_vector_out = gr.Markdown(elem_classes="output-box", show_copy_button=True,
latex_delimiters=[{"left": "$", "right": "$", "display": False}])

answer_prompt_input = gr.Textbox(
value=prompt.answer_prompt, label="Query Prompt", show_copy_button=True, lines=7
@@ -184,7 +274,7 @@ def toggle_slider(enable):
btn = gr.Button("Answer Question", variant="primary")

btn.click( # pylint: disable=no-member
fn=rag_answer,
fn=rag_answer_streaming,
inputs=[
inp,
raw_radio,
@@ -254,13 +344,13 @@ def several_rag_answer(
is_vector_only_answer: bool,
is_graph_only_answer: bool,
is_graph_vector_answer: bool,
graph_ratio: float,
rerank_method: Literal["bleu", "reranker"],
near_neighbor_first: bool,
custom_related_information: str,
graph_ratio_ui: float,
rerank_method_ui: Literal["bleu", "reranker"],
near_neighbor_first_ui: bool,
custom_related_information_ui: str,
answer_prompt: str,
keywords_extract_prompt: str,
answer_max_line_count: int = 1,
answer_max_line_count_ui: int = 1,
progress=gr.Progress(track_tqdm=True),
):
df = pd.read_excel(questions_path, dtype=str)
@@ -273,10 +363,10 @@ def several_rag_answer(
is_vector_only_answer,
is_graph_only_answer,
is_graph_vector_answer,
graph_ratio,
rerank_method,
near_neighbor_first,
custom_related_information,
graph_ratio_ui,
rerank_method_ui,
near_neighbor_first_ui,
custom_related_information_ui,
answer_prompt,
keywords_extract_prompt,
)
@@ -285,9 +375,9 @@ def several_rag_answer(
df.at[index, "Graph-only Answer"] = graph_only_answer
df.at[index, "Graph-Vector Answer"] = graph_vector_answer
progress((index + 1, total_rows))
answers_path = os.path.join(resource_path, "demo", "questions_answers.xlsx")
df.to_excel(answers_path, index=False)
return df.head(answer_max_line_count), answers_path
answers_path_ui = os.path.join(resource_path, "demo", "questions_answers.xlsx")
df.to_excel(answers_path_ui, index=False)
return df.head(answer_max_line_count_ui), answers_path_ui

with gr.Row():
with gr.Column():
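The `btn.click` wiring above relies on Gradio's support for (async) generator handlers: each `yield` pushes a partial update to the bound output components. A self-contained sketch of the pattern (component names and the echo logic are illustrative only, not part of this commit):

```python
import asyncio

import gradio as gr


async def stream_answer(question: str):
    # Each yield updates the bound Markdown output, mirroring how
    # rag_answer_streaming refreshes the four answer panes token by token.
    partial = ""
    for word in f"You asked: {question}".split():
        partial += word + " "
        await asyncio.sleep(0.05)
        yield partial


with gr.Blocks() as demo:
    inp = gr.Textbox(label="Question")
    out = gr.Markdown()
    gr.Button("Answer Question").click(fn=stream_answer, inputs=[inp], outputs=[out])

demo.launch()
```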
15 changes: 12 additions & 3 deletions hugegraph-llm/src/hugegraph_llm/models/llms/base.py
@@ -16,7 +16,7 @@
# under the License.

from abc import ABC, abstractmethod
from typing import Any, List, Optional, Callable, Dict
from typing import Any, AsyncGenerator, Generator, List, Optional, Callable, Dict


class BaseLLM(ABC):
Expand All @@ -43,8 +43,17 @@ def generate_streaming(
self,
messages: Optional[List[Dict[str, Any]]] = None,
prompt: Optional[str] = None,
on_token_callback: Callable = None,
) -> List[Any]:
on_token_callback: Optional[Callable] = None,
) -> Generator[str, None, None]:
"""Comment"""

@abstractmethod
async def agenerate_streaming(
self,
messages: Optional[List[Dict[str, Any]]] = None,
prompt: Optional[str] = None,
on_token_callback: Optional[Callable] = None,
) -> AsyncGenerator[str, None]:
"""Comment"""

@abstractmethod
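To make the new contract concrete, here is a minimal hypothetical implementation of `agenerate_streaming` (not part of this commit); only the signature follows the interface above, and the echo behaviour is purely illustrative:

```python
from typing import Any, AsyncGenerator, Callable, Dict, List, Optional


class EchoLLM:  # hypothetical; a real BaseLLM subclass must implement the other abstract methods too
    async def agenerate_streaming(
        self,
        messages: Optional[List[Dict[str, Any]]] = None,
        prompt: Optional[str] = None,
        on_token_callback: Optional[Callable] = None,
    ) -> AsyncGenerator[str, None]:
        """Yield the prompt back word by word to emulate token streaming."""
        text = prompt or (messages[-1]["content"] if messages else "")
        for token in text.split():
            if on_token_callback:
                on_token_callback(token)
            yield token + " "
```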
31 changes: 30 additions & 1 deletion hugegraph-llm/src/hugegraph_llm/models/llms/litellm.py
@@ -15,7 +15,7 @@
# specific language governing permissions and limitations
# under the License.

from typing import Callable, List, Optional, Dict, Any
from typing import Callable, List, Optional, Dict, Any, AsyncGenerator

import tiktoken
from litellm import completion, acompletion
@@ -137,6 +137,35 @@ def generate_streaming(
log.error("Error in streaming LiteLLM call: %s", e)
return f"Error: {str(e)}"

async def agenerate_streaming(
self,
messages: Optional[List[Dict[str, Any]]] = None,
prompt: Optional[str] = None,
on_token_callback: Optional[Callable] = None,
) -> AsyncGenerator[str, None]:
"""Generate a response to the query messages/prompt in async streaming mode."""
if messages is None:
assert prompt is not None, "Messages or prompt must be provided."
messages = [{"role": "user", "content": prompt}]
try:
response = await acompletion(
model=self.model,
messages=messages,
temperature=self.temperature,
max_tokens=self.max_tokens,
api_key=self.api_key,
base_url=self.api_base,
stream=True,
)
async for chunk in response:
if chunk.choices[0].delta.content:
if on_token_callback:
on_token_callback(chunk)
yield chunk.choices[0].delta.content
except (RateLimitError, BudgetExceededError, APIError) as e:
log.error("Error in async streaming LiteLLM call: %s", e)
yield f"Error: {str(e)}"

def num_tokens_from_string(self, string: str) -> int:
"""Get token count from string."""
try:
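A usage sketch for the new LiteLLM method (the `LiteLLMClient` class name and its no-argument construction are assumptions; the real client reads the model and credentials from the project settings):

```python
import asyncio

from hugegraph_llm.models.llms.litellm import LiteLLMClient  # class name assumed


async def main():
    llm = LiteLLMClient()
    # Tokens are printed as soon as LiteLLM's acompletion(stream=True) emits them.
    async for token in llm.agenerate_streaming(prompt="What is HugeGraph?"):
        print(token, end="", flush=True)


asyncio.run(main())
```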
47 changes: 37 additions & 10 deletions hugegraph-llm/src/hugegraph_llm/models/llms/ollama.py
@@ -17,7 +17,7 @@


import json
from typing import Any, List, Optional, Callable, Dict
from typing import Any, AsyncGenerator, Generator, List, Optional, Callable, Dict

import ollama
from retry import retry
@@ -89,22 +89,49 @@ def generate_streaming(
self,
messages: Optional[List[Dict[str, Any]]] = None,
prompt: Optional[str] = None,
on_token_callback: Callable = None,
) -> List[Any]:
on_token_callback: Optional[Callable] = None,
) -> Generator[str, None, None]:
"""Comment"""
if messages is None:
assert prompt is not None, "Messages or prompt must be provided."
messages = [{"role": "user", "content": prompt}]
stream = self.client.chat(

for chunk in self.client.chat(
model=self.model,
messages=messages,
stream=True
)
chunks = []
for chunk in stream:
on_token_callback(chunk["message"]["content"])
chunks.append(chunk)
return chunks
):
token = chunk["message"]["content"]
if on_token_callback:
on_token_callback(token)
yield token

async def agenerate_streaming(
self,
messages: Optional[List[Dict[str, Any]]] = None,
prompt: Optional[str] = None,
on_token_callback: Optional[Callable] = None,
) -> AsyncGenerator[str, None]:
"""Comment"""
if messages is None:
assert prompt is not None, "Messages or prompt must be provided."
messages = [{"role": "user", "content": prompt}]

try:
async_generator = await self.async_client.chat(
model=self.model,
messages=messages,
stream=True
)
async for chunk in async_generator:
token = chunk.get("message", {}).get("content", "")
if on_token_callback:
on_token_callback(token)
yield token
except Exception as e:
print(f"Retrying LLM call {e}")
raise e


def num_tokens_from_string(
self,
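The async method above follows the `ollama` client's documented streaming pattern, where awaiting `AsyncClient.chat(..., stream=True)` returns an async iterator of chunks. A standalone sketch (the model name and local server are assumptions):

```python
import asyncio

import ollama


async def main():
    client = ollama.AsyncClient()  # assumes an Ollama server on localhost:11434
    stream = await client.chat(
        model="llama3",  # illustrative model name
        messages=[{"role": "user", "content": "What is HugeGraph?"}],
        stream=True,
    )
    async for chunk in stream:
        print(chunk["message"]["content"], end="", flush=True)


asyncio.run(main())
```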