feat(llm): support async/streaming output mode in api layer #179

Status: Open. Wants to merge 15 commits into base: main. Changes shown from 1 commit.
4 changes: 3 additions & 1 deletion hugegraph-llm/src/hugegraph_llm/api/models/rag_requests.py
@@ -43,6 +43,7 @@ class RAGRequest(BaseModel):
prompt.gremlin_generate_prompt,
description="Prompt for the Text2Gremlin query.",
)
stream: bool = Query(False, description="Enable streaming response")


# TODO: import the default value of prompt.* dynamically
@@ -58,6 +59,7 @@ class GraphRAGRequest(BaseModel):
prompt.gremlin_generate_prompt,
description="Prompt for the Text2Gremlin query.",
)
stream: bool = Query(False, description="Enable streaming response")


class GraphConfigRequest(BaseModel):
@@ -94,4 +96,4 @@ class RerankerConfigRequest(BaseModel):

class LogStreamRequest(BaseModel):
admin_token: Optional[str] = None
log_file: Optional[str] = "llm-server.log"
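With the new stream field on both request models, clients can opt into Server-Sent Events on the same endpoints. A minimal client sketch for exercising /rag (host, port, and payload values are illustrative placeholders, not part of this PR):

import json
import requests

payload = {
    "query": "Tell me about HugeGraph",
    "raw_answer": False,
    "vector_only": False,
    "graph_only": True,
    "graph_vector_answer": False,
    "stream": True,  # the new field added in this PR
}

# stream=True keeps the connection open so SSE frames can be read as they arrive
with requests.post("http://localhost:8000/rag", json=payload, stream=True) as resp:
    for line in resp.iter_lines(decode_unicode=True):
        if not line or not line.startswith("data: "):
            continue  # skip blank SSE separators
        data = line[len("data: "):]
        if data == "[DONE]":  # end-of-stream sentinel emitted by the API
            break
        print(json.loads(data))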
296 changes: 233 additions & 63 deletions hugegraph-llm/src/hugegraph_llm/api/rag_api.py
A member commented:
Hi, could you tell me how you tested the APIs? By requesting them directly?

The Gradio UI has lost the API link now.

Before: [screenshot: API link shown in the Gradio UI]

Now: [screenshot: API link missing]

Maybe refer here: (Or Gradio's mount doc?)

@@ -16,8 +16,11 @@
# under the License.

import json
import asyncio
from typing import AsyncGenerator

from fastapi import status, APIRouter, HTTPException
from fastapi.responses import StreamingResponse

from hugegraph_llm.api.exceptions.rag_exceptions import generate_response
from hugegraph_llm.api.models.rag_requests import (
@@ -33,76 +36,243 @@


def rag_http_api(
    router: APIRouter,
    rag_answer_func,
    graph_rag_recall_func,
    apply_graph_conf,
    apply_llm_conf,
    apply_embedding_conf,
    apply_reranker_conf,
    rag_answer_stream_func=None,
    graph_rag_recall_stream_func=None,
):
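    # The two *_stream_func hooks are optional; when provided they are expected
    # to be async generator functions (consumed below with "async for"). When
    # absent, the endpoints fall back to the synchronous *_func implementations
    # and emit the whole result as a single SSE frame.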
@router.post("/rag", status_code=status.HTTP_200_OK)
def rag_answer_api(req: RAGRequest):
result = rag_answer_func(
text=req.query,
raw_answer=req.raw_answer,
vector_only_answer=req.vector_only,
graph_only_answer=req.graph_only,
graph_vector_answer=req.graph_vector_answer,
graph_ratio=req.graph_ratio,
rerank_method=req.rerank_method,
near_neighbor_first=req.near_neighbor_first,
custom_related_information=req.custom_priority_info,
answer_prompt=req.answer_prompt or prompt.answer_prompt,
keywords_extract_prompt=req.keywords_extract_prompt or prompt.keywords_extract_prompt,
gremlin_tmpl_num=req.gremlin_tmpl_num,
gremlin_prompt=req.gremlin_prompt or prompt.gremlin_generate_prompt,
)
# TODO: we need more info in the response for users to understand the query logic
return {
"query": req.query,
**{
key: value
for key, value in zip(["raw_answer", "vector_only", "graph_only", "graph_vector_answer"], result)
if getattr(req, key)
},
}
async def stream_rag_answer(
text,
raw_answer,
vector_only_answer,
graph_only_answer,
graph_vector_answer,
graph_ratio,
rerank_method,
near_neighbor_first,
custom_related_information,
answer_prompt,
keywords_extract_prompt,
gremlin_tmpl_num,
gremlin_prompt,
) -> AsyncGenerator[str, None]:
"""
Stream the RAG answer results
"""
if rag_answer_stream_func:
# If a streaming-specific function exists, use it
async for chunk in rag_answer_stream_func(
text=text,
raw_answer=raw_answer,
vector_only_answer=vector_only_answer,
graph_only_answer=graph_only_answer,
graph_vector_answer=graph_vector_answer,
graph_ratio=graph_ratio,
rerank_method=rerank_method,
near_neighbor_first=near_neighbor_first,
custom_related_information=custom_related_information,
answer_prompt=answer_prompt,
keywords_extract_prompt=keywords_extract_prompt,
gremlin_tmpl_num=gremlin_tmpl_num,
gremlin_prompt=gremlin_prompt,
):
yield f"data: {json.dumps({'chunk': chunk})}\n\n"
else:
# Otherwise, use the normal function but adapt it for streaming
# by sending the entire result at once
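            # NOTE: rag_answer_func below is synchronous and will block the
            # event loop for the duration of the pipeline; one option
            # (Python 3.9+, using the asyncio import above) would be:
            #   result = await asyncio.to_thread(rag_answer_func, text=text, ...)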
result = rag_answer_func(
text=text,
raw_answer=raw_answer,
vector_only_answer=vector_only_answer,
graph_only_answer=graph_only_answer,
graph_vector_answer=graph_vector_answer,
graph_ratio=graph_ratio,
rerank_method=rerank_method,
near_neighbor_first=near_neighbor_first,
custom_related_information=custom_related_information,
answer_prompt=answer_prompt,
keywords_extract_prompt=keywords_extract_prompt,
gremlin_tmpl_num=gremlin_tmpl_num,
gremlin_prompt=gremlin_prompt,
)

@router.post("/rag/graph", status_code=status.HTTP_200_OK)
def graph_rag_recall_api(req: GraphRAGRequest):
try:
result = graph_rag_recall_func(
query=req.query,
gremlin_tmpl_num=req.gremlin_tmpl_num,
            # Map the response keys to their corresponding flags explicitly;
            # the key names differ from the parameter names (e.g.
            # "vector_only" vs. vector_only_answer), so they cannot be
            # looked up by name.
            flags = {
                "raw_answer": raw_answer,
                "vector_only": vector_only_answer,
                "graph_only": graph_only_answer,
                "graph_vector_answer": graph_vector_answer,
            }
            response_data = {
                "query": text,
                **{key: value for key, value in zip(flags, result) if flags[key]},
            }

yield f"data: {json.dumps(response_data)}\n\n"
# Signal end of stream
yield "data: [DONE]\n\n"

async def stream_graph_rag_recall(
query,
gremlin_tmpl_num,
rerank_method,
near_neighbor_first,
custom_related_information,
gremlin_prompt,
) -> AsyncGenerator[str, None]:
"""
Stream the graph RAG recall results
"""
if graph_rag_recall_stream_func:
# If a streaming-specific function exists, use it
async for chunk in graph_rag_recall_stream_func(
query=query,
gremlin_tmpl_num=gremlin_tmpl_num,
rerank_method=rerank_method,
near_neighbor_first=near_neighbor_first,
custom_related_information=custom_related_information,
gremlin_prompt=gremlin_prompt,
):
yield f"data: {json.dumps({'chunk': chunk})}\n\n"
else:
# Otherwise, use the normal function but adapt it for streaming
try:
result = graph_rag_recall_func(
query=query,
gremlin_tmpl_num=gremlin_tmpl_num,
rerank_method=rerank_method,
near_neighbor_first=near_neighbor_first,
custom_related_information=custom_related_information,
gremlin_prompt=gremlin_prompt,
)

if isinstance(result, dict):
params = [
"query",
"keywords",
"match_vids",
"graph_result_flag",
"gremlin",
"graph_result",
"vertex_degree_list",
]
user_result = {key: result[key] for key in params if key in result}
yield f"data: {json.dumps({'graph_recall': user_result})}\n\n"
else:
# Note: Maybe only for qianfan/wenxin
yield f"data: {json.dumps({'graph_recall': json.dumps(result)})}\n\n"

# Signal end of stream
yield "data: [DONE]\n\n"

except TypeError as e:
log.error("TypeError in stream_graph_rag_recall: %s", e)
yield f"data: {json.dumps({'error': str(e), 'status': 400})}\n\n"
except Exception as e:
log.error("Unexpected error occurred: %s", e)
yield f"data: {json.dumps({'error': 'An unexpected error occurred.', 'status': 500})}\n\n"

@router.post("/rag", status_code=status.HTTP_200_OK)
async def rag_answer_api(req: RAGRequest):
if req.stream:
# Return a streaming response
return StreamingResponse(
stream_rag_answer(
text=req.query,
raw_answer=req.raw_answer,
vector_only_answer=req.vector_only,
graph_only_answer=req.graph_only,
graph_vector_answer=req.graph_vector_answer,
graph_ratio=req.graph_ratio,
rerank_method=req.rerank_method,
near_neighbor_first=req.near_neighbor_first,
custom_related_information=req.custom_priority_info,
answer_prompt=req.answer_prompt or prompt.answer_prompt,
keywords_extract_prompt=req.keywords_extract_prompt or prompt.keywords_extract_prompt,
gremlin_tmpl_num=req.gremlin_tmpl_num,
gremlin_prompt=req.gremlin_prompt or prompt.gremlin_generate_prompt,
),
media_type="text/event-stream",
)
else:
# Synchronous response (original behavior)
result = rag_answer_func(
text=req.query,
raw_answer=req.raw_answer,
vector_only_answer=req.vector_only,
graph_only_answer=req.graph_only,
graph_vector_answer=req.graph_vector_answer,
graph_ratio=req.graph_ratio,
rerank_method=req.rerank_method,
near_neighbor_first=req.near_neighbor_first,
custom_related_information=req.custom_priority_info,
answer_prompt=req.answer_prompt or prompt.answer_prompt,
keywords_extract_prompt=req.keywords_extract_prompt or prompt.keywords_extract_prompt,
gremlin_tmpl_num=req.gremlin_tmpl_num,
gremlin_prompt=req.gremlin_prompt or prompt.gremlin_generate_prompt,
)
# TODO: we need more info in the response for users to understand the query logic
return {
"query": req.query,
**{
key: value
for key, value in zip(["raw_answer", "vector_only", "graph_only", "graph_vector_answer"], result)
if getattr(req, key)
},
}

@router.post("/rag/graph", status_code=status.HTTP_200_OK)
async def graph_rag_recall_api(req: GraphRAGRequest):
if req.stream:
# Return a streaming response
return StreamingResponse(
stream_graph_rag_recall(
query=req.query,
gremlin_tmpl_num=req.gremlin_tmpl_num,
rerank_method=req.rerank_method,
near_neighbor_first=req.near_neighbor_first,
custom_related_information=req.custom_priority_info,
gremlin_prompt=req.gremlin_prompt or prompt.gremlin_generate_prompt,
),
media_type="text/event-stream",
)
else:
# Synchronous response (original behavior)
try:
result = graph_rag_recall_func(
query=req.query,
gremlin_tmpl_num=req.gremlin_tmpl_num,
rerank_method=req.rerank_method,
near_neighbor_first=req.near_neighbor_first,
custom_related_information=req.custom_priority_info,
gremlin_prompt=req.gremlin_prompt or prompt.gremlin_generate_prompt,
)

if isinstance(result, dict):
params = [
"query",
"keywords",
"match_vids",
"graph_result_flag",
"gremlin",
"graph_result",
"vertex_degree_list",
]
user_result = {key: result[key] for key in params if key in result}
return {"graph_recall": user_result}
# Note: Maybe only for qianfan/wenxin
return {"graph_recall": json.dumps(result)}

except TypeError as e:
log.error("TypeError in graph_rag_recall_api: %s", e)
raise HTTPException(status_code=status.HTTP_400_BAD_REQUEST, detail=str(e)) from e
except Exception as e:
log.error("Unexpected error occurred: %s", e)
raise HTTPException(
status_code=status.HTTP_500_INTERNAL_SERVER_ERROR, detail="An unexpected error occurred."
) from e

@router.post("/config/graph", status_code=status.HTTP_201_CREATED)
def graph_config_api(req: GraphConfigRequest):
@@ -145,4 +315,4 @@
res = apply_reranker_conf(req.api_key, req.reranker_model, None, origin_call="http")
else:
res = status.HTTP_501_NOT_IMPLEMENTED
return generate_response(RAGResponse(status_code=res, message="Missing Value"))
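To sanity-check /rag/graph in both modes, a similar sketch (same illustrative host and port assumptions as the /rag example above; the graph_recall keys mirror the params whitelist in the handler):

import json
import requests

BASE = "http://localhost:8000"  # illustrative; adjust to your deployment

# Non-streaming (original behavior): a single JSON object
resp = requests.post(f"{BASE}/rag/graph", json={"query": "test", "stream": False})
print(resp.json().get("graph_recall"))

# Streaming: SSE frames terminated by a "data: [DONE]" sentinel
with requests.post(f"{BASE}/rag/graph", json={"query": "test", "stream": True}, stream=True) as r:
    for line in r.iter_lines(decode_unicode=True):
        if line.startswith("data: ") and line != "data: [DONE]":
            print(json.loads(line[len("data: "):]))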