feat(llm): support async/streaming output mode in api layer #179

Open
wants to merge 15 commits into
base: main
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from 6 commits
@@ -20,7 +20,8 @@ jobs:
- name: Prepare HugeGraph Server Environment
run: |
docker run -d --name=graph -p 8080:8080 -e PASSWORD=admin hugegraph/hugegraph:1.3.0
sleep 1
# wait server init-done (avoid some test error:)
sleep 5

- uses: actions/checkout@v4

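The fixed `sleep 5` is a timing guess; a readiness poll is less sensitive to slow container starts. A minimal Python sketch of such a wait — the `/apis/version` path and the 60-second budget are assumptions for illustration, not taken from this diff:

# Hypothetical readiness wait instead of a fixed sleep.
import time
import urllib.error
import urllib.request

def wait_for_hugegraph(url: str = "http://localhost:8080/apis/version",
                       timeout_s: int = 60) -> None:
    deadline = time.time() + timeout_s
    while time.time() < deadline:
        try:
            with urllib.request.urlopen(url, timeout=3) as resp:
                if resp.status == 200:
                    return  # server is up and answering
        except urllib.error.URLError:
            pass  # not ready yet, keep polling
        time.sleep(2)
    raise TimeoutError(f"HugeGraph server not ready after {timeout_s}s")

if __name__ == "__main__":
    wait_for_hugegraph()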
3 changes: 2 additions & 1 deletion hugegraph-llm/src/hugegraph_llm/api/admin_api.py
@@ -30,7 +30,8 @@ def admin_http_api(router: APIRouter, log_stream):
@router.post("/logs", status_code=status.HTTP_200_OK)
async def log_stream_api(req: LogStreamRequest):
if admin_settings.admin_token != req.admin_token:
raise generate_response(RAGResponse(status_code=status.HTTP_403_FORBIDDEN, message="Invalid admin_token")) #pylint: disable=E0702
raise generate_response(RAGResponse(status_code=status.HTTP_403_FORBIDDEN,
message="Invalid admin_token")) #pylint: disable=E0702
log_path = os.path.join("logs", req.log_file)

# Create a StreamingResponse that reads from the log stream generator
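The hunk is cut off before the StreamingResponse construction. For orientation, a generic FastAPI sketch of streaming a log file chunk by chunk — this is not the project's actual `log_stream` implementation, whose signature is not visible in this diff:

# Generic log-streaming sketch, not the actual admin_api code.
import os
from fastapi import FastAPI
from fastapi.responses import StreamingResponse

app = FastAPI()

async def read_log(path: str, chunk_size: int = 4096):
    # Async generator yielding the file in chunks; StreamingResponse
    # flushes each chunk to the client as it is produced.
    with open(path, "r", encoding="utf-8") as f:
        while chunk := f.read(chunk_size):
            yield chunk

@app.post("/logs")
async def stream_logs(log_file: str):
    log_path = os.path.join("logs", log_file)
    return StreamingResponse(read_log(log_path), media_type="text/plain")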
85 changes: 85 additions & 0 deletions hugegraph-llm/src/hugegraph_llm/api/config_api.py
@@ -0,0 +1,85 @@
# Licensed to the Apache Software Foundation (ASF) under one
# or more contributor license agreements. See the NOTICE file
# distributed with this work for additional information
# regarding copyright ownership. The ASF licenses this file
# to you under the Apache License, Version 2.0 (the
# "License"); you may not use this file except in compliance
# with the License. You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing,
# software distributed under the License is distributed on an
# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
# KIND, either express or implied. See the License for the
# specific language governing permissions and limitations
# under the License.

from fastapi import status, APIRouter

from hugegraph_llm.api.exceptions.rag_exceptions import generate_response
from hugegraph_llm.api.models.rag_requests import (
GraphConfigRequest,
LLMConfigRequest,
RerankerConfigRequest,
)
from hugegraph_llm.api.models.rag_response import RAGResponse
from hugegraph_llm.config import llm_settings


async def config_http_api(
router: APIRouter,
apply_graph_conf,
apply_llm_conf,
apply_embedding_conf,
apply_reranker_conf,
):
@router.post("/config/graph", status_code=status.HTTP_201_CREATED)
async def graph_config_api(req: GraphConfigRequest):
res = await apply_graph_conf(req.ip, req.port, req.name,
req.user, req.pwd, req.gs, origin_call="http")
return generate_response(RAGResponse(status_code=res, message="Missing Value"))

@router.post("/config/llm", status_code=status.HTTP_201_CREATED)
async def llm_config_api(req: LLMConfigRequest):
llm_settings.llm_type = req.llm_type

if req.llm_type == "openai":
res = await apply_llm_conf(req.api_key, req.api_base, req.language_model,
req.max_tokens, origin_call="http")
elif req.llm_type == "qianfan_wenxin":
res = await apply_llm_conf(req.api_key, req.secret_key, req.language_model,
None, origin_call="http")
else:
res = await apply_llm_conf(req.host, req.port, req.language_model,
None, origin_call="http")
return generate_response(RAGResponse(status_code=res, message="Missing Value"))

@router.post("/config/embedding", status_code=status.HTTP_201_CREATED)
async def embedding_config_api(req: LLMConfigRequest):
llm_settings.embedding_type = req.llm_type

if req.llm_type == "openai":
res = await apply_embedding_conf(req.api_key, req.api_base,
req.language_model, origin_call="http")
elif req.llm_type == "qianfan_wenxin":
res = await apply_embedding_conf(req.api_key, req.api_base,
None, origin_call="http")
else:
res = await apply_embedding_conf(req.host, req.port, req.language_model,
origin_call="http")
return generate_response(RAGResponse(status_code=res, message="Missing Value"))

@router.post("/config/rerank", status_code=status.HTTP_201_CREATED)
async def rerank_config_api(req: RerankerConfigRequest):
llm_settings.reranker_type = req.reranker_type

if req.reranker_type == "cohere":
res = await apply_reranker_conf(req.api_key, req.reranker_model,
req.cohere_base_url, origin_call="http")
elif req.reranker_type == "siliconflow":
res = await apply_reranker_conf(req.api_key, req.reranker_model,
None, origin_call="http")
else:
res = status.HTTP_501_NOT_IMPLEMENTED
return generate_response(RAGResponse(status_code=res, message="Missing Value"))
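For reference, the new endpoints can be exercised directly over HTTP. A hedged example call to /config/llm — the host, port, and lack of a path prefix are assumptions; the field names mirror how LLMConfigRequest is used above:

# Illustrative request against the async config endpoint; URL and values
# are placeholders, not taken from the PR.
import httpx

payload = {
    "llm_type": "openai",
    "api_key": "sk-...",                      # placeholder credential
    "api_base": "https://api.openai.com/v1",
    "language_model": "gpt-4o-mini",
    "max_tokens": 4096,
}

resp = httpx.post("http://localhost:8001/config/llm", json=payload, timeout=30.0)
print(resp.status_code, resp.json())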
hugegraph-llm/src/hugegraph_llm/api/exceptions/rag_exceptions.py
@@ -21,7 +21,8 @@

class ExternalException(HTTPException):
def __init__(self):
super().__init__(status_code=400, detail="Connect failed with error code -1, please check the input.")
super().__init__(status_code=400, detail="Connect failed with error code -1, "
"please check the input.")


class ConnectionFailedException(HTTPException):
46 changes: 32 additions & 14 deletions hugegraph-llm/src/hugegraph_llm/api/models/rag_requests.py
@@ -25,39 +25,57 @@

class RAGRequest(BaseModel):
query: str = Query("", description="Query you want to ask")
raw_answer: bool = Query(False, description="Use LLM to generate answer directly")
vector_only: bool = Query(False, description="Use LLM to generate answer with vector")
graph_only: bool = Query(True, description="Use LLM to generate answer with graph RAG only")
graph_vector_answer: bool = Query(False, description="Use LLM to generate answer with vector & GraphRAG")
graph_ratio: float = Query(0.5, description="The ratio of GraphRAG ans & vector ans")
rerank_method: Literal["bleu", "reranker"] = Query("bleu", description="Method to rerank the results.")
near_neighbor_first: bool = Query(False, description="Prioritize near neighbors in the search results.")
custom_priority_info: str = Query("", description="Custom information to prioritize certain results.")
answer_prompt: Optional[str] = Query(prompt.answer_prompt, description="Prompt to guide the answer generation.")
raw_answer: bool = Query(False,
description="Use LLM to generate answer directly")
vector_only: bool = Query(False,
description="Use LLM to generate answer with vector")
graph_only: bool = Query(True,
description="Use LLM to generate answer with graph RAG only")
graph_vector_answer: bool = Query(False,
description="Use LLM to generate answer with vector & GraphRAG")
graph_ratio: float = Query(0.5,
description="The ratio of GraphRAG ans & vector ans")
rerank_method: Literal["bleu", "reranker"] = Query("bleu",
description="Method to rerank the results.")
near_neighbor_first: bool = Query(False,
description="Prioritize near neighbors in the search results.")
custom_priority_info: str = Query("",
description="Custom information to prioritize certain results.")
answer_prompt: Optional[str] = Query(prompt.answer_prompt,
description="Prompt to guide the answer generation.")
keywords_extract_prompt: Optional[str] = Query(
prompt.keywords_extract_prompt,
description="Prompt for extracting keywords from query.",
)
gremlin_tmpl_num: int = Query(1, description="Number of Gremlin templates to use.")
gremlin_tmpl_num: int = Query(1,
description="Number of Gremlin templates to use.")
gremlin_prompt: Optional[str] = Query(
prompt.gremlin_generate_prompt,
description="Prompt for the Text2Gremlin query.",
)
stream: bool = Query(False,
description="Enable streaming response")


# TODO: import the default value of prompt.* dynamically
class GraphRAGRequest(BaseModel):
query: str = Query("", description="Query you want to ask")
gremlin_tmpl_num: int = Query(
1, description="Number of Gremlin templates to use. If num <=0 means template is not provided"
1,
description="Number of Gremlin templates to use. If num <=0 means template is not provided"
)
rerank_method: Literal["bleu", "reranker"] = Query("bleu", description="Method to rerank the results.")
near_neighbor_first: bool = Query(False, description="Prioritize near neighbors in the search results.")
custom_priority_info: str = Query("", description="Custom information to prioritize certain results.")
rerank_method: Literal["bleu", "reranker"] = Query("bleu",
description="Method to rerank the results.")
near_neighbor_first: bool = Query(False,
description="Prioritize near neighbors in the search results.")
custom_priority_info: str = Query("",
description="Custom information to prioritize certain results.")
gremlin_prompt: Optional[str] = Query(
prompt.gremlin_generate_prompt,
description="Prompt for the Text2Gremlin query.",
)
stream: bool = Query(False,
description="Enable streaming response")


class GraphConfigRequest(BaseModel):
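With the new `stream` flag in place, a client could consume the /rag endpoint incrementally. A sketch — the URL, port, and the assumption that the stream arrives as plain text chunks are illustrative; the wire format is not pinned down by this diff:

# Hypothetical streaming client for the new `stream: bool` request field.
import httpx

payload = {
    "query": "Tell me about HugeGraph",
    "graph_only": True,
    "stream": True,
}

with httpx.stream("POST", "http://localhost:8001/rag", json=payload, timeout=None) as resp:
    resp.raise_for_status()
    for chunk in resp.iter_text():
        print(chunk, end="", flush=True)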
86 changes: 19 additions & 67 deletions hugegraph-llm/src/hugegraph_llm/api/rag_api.py
Review comment (Member):

Hi, could you tell me how you tested the APIs? By requesting them directly?

The Gradio UI has lost the API link now.

Before: [screenshot]

Now: [screenshot]

Maybe refer here: (or Gradio's mount doc?)
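One option for keeping the API docs reachable is to leave the REST routers on the FastAPI app and mount the Gradio UI onto it with gr.mount_gradio_app, as described in Gradio's mounting docs. A minimal sketch — the routes and mount path are illustrative, not the project's actual app.py:

# Sketch of mounting a Gradio UI onto an existing FastAPI app so the REST
# routes (and the interactive docs at /docs) stay reachable.
import gradio as gr
from fastapi import FastAPI

app = FastAPI()

@app.get("/health")
def health():
    return {"status": "ok"}

with gr.Blocks() as demo:
    gr.Markdown("HugeGraph-LLM demo UI")

# Mount last: Starlette matches routes in registration order, so the API
# routes registered above keep precedence over the catch-all UI mount.
app = gr.mount_gradio_app(app, demo, path="/")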

@@ -15,35 +15,27 @@
# specific language governing permissions and limitations
# under the License.


import json

from fastapi import status, APIRouter, HTTPException

from hugegraph_llm.api.exceptions.rag_exceptions import generate_response
from fastapi import status, APIRouter, HTTPException
from hugegraph_llm.api.models.rag_requests import (
RAGRequest,
GraphConfigRequest,
LLMConfigRequest,
RerankerConfigRequest,
GraphRAGRequest,
)
from hugegraph_llm.api.models.rag_response import RAGResponse
from hugegraph_llm.config import llm_settings, prompt
from hugegraph_llm.config import prompt
from hugegraph_llm.utils.log import log


def rag_http_api(
router: APIRouter,
rag_answer_func,
graph_rag_recall_func,
apply_graph_conf,
apply_llm_conf,
apply_embedding_conf,
apply_reranker_conf,
async def rag_http_api(
@imbajin (Member) commented on Mar 6, 2025:

If we use an async API call, it seems we should also update the usage in app.py?

[screenshot]

It still doesn't work as expected. (Also note that admin_api needs to be updated in the same way.)

Author reply:

I will try to resolve the issues and then ping you again, sir!

router: APIRouter,
rag_answer_func,
graph_rag_recall_func,
):
@router.post("/rag", status_code=status.HTTP_200_OK)
def rag_answer_api(req: RAGRequest):
result = rag_answer_func(
async def rag_answer_api(req: RAGRequest):
result = await rag_answer_func(
text=req.query,
raw_answer=req.raw_answer,
vector_only_answer=req.vector_only,
@@ -54,24 +46,26 @@ def rag_answer_api(req: RAGRequest):
near_neighbor_first=req.near_neighbor_first,
custom_related_information=req.custom_priority_info,
answer_prompt=req.answer_prompt or prompt.answer_prompt,
keywords_extract_prompt=req.keywords_extract_prompt or prompt.keywords_extract_prompt,
keywords_extract_prompt=req.keywords_extract_prompt
or prompt.keywords_extract_prompt,
gremlin_tmpl_num=req.gremlin_tmpl_num,
gremlin_prompt=req.gremlin_prompt or prompt.gremlin_generate_prompt,
)
# TODO: we need more info in the response for users to understand the query logic

return {
"query": req.query,
**{
key: value
for key, value in zip(["raw_answer", "vector_only", "graph_only", "graph_vector_answer"], result)
for key, value in zip(["raw_answer", "vector_only", "graph_only",
"graph_vector_answer"], result)
if getattr(req, key)
},
}

@router.post("/rag/graph", status_code=status.HTTP_200_OK)
def graph_rag_recall_api(req: GraphRAGRequest):
async def graph_rag_recall_api(req: GraphRAGRequest):
try:
result = graph_rag_recall_func(
result = await graph_rag_recall_func(
query=req.query,
gremlin_tmpl_num=req.gremlin_tmpl_num,
rerank_method=req.rerank_method,
@@ -92,7 +86,7 @@ def graph_rag_recall_api(req: GraphRAGRequest):
]
user_result = {key: result[key] for key in params if key in result}
return {"graph_recall": user_result}
# Note: Maybe only for qianfan/wenxin

return {"graph_recall": json.dumps(result)}

except TypeError as e:
@@ -101,48 +95,6 @@ def graph_rag_recall_api(req: GraphRAGRequest):
except Exception as e:
log.error("Unexpected error occurred: %s", e)
raise HTTPException(
status_code=status.HTTP_500_INTERNAL_SERVER_ERROR, detail="An unexpected error occurred."
status_code=status.HTTP_500_INTERNAL_SERVER_ERROR,
detail="An unexpected error occurred."
) from e

@router.post("/config/graph", status_code=status.HTTP_201_CREATED)
def graph_config_api(req: GraphConfigRequest):
# Accept status code
res = apply_graph_conf(req.ip, req.port, req.name, req.user, req.pwd, req.gs, origin_call="http")
return generate_response(RAGResponse(status_code=res, message="Missing Value"))

# TODO: restructure the implement of llm to three types, like "/config/chat_llm"
@router.post("/config/llm", status_code=status.HTTP_201_CREATED)
def llm_config_api(req: LLMConfigRequest):
llm_settings.llm_type = req.llm_type

if req.llm_type == "openai":
res = apply_llm_conf(req.api_key, req.api_base, req.language_model, req.max_tokens, origin_call="http")
elif req.llm_type == "qianfan_wenxin":
res = apply_llm_conf(req.api_key, req.secret_key, req.language_model, None, origin_call="http")
else:
res = apply_llm_conf(req.host, req.port, req.language_model, None, origin_call="http")
return generate_response(RAGResponse(status_code=res, message="Missing Value"))

@router.post("/config/embedding", status_code=status.HTTP_201_CREATED)
def embedding_config_api(req: LLMConfigRequest):
llm_settings.embedding_type = req.llm_type

if req.llm_type == "openai":
res = apply_embedding_conf(req.api_key, req.api_base, req.language_model, origin_call="http")
elif req.llm_type == "qianfan_wenxin":
res = apply_embedding_conf(req.api_key, req.api_base, None, origin_call="http")
else:
res = apply_embedding_conf(req.host, req.port, req.language_model, origin_call="http")
return generate_response(RAGResponse(status_code=res, message="Missing Value"))

@router.post("/config/rerank", status_code=status.HTTP_201_CREATED)
def rerank_config_api(req: RerankerConfigRequest):
llm_settings.reranker_type = req.reranker_type

if req.reranker_type == "cohere":
res = apply_reranker_conf(req.api_key, req.reranker_model, req.cohere_base_url, origin_call="http")
elif req.reranker_type == "siliconflow":
res = apply_reranker_conf(req.api_key, req.reranker_model, None, origin_call="http")
else:
res = status.HTTP_501_NOT_IMPLEMENTED
return generate_response(RAGResponse(status_code=res, message="Missing Value"))
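Regarding the review note above about updating app.py: once rag_http_api is a coroutine, its caller has to await it (or drive it with asyncio.run) before the router is included. A self-contained sketch of the pattern, using placeholder functions rather than the project's real app.py wiring:

# Minimal illustration only: the helper below stands in for the now-async
# rag_http_api; the answer function is a dummy, not hugegraph_llm code.
import asyncio
from fastapi import APIRouter, FastAPI

async def rag_http_api(router: APIRouter, rag_answer_func):
    @router.post("/rag")
    async def rag_answer_api(query: str):
        return {"answer": await rag_answer_func(query)}

async def fake_rag_answer(query: str) -> str:   # placeholder pipeline
    return f"echo: {query}"

app = FastAPI()
api_router = APIRouter(prefix="/api")

# Module-level code cannot `await`, so either run the coroutine once here or
# await it inside an async startup/lifespan hook before including the router.
asyncio.run(rag_http_api(api_router, fake_rag_answer))
app.include_router(api_router)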