Skip to content

Commit

Permalink
Merge branch 'main' into add_variable
Browse files Browse the repository at this point in the history
  • Loading branch information
simon824 authored Oct 26, 2023
2 parents 5a4da8c + 31b1720 commit c6519fb
Show file tree
Hide file tree
Showing 16 changed files with 851 additions and 29 deletions.
4 changes: 0 additions & 4 deletions .github/workflows/check-dependencies.yml
Original file line number Diff line number Diff line change
@@ -1,10 +1,6 @@
name: "3rd-party check"

on:
push:
branches:
- main
- 'release-*'
pull_request:

permissions:
Expand Down
114 changes: 114 additions & 0 deletions hugegraph-llm/examples/graph_rag_test.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,114 @@
# Licensed to the Apache Software Foundation (ASF) under one
# or more contributor license agreements. See the NOTICE file
# distributed with this work for additional information
# regarding copyright ownership. The ASF licenses this file
# to you under the Apache License, Version 2.0 (the
# "License"); you may not use this file except in compliance
# with the License. You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing,
# software distributed under the License is distributed on an
# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
# KIND, either express or implied. See the License for the
# specific language governing permissions and limitations
# under the License.


import os

from hugegraph_llm.operators.graph_rag_operator import GraphRAG
from pyhugegraph.client import PyHugeClient


def prepare_data():
    """Create the demo schema and load a small movie graph into HugeGraph.

    Connects to the server at 127.0.0.1:18080 (graph ``hugegraph``, user and
    password ``admin``), defines the ``Person`` and ``Movie`` vertex labels,
    a secondary index on ``name`` for each of them and the ``ActedIn`` edge
    label, then inserts two actors, three movies and the edges between them.

    Every schema element is created with ``ifNotExist()``.
    """
    client = PyHugeClient(
        "127.0.0.1", 18080, "hugegraph", "admin", "admin"
    )

    # Schema: property keys first, then the labels and indexes that use them.
    schema = client.schema()
    schema.propertyKey("name").asText().ifNotExist().create()
    schema.propertyKey("birthDate").asText().ifNotExist().create()
    schema.vertexLabel("Person").properties("name", "birthDate") \
        .useCustomizeStringId().ifNotExist().create()
    schema.vertexLabel("Movie").properties("name").useCustomizeStringId().ifNotExist().create()
    schema.indexLabel("PersonByName").onV("Person").by("name").secondary().ifNotExist().create()
    schema.indexLabel("MovieByName").onV("Movie").by("name").secondary().ifNotExist().create()
    schema.edgeLabel("ActedIn").sourceLabel("Person").targetLabel("Movie").ifNotExist().create()

    graph = client.graph()
    try:
        # Both labels use customized string ids; the name doubles as the id.
        graph.addVertex("Person", {"name": "Al Pacino", "birthDate": "1940-04-25"}, id="Al Pacino")
        graph.addVertex(
            "Person", {"name": "Robert De Niro", "birthDate": "1943-08-17"}, id="Robert De Niro")
        graph.addVertex("Movie", {"name": "The Godfather"}, id="The Godfather")
        graph.addVertex("Movie", {"name": "The Godfather Part II"}, id="The Godfather Part II")
        graph.addVertex("Movie", {"name": "The Godfather Coda The Death of Michael Corleone"},
                        id="The Godfather Coda The Death of Michael Corleone")

        # Edges reference the vertex ids assigned above.
        graph.addEdge("ActedIn", "Al Pacino", "The Godfather", {})
        graph.addEdge("ActedIn", "Al Pacino", "The Godfather Part II", {})
        graph.addEdge(
            "ActedIn", "Al Pacino", "The Godfather Coda The Death of Michael Corleone", {})
        graph.addEdge("ActedIn", "Robert De Niro", "The Godfather Part II", {})
    finally:
        # Always call close(), even when an insert fails part-way through.
        graph.close()


if __name__ == '__main__':
    # NOTE(review): these look like placeholders to be filled in before
    # running (proxy settings and the OpenAI API key) - confirm.
    os.environ["http_proxy"] = ""
    os.environ["https_proxy"] = ""
    os.environ["OPENAI_API_KEY"] = ""

    # Uncomment to load the demo graph before querying.
    # prepare_data()

    question = "Tell me about Al Pacino."

    # ---- Variant 1: configure the operators through the context dict ----
    graph_rag = GraphRAG()
    context = {
        # hugegraph client
        "ip": "localhost",  # defaults to "localhost" if not set
        "port": 18080,  # defaults to 8080 if not set
        "user": "admin",  # defaults to "admin" if not set
        "pwd": "admin",  # defaults to "admin" if not set
        "graph": "hugegraph",  # defaults to "hugegraph" if not set

        # query question
        "query": question,  # must be set

        # keywords extraction
        "max_keywords": 5,  # defaults to 5 if not set
        "language": "english",  # defaults to "english" if not set

        # graph rag query
        "prop_to_match": "name",  # defaults to None if not set
        "max_deep": 2,  # defaults to 2 if not set
        "max_items": 30,  # defaults to 30 if not set

        # print intermediate processes result
        "verbose": True,  # defaults to False if not set
    }
    result = graph_rag \
        .extract_keyword() \
        .query_graph_for_rag() \
        .synthesize_answer() \
        .run(**context)
    print(f"Query:\n- {context['query']}")
    print(f"Answer:\n- {result['answer']}")

    print("--------------------------------------------------------")

    # ---- Variant 2: configure the operators through parameters ----
    # GraphRAG only ever appends operators to the instance, so reusing the
    # object from variant 1 would run its three operators again (without the
    # context keys they rely on) before the ones added below. Start fresh.
    graph_rag = GraphRAG()
    graph_client = PyHugeClient(
        "127.0.0.1", 18080, "hugegraph", "admin", "admin"
    )
    result = graph_rag.extract_keyword(
        text=question,
        max_keywords=5,  # defaults to 5 if not set
        language="english",  # defaults to "english" if not set
    ).query_graph_for_rag(
        graph_client=graph_client,
        max_deep=2,  # defaults to 2 if not set
        max_items=30,  # defaults to 30 if not set
        prop_to_match=None,  # defaults to None if not set
    ).synthesize_answer().run(verbose=True)
    print(f"Query:\n- {question}")
    print(f"Answer:\n- {result['answer']}")
1 change: 1 addition & 0 deletions hugegraph-llm/requirements.txt
Original file line number Diff line number Diff line change
@@ -1,3 +1,4 @@
openai==0.28.1
retry==0.9.2
tiktoken==0.5.1
nltk==3.8.1
6 changes: 3 additions & 3 deletions hugegraph-llm/src/hugegraph_llm/llms/base.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,7 +17,7 @@


from abc import ABC, abstractmethod
from typing import Any, List, Optional, Callable
from typing import Any, List, Optional, Callable, Dict


class BaseLLM(ABC):
Expand All @@ -26,15 +26,15 @@ class BaseLLM(ABC):
@abstractmethod
def generate(
self,
messages: Optional[List[str]] = None,
messages: Optional[List[Dict[str, Any]]] = None,
prompt: Optional[str] = None,
) -> str:
"""Comment"""

@abstractmethod
async def generate_streaming(
self,
messages: Optional[List[str]] = None,
messages: Optional[List[Dict[str, Any]]] = None,
prompt: Optional[str] = None,
on_token_callback: Callable = None,
) -> List[Any]:
Expand Down
22 changes: 15 additions & 7 deletions hugegraph-llm/src/hugegraph_llm/llms/openai_llm.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,7 +16,8 @@
# under the License.


from typing import Callable, List, Optional
import os
from typing import Callable, List, Optional, Dict, Any
import openai
import tiktoken
from retry import retry
Expand All @@ -34,17 +35,21 @@ def __init__(
max_tokens: int = 1000,
temperature: float = 0.0,
) -> None:
openai.api_key = api_key
openai.api_key = api_key or os.getenv("OPENAI_API_KEY")
self.model = model_name
self.max_tokens = max_tokens
self.temperature = temperature

@retry(tries=3, delay=1)
def generate(
self,
messages: Optional[List[str]] = None,
messages: Optional[List[Dict[str, Any]]] = None,
prompt: Optional[str] = None,
) -> str:
"""Generate a response to the query messages/prompt."""
if messages is None:
assert prompt is not None, "Messages or prompt must be provided."
messages = [{"role": "user", "content": prompt}]
try:
completions = openai.ChatCompletion.create(
model=self.model,
Expand All @@ -57,18 +62,19 @@ def generate(
except openai.error.InvalidRequestError as e:
return str(f"Error: {e}")
# catch authorization errors / do not retry
except openai.error.AuthenticationError as e:
return f"Error: The provided OpenAI API key is invalid, {e}"
except openai.error.AuthenticationError:
return "Error: The provided OpenAI API key is invalid"
except Exception as e:
print(f"Retrying LLM call {e}")
raise Exception() from e
raise e

async def generate_streaming(
self,
messages: Optional[List[str]] = None,
messages: Optional[List[Dict[str, Any]]] = None,
prompt: Optional[str] = None,
on_token_callback: Callable = None,
) -> str:
"""Generate a response to the query messages/prompt in streaming mode."""
if messages is None:
assert prompt is not None, "Messages or prompt must be provided."
messages = [{"role": "user", "content": prompt}]
Expand All @@ -89,10 +95,12 @@ async def generate_streaming(
return result

async def num_tokens_from_string(self, string: str) -> int:
"""Get token count from string."""
encoding = tiktoken.encoding_for_model(self.model)
num_tokens = len(encoding.encode(string))
return num_tokens

async def max_allowed_token_length(self) -> int:
"""Get max-allowed token length"""
# TODO: list all models and their max tokens from api
return 2049
16 changes: 16 additions & 0 deletions hugegraph-llm/src/hugegraph_llm/operators/__init__.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,16 @@
# Licensed to the Apache Software Foundation (ASF) under one
# or more contributor license agreements. See the NOTICE file
# distributed with this work for additional information
# regarding copyright ownership. The ASF licenses this file
# to you under the Apache License, Version 2.0 (the
# "License"); you may not use this file except in compliance
# with the License. You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing,
# software distributed under the License is distributed on an
# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
# KIND, either express or implied. See the License for the
# specific language governing permissions and limitations
# under the License.
Original file line number Diff line number Diff line change
Expand Up @@ -16,13 +16,13 @@
# under the License.


from hugegraph_llm.llms.base import BaseLLM
from hugegraph_llm.operators.hugegraph_op.commit_data_to_kg import CommitDataToKg
from hugegraph_llm.operators.llm_op.disambiguate_data import DisambiguateData
from hugegraph_llm.operators.llm_op.parse_text_to_data import (
ParseTextToData,
ParseTextToDataWithSchemas,
)
from hugegraph_llm.llms.base import BaseLLM


class KgBuilder:
Expand Down
89 changes: 89 additions & 0 deletions hugegraph-llm/src/hugegraph_llm/operators/graph_rag_operator.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,89 @@
# Licensed to the Apache Software Foundation (ASF) under one
# or more contributor license agreements. See the NOTICE file
# distributed with this work for additional information
# regarding copyright ownership. The ASF licenses this file
# to you under the Apache License, Version 2.0 (the
# "License"); you may not use this file except in compliance
# with the License. You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing,
# software distributed under the License is distributed on an
# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
# KIND, either express or implied. See the License for the
# specific language governing permissions and limitations
# under the License.


from typing import Dict, Any, Optional, List

from hugegraph_llm.llms.base import BaseLLM
from hugegraph_llm.llms.openai_llm import OpenAIChat
from hugegraph_llm.operators.hugegraph_op.graph_rag_query import GraphRAGQuery
from hugegraph_llm.operators.llm_op.answer_synthesize import AnswerSynthesize
from hugegraph_llm.operators.llm_op.keyword_extract import KeywordExtract
from pyhugegraph.client import PyHugeClient


class GraphRAG:
    """Fluent builder for a graph-RAG pipeline.

    Each builder method appends one operator to an internal list and returns
    ``self`` so that calls can be chained. :meth:`run` then executes the
    operators in the order they were added, passing a context dict from one
    operator to the next.

    Operators accumulate on the instance and are never cleared, so build
    every distinct pipeline on a fresh ``GraphRAG`` object.
    """

    def __init__(self, llm: Optional[BaseLLM] = None):
        """Store the LLM to use; fall back to ``OpenAIChat()`` when none is given."""
        self._llm = llm or OpenAIChat()
        self._operators: List[Any] = []

    def extract_keyword(
        self,
        text: Optional[str] = None,
        max_keywords: int = 5,
        language: str = 'english',
        extract_template: Optional[str] = None,
        expand_template: Optional[str] = None,
    ) -> "GraphRAG":
        """Append a ``KeywordExtract`` operator and return ``self``.

        All arguments are forwarded unchanged to ``KeywordExtract``.
        """
        self._operators.append(
            KeywordExtract(
                text=text,
                max_keywords=max_keywords,
                language=language,
                extract_template=extract_template,
                expand_template=expand_template,
            )
        )
        return self

    def query_graph_for_rag(
        self,
        graph_client: Optional[PyHugeClient] = None,
        max_deep: int = 2,
        max_items: int = 30,
        prop_to_match: Optional[str] = None,
    ) -> "GraphRAG":
        """Append a ``GraphRAGQuery`` operator and return ``self``.

        ``graph_client`` is passed to ``GraphRAGQuery`` as ``client``; the
        remaining arguments are forwarded unchanged.
        """
        self._operators.append(
            GraphRAGQuery(
                client=graph_client,
                max_deep=max_deep,
                max_items=max_items,
                prop_to_match=prop_to_match,
            )
        )
        return self

    def synthesize_answer(
        self,
        prompt_template: Optional[str] = None,
    ) -> "GraphRAG":
        """Append an ``AnswerSynthesize`` operator and return ``self``."""
        self._operators.append(
            AnswerSynthesize(
                prompt_template=prompt_template,
            )
        )
        return self

    def run(self, **kwargs) -> Dict[str, Any]:
        """Execute the configured operators in order and return the final context.

        If no operator has been added yet, the default pipeline
        (keyword extraction, graph query, answer synthesis) is built first.

        Args:
            **kwargs: Initial context entries. The ``"llm"`` entry is always
                set to this instance's LLM, replacing any caller-supplied value.

        Returns:
            The context returned by the last operator.
        """
        if not self._operators:
            self.extract_keyword().query_graph_for_rag().synthesize_answer()

        context = kwargs
        context["llm"] = self._llm
        # Each operator receives the previous operator's returned context.
        for op in self._operators:
            context = op.run(context=context)
        return context
Loading

0 comments on commit c6519fb

Please sign in to comment.