-
Notifications
You must be signed in to change notification settings - Fork 33
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
feat(llm): support async/streaming output mode in api layer #179
base: main
Are you sure you want to change the base?
Changes from 12 commits
1e39f04
23e4680
8c7dbaf
8c9f0f8
66c221e
fb233f8
210b4e2
06ef02d
6dd0b7f
1d70b2a
05e7452
dfde6b0
d4cd537
d37b52e
ebdc387
File filter
Filter by extension
Conversations
Jump to
Diff view
Diff view
There are no files selected for viewing
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,97 @@ | ||
# Licensed to the Apache Software Foundation (ASF) under one | ||
# or more contributor license agreements. See the NOTICE file | ||
# distributed with this work for additional information | ||
# regarding copyright ownership. The ASF licenses this file | ||
# to you under the Apache License, Version 2.0 (the | ||
# "License"); you may not use this file except in compliance | ||
# with the License. You may obtain a copy of the License at | ||
# | ||
# http://www.apache.org/licenses/LICENSE-2.0 | ||
# | ||
# Unless required by applicable law or agreed to in writing, | ||
# software distributed under the License is distributed on an | ||
# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY | ||
# KIND, either express or implied. See the License for the | ||
# specific language governing permissions and limitations | ||
# under the License. | ||
|
||
from fastapi import status, APIRouter | ||
|
||
from hugegraph_llm.api.exceptions.rag_exceptions import generate_response | ||
from hugegraph_llm.api.models.rag_requests import ( | ||
GraphConfigRequest, | ||
LLMConfigRequest, | ||
RerankerConfigRequest, | ||
) | ||
from hugegraph_llm.api.models.rag_response import RAGResponse | ||
from hugegraph_llm.config import llm_settings | ||
|
||
|
||
|
||
async def graph_config_route(router: APIRouter, apply_graph_conf):
    """Register the POST /config/graph endpoint on *router*.

    The handler forwards the request fields to ``apply_graph_conf`` and wraps
    the status code it returns in the standard RAG response envelope.
    """
    @router.post("/config/graph", status_code=status.HTTP_201_CREATED)
    async def graph_config_api(req: GraphConfigRequest):
        # apply_graph_conf reports its outcome as an HTTP-style status code.
        applied_status = await apply_graph_conf(
            req.ip, req.port, req.name, req.user, req.pwd, req.gs, origin_call="http"
        )
        return generate_response(RAGResponse(status_code=applied_status, message="Missing Value"))

    return graph_config_api
|
||
async def llm_config_route(router: APIRouter, apply_llm_conf):
    """Register the POST /config/llm endpoint on *router*.

    The handler records the requested LLM type in ``llm_settings`` and passes
    type-specific credentials through to ``apply_llm_conf``.
    """
    # TODO: restructure the implement of llm to three types, like "/config/chat_llm"
    @router.post("/config/llm", status_code=status.HTTP_201_CREATED)
    async def llm_config_api(req: LLMConfigRequest):
        llm_settings.llm_type = req.llm_type

        if req.llm_type == "openai":
            conf_args = (req.api_key, req.api_base, req.language_model, req.max_tokens)
        elif req.llm_type == "qianfan_wenxin":
            conf_args = (req.api_key, req.secret_key, req.language_model, None)
        else:
            # Any other type is treated as a host/port-addressed service.
            conf_args = (req.host, req.port, req.language_model, None)
        applied_status = await apply_llm_conf(*conf_args, origin_call="http")
        return generate_response(RAGResponse(status_code=applied_status, message="Missing Value"))

    return llm_config_api
|
||
async def embedding_config_route(router: APIRouter, apply_embedding_conf):
    """Register the POST /config/embedding endpoint on *router*.

    The handler records the requested embedding type in ``llm_settings`` and
    passes type-specific connection details through to ``apply_embedding_conf``.
    """
    @router.post("/config/embedding", status_code=status.HTTP_201_CREATED)
    async def embedding_config_api(req: LLMConfigRequest):
        llm_settings.embedding_type = req.llm_type

        if req.llm_type == "openai":
            conf_args = (req.api_key, req.api_base, req.language_model)
        elif req.llm_type == "qianfan_wenxin":
            # NOTE(review): this passes api_base where the LLM route passes
            # secret_key for qianfan_wenxin — confirm the asymmetry is intended.
            conf_args = (req.api_key, req.api_base, None)
        else:
            conf_args = (req.host, req.port, req.language_model)
        applied_status = await apply_embedding_conf(*conf_args, origin_call="http")
        return generate_response(RAGResponse(status_code=applied_status, message="Missing Value"))

    return embedding_config_api
|
||
async def rerank_config_route(router: APIRouter, apply_reranker_conf):
    """Register the POST /config/rerank endpoint on *router*.

    Only "cohere" and "siliconflow" rerankers are configurable; any other
    type is answered with 501 Not Implemented.
    """
    @router.post("/config/rerank", status_code=status.HTTP_201_CREATED)
    async def rerank_config_api(req: RerankerConfigRequest):
        llm_settings.reranker_type = req.reranker_type

        if req.reranker_type == "cohere":
            applied_status = await apply_reranker_conf(
                req.api_key, req.reranker_model, req.cohere_base_url, origin_call="http"
            )
        elif req.reranker_type == "siliconflow":
            applied_status = await apply_reranker_conf(
                req.api_key, req.reranker_model, None, origin_call="http"
            )
        else:
            # Unsupported reranker types are reported, not applied.
            applied_status = status.HTTP_501_NOT_IMPLEMENTED
        return generate_response(RAGResponse(status_code=applied_status, message="Missing Value"))

    return rerank_config_api
|
||
|
||
async def config_http_api(
    router: APIRouter,
    apply_graph_conf,
    apply_llm_conf,
    apply_embedding_conf,
    apply_reranker_conf,
):
    """Attach every configuration endpoint (graph/llm/embedding/rerank) to *router*."""
    registrations = (
        (graph_config_route, apply_graph_conf),
        (llm_config_route, apply_llm_conf),
        (embedding_config_route, apply_embedding_conf),
        (rerank_config_route, apply_reranker_conf),
    )
    for register, conf_applier in registrations:
        await register(router, conf_applier)
Original file line number | Diff line number | Diff line change |
---|---|---|
|
@@ -19,34 +19,36 @@ | |
|
||
from fastapi import status, APIRouter, HTTPException | ||
|
||
from hugegraph_llm.api.exceptions.rag_exceptions import generate_response | ||
from hugegraph_llm.api.models.rag_requests import ( | ||
RAGRequest, | ||
GraphConfigRequest, | ||
LLMConfigRequest, | ||
RerankerConfigRequest, | ||
GraphRAGRequest, | ||
) | ||
from hugegraph_llm.config import huge_settings | ||
from hugegraph_llm.api.models.rag_response import RAGResponse | ||
from hugegraph_llm.config import llm_settings, prompt | ||
from hugegraph_llm.config import prompt | ||
from hugegraph_llm.utils.log import log | ||
|
||
from hugegraph_llm.api.config_api import ( | ||
graph_config_route, | ||
llm_config_route, | ||
embedding_config_route, | ||
rerank_config_route | ||
) | ||
|
||
# pylint: disable=too-many-statements | ||
def rag_http_api( | ||
async def rag_http_api( | ||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. I will try to resolve the issues and then I will ping you again sir! |
||
router: APIRouter, | ||
rag_answer_func, | ||
graph_rag_recall_func, | ||
apply_graph_conf, | ||
apply_llm_conf, | ||
apply_embedding_conf, | ||
apply_reranker_conf, | ||
apply_graph_conf= None, | ||
apply_llm_conf= None, | ||
apply_embedding_conf= None, | ||
apply_reranker_conf= None, | ||
): | ||
@router.post("/rag", status_code=status.HTTP_200_OK) | ||
def rag_answer_api(req: RAGRequest): | ||
async def rag_answer_api(req: RAGRequest): | ||
set_graph_config(req) | ||
|
||
result = rag_answer_func( | ||
result = await rag_answer_func( | ||
text=req.query, | ||
raw_answer=req.raw_answer, | ||
vector_only_answer=req.vector_only, | ||
|
@@ -86,11 +88,11 @@ def set_graph_config(req): | |
huge_settings.graph_space = req.client_config.gs | ||
|
||
@router.post("/rag/graph", status_code=status.HTTP_200_OK) | ||
def graph_rag_recall_api(req: GraphRAGRequest): | ||
async def graph_rag_recall_api(req: GraphRAGRequest): | ||
try: | ||
set_graph_config(req) | ||
|
||
result = graph_rag_recall_func( | ||
result = await graph_rag_recall_func( | ||
query=req.query, | ||
max_graph_items=req.max_graph_items, | ||
topk_return_results=req.topk_return_results, | ||
|
@@ -108,7 +110,7 @@ def graph_rag_recall_api(req: GraphRAGRequest): | |
from hugegraph_llm.operators.hugegraph_op.graph_rag_query import GraphRAGQuery | ||
graph_rag = GraphRAGQuery() | ||
graph_rag.init_client(result) | ||
vertex_details = graph_rag.get_vertex_details(result["match_vids"]) | ||
vertex_details = await graph_rag.get_vertex_details(result["match_vids"]) | ||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. so as here |
||
|
||
if vertex_details: | ||
result["match_vids"] = vertex_details | ||
|
@@ -137,45 +139,7 @@ def graph_rag_recall_api(req: GraphRAGRequest): | |
status_code=status.HTTP_500_INTERNAL_SERVER_ERROR, detail="An unexpected error occurred." | ||
) from e | ||
|
||
@router.post("/config/graph", status_code=status.HTTP_201_CREATED) | ||
def graph_config_api(req: GraphConfigRequest): | ||
# Accept status code | ||
res = apply_graph_conf(req.ip, req.port, req.name, req.user, req.pwd, req.gs, origin_call="http") | ||
return generate_response(RAGResponse(status_code=res, message="Missing Value")) | ||
|
||
# TODO: restructure the implement of llm to three types, like "/config/chat_llm" | ||
@router.post("/config/llm", status_code=status.HTTP_201_CREATED) | ||
def llm_config_api(req: LLMConfigRequest): | ||
llm_settings.llm_type = req.llm_type | ||
|
||
if req.llm_type == "openai": | ||
res = apply_llm_conf(req.api_key, req.api_base, req.language_model, req.max_tokens, origin_call="http") | ||
elif req.llm_type == "qianfan_wenxin": | ||
res = apply_llm_conf(req.api_key, req.secret_key, req.language_model, None, origin_call="http") | ||
else: | ||
res = apply_llm_conf(req.host, req.port, req.language_model, None, origin_call="http") | ||
return generate_response(RAGResponse(status_code=res, message="Missing Value")) | ||
|
||
@router.post("/config/embedding", status_code=status.HTTP_201_CREATED) | ||
def embedding_config_api(req: LLMConfigRequest): | ||
llm_settings.embedding_type = req.llm_type | ||
|
||
if req.llm_type == "openai": | ||
res = apply_embedding_conf(req.api_key, req.api_base, req.language_model, origin_call="http") | ||
elif req.llm_type == "qianfan_wenxin": | ||
res = apply_embedding_conf(req.api_key, req.api_base, None, origin_call="http") | ||
else: | ||
res = apply_embedding_conf(req.host, req.port, req.language_model, origin_call="http") | ||
return generate_response(RAGResponse(status_code=res, message="Missing Value")) | ||
|
||
@router.post("/config/rerank", status_code=status.HTTP_201_CREATED) | ||
def rerank_config_api(req: RerankerConfigRequest): | ||
llm_settings.reranker_type = req.reranker_type | ||
|
||
if req.reranker_type == "cohere": | ||
res = apply_reranker_conf(req.api_key, req.reranker_model, req.cohere_base_url, origin_call="http") | ||
elif req.reranker_type == "siliconflow": | ||
res = apply_reranker_conf(req.api_key, req.reranker_model, None, origin_call="http") | ||
else: | ||
res = status.HTTP_501_NOT_IMPLEMENTED | ||
return generate_response(RAGResponse(status_code=res, message="Missing Value")) | ||
await graph_config_route(router, apply_graph_conf) | ||
await llm_config_route(router, apply_llm_conf) | ||
await embedding_config_route(router, apply_embedding_conf) | ||
await rerank_config_route(router, apply_reranker_conf) |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Hi, could you tell me how you tested the APIs? By requesting them directly?
The Gradio UI has lost the API link now
Before:

Now:

Maybe refer here: (Or Gradio's mount doc?)
incubator-hugegraph-ai/hugegraph-llm/src/hugegraph_llm/demo/rag_demo/app.py
Line 182 in 7ae5d6f