Skip to content

Commit

Permalink
refactor(llm): replace vid by full vertexes info (#189)
Browse files Browse the repository at this point in the history
  • Loading branch information
imbajin authored Mar 3, 2025
1 parent 8c1ffbb commit ca28faf
Show file tree
Hide file tree
Showing 4 changed files with 30 additions and 9 deletions.
2 changes: 1 addition & 1 deletion hugegraph-llm/src/hugegraph_llm/api/models/rag_requests.py
Original file line number Diff line number Diff line change
Expand Up @@ -76,7 +76,7 @@ class GraphRAGRequest(BaseModel):
from the query, by default only the most similar one is returned.")

client_config : Optional[GraphConfigRequest] = Query(None, description="hugegraph server config.")
get_vid_only: bool = Query(False, description="return only keywords & vid (early stop).")
get_vertex_only: bool = Query(False, description="return only keywords & vertex (early stop).")

gremlin_tmpl_num: int = Query(
1, description="Number of Gremlin templates to use. If num <=0 means template is not provided"
Expand Down
13 changes: 11 additions & 2 deletions hugegraph-llm/src/hugegraph_llm/api/rag_api.py
Original file line number Diff line number Diff line change
Expand Up @@ -32,7 +32,7 @@
from hugegraph_llm.config import llm_settings, prompt
from hugegraph_llm.utils.log import log


# pylint: disable=too-many-statements
def rag_http_api(
router: APIRouter,
rag_answer_func,
Expand Down Expand Up @@ -101,9 +101,18 @@ def graph_rag_recall_api(req: GraphRAGRequest):
near_neighbor_first=req.near_neighbor_first,
custom_related_information=req.custom_priority_info,
gremlin_prompt=req.gremlin_prompt or prompt.gremlin_generate_prompt,
get_vid_only=req.get_vid_only
get_vertex_only=req.get_vertex_only
)

if req.get_vertex_only:
from hugegraph_llm.operators.hugegraph_op.graph_rag_query import GraphRAGQuery
graph_rag = GraphRAGQuery()
graph_rag.init_client(result)
vertex_details = graph_rag.get_vertex_details(result["match_vids"])

if vertex_details:
result["match_vids"] = vertex_details

if isinstance(result, dict):
params = [
"query",
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -192,15 +192,16 @@ def graph_rag_recall(
topk_return_results: int,
vector_dis_threshold: float,
topk_per_keyword: int,
get_vid_only: bool
get_vertex_only: bool = False,
) -> dict:
store_schema(prompt.text2gql_graph_schema, query, gremlin_prompt)
rag = RAGPipeline()
rag.extract_keywords().keywords_to_vid(
vector_dis_threshold=vector_dis_threshold,
topk_per_keyword=topk_per_keyword,
)
if not get_vid_only:

if not get_vertex_only:
rag.import_schema(huge_settings.graph_name).query_graphdb(
num_gremlin_generate_example=gremlin_tmpl_num,
gremlin_prompt=gremlin_prompt,
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -110,7 +110,7 @@ def __init__(
self._gremlin_prompt = gremlin_prompt or prompt.gremlin_generate_prompt

def run(self, context: Dict[str, Any]) -> Dict[str, Any]:
self._init_client(context)
self.init_client(context)

# initial flag: -1 means no result, 0 means subgraph query, 1 means gremlin query
context["graph_result_flag"] = -1
Expand Down Expand Up @@ -239,7 +239,9 @@ def _subgraph_query(self, context: Dict[str, Any]) -> Dict[str, Any]:
)
return context

def _init_client(self, context):
# TODO: move this method to a util file for reuse (remove self param)
def init_client(self, context):
"""Initialize the HugeGraph client from context or default settings."""
# pylint: disable=R0915 (too-many-statements)
if self._client is None:
if isinstance(context.get("graph_client"), PyHugeClient):
Expand All @@ -254,6 +256,15 @@ def _init_client(self, context):
self._client = PyHugeClient(ip, port, graph, user, pwd, gs)
assert self._client is not None, "No valid graph to search."

def get_vertex_details(self, vertex_ids: List[str]) -> List[Dict[str, Any]]:
if not vertex_ids:
return []

formatted_ids = ", ".join(f"'{vid}'" for vid in vertex_ids)
gremlin_query = f"g.V({formatted_ids}).limit(20)"
result = self._client.gremlin().exec(gremlin=gremlin_query)["data"]
return result

def _format_graph_from_vertex(self, query_result: List[Any]) -> Set[str]:
knowledge = set()
for item in query_result:
Expand Down Expand Up @@ -374,8 +385,8 @@ def _extract_labels_from_schema(self) -> Tuple[List[str], List[str]]:
schema = self._get_graph_schema()
vertex_props_str, edge_props_str = schema.split("\n")[:2]
# TODO: rename to vertex (also need update in the schema)
vertex_props_str = vertex_props_str[len("Vertex properties: ") :].strip("[").strip("]")
edge_props_str = edge_props_str[len("Edge properties: ") :].strip("[").strip("]")
vertex_props_str = vertex_props_str[len("Vertex properties: "):].strip("[").strip("]")
edge_props_str = edge_props_str[len("Edge properties: "):].strip("[").strip("]")
vertex_labels = self._extract_label_names(vertex_props_str)
edge_labels = self._extract_label_names(edge_props_str)
return vertex_labels, edge_labels
Expand Down

0 comments on commit ca28faf

Please sign in to comment.