feat(llm): support async/streaming output mode in api layer #179

Open · wants to merge 15 commits into base: main
Changes from 13 commits
7 changes: 2 additions & 5 deletions hugegraph-llm/src/hugegraph_llm/api/admin_api.py
@@ -16,21 +16,18 @@
 # under the License.
 import os
 
-from fastapi import status, APIRouter
+from fastapi import status, APIRouter, HTTPException
 from fastapi.responses import StreamingResponse
 
-from hugegraph_llm.api.exceptions.rag_exceptions import generate_response
 from hugegraph_llm.api.models.rag_requests import LogStreamRequest
-from hugegraph_llm.api.models.rag_response import RAGResponse
 from hugegraph_llm.config import admin_settings
 
 
-# FIXME: line 31: E0702: Raising dict while only classes or instances are allowed (raising-bad-type)
 def admin_http_api(router: APIRouter, log_stream):
     @router.post("/logs", status_code=status.HTTP_200_OK)
     async def log_stream_api(req: LogStreamRequest):
         if admin_settings.admin_token != req.admin_token:
-            raise generate_response(RAGResponse(status_code=status.HTTP_403_FORBIDDEN, message="Invalid admin_token"))  #pylint: disable=E0702
+            raise HTTPException(status_code=status.HTTP_403_FORBIDDEN, detail="Invalid admin_token")
         log_path = os.path.join("logs", req.log_file)
 
         # Create a StreamingResponse that reads from the log stream generator
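The endpoint body is truncated above; for context, the usual FastAPI pattern it describes (wrapping a log-reading generator in a StreamingResponse) looks roughly like the sketch below. The generator here is a hypothetical stand-in for the PR's log_stream argument, not code from the diff:

import os

from fastapi import APIRouter, status
from fastapi.responses import StreamingResponse

router = APIRouter()

# Hypothetical generator: yields the log file line by line so the
# response can be streamed instead of buffered fully in memory.
async def tail_log(log_path: str):
    with open(log_path, "r", encoding="utf-8") as f:
        for line in f:
            yield line

@router.post("/logs", status_code=status.HTTP_200_OK)
async def log_stream_api(log_file: str):
    log_path = os.path.join("logs", log_file)
    return StreamingResponse(tail_log(log_path), media_type="text/plain")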
97 changes: 97 additions & 0 deletions hugegraph-llm/src/hugegraph_llm/api/config_api.py
@@ -0,0 +1,97 @@
# Licensed to the Apache Software Foundation (ASF) under one
# or more contributor license agreements. See the NOTICE file
# distributed with this work for additional information
# regarding copyright ownership. The ASF licenses this file
# to you under the Apache License, Version 2.0 (the
# "License"); you may not use this file except in compliance
# with the License. You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing,
# software distributed under the License is distributed on an
# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
# KIND, either express or implied. See the License for the
# specific language governing permissions and limitations
# under the License.

from fastapi import status, APIRouter

from hugegraph_llm.api.exceptions.rag_exceptions import generate_response
from hugegraph_llm.api.models.rag_requests import (
    GraphConfigRequest,
    LLMConfigRequest,
    RerankerConfigRequest,
)
from hugegraph_llm.api.models.rag_response import RAGResponse
from hugegraph_llm.config import llm_settings


async def graph_config_route(router: APIRouter, apply_graph_conf):
    @router.post("/config/graph", status_code=status.HTTP_201_CREATED)
    async def graph_config_api(req: GraphConfigRequest):
        # Accept status code
        res = await apply_graph_conf(req.ip, req.port, req.name, req.user, req.pwd, req.gs, origin_call="http")
        return generate_response(RAGResponse(status_code=res, message="Missing Value"))

    return graph_config_api


async def llm_config_route(router: APIRouter, apply_llm_conf):
    # TODO: restructure the implementation of llm to three types, like "/config/chat_llm"
    @router.post("/config/llm", status_code=status.HTTP_201_CREATED)
    async def llm_config_api(req: LLMConfigRequest):
        llm_settings.llm_type = req.llm_type

        if req.llm_type == "openai":
            res = await apply_llm_conf(req.api_key, req.api_base, req.language_model, req.max_tokens,
                                       origin_call="http")
        elif req.llm_type == "qianfan_wenxin":
            res = await apply_llm_conf(req.api_key, req.secret_key, req.language_model, None, origin_call="http")
        else:
            res = await apply_llm_conf(req.host, req.port, req.language_model, None, origin_call="http")
        return generate_response(RAGResponse(status_code=res, message="Missing Value"))

    return llm_config_api


async def embedding_config_route(router: APIRouter, apply_embedding_conf):
    @router.post("/config/embedding", status_code=status.HTTP_201_CREATED)
    async def embedding_config_api(req: LLMConfigRequest):
        llm_settings.embedding_type = req.llm_type

        if req.llm_type == "openai":
            res = await apply_embedding_conf(req.api_key, req.api_base, req.language_model, origin_call="http")
        elif req.llm_type == "qianfan_wenxin":
            res = await apply_embedding_conf(req.api_key, req.api_base, None, origin_call="http")
        else:
            res = await apply_embedding_conf(req.host, req.port, req.language_model, origin_call="http")
        return generate_response(RAGResponse(status_code=res, message="Missing Value"))

    return embedding_config_api


async def rerank_config_route(router: APIRouter, apply_reranker_conf):
    @router.post("/config/rerank", status_code=status.HTTP_201_CREATED)
    async def rerank_config_api(req: RerankerConfigRequest):
        llm_settings.reranker_type = req.reranker_type

        if req.reranker_type == "cohere":
            res = await apply_reranker_conf(req.api_key, req.reranker_model, req.cohere_base_url, origin_call="http")
        elif req.reranker_type == "siliconflow":
            res = await apply_reranker_conf(req.api_key, req.reranker_model, None, origin_call="http")
        else:
            res = status.HTTP_501_NOT_IMPLEMENTED
        return generate_response(RAGResponse(status_code=res, message="Missing Value"))

    return rerank_config_api


async def config_http_api(
    router: APIRouter,
    apply_graph_conf,
    apply_llm_conf,
    apply_embedding_conf,
    apply_reranker_conf,
):
    await graph_config_route(router, apply_graph_conf)
    await llm_config_route(router, apply_llm_conf)
    await embedding_config_route(router, apply_embedding_conf)
    await rerank_config_route(router, apply_reranker_conf)
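Since config_http_api and the per-route registrars are coroutines, callers have to await them before the app starts serving. A minimal wiring sketch under that assumption; the apply_conf placeholder below is illustrative and not part of this file:

import asyncio

from fastapi import APIRouter, FastAPI

from hugegraph_llm.api.config_api import config_http_api

async def apply_conf(*args, origin_call=None):
    # Placeholder: a real implementation would validate and persist the
    # config, then return an HTTP-style status code.
    return 201

app = FastAPI()
router = APIRouter(prefix="/api")

# Route registration is async in this PR, so drive it once at startup.
asyncio.run(config_http_api(router, apply_conf, apply_conf, apply_conf, apply_conf))
app.include_router(router)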
hugegraph-llm/src/hugegraph_llm/api/exceptions/rag_exceptions.py
@@ -21,7 +21,8 @@
 
 class ExternalException(HTTPException):
     def __init__(self):
-        super().__init__(status_code=400, detail="Connect failed with error code -1, please check the input.")
+        super().__init__(status_code=400, detail="Connect failed with error code -1, "
+                                                 "please check the input.")
 
 
 class ConnectionFailedException(HTTPException):
2 changes: 2 additions & 0 deletions hugegraph-llm/src/hugegraph_llm/api/models/rag_requests.py
@@ -50,6 +50,7 @@ class RAGRequest(BaseModel):
     topk_per_keyword : int = Query(1, description="TopK results returned for each keyword \
 extracted from the query, by default only the most similar one is returned.")
     client_config: Optional[GraphConfigRequest] = Query(None, description="hugegraph server config.")
+    stream: bool = Query(False, description="Whether to use streaming response")
 
     # Keep prompt params in the end
     answer_prompt: Optional[str] = Query(prompt.answer_prompt, description="Prompt to guide the answer generation.")
@@ -77,6 +78,7 @@ class GraphRAGRequest(BaseModel):
 
     client_config : Optional[GraphConfigRequest] = Query(None, description="hugegraph server config.")
     get_vertex_only: bool = Query(False, description="return only keywords & vertex (early stop).")
+    stream: bool = Query(False, description="Whether to use streaming response")
 
     gremlin_tmpl_num: int = Query(
         1, description="Number of Gremlin templates to use. If num <=0 means template is not provided"
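The new stream flag only declares the client's preference; the handler still has to branch on it. A hedged sketch of what that branch could look like — the rag_answer_streaming generator is an assumption for illustration, not code from this PR:

from fastapi import APIRouter, status
from fastapi.responses import StreamingResponse

from hugegraph_llm.api.models.rag_requests import RAGRequest

router = APIRouter()

# Hypothetical async generator standing in for a streaming LLM answer.
async def rag_answer_streaming(query: str):
    for chunk in ("partial ", "answer ", "text"):
        yield chunk

@router.post("/rag", status_code=status.HTTP_200_OK)
async def rag_answer_api(req: RAGRequest):
    if req.stream:
        # Chunked reply: the client receives tokens as they are produced.
        return StreamingResponse(rag_answer_streaming(req.query), media_type="text/event-stream")
    # Non-streaming reply keeps the existing JSON contract.
    return {"query": req.query, "answer": "..."}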
80 changes: 22 additions & 58 deletions hugegraph-llm/src/hugegraph_llm/api/rag_api.py
Review comment (Member):
Hi, could you tell me how you tested the APIs? By requesting them directly?

The Gradio UI has lost the API link now.

Before:
[screenshot]

Now:
[screenshot]
Maybe refer here: (Or Gradio's mount doc?)
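For reference, Gradio's documented way to coexist with FastAPI routes is gr.mount_gradio_app; mounting the UI on a sub-path keeps /docs and the API links reachable. A generic sketch under that assumption, not this repo's actual app.py:

import gradio as gr
from fastapi import FastAPI

app = FastAPI()

@app.get("/health")
def health():
    return {"status": "ok"}

with gr.Blocks() as demo:
    gr.Markdown("HugeGraph RAG demo UI")

# Mounting under /gradio leaves the FastAPI docs and API routes visible at the root.
app = gr.mount_gradio_app(app, demo, path="/gradio")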

@@ -19,34 +19,36 @@
 
 from fastapi import status, APIRouter, HTTPException
 
-from hugegraph_llm.api.exceptions.rag_exceptions import generate_response
 from hugegraph_llm.api.models.rag_requests import (
     RAGRequest,
-    GraphConfigRequest,
-    LLMConfigRequest,
-    RerankerConfigRequest,
     GraphRAGRequest,
 )
 from hugegraph_llm.config import huge_settings
-from hugegraph_llm.api.models.rag_response import RAGResponse
-from hugegraph_llm.config import llm_settings, prompt
+from hugegraph_llm.config import prompt
 from hugegraph_llm.utils.log import log
 
+from hugegraph_llm.api.config_api import (
+    graph_config_route,
+    llm_config_route,
+    embedding_config_route,
+    rerank_config_route
+)
 
 # pylint: disable=too-many-statements
-def rag_http_api(
+async def rag_http_api(
Review comment (@imbajin, Member, Mar 6, 2025):

If we use an async API call, it seems we should also update the usage in app.py?

[screenshot]

Still does not work as expected. (Also note that admin_api needs to be used in the same way.)

Author reply:
I will try to resolve the issues and then I will ping you again sir!
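In concrete terms, the reviewer's point is that app.py previously called rag_http_api(...) synchronously; once it is a coroutine, the call site must await it (or drive it with asyncio.run). A hedged sketch of the adjusted call site, with placeholder handler functions standing in for the real pipeline:

import asyncio

from fastapi import APIRouter, FastAPI

from hugegraph_llm.api.rag_api import rag_http_api

async def rag_answer(*args, **kwargs):  # placeholder for the real RAG pipeline
    return {}

async def graph_rag_recall(*args, **kwargs):  # placeholder for the recall pipeline
    return {}

app = FastAPI()
api_router = APIRouter(prefix="/api")

# rag_http_api is now a coroutine, so registration must be awaited before serving.
asyncio.run(rag_http_api(api_router, rag_answer, graph_rag_recall))
app.include_router(api_router)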

     router: APIRouter,
     rag_answer_func,
     graph_rag_recall_func,
-    apply_graph_conf,
-    apply_llm_conf,
-    apply_embedding_conf,
-    apply_reranker_conf,
+    apply_graph_conf=None,
+    apply_llm_conf=None,
+    apply_embedding_conf=None,
+    apply_reranker_conf=None,
 ):
     @router.post("/rag", status_code=status.HTTP_200_OK)
-    def rag_answer_api(req: RAGRequest):
+    async def rag_answer_api(req: RAGRequest):
         set_graph_config(req)
 
-        result = rag_answer_func(
+        result = await rag_answer_func(
             text=req.query,
             raw_answer=req.raw_answer,
             vector_only_answer=req.vector_only,
@@ -86,11 +88,11 @@ def set_graph_config(req):
         huge_settings.graph_space = req.client_config.gs
 
     @router.post("/rag/graph", status_code=status.HTTP_200_OK)
-    def graph_rag_recall_api(req: GraphRAGRequest):
+    async def graph_rag_recall_api(req: GraphRAGRequest):
         try:
             set_graph_config(req)
 
-            result = graph_rag_recall_func(
+            result = await graph_rag_recall_func(
                 query=req.query,
                 max_graph_items=req.max_graph_items,
                 topk_return_results=req.topk_return_results,
@@ -108,7 +110,7 @@ def graph_rag_recall_api(req: GraphRAGRequest):
             from hugegraph_llm.operators.hugegraph_op.graph_rag_query import GraphRAGQuery
             graph_rag = GraphRAGQuery()
             graph_rag.init_client(result)
-            vertex_details = graph_rag.get_vertex_details(result["match_vids"])
+            vertex_details = await graph_rag.get_vertex_details(result["match_vids"])
Review comment (Member): The same applies here.


             if vertex_details:
                 result["match_vids"] = vertex_details
@@ -137,45 +139,7 @@ def graph_rag_recall_api(req: GraphRAGRequest):
                 status_code=status.HTTP_500_INTERNAL_SERVER_ERROR, detail="An unexpected error occurred."
             ) from e
 
-    @router.post("/config/graph", status_code=status.HTTP_201_CREATED)
-    def graph_config_api(req: GraphConfigRequest):
-        # Accept status code
-        res = apply_graph_conf(req.ip, req.port, req.name, req.user, req.pwd, req.gs, origin_call="http")
-        return generate_response(RAGResponse(status_code=res, message="Missing Value"))
-
-    # TODO: restructure the implement of llm to three types, like "/config/chat_llm"
-    @router.post("/config/llm", status_code=status.HTTP_201_CREATED)
-    def llm_config_api(req: LLMConfigRequest):
-        llm_settings.llm_type = req.llm_type
-
-        if req.llm_type == "openai":
-            res = apply_llm_conf(req.api_key, req.api_base, req.language_model, req.max_tokens, origin_call="http")
-        elif req.llm_type == "qianfan_wenxin":
-            res = apply_llm_conf(req.api_key, req.secret_key, req.language_model, None, origin_call="http")
-        else:
-            res = apply_llm_conf(req.host, req.port, req.language_model, None, origin_call="http")
-        return generate_response(RAGResponse(status_code=res, message="Missing Value"))
-
-    @router.post("/config/embedding", status_code=status.HTTP_201_CREATED)
-    def embedding_config_api(req: LLMConfigRequest):
-        llm_settings.embedding_type = req.llm_type
-
-        if req.llm_type == "openai":
-            res = apply_embedding_conf(req.api_key, req.api_base, req.language_model, origin_call="http")
-        elif req.llm_type == "qianfan_wenxin":
-            res = apply_embedding_conf(req.api_key, req.api_base, None, origin_call="http")
-        else:
-            res = apply_embedding_conf(req.host, req.port, req.language_model, origin_call="http")
-        return generate_response(RAGResponse(status_code=res, message="Missing Value"))
-
-    @router.post("/config/rerank", status_code=status.HTTP_201_CREATED)
-    def rerank_config_api(req: RerankerConfigRequest):
-        llm_settings.reranker_type = req.reranker_type
-
-        if req.reranker_type == "cohere":
-            res = apply_reranker_conf(req.api_key, req.reranker_model, req.cohere_base_url, origin_call="http")
-        elif req.reranker_type == "siliconflow":
-            res = apply_reranker_conf(req.api_key, req.reranker_model, None, origin_call="http")
-        else:
-            res = status.HTTP_501_NOT_IMPLEMENTED
-        return generate_response(RAGResponse(status_code=res, message="Missing Value"))
+    await graph_config_route(router, apply_graph_conf)
+    await llm_config_route(router, apply_llm_conf)
+    await embedding_config_route(router, apply_embedding_conf)
+    await rerank_config_route(router, apply_reranker_conf)