Skip to content

Commit

Permalink
rebase
Browse files Browse the repository at this point in the history
  • Loading branch information
chengbiao-jin committed Oct 29, 2024
2 parents 696c52c + a477e31 commit 1957332
Show file tree
Hide file tree
Showing 98 changed files with 9,601 additions and 2,585 deletions.
4 changes: 3 additions & 1 deletion .gitignore
Original file line number Diff line number Diff line change
Expand Up @@ -11,4 +11,6 @@ site/
pyTigerGraph.egg-info
tests1.py
.ipynb_checkpoints
pyrightconfig.json
venv
requirement.txt
pyrightconfig.json
1 change: 1 addition & 0 deletions pyTigerGraph/__init__.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,5 @@
from pyTigerGraph.pyTigerGraph import TigerGraphConnection
from pyTigerGraph.pytgasync.pyTigerGraph import AsyncTigerGraphConnection

__version__ = "1.7.4"

Expand Down
82 changes: 44 additions & 38 deletions pyTigerGraph/ai/ai.py
Original file line number Diff line number Diff line change
Expand Up @@ -31,13 +31,14 @@

from pyTigerGraph import TigerGraphConnection


class AI:
def __init__(self, conn: TigerGraphConnection) -> None:
def __init__(self, conn: TigerGraphConnection) -> None:
"""NO DOC: Initiate an AI object. Currently in beta testing.
Args:
conn (TigerGraphConnection):
Accept a TigerGraphConnection to run queries with
Returns:
None
"""
Expand All @@ -47,7 +48,7 @@ def __init__(self, conn: TigerGraphConnection) -> None:
# split scheme and host
scheme, host = conn.host.split("://")
self.nlqs_host = scheme + "://copilot-" + host

def configureInquiryAIHost(self, hostname: str):
""" DEPRECATED: Configure the hostname of the InquiryAI service.
Not recommended to use. Use configureCoPilotHost() instead.
Expand All @@ -56,8 +57,8 @@ def configureInquiryAIHost(self, hostname: str):
The hostname (and port number) of the InquiryAI serivce.
"""
warnings.warn(
"The `configureInquiryAIHost()` function is deprecated; use `configureCoPilotHost()` function instead.",
DeprecationWarning)
"The `configureInquiryAIHost()` function is deprecated; use `configureCoPilotHost()` function instead.",
DeprecationWarning)
self.nlqs_host = hostname

def configureCoPilotHost(self, hostname: str):
Expand Down Expand Up @@ -96,14 +97,17 @@ def registerCustomQuery(self, query_name: str, description: str = None, docstrin
"queries": [query_name]
}
url = self.nlqs_host+"/"+self.conn.graphname+"/upsert_from_gsql"
return self.conn._req("POST", url, authMode="pwd", data = data, jsonData=True, resKey=None)
return self.conn._req("POST", url, authMode="pwd", data=data, jsonData=True, resKey=None)
else:
if description is None:
raise ValueError("When using TigerGraph 3.x, query descriptions are required parameters.")
raise ValueError(
"When using TigerGraph 3.x, query descriptions are required parameters.")
if docstring is None:
raise ValueError("When using TigerGraph 3.x, query docstrings are required parameters.")
raise ValueError(
"When using TigerGraph 3.x, query docstrings are required parameters.")
if param_types is None:
raise ValueError("When using TigerGraph 3.x, query parameter types are required parameters.")
raise ValueError(
"When using TigerGraph 3.x, query parameter types are required parameters.")
data = {
"function_header": query_name,
"description": description,
Expand All @@ -112,8 +116,8 @@ def registerCustomQuery(self, query_name: str, description: str = None, docstrin
"graphname": self.conn.graphname
}
url = self.nlqs_host+"/"+self.conn.graphname+"/register_docs"
return self.conn._req("POST", url, authMode="pwd", data = data, jsonData=True, resKey=None)
return self.conn._req("POST", url, authMode="pwd", data=data, jsonData=True, resKey=None)

def updateCustomQuery(self, query_name: str, description: str = None, docstring: str = None, param_types: dict = None):
""" Update a custom query with the InquiryAI service.
Args:
Expand Down Expand Up @@ -142,14 +146,17 @@ def updateCustomQuery(self, query_name: str, description: str = None, docstring:
"queries": [query_name]
}
url = self.nlqs_host+"/"+self.conn.graphname+"/upsert_from_gsql"
return self.conn._req("POST", url, authMode="pwd", data = data, jsonData=True, resKey=None)
return self.conn._req("POST", url, authMode="pwd", data=data, jsonData=True, resKey=None)
else:
if description is None:
raise ValueError("When using TigerGraph 3.x, query descriptions are required parameters.")
raise ValueError(
"When using TigerGraph 3.x, query descriptions are required parameters.")
if docstring is None:
raise ValueError("When using TigerGraph 3.x, query docstrings are required parameters.")
raise ValueError(
"When using TigerGraph 3.x, query docstrings are required parameters.")
if param_types is None:
raise ValueError("When using TigerGraph 3.x, query parameter types are required parameters.")
raise ValueError(
"When using TigerGraph 3.x, query parameter types are required parameters.")
data = {
"function_header": query_name,
"description": description,
Expand All @@ -160,8 +167,8 @@ def updateCustomQuery(self, query_name: str, description: str = None, docstring:

json_payload = {"id": "", "query_info": data}
url = self.nlqs_host+"/"+self.conn.graphname+"/upsert_docs"
return self.conn._req("POST", url, authMode="pwd", data = json_payload, jsonData=True, resKey=None)
return self.conn._req("POST", url, authMode="pwd", data=json_payload, jsonData=True, resKey=None)

def deleteCustomQuery(self, query_name: str):
""" Delete a custom query with the InquiryAI service.
Args:
Expand All @@ -172,9 +179,9 @@ def deleteCustomQuery(self, query_name: str):
"""
data = {"ids": [], "expr": "function_header == '"+query_name+"'"}
url = self.nlqs_host+"/"+self.conn.graphname+"/delete_docs"
return self.conn._req("POST", url, authMode="pwd", data = data, jsonData=True, resKey=None)
return self.conn._req("POST", url, authMode="pwd", data=data, jsonData=True, resKey=None)

def retrieveDocs(self, query:str, top_k:int = 3):
def retrieveDocs(self, query: str, top_k: int = 3):
""" Retrieve docs from the vector store.
Args:
query (str):
Expand All @@ -188,8 +195,9 @@ def retrieveDocs(self, query:str, top_k:int = 3):
"query": query
}

url = self.nlqs_host+"/"+self.conn.graphname+"/retrieve_docs?top_k="+str(top_k)
return self.conn._req("POST", url, authMode="pwd", data = data, jsonData=True, resKey=None, skipCheck=True)
url = self.nlqs_host+"/"+self.conn.graphname + \
"/retrieve_docs?top_k="+str(top_k)
return self.conn._req("POST", url, authMode="pwd", data=data, jsonData=True, resKey=None, skipCheck=True)

def query(self, query):
""" Query the database with natural language.
Expand All @@ -204,24 +212,24 @@ def query(self, query):
}

url = self.nlqs_host+"/"+self.conn.graphname+"/query"
return self.conn._req("POST", url, authMode="pwd", data = data, jsonData=True, resKey=None)
return self.conn._req("POST", url, authMode="pwd", data=data, jsonData=True, resKey=None)

def coPilotHealth(self):
""" Check the health of the CoPilot service.
Returns:
JSON response from the CoPilot service.
"""
url = self.nlqs_host+"/health"
return self.conn._req("GET", url, authMode="pwd", resKey=None)

def initializeSupportAI(self):
""" Initialize the SupportAI service.
Returns:
JSON response from the SupportAI service.
"""
url = self.nlqs_host+"/"+self.conn.graphname+"/supportai/initialize"
return self.conn._req("POST", url, authMode="pwd", resKey=None)

def createDocumentIngest(self, data_source, data_source_config, loader_config, file_format):
""" Create a document ingest.
Args:
Expand All @@ -244,8 +252,8 @@ def createDocumentIngest(self, data_source, data_source_config, loader_config, f
}

url = self.nlqs_host+"/"+self.conn.graphname+"/supportai/create_ingest"
return self.conn._req("POST", url, authMode="pwd", data = data, jsonData=True, resKey=None)
return self.conn._req("POST", url, authMode="pwd", data=data, jsonData=True, resKey=None)

def runDocumentIngest(self, load_job_id, data_source_id, data_path):
""" Run a document ingest.
Args:
Expand All @@ -264,9 +272,9 @@ def runDocumentIngest(self, load_job_id, data_source_id, data_path):
"file_path": data_path
}
url = self.nlqs_host+"/"+self.conn.graphname+"/supportai/ingest"
return self.conn._req("POST", url, authMode="pwd", data = data, jsonData=True, resKey=None)
def searchDocuments(self, query, method = "hnswoverlap", method_parameters: dict = {"indices": ["Document", "DocumentChunk", "Entity", "Relationship"], "top_k": 2, "num_hops": 2, "num_seen_min": 2}):
return self.conn._req("POST", url, authMode="pwd", data=data, jsonData=True, resKey=None)

def searchDocuments(self, query, method="hnswoverlap", method_parameters: dict = {"indices": ["Document", "DocumentChunk", "Entity", "Relationship"], "top_k": 2, "num_hops": 2, "num_seen_min": 2}):
""" Search documents.
Args:
query (str):
Expand All @@ -284,9 +292,9 @@ def searchDocuments(self, query, method = "hnswoverlap", method_parameters: dict
"method_params": method_parameters
}
url = self.nlqs_host+"/"+self.conn.graphname+"/supportai/search"
return self.conn._req("POST", url, authMode="pwd", data = data, jsonData=True, resKey=None)
def answerQuestion(self, query, method = "hnswoverlap", method_parameters: dict = {"indices": ["Document", "DocumentChunk", "Entity", "Relationship"], "top_k": 2, "num_hops": 2, "num_seen_min": 2}):
return self.conn._req("POST", url, authMode="pwd", data=data, jsonData=True, resKey=None)

def answerQuestion(self, query, method="hnswoverlap", method_parameters: dict = {"indices": ["Document", "DocumentChunk", "Entity", "Relationship"], "top_k": 2, "num_hops": 2, "num_seen_min": 2}):
""" Answer a question.
Args:
query (str):
Expand All @@ -304,8 +312,8 @@ def answerQuestion(self, query, method = "hnswoverlap", method_parameters: dict
"method_params": method_parameters
}
url = self.nlqs_host+"/"+self.conn.graphname+"/supportai/answerquestion"
return self.conn._req("POST", url, authMode="pwd", data = data, jsonData=True, resKey=None)
return self.conn._req("POST", url, authMode="pwd", data=data, jsonData=True, resKey=None)

def forceConsistencyUpdate(self, method="supportai"):
""" Force a consistency update for SupportAI embeddings.
Args:
Expand All @@ -317,8 +325,7 @@ def forceConsistencyUpdate(self, method="supportai"):
"""
url = f"{self.nlqs_host}/{self.conn.graphname}/{method}/forceupdate/"
return self.conn._req("GET", url, authMode="pwd", resKey=None)

''' TODO: Add support in CoPilot

def checkConsistencyProgress(self, method="supportai"):
""" Check the progress of the consistency update.
Args:
Expand All @@ -330,4 +337,3 @@ def checkConsistencyProgress(self, method="supportai"):
"""
url = f"{self.nlqs_host}/{self.conn.graphname}/supportai/consistency_status/{method}"
return self.conn._req("GET", url, authMode="pwd", resKey=None)
'''
Empty file added pyTigerGraph/common/__init__.py
Empty file.
140 changes: 140 additions & 0 deletions pyTigerGraph/common/auth.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,140 @@
import json
import logging

from datetime import datetime
from typing import Union, Tuple, Dict

from pyTigerGraph.common.exception import TigerGraphException

logger = logging.getLogger(__name__)

def _parse_get_secrets(response: str) -> Dict[str, str]:
secrets_dict = {}
lines = response.split("\n")

i = 0
while i < len(lines):
line = lines[i]
# s = ""
if "- Secret" in line:
secret = line.split(": ")[1]
i += 1
line = lines[i]
if "- Alias" in line:
secrets_dict[line.split(": ")[1]] = secret
i += 1
return secrets_dict

def _parse_create_secret(response: str, alias: str = "", withAlias: bool = False) -> Union[str, Dict[str, str]]:
try:
if "already exists" in response:
error_msg = "The secret "
if alias != "":
error_msg += "with alias {} ".format(alias)
error_msg += "already exists."
raise TigerGraphException(error_msg, "E-00001")

secret = "".join(response).replace('\n', '').split(
'The secret: ')[1].split(" ")[0].strip()

if not withAlias:
if logger.level == logging.DEBUG:
logger.debug("return: " + str(secret))
logger.info("exit: createSecret (withAlias")

return secret

if alias:
ret = {alias: secret}

if logger.level == logging.DEBUG:
logger.debug("return: " + str(ret))
logger.info("exit: createSecret (alias)")

return ret

return secret

except IndexError as e:
raise TigerGraphException(
"Failed to parse secret from response.", "E-00002") from e

def _prep_token_request(restppUrl: str,
gsUrl: str,
graphname: str,
version: str = None,
secret: str = None,
lifetime: int = None,
token: str = None,
method: str = None):
major_ver, minor_ver, patch_ver = (0, 0, 0)
if version:
major_ver, minor_ver, patch_ver = version.split(".")

if 0 < int(major_ver) < 3 or (int(major_ver) == 3 and int(minor_ver) < 5):
method = "GET"
url = restppUrl + "/requesttoken?secret=" + secret + \
("&lifetime=" + str(lifetime) if lifetime else "") + \
("&token=" + token if token else "")
authMode = None
if not secret:
raise TigerGraphException(
"Cannot request a token with username/password for versions < 3.5.")
else:
method = "POST"
url = gsUrl + "/gsql/v1/tokens" # used for TG 4.x
data = {"graph": graphname}

# alt_url and alt_data used to construct the method and url for functions run in TG version 3.x
alt_url = restppUrl+"/requesttoken" # used for TG 3.x
alt_data = {}

if lifetime:
data["lifetime"] = str(lifetime)
alt_data["lifetime"] = str(lifetime)
if token:
data["tokens"] = token
alt_data["token"] = token
if secret:
authMode = "None"
data["secret"] = secret
alt_data["secret"] = secret
else:
authMode = "pwd"

alt_data = json.dumps(alt_data)

return method, url, alt_url, authMode, data, alt_data

def _parse_token_response(response: dict,
setToken: bool,
mainVer: int,
base64_credential: str) -> Tuple[Union[Tuple[str, str], str], dict]:
if not response.get("error"):
token = response["token"]
if setToken:
apiToken = token
authHeader = {'Authorization': "Bearer " + apiToken}
else:
apiToken = None
authHeader = {
'Authorization': 'Basic {0}'.format(base64_credential)}

if response.get("expiration"):
# On >=4.1 the format for the date of expiration changed. Convert back to old format
# Can't use self._versionGreaterThan4_0 since you need a token for that
if mainVer == 4:
return (token, response.get("expiration")), authHeader
else:
return (token, response.get("expiration"), \
datetime.utcfromtimestamp(
float(response.get("expiration"))).strftime('%Y-%m-%d %H:%M:%S')), authHeader
else:
return token, authHeader

elif "Endpoint is not found from url = /requesttoken" in response["message"]:
raise TigerGraphException("REST++ authentication is not enabled, can't generate token.",
None)
else:
raise TigerGraphException(
response["message"], (response["code"] if "code" in response else None))
Loading

0 comments on commit 1957332

Please sign in to comment.