diff --git a/.gitignore b/.gitignore
index 2c2b81b6..07cae8e4 100644
--- a/.gitignore
+++ b/.gitignore
@@ -11,4 +11,6 @@ site/
 pyTigerGraph.egg-info
 tests1.py
 .ipynb_checkpoints
-pyrightconfig.json
+venv
+requirement.txt
+pyrightconfig.json
\ No newline at end of file
diff --git a/pyTigerGraph/__init__.py b/pyTigerGraph/__init__.py
index fe4ba969..f4d495a1 100644
--- a/pyTigerGraph/__init__.py
+++ b/pyTigerGraph/__init__.py
@@ -1,4 +1,5 @@
 from pyTigerGraph.pyTigerGraph import TigerGraphConnection
+from pyTigerGraph.pytgasync.pyTigerGraph import AsyncTigerGraphConnection
 
 __version__ = "1.7.4"
diff --git a/pyTigerGraph/ai/ai.py b/pyTigerGraph/ai/ai.py
index 84ac724b..bd798aa3 100644
--- a/pyTigerGraph/ai/ai.py
+++ b/pyTigerGraph/ai/ai.py
@@ -31,13 +31,14 @@ from pyTigerGraph import TigerGraphConnection
 
+
 class AI:
-    def __init__(self, conn: TigerGraphConnection) -> None:
+    def __init__(self, conn: TigerGraphConnection) -> None:
         """NO DOC: Initiate an AI object. Currently in beta testing.
 
             Args:
                 conn (TigerGraphConnection):
                     Accept a TigerGraphConnection to run queries with
-            
+
             Returns:
                 None
         """
@@ -47,7 +48,7 @@ def __init__(self, conn: TigerGraphConnection) -> None:
         # split scheme and host
         scheme, host = conn.host.split("://")
         self.nlqs_host = scheme + "://copilot-" + host
-    
+
     def configureInquiryAIHost(self, hostname: str):
         """ DEPRECATED: Configure the hostname of the InquiryAI service.
             Not recommended to use. Use configureCoPilotHost() instead.
@@ -56,8 +57,8 @@ def configureInquiryAIHost(self, hostname: str):
                 The hostname (and port number) of the InquiryAI service.
         """
         warnings.warn(
-            "The `configureInquiryAIHost()` function is deprecated; use `configureCoPilotHost()` function instead.",
-            DeprecationWarning)
+            "The `configureInquiryAIHost()` function is deprecated; use `configureCoPilotHost()` function instead.",
+            DeprecationWarning)
         self.nlqs_host = hostname
 
     def configureCoPilotHost(self, hostname: str):
@@ -96,14 +97,17 @@ def registerCustomQuery(self, query_name: str, description: str = None, docstrin
                 "queries": [query_name]
             }
             url = self.nlqs_host+"/"+self.conn.graphname+"/upsert_from_gsql"
-            return self.conn._req("POST", url, authMode="pwd", data = data, jsonData=True, resKey=None)
+            return self.conn._req("POST", url, authMode="pwd", data=data, jsonData=True, resKey=None)
         else:
             if description is None:
-                raise ValueError("When using TigerGraph 3.x, query descriptions are required parameters.")
+                raise ValueError(
+                    "When using TigerGraph 3.x, query descriptions are required parameters.")
             if docstring is None:
-                raise ValueError("When using TigerGraph 3.x, query docstrings are required parameters.")
+                raise ValueError(
+                    "When using TigerGraph 3.x, query docstrings are required parameters.")
             if param_types is None:
-                raise ValueError("When using TigerGraph 3.x, query parameter types are required parameters.")
+                raise ValueError(
+                    "When using TigerGraph 3.x, query parameter types are required parameters.")
             data = {
                 "function_header": query_name,
                 "description": description,
@@ -112,8 +116,8 @@ def registerCustomQuery(self, query_name: str, description: str = None, docstrin
                 "graphname": self.conn.graphname
             }
             url = self.nlqs_host+"/"+self.conn.graphname+"/register_docs"
-            return self.conn._req("POST", url, authMode="pwd", data = data, jsonData=True, resKey=None)
-        
+            return self.conn._req("POST", url, authMode="pwd", data=data, jsonData=True, resKey=None)
+
     def updateCustomQuery(self, query_name: str, description: str = None, docstring: str = None, param_types: dict = None):
         """ Update
a custom query with the InquiryAI service. Args: @@ -142,14 +146,17 @@ def updateCustomQuery(self, query_name: str, description: str = None, docstring: "queries": [query_name] } url = self.nlqs_host+"/"+self.conn.graphname+"/upsert_from_gsql" - return self.conn._req("POST", url, authMode="pwd", data = data, jsonData=True, resKey=None) + return self.conn._req("POST", url, authMode="pwd", data=data, jsonData=True, resKey=None) else: if description is None: - raise ValueError("When using TigerGraph 3.x, query descriptions are required parameters.") + raise ValueError( + "When using TigerGraph 3.x, query descriptions are required parameters.") if docstring is None: - raise ValueError("When using TigerGraph 3.x, query docstrings are required parameters.") + raise ValueError( + "When using TigerGraph 3.x, query docstrings are required parameters.") if param_types is None: - raise ValueError("When using TigerGraph 3.x, query parameter types are required parameters.") + raise ValueError( + "When using TigerGraph 3.x, query parameter types are required parameters.") data = { "function_header": query_name, "description": description, @@ -160,8 +167,8 @@ def updateCustomQuery(self, query_name: str, description: str = None, docstring: json_payload = {"id": "", "query_info": data} url = self.nlqs_host+"/"+self.conn.graphname+"/upsert_docs" - return self.conn._req("POST", url, authMode="pwd", data = json_payload, jsonData=True, resKey=None) - + return self.conn._req("POST", url, authMode="pwd", data=json_payload, jsonData=True, resKey=None) + def deleteCustomQuery(self, query_name: str): """ Delete a custom query with the InquiryAI service. Args: @@ -172,9 +179,9 @@ def deleteCustomQuery(self, query_name: str): """ data = {"ids": [], "expr": "function_header == '"+query_name+"'"} url = self.nlqs_host+"/"+self.conn.graphname+"/delete_docs" - return self.conn._req("POST", url, authMode="pwd", data = data, jsonData=True, resKey=None) + return self.conn._req("POST", url, authMode="pwd", data=data, jsonData=True, resKey=None) - def retrieveDocs(self, query:str, top_k:int = 3): + def retrieveDocs(self, query: str, top_k: int = 3): """ Retrieve docs from the vector store. Args: query (str): @@ -188,8 +195,9 @@ def retrieveDocs(self, query:str, top_k:int = 3): "query": query } - url = self.nlqs_host+"/"+self.conn.graphname+"/retrieve_docs?top_k="+str(top_k) - return self.conn._req("POST", url, authMode="pwd", data = data, jsonData=True, resKey=None, skipCheck=True) + url = self.nlqs_host+"/"+self.conn.graphname + \ + "/retrieve_docs?top_k="+str(top_k) + return self.conn._req("POST", url, authMode="pwd", data=data, jsonData=True, resKey=None, skipCheck=True) def query(self, query): """ Query the database with natural language. @@ -204,8 +212,8 @@ def query(self, query): } url = self.nlqs_host+"/"+self.conn.graphname+"/query" - return self.conn._req("POST", url, authMode="pwd", data = data, jsonData=True, resKey=None) - + return self.conn._req("POST", url, authMode="pwd", data=data, jsonData=True, resKey=None) + def coPilotHealth(self): """ Check the health of the CoPilot service. Returns: @@ -213,7 +221,7 @@ def coPilotHealth(self): """ url = self.nlqs_host+"/health" return self.conn._req("GET", url, authMode="pwd", resKey=None) - + def initializeSupportAI(self): """ Initialize the SupportAI service. 
Returns: @@ -221,7 +229,7 @@ def initializeSupportAI(self): """ url = self.nlqs_host+"/"+self.conn.graphname+"/supportai/initialize" return self.conn._req("POST", url, authMode="pwd", resKey=None) - + def createDocumentIngest(self, data_source, data_source_config, loader_config, file_format): """ Create a document ingest. Args: @@ -244,8 +252,8 @@ def createDocumentIngest(self, data_source, data_source_config, loader_config, f } url = self.nlqs_host+"/"+self.conn.graphname+"/supportai/create_ingest" - return self.conn._req("POST", url, authMode="pwd", data = data, jsonData=True, resKey=None) - + return self.conn._req("POST", url, authMode="pwd", data=data, jsonData=True, resKey=None) + def runDocumentIngest(self, load_job_id, data_source_id, data_path): """ Run a document ingest. Args: @@ -264,9 +272,9 @@ def runDocumentIngest(self, load_job_id, data_source_id, data_path): "file_path": data_path } url = self.nlqs_host+"/"+self.conn.graphname+"/supportai/ingest" - return self.conn._req("POST", url, authMode="pwd", data = data, jsonData=True, resKey=None) - - def searchDocuments(self, query, method = "hnswoverlap", method_parameters: dict = {"indices": ["Document", "DocumentChunk", "Entity", "Relationship"], "top_k": 2, "num_hops": 2, "num_seen_min": 2}): + return self.conn._req("POST", url, authMode="pwd", data=data, jsonData=True, resKey=None) + + def searchDocuments(self, query, method="hnswoverlap", method_parameters: dict = {"indices": ["Document", "DocumentChunk", "Entity", "Relationship"], "top_k": 2, "num_hops": 2, "num_seen_min": 2}): """ Search documents. Args: query (str): @@ -284,9 +292,9 @@ def searchDocuments(self, query, method = "hnswoverlap", method_parameters: dict "method_params": method_parameters } url = self.nlqs_host+"/"+self.conn.graphname+"/supportai/search" - return self.conn._req("POST", url, authMode="pwd", data = data, jsonData=True, resKey=None) - - def answerQuestion(self, query, method = "hnswoverlap", method_parameters: dict = {"indices": ["Document", "DocumentChunk", "Entity", "Relationship"], "top_k": 2, "num_hops": 2, "num_seen_min": 2}): + return self.conn._req("POST", url, authMode="pwd", data=data, jsonData=True, resKey=None) + + def answerQuestion(self, query, method="hnswoverlap", method_parameters: dict = {"indices": ["Document", "DocumentChunk", "Entity", "Relationship"], "top_k": 2, "num_hops": 2, "num_seen_min": 2}): """ Answer a question. Args: query (str): @@ -304,8 +312,8 @@ def answerQuestion(self, query, method = "hnswoverlap", method_parameters: dict "method_params": method_parameters } url = self.nlqs_host+"/"+self.conn.graphname+"/supportai/answerquestion" - return self.conn._req("POST", url, authMode="pwd", data = data, jsonData=True, resKey=None) - + return self.conn._req("POST", url, authMode="pwd", data=data, jsonData=True, resKey=None) + def forceConsistencyUpdate(self, method="supportai"): """ Force a consistency update for SupportAI embeddings. Args: @@ -317,8 +325,7 @@ def forceConsistencyUpdate(self, method="supportai"): """ url = f"{self.nlqs_host}/{self.conn.graphname}/{method}/forceupdate/" return self.conn._req("GET", url, authMode="pwd", resKey=None) - - ''' TODO: Add support in CoPilot + def checkConsistencyProgress(self, method="supportai"): """ Check the progress of the consistency update. 
Args: @@ -330,4 +337,3 @@ def checkConsistencyProgress(self, method="supportai"): """ url = f"{self.nlqs_host}/{self.conn.graphname}/supportai/consistency_status/{method}" return self.conn._req("GET", url, authMode="pwd", resKey=None) - ''' \ No newline at end of file diff --git a/pyTigerGraph/common/__init__.py b/pyTigerGraph/common/__init__.py new file mode 100644 index 00000000..e69de29b diff --git a/pyTigerGraph/common/auth.py b/pyTigerGraph/common/auth.py new file mode 100644 index 00000000..fac2e403 --- /dev/null +++ b/pyTigerGraph/common/auth.py @@ -0,0 +1,140 @@ +import json +import logging + +from datetime import datetime +from typing import Union, Tuple, Dict + +from pyTigerGraph.common.exception import TigerGraphException + +logger = logging.getLogger(__name__) + +def _parse_get_secrets(response: str) -> Dict[str, str]: + secrets_dict = {} + lines = response.split("\n") + + i = 0 + while i < len(lines): + line = lines[i] + # s = "" + if "- Secret" in line: + secret = line.split(": ")[1] + i += 1 + line = lines[i] + if "- Alias" in line: + secrets_dict[line.split(": ")[1]] = secret + i += 1 + return secrets_dict + +def _parse_create_secret(response: str, alias: str = "", withAlias: bool = False) -> Union[str, Dict[str, str]]: + try: + if "already exists" in response: + error_msg = "The secret " + if alias != "": + error_msg += "with alias {} ".format(alias) + error_msg += "already exists." + raise TigerGraphException(error_msg, "E-00001") + + secret = "".join(response).replace('\n', '').split( + 'The secret: ')[1].split(" ")[0].strip() + + if not withAlias: + if logger.level == logging.DEBUG: + logger.debug("return: " + str(secret)) + logger.info("exit: createSecret (withAlias") + + return secret + + if alias: + ret = {alias: secret} + + if logger.level == logging.DEBUG: + logger.debug("return: " + str(ret)) + logger.info("exit: createSecret (alias)") + + return ret + + return secret + + except IndexError as e: + raise TigerGraphException( + "Failed to parse secret from response.", "E-00002") from e + +def _prep_token_request(restppUrl: str, + gsUrl: str, + graphname: str, + version: str = None, + secret: str = None, + lifetime: int = None, + token: str = None, + method: str = None): + major_ver, minor_ver, patch_ver = (0, 0, 0) + if version: + major_ver, minor_ver, patch_ver = version.split(".") + + if 0 < int(major_ver) < 3 or (int(major_ver) == 3 and int(minor_ver) < 5): + method = "GET" + url = restppUrl + "/requesttoken?secret=" + secret + \ + ("&lifetime=" + str(lifetime) if lifetime else "") + \ + ("&token=" + token if token else "") + authMode = None + if not secret: + raise TigerGraphException( + "Cannot request a token with username/password for versions < 3.5.") + else: + method = "POST" + url = gsUrl + "/gsql/v1/tokens" # used for TG 4.x + data = {"graph": graphname} + + # alt_url and alt_data used to construct the method and url for functions run in TG version 3.x + alt_url = restppUrl+"/requesttoken" # used for TG 3.x + alt_data = {} + + if lifetime: + data["lifetime"] = str(lifetime) + alt_data["lifetime"] = str(lifetime) + if token: + data["tokens"] = token + alt_data["token"] = token + if secret: + authMode = "None" + data["secret"] = secret + alt_data["secret"] = secret + else: + authMode = "pwd" + + alt_data = json.dumps(alt_data) + + return method, url, alt_url, authMode, data, alt_data + +def _parse_token_response(response: dict, + setToken: bool, + mainVer: int, + base64_credential: str) -> Tuple[Union[Tuple[str, str], str], dict]: + if not 
response.get("error"):
+        token = response["token"]
+        if setToken:
+            apiToken = token
+            authHeader = {'Authorization': "Bearer " + apiToken}
+        else:
+            apiToken = None
+            authHeader = {
+                'Authorization': 'Basic {0}'.format(base64_credential)}
+
+        if response.get("expiration"):
+            # On >=4.1 the format for the date of expiration changed. Convert back to old format
+            # Can't use self._versionGreaterThan4_0 since you need a token for that
+            if mainVer == 4:
+                return (token, response.get("expiration")), authHeader
+            else:
+                return (token, response.get("expiration"), \
+                    datetime.utcfromtimestamp(
+                        float(response.get("expiration"))).strftime('%Y-%m-%d %H:%M:%S')), authHeader
+        else:
+            return token, authHeader
+
+    elif "Endpoint is not found from url = /requesttoken" in response["message"]:
+        raise TigerGraphException("REST++ authentication is not enabled, can't generate token.",
+                                  None)
+    else:
+        raise TigerGraphException(
+            response["message"], (response["code"] if "code" in response else None))
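Taken together, `_prep_token_request()` and `_parse_token_response()` split token handling into a prepare step and a parse step. A minimal sketch of the prepare half, assuming this branch is installed; the host, graph name, secret, and lifetime below are illustrative placeholders, not values from this PR:

```python
# Sketch only: placeholder host/graph/secret values, not a tested recipe.
from pyTigerGraph.common.auth import _prep_token_request

method, url, alt_url, authMode, data, alt_data = _prep_token_request(
    restppUrl="https://example.tgcloud.io:443/restpp",  # placeholder
    gsUrl="https://example.tgcloud.io:443",             # placeholder
    graphname="MyGraph",
    version="4.1.0",    # a 4.x version string selects the POST path
    secret="<secret>",  # elided; required for versions < 3.5
    lifetime=2592000,
)
# On 4.x: method == "POST", url ends with "/gsql/v1/tokens", and
# data == {"graph": "MyGraph", "lifetime": "2592000", "secret": "<secret>"};
# alt_url/alt_data hold the equivalent 3.x /requesttoken request.
```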
diff --git a/pyTigerGraph/common/base.py b/pyTigerGraph/common/base.py
new file mode 100644
index 00000000..0e2454b3
--- /dev/null
+++ b/pyTigerGraph/common/base.py
@@ -0,0 +1,395 @@
+"""
+`TigerGraphConnection`
+
+A TigerGraphConnection object provides the HTTP(S) communication used by all other modules.
+"""
+
+import base64
+import json
+import logging
+import sys
+import re
+import warnings
+import requests
+
+from typing import Union
+from urllib.parse import urlparse
+
+from pyTigerGraph.common.exception import TigerGraphException
+
+
+def excepthook(type, value, traceback):
+    """NO DOC
+
+    This function prints out a given traceback and exception to sys.stderr.
+
+    See: https://docs.python.org/3/library/sys.html#sys.excepthook
+    """
+    print(value)
+    # TODO Proper logging
+
+
+logger = logging.getLogger(__name__)
+
+
+class PyTigerGraphCore(object):
+    def __init__(self, host: str = "http://127.0.0.1", graphname: str = "MyGraph",
+                 gsqlSecret: str = "", username: str = "tigergraph", password: str = "tigergraph",
+                 tgCloud: bool = False, restppPort: Union[int, str] = "9000",
+                 gsPort: Union[int, str] = "14240", gsqlVersion: str = "", version: str = "",
+                 apiToken: str = "", useCert: bool = None, certPath: str = None, debug: bool = None,
+                 sslPort: Union[int, str] = "443", gcp: bool = False, jwtToken: str = ""):
+        """Initiate a connection object.
+
+        Args:
+            host:
+                The host name or IP address of the TigerGraph server. Make sure to include the
+                protocol (http:// or https://). If `certPath` is `None` and the protocol is https,
+                a self-signed certificate will be used.
+            graphname:
+                The default graph for running queries.
+            gsqlSecret:
+                The secret key for GSQL. See https://docs.tigergraph.com/tigergraph-server/current/user-access/managing-credentials#_secrets.
+            username:
+                The username on the TigerGraph server.
+            password:
+                The password for that user.
+            tgCloud:
+                Set to `True` if using TigerGraph Cloud. If your hostname contains `tgcloud`, then
+                this is automatically set to `True`, and you do not need to set this argument.
+            restppPort:
+                The port for REST++ queries.
+            gsPort:
+                The port of the GSQL server.
+            gsqlVersion:
+                The version of the GSQL client to be used. Effectively the version of the database
+                being connected to.
+            apiToken (Optional):
+                Parameter for specifying a RESTPP service token. Use `getToken()` to get a token.
+            version:
+                DEPRECATED; use `gsqlVersion`.
+            useCert:
+                DEPRECATED; the need for a CA certificate is now determined by URL scheme.
+            certPath:
+                The filesystem path to the CA certificate. Required in case of https connections.
+            debug:
+                DEPRECATED; configure standard logging in your app.
+            sslPort:
+                Port for fetching SSL certificate in case of firewall.
+            gcp:
+                DEPRECATED. Previously used for connecting to databases provisioned on GCP in TigerGraph Cloud.
+            jwtToken:
+                The JWT token generated on the customer side for authentication.
+
+        Raises:
+            TigerGraphException: In case of an invalid URL scheme.
+
+        """
+        logger.info("entry: __init__")
+        if logger.level == logging.DEBUG:
+            logger.debug("params: " + self._locals(locals()))
+
+        inputHost = urlparse(host)
+        if inputHost.scheme not in ["http", "https"]:
+            raise TigerGraphException("Invalid URL scheme. Supported schemes are http and https.",
+                                      "E-0003")
+        self.netloc = inputHost.netloc
+        self.host = "{0}://{1}".format(inputHost.scheme, self.netloc)
+        if gsqlSecret != "":
+            self.username = "__GSQL__secret"
+            self.password = gsqlSecret
+        else:
+            self.username = username
+            self.password = password
+        self.graphname = graphname
+        self.responseConfigHeader = {}
+        self.awsIamHeaders = {}
+
+        self.jwtToken = jwtToken
+        self.apiToken = apiToken
+        self.base64_credential = base64.b64encode(
+            "{0}:{1}".format(self.username, self.password).encode("utf-8")).decode("utf-8")
+
+        self.authHeader = self._set_auth_header()
+
+        # TODO Eliminate version and use gsqlVersion only, meaning TigerGraph server version
+        if gsqlVersion:
+            self.version = gsqlVersion
+        elif version:
+            warnings.warn(
+                "The `version` parameter is deprecated; use the `gsqlVersion` parameter instead.",
+                DeprecationWarning)
+            self.version = version
+        else:
+            self.version = ""
+
+        if debug is not None:
+            warnings.warn(
+                "The `debug` parameter is deprecated; configure standard logging in your app.",
+                DeprecationWarning)
+        if not debug:
+            sys.excepthook = excepthook  # TODO Why was this necessary? Can it be removed?
+ sys.tracebacklimit = None + + self.schema = None + + # TODO Remove useCert parameter + if useCert is not None: + warnings.warn( + "The `useCert` parameter is deprecated; the need for a CA certificate is now determined by URL scheme.", + DeprecationWarning) + if inputHost.scheme == "http": + self.downloadCert = False + self.useCert = False + self.certPath = "" + elif inputHost.scheme == "https": + if not certPath: + self.downloadCert = True + else: + self.downloadCert = False + self.useCert = True + self.certPath = certPath + self.sslPort = str(sslPort) + + # TODO Remove gcp parameter + if gcp: + warnings.warn("The `gcp` parameter is deprecated.", + DeprecationWarning) + self.tgCloud = tgCloud or gcp + if "tgcloud" in self.netloc.lower(): + try: # If get request succeeds, using TG Cloud instance provisioned after 6/20/2022 + self._get(self.host + "/api/ping", resKey="message") + self.tgCloud = True + # If get request fails, using TG Cloud instance provisioned before 6/20/2022, before new firewall config + except requests.exceptions.RequestException: + self.tgCloud = False + except TigerGraphException: + raise (TigerGraphException("Incorrect graphname.")) + + restppPort = str(restppPort) + sslPort = str(sslPort) + if self.tgCloud and (restppPort == "9000" or restppPort == "443"): + self.restppPort = sslPort + self.restppUrl = self.host + ":" + sslPort + "/restpp" + else: + self.restppPort = restppPort + self.restppUrl = self.host + ":" + self.restppPort + self.gsPort = "" + gsPort = str(gsPort) + if self.tgCloud and (gsPort == "14240" or gsPort == "443"): + self.gsPort = sslPort + self.gsUrl = self.host + ":" + sslPort + else: + self.gsPort = gsPort + self.gsUrl = self.host + ":" + self.gsPort + self.url = "" + + if self.username.startswith("arn:aws:iam::"): + import boto3 + from botocore.awsrequest import AWSRequest + from botocore.auth import SigV4Auth + # Prepare a GetCallerIdentity request. 
+            request = AWSRequest(
+                method="POST",
+                url="https://sts.amazonaws.com/?Action=GetCallerIdentity&Version=2011-06-15",
+                headers={
+                    'Host': 'sts.amazonaws.com'
+                })
+            # Get headers
+            SigV4Auth(boto3.Session().get_credentials(),
+                      "sts", "us-east-1").add_auth(request)
+            self.awsIamHeaders["X-Amz-Date"] = request.headers["X-Amz-Date"]
+            self.awsIamHeaders["X-Amz-Security-Token"] = request.headers["X-Amz-Security-Token"]
+            self.awsIamHeaders["Authorization"] = request.headers["Authorization"]
+
+        if self.jwtToken:
+            self._verify_jwt_token_support()
+
+        self.asynchronous = False
+
+        logger.info("exit: __init__")
+
+    def _set_auth_header(self):
+        """Set the authentication header based on available tokens or credentials."""
+        if self.jwtToken:
+            return {"Authorization": "Bearer " + self.jwtToken}
+        elif self.apiToken:
+            return {"Authorization": "Bearer " + self.apiToken}
+        else:
+            return {"Authorization": "Basic {0}".format(self.base64_credential)}
+
+    def _verify_jwt_token_support(self):
+        try:
+            # Check JWT support for RestPP server
+            logger.debug(
+                "Attempting to verify JWT token support with getVer() on RestPP server.")
+            logger.debug(f"Using auth header: {self.authHeader}")
+            version = self.getVer()
+            logger.info(f"Database version: {version}")
+
+            # Check JWT support for GSQL server
+            if self._version_greater_than_4_0():
+                logger.debug(
+                    f"Attempting to get auth info with URL: {self.gsUrl + '/gsql/v1/auth/simple'}")
+                self._get(f"{self.gsUrl}/gsql/v1/auth/simple",
+                          authMode="token", resKey=None)
+            else:
+                logger.debug(
+                    f"Attempting to get auth info with URL: {self.gsUrl + '/gsqlserver/gsql/simpleauth'}")
+                self._get(f"{self.gsUrl}/gsqlserver/gsql/simpleauth",
+                          authMode="token", resKey=None)
+        except requests.exceptions.ConnectionError as e:
+            logger.error(f"Connection error: {e}.")
+            raise RuntimeError(f"Connection error: {e}.") from e
+        except Exception as e:
+            message = "The JWT token might be invalid or expired or DB version doesn't support JWT token. Please generate new JWT token or switch to API token or username/password."
+            logger.error(f"Error occurred: {e}. {message}")
+            raise RuntimeError(message) from e
+
+    def _locals(self, _locals: dict) -> str:
+        del _locals["self"]
+        return str(_locals)
+
+    def _error_check(self, res: dict) -> bool:
+        """Checks if the JSON document returned by an endpoint contains `error: true`. If so,
+        it raises an exception.
+
+        Args:
+            res:
+                The output from a request.
+
+        Returns:
+            False if no error occurred.
+
+        Raises:
+            TigerGraphException: if request returned with error, indicated in the returned JSON.
+        """
+        if "error" in res and res["error"] and res["error"] != "false":
+            # Endpoint might return string "false" rather than Boolean false
+            raise TigerGraphException(
+                res["message"], (res["code"] if "code" in res else None))
+        return False
+
+    def _prep_req(self, authMode, headers, url, method, data):
+        logger.info("entry: _req")
+        if logger.level == logging.DEBUG:
+            logger.debug("params: " + self._locals(locals()))
+
+        _headers = {}
+
+        # If JWT token is provided, always use jwtToken as token
+        if authMode == "token":
+            if isinstance(self.jwtToken, str) and self.jwtToken.strip() != "":
+                token = self.jwtToken
+            elif isinstance(self.apiToken, tuple):
+                token = self.apiToken[0]
+            elif isinstance(self.apiToken, str) and self.apiToken.strip() != "":
+                token = self.apiToken
+            else:
+                token = None
+
+            if token:
+                self.authHeader = {'Authorization': "Bearer " + token}
+                _headers = self.authHeader
+            else:
+                self.authHeader = {
+                    'Authorization': 'Basic {0}'.format(self.base64_credential)}
+                _headers = self.authHeader
+                authMode = 'pwd'
+
+        if authMode == "pwd":
+            if self.jwtToken:
+                _headers = {'Authorization': "Bearer " + self.jwtToken}
+            else:
+                _headers = {'Authorization': 'Basic {0}'.format(
+                    self.base64_credential)}
+
+        if headers:
+            _headers.update(headers)
+        if self.awsIamHeaders:
+            # version >=4.1 has removed /gsqlserver/
+            if url.startswith(self.gsUrl + "/gsqlserver/") or (self._versionGreaterThan4_0() and url.startswith(self.gsUrl)):
+                _headers.update(self.awsIamHeaders)
+        if self.responseConfigHeader:
+            _headers.update(self.responseConfigHeader)
+        if method == "POST" or method == "PUT" or method == "DELETE":
+            _data = data
+        else:
+            _data = None
+
+        if self.useCert is True or self.certPath is not None:
+            verify = False
+        else:
+            verify = True
+
+        _headers.update({"X-User-Agent": "pyTigerGraph"})
+
+        return _headers, _data, verify
+
+    def _parse_req(self, res, jsonResponse, strictJson, skipCheck, resKey):
+        if jsonResponse:
+            try:
+                res = json.loads(res.text, strict=strictJson)
+            except:
+                raise TigerGraphException("Cannot parse json: " + res.text)
+        else:
+            res = res.text
+
+        if not skipCheck:
+            self._error_check(res)
+        if not resKey:
+            if logger.level == logging.DEBUG:
+                logger.debug("return: " + str(res))
+            logger.info("exit: _req (no resKey)")
+
+            return res
+
+        if logger.level == logging.DEBUG:
+            logger.debug("return: " + str(res[resKey]))
+        logger.info("exit: _req (resKey)")
+
+        return res[resKey]
+
+    def customizeHeader(self, timeout: int = 16_000, responseSize: int = 3.2e+7):
+        """Method to configure the request header.
+
+        Args:
+            timeout (int, optional):
+                The timeout value desired in milliseconds. Defaults to 16,000 ms (16 sec)
+            responseSize:
+                The size of the response in bytes. Defaults to 3.2E7 bytes (32 MB).
+
+        Returns:
+            Nothing. Sets `responseConfigHeader` class attribute.
+ """ + self.responseConfigHeader = { + "GSQL-TIMEOUT": str(timeout), "RESPONSE-LIMIT": str(responseSize)} + + def _parse_get_ver(self, version, component, full): + ret = "" + for v in version: + if v["name"] == component.lower(): + ret = v["version"] + if ret != "": + if full: + return ret + ret = re.search("_.+_", ret) + ret = ret.group().strip("_") + return ret + else: + raise TigerGraphException( + "\"" + component + "\" is not a valid component.", None) + + def _parse_get_version(self, response, raw): + if raw: + return response + res = response.split("\n") + components = [] + for i in range(len(res)): + if 2 < i < len(res) - 1: + m = res[i].split() + component = {"name": m[0], "version": m[1], "hash": m[2], + "datetime": m[3] + " " + m[4] + " " + m[5]} + components.append(component) + + return components diff --git a/pyTigerGraph/common/dataset.py b/pyTigerGraph/common/dataset.py new file mode 100644 index 00000000..ca32789e --- /dev/null +++ b/pyTigerGraph/common/dataset.py @@ -0,0 +1,31 @@ +import logging + +from pyTigerGraph.datasets import Datasets + +logger = logging.getLogger(__name__) + +def _parse_ingest_dataset(responses: str, cleanup: bool, dataset: Datasets): + for resp in responses: + stats = resp[0]["statistics"] + if "vertex" in stats: + for vstats in stats["vertex"]: + print( + "Ingested {} objects into VERTEX {}".format( + vstats["validObject"], vstats["typeName"] + ), + flush=True, + ) + if "edge" in stats: + for estats in stats["edge"]: + print( + "Ingested {} objects into EDGE {}".format( + estats["validObject"], estats["typeName"] + ), + flush=True, + ) + if logger.level == logging.DEBUG: + logger.debug(str(resp)) + + if cleanup: + print("---- Cleaning ----", flush=True) + dataset.clean_up() diff --git a/pyTigerGraph/common/edge.py b/pyTigerGraph/common/edge.py new file mode 100644 index 00000000..5cf9c1b5 --- /dev/null +++ b/pyTigerGraph/common/edge.py @@ -0,0 +1,474 @@ +import json +import logging + +from typing import TYPE_CHECKING, Union + +if TYPE_CHECKING: + import pandas as pd + +from pyTigerGraph.common.exception import TigerGraphException +from pyTigerGraph.common.util import ( + _safe_char +) +from pyTigerGraph.common.schema import ( + _upsert_attrs +) + +logger = logging.getLogger(__name__) + +___trgvtxids = "___trgvtxids" + +def _parse_get_edge_source_vertex_type(edgeTypeDetails): + # Edge type with a single source vertex type + if edgeTypeDetails["FromVertexTypeName"] != "*": + ret = edgeTypeDetails["FromVertexTypeName"] + + if logger.level == logging.DEBUG: + logger.debug("return: " + str(ret)) + logger.info("exit: getEdgeSourceVertexType (single source)") + + return ret + + # Edge type with multiple source vertex types + if "EdgePairs" in edgeTypeDetails: + # v3.0 and later notation + vts = set() + for ep in edgeTypeDetails["EdgePairs"]: + vts.add(ep["From"]) + + if logger.level == logging.DEBUG: + logger.debug("return: " + str(vts)) + logger.info("exit: getEdgeSourceVertexType (multi source)") + + return vts + else: + # 2.6.1 and earlier notation + if logger.level == logging.DEBUG: + logger.debug("return: *") + logger.info( + "exit: getEdgeSourceVertexType (multi source, pre-3.x)") + + return "*" + +def _parse_get_edge_target_vertex_type(edgeTypeDetails): + # Edge type with a single target vertex type + if edgeTypeDetails["ToVertexTypeName"] != "*": + ret = edgeTypeDetails["ToVertexTypeName"] + + if logger.level == logging.DEBUG: + logger.debug("return: " + str(ret)) + logger.info("exit: getEdgeTargetVertexType (single target)") + + return ret + + 
# Edge type with multiple target vertex types
+    if "EdgePairs" in edgeTypeDetails:
+        # v3.0 and later notation
+        vts = set()
+        for ep in edgeTypeDetails["EdgePairs"]:
+            vts.add(ep["To"])
+
+        if logger.level == logging.DEBUG:
+            logger.debug("return: " + str(vts))
+        logger.info("exit: getEdgeTargetVertexType (multi target)")
+
+        return vts
+    else:
+        # 2.6.1 and earlier notation
+        if logger.level == logging.DEBUG:
+            logger.debug("return: *")
+        logger.info(
+            "exit: getEdgeTargetVertexType (multi target, pre-3.x)")
+
+        return "*"
+
+def _prep_get_edge_count_from(restppUrl: str,
+                              graphname: str,
+                              sourceVertexType: str = "",
+                              sourceVertexId: Union[str, int] = None,
+                              edgeType: str = "",
+                              targetVertexType: str = "",
+                              targetVertexId: Union[str, int] = None,
+                              where: str = ""):
+    data = None
+    # If WHERE condition is not specified, use /builtins else use /vertices
+    if where or (sourceVertexType and sourceVertexId):
+        if not sourceVertexType or not sourceVertexId:
+            raise TigerGraphException(
+                "If where condition is specified, then both sourceVertexType and sourceVertexId"
+                " must be provided too.", None)
+        url = restppUrl + "/graph/" + _safe_char(graphname) + "/edges/" + \
+            _safe_char(sourceVertexType) + "/" + \
+            _safe_char(sourceVertexId)
+        if edgeType:
+            url += "/" + _safe_char(edgeType)
+        if targetVertexType:
+            url += "/" + _safe_char(targetVertexType)
+        if targetVertexId:
+            url += "/" + _safe_char(targetVertexId)
+        url += "?count_only=true"
+        if where:
+            url += "&filter=" + _safe_char(where)
+    else:
+        if not edgeType:  # TODO Is this a valid check?
+            raise TigerGraphException(
+                "A valid edge type or \"*\" must be specified for edge type.", None)
+        data = '{"function":"stat_edge_number","type":"' + edgeType + '"' \
+            + (',"from_type":"' + sourceVertexType + '"' if sourceVertexType else '') \
+            + (',"to_type":"' + targetVertexType + '"' if targetVertexType else '') \
+            + '}'
+        url = restppUrl + "/builtins/" + graphname
+    return url, data
+
+def _parse_get_edge_count_from(res, edgeType):
+    if len(res) == 1 and res[0]["e_type"] == edgeType:
+        ret = res[0]["count"]
+
+        if logger.level == logging.DEBUG:
+            logger.debug("return: " + str(ret))
+        logger.info("exit: getEdgeCountFrom (single edge type)")
+
+        return ret
+
+    ret = {}
+    for r in res:
+        ret[r["e_type"]] = r["count"]
+    return ret
+
+def _prep_upsert_edge(sourceVertexType: str,
+                      sourceVertexId: str,
+                      edgeType: str,
+                      targetVertexType: str,
+                      targetVertexId: str,
+                      attributes: dict = None):
+    '''defining edge schema structure for upsertEdge()'''
+    if not(attributes):
+        attributes = {}
+
+    vals = _upsert_attrs(attributes)
+    data = json.dumps({
+        "edges": {
+            sourceVertexType: {
+                sourceVertexId: {
+                    edgeType: {
+                        targetVertexType: {
+                            targetVertexId: vals
+                        }
+                    }
+                }
+            }
+        }
+    })
+    return data
+
+def _dumps(data) -> str:
+    """Generates the JSON format expected by the endpoint (Used in upsertEdges()).
+
+    The important thing this function does is converting the list of target vertex IDs and
+    the attributes belonging to the edge instances into a JSON object that can contain
+    multiple occurrences of the same key. If these details were stored in a dictionary,
+    then in the case of MultiEdge only the last instance would be retained (as the key would be
+    the target vertex ID).
+
+    Args:
+        data:
+            The Python data structure containing the edge instance details.
+
+    Returns:
+        The JSON to be sent to the endpoint.
+ """ + ret = "" + if isinstance(data, dict): + c1 = 0 + for k1, v1 in data.items(): + if c1 > 0: + ret += "," + if k1 == ___trgvtxids: + # Dealing with the (possibly multiple instances of) edge details + # v1 should be a dict of lists + c2 = 0 + for k2, v2 in v1.items(): + if c2 > 0: + ret += "," + c3 = 0 + for v3 in v2: + if c3 > 0: + ret += "," + ret += json.dumps(k2) + ':' + json.dumps(v3) + c3 += 1 + c2 += 1 + else: + ret += json.dumps(k1) + ':' + _dumps(data[k1]) + c1 += 1 + return "{" + ret + "}" + +def _prep_upsert_edges(sourceVertexType, + edgeType, + targetVertexType, + edges): + '''converting vertex parameters into edge structure''' + data = {sourceVertexType: {}} + l1 = data[sourceVertexType] + for e in edges: + if len(e) > 2: + vals = _upsert_attrs(e[2]) + else: + vals = {} + # sourceVertexId + # Converted to string as the key in the JSON payload must be a string + sourceVertexId = str(e[0]) + if sourceVertexId not in l1: + l1[sourceVertexId] = {} + l2 = l1[sourceVertexId] + # edgeType + if edgeType not in l2: + l2[edgeType] = {} + l3 = l2[edgeType] + # targetVertexType + if targetVertexType not in l3: + l3[targetVertexType] = {} + l4 = l3[targetVertexType] + if ___trgvtxids not in l4: + l4[___trgvtxids] = {} + l4 = l4[___trgvtxids] + # targetVertexId + # Converted to string as the key in the JSON payload must be a string + targetVertexId = str(e[1]) + if targetVertexId not in l4: + l4[targetVertexId] = [] + l4[targetVertexId].append(vals) + + data = _dumps({"edges": data}) + return data + +def _prep_upsert_edge_dataframe(df, from_id, to_id, attributes): + '''converting dataframe into an upsertable object structure''' + json_up = [] + + for index in df.index: + json_up.append(json.loads(df.loc[index].to_json())) + json_up[-1] = ( + index if from_id is None else json_up[-1][from_id], + index if to_id is None else json_up[-1][to_id], + json_up[-1] if attributes is None + else {target: json_up[-1][source] for target, source in attributes.items()} + ) + return json_up + +def _prep_get_edges(restppUrl: str, + graphname: str, + sourceVertexType: str, + sourceVertexId: str, + edgeType: str = "", + targetVertexType: str = "", + targetVertexId: str = "", + select: str = "", + where: str = "", + limit: Union[int, str] = None, + sort: str = "", + timeout: int = 0): + '''url builder for getEdges()''' + # TODO Change sourceVertexId to sourceVertexIds and allow passing both str and list as + # parameter + if not sourceVertexType or not sourceVertexId: + raise TigerGraphException( + "Both source vertex type and source vertex ID must be provided.", None) + url = restppUrl + "/graph/" + graphname + "/edges/" + sourceVertexType + "/" + \ + str(sourceVertexId) + if edgeType: + url += "/" + edgeType + if targetVertexType: + url += "/" + targetVertexType + if targetVertexId: + url += "/" + str(targetVertexId) + isFirst = True + if select: + url += "?select=" + select + isFirst = False + if where: + url += ("?" if isFirst else "&") + "filter=" + where + isFirst = False + if limit: + url += ("?" if isFirst else "&") + "limit=" + str(limit) + isFirst = False + if sort: + url += ("?" if isFirst else "&") + "sort=" + sort + isFirst = False + if timeout and timeout > 0: + url += ("?" 
if isFirst else "&") + "timeout=" + str(timeout)
+    return url
+
+def _prep_get_edges_by_type(graphname,
+                            sourceVertexType,
+                            edgeType):
+    '''build the query to select edges for getEdgesByType()'''
+    # TODO Support edges with multiple source vertex types
+    if isinstance(sourceVertexType, set) or sourceVertexType == "*":
+        raise TigerGraphException(
+            "Edges with multiple source vertex types are not currently supported.", None)
+
+    queryText = \
+        'INTERPRET QUERY () FOR GRAPH $graph { \
+            SetAccum<EDGE> @@edges; \
+            start = {ANY}; \
+            res = \
+                SELECT s \
+                FROM start:s-(:e)->ANY:t \
+                WHERE e.type == "$edgeType" \
+                    AND s.type == "$sourceEdgeType" \
+                ACCUM @@edges += e; \
+            PRINT @@edges AS edges; \
+        }'
+
+    queryText = queryText.replace("$graph", graphname) \
+        .replace('$sourceEdgeType', sourceVertexType) \
+        .replace('$edgeType', edgeType)
+    return queryText
+
+def _parse_get_edge_stats(responses, skipNA):
+    '''error checking and parsing responses for getEdgeStats()'''
+    ret = {}
+    for et, res in responses:
+        if res["error"]:
+            if "stat_edge_attr is skip" in res["message"] or \
+                    "No valid edge for the input edge type" in res["message"]:
+                if not skipNA:
+                    ret[et] = {}
+            else:
+                raise TigerGraphException(res["message"],
+                                          (res["code"] if "code" in res else None))
+        else:
+            res = res["results"]
+            for r in res:
+                ret[r["e_type"]] = r["attributes"]
+    return ret
+
+def _prep_del_edges(restppUrl: str,
+                    graphname: str,
+                    sourceVertexType,
+                    sourceVertexId,
+                    edgeType,
+                    targetVertexType,
+                    targetVertexId,
+                    where,
+                    limit,
+                    sort,
+                    timeout):
+    '''url building for delEdges()'''
+    if not sourceVertexType or not sourceVertexId:
+        raise TigerGraphException("Both sourceVertexType and sourceVertexId must be provided.",
+                                  None)
+
+    url = restppUrl + "/graph/" + graphname + "/edges/" + sourceVertexType + "/" + str(
+        sourceVertexId)
+
+    if edgeType:
+        url += "/" + edgeType
+    if targetVertexType:
+        url += "/" + targetVertexType
+    if targetVertexId:
+        url += "/" + str(targetVertexId)
+
+    isFirst = True
+    if where:
+        url += ("?" if isFirst else "&") + "filter=" + where
+        isFirst = False
+    if limit and sort:  # These two must be provided together
+        url += ("?" if isFirst else "&") + "limit=" + \
+            str(limit) + "&sort=" + sort
+        isFirst = False
+    if timeout and timeout > 0:
+        url += ("?" if isFirst else "&") + "timeout=" + str(timeout)
+    return url
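As a quick illustration of the query text that `_prep_get_edges_by_type()` above assembles — a sketch only; the graph, vertex, and edge type names are made up, and it assumes this branch is installed:

```python
from pyTigerGraph.common.edge import _prep_get_edges_by_type

gsql = _prep_get_edges_by_type("MyGraph", "Person", "KNOWS")
# The $graph/$sourceEdgeType/$edgeType placeholders have been substituted:
assert "FOR GRAPH MyGraph" in gsql
assert 'e.type == "KNOWS"' in gsql and 's.type == "Person"' in gsql
```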
+
+def edgeSetToDataFrame(edgeSet: list,
+                       withId: bool = True,
+                       withType: bool = False) -> 'pd.DataFrame':
+    """Converts an edge set to Pandas DataFrame
+
+    Edge sets contain instances of the same edge type. Edge sets are not generated "naturally"
+    like vertex sets. Instead, you need to collect edges in (global) accumulators, like when you
+    want to visualize them in GraphStudio or by other tools.
+
+    For example:
+    ```
+    SetAccum<EDGE> @@edges;
+
+    start = {country.*};
+
+    result =
+        SELECT trg
+        FROM start:src -(city_in_country:e)- city:trg
+        ACCUM @@edges += e;
+
+    PRINT start, result, @@edges;
+    ```
+
+    The `@@edges` is an edge set.
+    It contains, for each edge instance, the source and target vertex type and ID, the edge type,
+    a directedness indicator and the (optional) attributes. /
+
+    [NOTE]
+    `start` and `result` are vertex sets.
+
+    An edge set has this structure (when serialised as JSON):
+
+    [source.wrap, json]
+    ----
+    [
+        {
+            "e_type": <edge_type_name>,
+            "from_type": <source_vertex_type_name>,
+            "from_id": <source_vertex_id>,
+            "to_type": <target_vertex_type_name>,
+            "to_id": <target_vertex_id>,
+            "directed": <true_or_false>,
+            "attributes":
+                {
+                    "attr1": <attr1_value>,
+                    "attr2": <attr2_value>,
+                    ⋮
+                }
+        },
+        ⋮
+    ]
+    ----
+
+    Args:
+        edgeSet:
+            A JSON array containing an edge set in the format returned by queries (see below).
+        withId:
+            Whether to include the type and primary ID of source and target vertices as a column. Default is `True`.
+        withType:
+            Whether to include edge type info as a column. Default is `False`.
+
+    Returns:
+        A pandas DataFrame containing the edge attributes and optionally the type and primary
+        ID of source and target vertices, and the edge type.
+
+    """
+    logger.info("entry: edgeSetToDataFrame")
+    logger.debug("params: " + str(locals()))
+
+    try:
+        import pandas as pd
+    except ImportError:
+        raise ImportError("Pandas is required to use this function. "
+                          "Download pandas using 'pip install pandas'.")
+
+    df = pd.DataFrame(edgeSet)
+    cols = []
+    if withId:
+        cols.extend([df["from_type"], df["from_id"],
+                     df["to_type"], df["to_id"]])
+    if withType:
+        cols.append(df["e_type"])
+    cols.append(pd.DataFrame(df["attributes"].tolist()))
+
+    ret = pd.concat(cols, axis=1)
+
+    if logger.level == logging.DEBUG:
+        logger.debug("return: " + str(ret))
+    logger.info("exit: edgeSetToDataFrame")
+
+    return ret
diff --git a/pyTigerGraph/pyTigerGraphException.py b/pyTigerGraph/common/exception.py
similarity index 100%
rename from pyTigerGraph/pyTigerGraphException.py
rename to pyTigerGraph/common/exception.py
diff --git a/pyTigerGraph/common/gsql.py b/pyTigerGraph/common/gsql.py
new file mode 100644
index 00000000..6ea67efb
--- /dev/null
+++ b/pyTigerGraph/common/gsql.py
@@ -0,0 +1,96 @@
+"""GSQL Interface
+
+Use GSQL within pyTigerGraph.
+All functions in this module are called as methods on a link:https://docs.tigergraph.com/pytigergraph/current/core-functions/base[`TigerGraphConnection` object].
+"""
+import logging
+import re
+
+from typing import Union, Tuple, Dict
+from urllib.parse import urlparse, quote_plus
+
+from pyTigerGraph.common.base import PyTigerGraphCore
+from pyTigerGraph.common.exception import TigerGraphException
+
+logger = logging.getLogger(__name__)
+
+ANSI_ESCAPE = re.compile(r'\x1B(?:[@-Z\\-_]|\[[0-?]*[ -/]*[@-~])')
+
+# Once again, we could just pass res and the query parameter in, but this is simpler
+# and allows for an easier pattern
+def _parse_gsql(res, query: str, graphname: str = None, options=None):
+    def check_error(query: str, resp: str) -> None:
+        if "CREATE VERTEX" in query.upper():
+            if "Failed to create vertex types" in resp:
+                raise TigerGraphException(resp)
+        if ("CREATE DIRECTED EDGE" in query.upper()) or ("CREATE UNDIRECTED EDGE" in query.upper()):
+            if "Failed to create edge types" in resp:
+                raise TigerGraphException(resp)
+        if "CREATE GRAPH" in query.upper():
+            if ("The graph" in resp) and ("could not be created!" in resp):
+                raise TigerGraphException(resp)
+        if "CREATE DATA_SOURCE" in query.upper():
+            if ("Successfully created local data sources" not in resp) and ("Successfully created data sources" not in resp):
+                raise TigerGraphException(resp)
+        if "CREATE LOADING JOB" in query.upper():
+            if "Successfully created loading jobs" not in resp:
+                raise TigerGraphException(resp)
+        if "RUN LOADING JOB" in query.upper():
+            if "LOAD SUCCESSFUL" not in resp:
+                raise TigerGraphException(resp)
+
+    def clean_res(resp: list) -> str:
+        ret = []
+        for line in resp:
+            if not line.startswith("__GSQL__"):
+                ret.append(line)
+        return "\n".join(ret)
+
+    if isinstance(res, list):
+        ret = clean_res(res)
+    else:
+        ret = clean_res(res.splitlines())
+
+    check_error(query, ret)
+
+    string_without_ansi = ANSI_ESCAPE.sub('', ret)
+
+    if logger.level == logging.DEBUG:
+        logger.debug("return: " + str(ret))
+    logger.info("exit: gsql (success)")
+
+    return string_without_ansi
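For the `edgeSetToDataFrame()` helper above, a small self-contained example (requires pandas; the edge type and attribute are invented for illustration):

```python
from pyTigerGraph.common.edge import edgeSetToDataFrame

edge_set = [
    {"e_type": "city_in_country", "from_type": "city", "from_id": "Paris",
     "to_type": "country", "to_id": "France", "directed": False,
     "attributes": {"since": 987}},
]
df = edgeSetToDataFrame(edge_set, withId=True, withType=True)
print(list(df.columns))  # ['from_type', 'from_id', 'to_type', 'to_id', 'e_type', 'since']
```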
+def _prep_get_udf(ExprFunctions: bool = True, ExprUtil: bool = True):
+    urls = {}  # urls when using TG 4.x
+    alt_urls = {}  # urls when using TG 3.x
+    if ExprFunctions:
+        alt_urls["ExprFunctions"] = (
+            "/gsqlserver/gsql/userdefinedfunction?filename=ExprFunctions")
+        urls["ExprFunctions"] = ("/gsql/v1/udt/files/ExprFunctions")
+    if ExprUtil:
+        alt_urls["ExprUtil"] = (
+            "/gsqlserver/gsql/userdefinedfunction?filename=ExprUtil")
+        urls["ExprUtil"] = ("/gsql/v1/udt/files/ExprUtil")
+
+    return urls, alt_urls
+
+def _parse_get_udf(responses, json_out):
+    rets = []
+    for file_name in responses:
+        resp = responses[file_name]
+        if not resp["error"]:
+            logger.info(f"{file_name} get successfully")
+            rets.append(resp["results"])
+        else:
+            logger.error(f"Failed to get {file_name}")
+            raise TigerGraphException(resp["message"])
+
+    if json_out:
+        # concatenate the list of dicts into one dict
+        rets[0].update(rets[-1])
+        return rets[0]
+    if len(rets) == 2:
+        return tuple(rets)
+    if rets:
+        return rets[0]
+    return ""
diff --git a/pyTigerGraph/common/loading.py b/pyTigerGraph/common/loading.py
new file mode 100644
index 00000000..f02c25a2
--- /dev/null
+++ b/pyTigerGraph/common/loading.py
@@ -0,0 +1,29 @@
+"""Loading Job Functions
+
+The functions on this page run loading jobs on the TigerGraph server.
+All functions in this module are called as methods on a link:https://docs.tigergraph.com/pytigergraph/current/core-functions/base[`TigerGraphConnection` object].
+"""
+import logging
+
+
+logger = logging.getLogger(__name__)
+
+def _prep_run_loading_job_with_file(filePath, jobName, fileTag, sep, eol):
+    '''read file contents for runLoadingJobWithFile()'''
+    try:
+        data = open(filePath, 'rb').read()
+        params = {
+            "tag": jobName,
+            "filename": fileTag,
+        }
+        if sep is not None:
+            params["sep"] = sep
+        if eol is not None:
+            params["eol"] = eol
+        return data, params
+    except OSError as ose:
+        logger.error(ose.strerror)
+        logger.info("exit: runLoadingJobWithFile")
+
+        return None, None
+        # TODO Should throw exception instead?
diff --git a/pyTigerGraph/common/path.py b/pyTigerGraph/common/path.py
new file mode 100644
index 00000000..66a7f145
--- /dev/null
+++ b/pyTigerGraph/common/path.py
@@ -0,0 +1,141 @@
+"""Path Finding Functions.
+
+The functions on this page find paths between vertices within the graph.
+All functions in this module are called as methods on a link:https://docs.tigergraph.com/pytigergraph/current/core-functions/base[`TigerGraphConnection` object].
+"""
+
+import json
+import logging
+
+from typing import Union
+
+
+logger = logging.getLogger(__name__)
+
+
+def _prepare_path_params(sourceVertices: Union[dict, tuple, list],
+                         targetVertices: Union[dict, tuple, list],
+                         maxLength: int = None,
+                         vertexFilters: Union[list, dict] = None,
+                         edgeFilters: Union[list, dict] = None,
+                         allShortestPaths: bool = False) -> str:
+    """Prepares the input parameters by transforming them to the format expected by the path
+    algorithms.
+
+    See xref:tigergraph-server:API:built-in-endpoints.adoc#[Parameters and output format for path finding]
+
+    A vertex set is a dict that has three top-level keys: `v_type`, `v_id`, `attributes` (also a dictionary).
+
+    Args:
+        sourceVertices:
+            A vertex set (a list of vertices) or a list of `(vertexType, vertexID)` tuples;
+            the source vertices of the shortest paths sought.
+        targetVertices:
+            A vertex set (a list of vertices) or a list of `(vertexType, vertexID)` tuples;
+            the target vertices of the shortest paths sought.
+        maxLength:
+            The maximum length of a shortest path. Optional, default is `6`.
+        vertexFilters:
+            An optional list of `(vertexType, condition)` tuples or
+            `{"type": <vertex_type>, "condition": <condition>}` dictionaries.
+        edgeFilters:
+            An optional list of `(edgeType, condition)` tuples or
+            `{"type": <edge_type>, "condition": <condition>}` dictionaries.
+        allShortestPaths:
+            If `True`, the endpoint will return all shortest paths between the source and target.
+            Default is `False`, meaning that the endpoint will return only one path.
+
+    Returns:
+        A string representation of the dictionary of end-point parameters.
+    """
+
+    def parse_vertices(vertices: Union[dict, tuple, list]) -> list:
+        """Parses vertex input parameters and converts it to the format required by the path
+        finding endpoints.
+
+        Args:
+            vertices:
+                A vertex set (a list of vertices) or a list of `(vertexType, vertexID)` tuples;
+                the source or target vertices of the shortest paths sought.
+        Returns:
+            A list of vertices in the format required by the path finding endpoints.
+        """
+        logger.info("entry: parseVertices")
+        logger.debug("params: " + str(locals()))
+
+        ret = []
+        if not isinstance(vertices, list):
+            vertices = [vertices]
+        for v in vertices:
+            if isinstance(v, tuple):
+                tmp = {"type": v[0], "id": v[1]}
+                ret.append(tmp)
+            elif isinstance(v, dict) and "v_type" in v and "v_id" in v:
+                tmp = {"type": v["v_type"], "id": v["v_id"]}
+                ret.append(tmp)
+            else:
+                logger.warning("Invalid vertex type or value: " + str(v))
+
+        if logger.level == logging.DEBUG:
+            logger.debug("return: " + str(ret))
+        logger.info("exit: parseVertices")
+
+        return ret
+
+    def parse_filters(filters: Union[dict, tuple, list]) -> list:
+        """Parses filter input parameters and converts it to the format required by the path
+        finding endpoints.
+
+        Args:
+            filters:
+                A list of `(vertexType, condition)` tuples or
+                `{"type": <vertex_or_edge_type>, "condition": <condition>}` dictionaries.
+
+        Returns:
+            A list of filters in the format required by the path finding endpoints.
+        """
+        logger.info("entry: parseFilters")
+        logger.debug("params: " + str(locals()))
+
+        ret = []
+        if not isinstance(filters, list):
+            filters = [filters]
+        for f in filters:
+            if isinstance(f, tuple):
+                tmp = {"type": f[0], "condition": f[1]}
+                ret.append(tmp)
+            elif isinstance(f, dict) and "type" in f and "condition" in f:
+                tmp = {"type": f["type"], "condition": f["condition"]}
+                ret.append(tmp)
+            else:
+                logger.warning("Invalid filter type or value: " + str(f))
+
+        logger.debug("return: " + str(ret))
+        logger.info("exit: parseFilters")
+
+        return ret
+
+    logger.info("entry: _preparePathParams")
+    logger.debug("params: " + str(locals()))
+
+    # Assembling the input payload
+    if not sourceVertices or not targetVertices:
+        return ""
+        # TODO Should allow returning error instead of handling missing parameters here?
+    data = {"sources": parse_vertices(
+        sourceVertices), "targets": parse_vertices(targetVertices)}
+    if vertexFilters:
+        data["vertexFilters"] = parse_filters(vertexFilters)
+    if edgeFilters:
+        data["edgeFilters"] = parse_filters(edgeFilters)
+    if maxLength:
+        data["maxLength"] = maxLength
+    if allShortestPaths:
+        data["allShortestPaths"] = True
+
+    ret = json.dumps(data)
+
+    logger.debug("return: " + str(ret))
+    logger.info("exit: _preparePathParams")
+
+    return ret
diff --git a/pyTigerGraph/common/query.py b/pyTigerGraph/common/query.py
new file mode 100644
index 00000000..d5b0a5a4
--- /dev/null
+++ b/pyTigerGraph/common/query.py
@@ -0,0 +1,126 @@
+"""Query Functions.
+
+The functions on this page run installed or interpreted queries in TigerGraph.
+All functions in this module are called as methods on a link:https://docs.tigergraph.com/pytigergraph/current/core-functions/base[`TigerGraphConnection` object].
+"""
+import json
+import logging
+
+from datetime import datetime
+from typing import TYPE_CHECKING, Union, Optional
+
+if TYPE_CHECKING:
+    import pandas as pd
+
+from pyTigerGraph.common.exception import TigerGraphException
+from pyTigerGraph.common.util import (
+    _safe_char
+)
+
+logger = logging.getLogger(__name__)
+
+# TODO getQueries()  # List _all_ query names
+def _parse_get_installed_queries(fmt, ret):
+    if fmt == "json":
+        ret = json.dumps(ret)
+    if fmt == "df":
+        try:
+            import pandas as pd
+        except ImportError:
+            raise ImportError("Pandas is required to use this function. "
+                              "Download pandas using 'pip install pandas'.")
+        ret = pd.DataFrame(ret).T
+    return ret
+
+# TODO installQueries()
+#   POST /gsql/queries/install
+#   xref:tigergraph-server:API:built-in-endpoints.adoc#_install_a_query[Install a query]
+
+# TODO checkQueryInstallationStatus()
+#   GET /gsql/queries/install/{request_id}
+#   xref:tigergraph-server:API:built-in-endpoints.adoc#_check_query_installation_status[Check query installation status]
+def _parse_query_parameters(params: dict) -> str:
+    """Parses a dictionary of query parameters and converts them to query strings.
+
+    While most of the values provided for various query parameter types can be easily converted
+    to query strings (key1=value1&key2=value2), `SET` and `BAG` parameter types, and especially
+    `VERTEX` and `SET<VERTEX>` (i.e. vertex primary ID types without vertex type specification)
+    require special handling.
+
+    See xref:tigergraph-server:API:built-in-endpoints.adoc#_query_parameter_passing[Query parameter passing]
+
+    TODO Accept this format for SET<VERTEX>:
+        "key": [([p_id1, p_id2, ...], "vtype"), ...]
+        I.e. multiple primary IDs of the same vertex type
+    """
+    logger.info("entry: _parseQueryParameters")
+    logger.debug("params: " + str(params))
+
+    ret = ""
+    for k, v in params.items():
+        if isinstance(v, tuple):
+            if len(v) == 2 and isinstance(v[1], str):
+                ret += k + "=" + str(v[0]) + "&" + k + \
+                    ".type=" + _safe_char(v[1]) + "&"
+            else:
+                raise TigerGraphException(
+                    "Invalid parameter value: (vertex_primary_id, vertex_type)"
+                    " was expected.")
+        elif isinstance(v, list):
+            i = 0
+            for vv in v:
+                if isinstance(vv, tuple):
+                    if len(vv) == 2 and isinstance(vv[1], str):
+                        ret += k + "[" + str(i) + "]=" + _safe_char(vv[0]) + "&" + \
+                            k + "[" + str(i) + "].type=" + vv[1] + "&"
+                    else:
+                        raise TigerGraphException(
+                            "Invalid parameter value: (vertex_primary_id , vertex_type)"
+                            " was expected.")
+                else:
+                    ret += k + "=" + _safe_char(vv) + "&"
+                i += 1
+        elif isinstance(v, datetime):
+            ret += k + "=" + \
+                _safe_char(v.strftime("%Y-%m-%d %H:%M:%S")) + "&"
+        else:
+            ret += k + "=" + _safe_char(v) + "&"
+    ret = ret[:-1]
+
+    if logger.level == logging.DEBUG:
+        logger.debug("return: " + str(ret))
+    logger.info("exit: _parseQueryParameters")
+
+    return ret
+
+def _prep_run_installed_query(timeout, sizeLimit, runAsync, replica, threadLimit, memoryLimit):
+    """header builder for runInstalledQuery()"""
+    headers = {}
+    res_key = "results"
+    if timeout and timeout > 0:
+        headers["GSQL-TIMEOUT"] = str(timeout)
+    if sizeLimit and sizeLimit > 0:
+        headers["RESPONSE-LIMIT"] = str(sizeLimit)
+    if runAsync:
+        headers["GSQL-ASYNC"] = "true"
+        res_key = "request_id"
+    if replica:
+        headers["GSQL-REPLICA"] = str(replica)
+    if threadLimit:
+        headers["GSQL-THREAD-LIMIT"] = str(threadLimit)
+    if memoryLimit:
+        headers["GSQL-QueryLocalMemLimitMB"] = str(memoryLimit)
+    return headers, res_key
+
+def _prep_get_statistics(seconds, segments):
+    '''parameter parsing for getStatistics()'''
+    if not seconds:
+        seconds = 10
+    else:
+        seconds = min(max(seconds, 0), 60)
+    if not segments:
+        segments = 10
+    else:
+        segments = min(max(segments, 0), 100)
+    return seconds, segments
diff --git a/pyTigerGraph/common/schema.py b/pyTigerGraph/common/schema.py
new file mode 100644
index 00000000..022ba949
--- /dev/null
+++ b/pyTigerGraph/common/schema.py
@@ -0,0 +1,112 @@
+"""Schema Functions.
+
+The functions in this page retrieve information about the graph schema.
+All functions in this module are called as methods on a link:https://docs.tigergraph.com/pytigergraph/current/core-functions/base[`TigerGraphConnection` object].
+"""
+import json
+import logging
+
+from typing import Union
+
+
+logger = logging.getLogger(__name__)
+
+def _get_attr_type(attrType: dict) -> str:
+    """Returns attribute data type in simple format.
+
+    Args:
+        attribute:
+            The details of the attribute's data type.
+
+    Returns:
+        Either "(scalar_type)" or "(complex_type, scalar_type)" string.
+    """
+    ret = attrType["Name"]
+    if "KeyTypeName" in attrType:
+        ret += "(" + attrType["KeyTypeName"] + \
+            "," + attrType["ValueTypeName"] + ")"
+    elif "ValueTypeName" in attrType:
+        ret += "(" + attrType["ValueTypeName"] + ")"
+
+    return ret
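`_get_attr_type()` flattens the REST++ schema metadata into the short display form described in its docstring; for example (the input dicts are hand-made):

```python
from pyTigerGraph.common.schema import _get_attr_type

print(_get_attr_type({"Name": "INT"}))                              # INT
print(_get_attr_type({"Name": "LIST", "ValueTypeName": "STRING"}))  # LIST(STRING)
print(_get_attr_type({"Name": "MAP", "KeyTypeName": "STRING",
                      "ValueTypeName": "INT"}))                     # MAP(STRING,INT)
```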
+def _upsert_attrs(attributes: dict) -> dict:
+    """Transforms attributes (provided as a table) into a hierarchy as expected by the upsert
+    functions.
+
+    Args:
+        attributes: A dictionary of attribute/value pairs (with an optional operator) in this
+            format:
+                {<attribute_name>: <attribute_value>|(<attribute_value>, <operator>), …}
+
+    Returns:
+        A dictionary in this format:
+            {
+                <attribute_name>: {"value": <attribute_value>},
+                <attribute_name>: {"value": <attribute_value>, "op": <operator>}
+            }
+
+    Documentation:
+        xref:tigergraph-server:API:built-in-endpoints.adoc#operation-codes[Operation codes]
+    """
+    logger.info("entry: _upsertAttrs")
+    logger.debug("params: " + str(locals()))
+
+    if not isinstance(attributes, dict):
+        return {}
+        # TODO Should return something else or raise exception?
+    vals = {}
+    for attr in attributes:
+        val = attributes[attr]
+        if isinstance(val, tuple):
+            vals[attr] = {"value": val[0], "op": val[1]}
+        elif isinstance(val, dict):
+            vals[attr] = {"value": {"keylist": list(
+                val.keys()), "valuelist": list(val.values())}}
+        else:
+            vals[attr] = {"value": val}
+
+    if logger.level == logging.DEBUG:
+        logger.debug("return: " + str(vals))
+    logger.info("exit: _upsertAttrs")
+
+    return vals
+
+def _prep_upsert_data(data: Union[str, object],
+                      atomic: bool = False,
+                      ackAll: bool = False,
+                      newVertexOnly: bool = False,
+                      vertexMustExist: bool = False,
+                      updateVertexOnly: bool = False):
+    if not isinstance(data, str):
+        data = json.dumps(data)
+    headers = {}
+    if atomic:
+        headers["gsql-atomic-level"] = "atomic"
+    params = {}
+    if ackAll:
+        params["ack"] = "all"
+    if newVertexOnly:
+        params["new_vertex_only"] = True
+    if vertexMustExist:
+        params["vertex_must_exist"] = True
+    if updateVertexOnly:
+        params["update_vertex_only"] = True
+
+    return data, headers, params
+
+def _prep_get_endpoints(restppUrl: str,
+                        graphname: str,
+                        builtin,
+                        dynamic,
+                        static):
+    """Builds url starter and preps parameters of getEndpoints"""
+    ret = {}
+    if not (builtin or dynamic or static):
+        bui = dyn = sta = True
+    else:
+        bui = builtin
+        dyn = dynamic
+        sta = static
+    url = restppUrl + "/endpoints/" + graphname + "?"
+    return bui, dyn, sta, url, ret
diff --git a/pyTigerGraph/common/util.py b/pyTigerGraph/common/util.py
new file mode 100644
index 00000000..77aea0dc
--- /dev/null
+++ b/pyTigerGraph/common/util.py
@@ -0,0 +1,71 @@
+"""Utility Functions.
+
+Utility functions for pyTigerGraph.
+All functions in this module are called as methods on a link:https://docs.tigergraph.com/pytigergraph/current/core-functions/base[`TigerGraphConnection` object].
+"""
+
+import logging
+import urllib
+
+from typing import Any, TYPE_CHECKING
+from urllib.parse import urlparse
+
+from pyTigerGraph.common.exception import TigerGraphException
+
+logger = logging.getLogger(__name__)
+
+def _safe_char(inputString: Any) -> str:
+    """Replace special characters in string using the %xx escape.
+
+    Args:
+        inputString:
+            The string to process
+
+    Returns:
+        Processed string.
+
+    Documentation:
+        https://docs.python.org/3/library/urllib.parse.html#url-quoting
+    """
+    return urllib.parse.quote(str(inputString), safe='')
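The attribute-table-to-payload transformation done by `_upsert_attrs()` above, in one small example (attribute names invented):

```python
from pyTigerGraph.common.schema import _upsert_attrs

print(_upsert_attrs({"name": "Tom", "visits": (1, "+")}))
# {'name': {'value': 'Tom'}, 'visits': {'value': 1, 'op': '+'}}
```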
+ ret["daysRemaining"] = -1 + else: + raise TigerGraphException(res["message"], res["code"]) + + return ret + +def _prep_get_system_metrics(from_ts: int = None, + to_ts: int = None, + latest: int = None, + who: str = None, + where: str = None): + params = {} + _json = {} # in >=4.1 we need a json request of different parameter names + if from_ts or to_ts: + _json["TimeRange"] = {} + if from_ts: + params["from"] = from_ts + _json['TimeRange']['StartTimestampNS'] = str(from_ts) + if to_ts: + params["to"] = to_ts + _json['TimeRange']['EndTimestampNS'] = str(from_ts) + if latest: + params["latest"] = latest + _json["LatestNum"] = str(latest) + if who: + params["who"] = who + if where: + params["where"] = where + _json["HostID"] = where + + return params, _json diff --git a/pyTigerGraph/common/vertex.py b/pyTigerGraph/common/vertex.py new file mode 100644 index 00000000..9eed7a39 --- /dev/null +++ b/pyTigerGraph/common/vertex.py @@ -0,0 +1,206 @@ +"""Vertex Functions. + +Functions to upsert, retrieve and delete vertices. + +All functions in this module are called as methods on a link:https://docs.tigergraph.com/pytigergraph/current/core-functions/base[`TigerGraphConnection` object]. +""" +import json +import logging + +from typing import TYPE_CHECKING, Union + +if TYPE_CHECKING: + import pandas as pd + +from pyTigerGraph.common.exception import TigerGraphException +from pyTigerGraph.common.util import _safe_char + +logger = logging.getLogger(__name__) + + +def _parse_get_vertex_count(res, vertexType, where): + if where: + if vertexType == "*": + raise TigerGraphException( + "VertexType cannot be \"*\" if where condition is specified.", None) + else: + raise TigerGraphException( + "VertexType cannot be a list if where condition is specified.", None) + + ret = {d["v_type"]: d["count"] for d in res} + + if isinstance(vertexType, list): + ret = {vt: ret[vt] for vt in vertexType} + + return ret + +def _prep_upsert_vertex_dataframe(df, v_id, attributes): + json_up = [] + + for index in df.index: + json_up.append(json.loads(df.loc[index].to_json())) + json_up[-1] = ( + index if v_id is None else json_up[-1][v_id], + json_up[-1] if attributes is None + else {target: json_up[-1][source] for target, source in attributes.items()} + ) + return json_up + +def _prep_get_vertices(restppUrl: str, graphname: str, vertexType: str, select: str = "", where: str = "", + limit: Union[int, str] = None, sort: str = "", timeout: int = 0): + '''url builder for getVertices()''' + + url = restppUrl + "/graph/" + graphname + "/vertices/" + vertexType + isFirst = True + if select: + url += "?select=" + select + isFirst = False + if where: + url += ("?" if isFirst else "&") + "filter=" + where + isFirst = False + if limit: + url += ("?" if isFirst else "&") + "limit=" + str(limit) + isFirst = False + if sort: + url += ("?" if isFirst else "&") + "sort=" + sort + isFirst = False + if timeout and timeout > 0: + url += ("?" 
if isFirst else "&") + "timeout=" + str(timeout) + return url + +def _prep_get_vertices_by_id(restppUrl: str, graphname: str, vertexIds, vertexType): + '''parameter parsing and url building for getVerticesById()''' + + if not vertexIds: + raise TigerGraphException("No vertex ID was specified.", None) + vids = [] + if isinstance(vertexIds, (int, str)): + vids.append(vertexIds) + else: + vids = vertexIds + url = restppUrl + "/graph/" + graphname + "/vertices/" + vertexType + "/" + return vids, url + +def _parse_get_vertex_stats(responses, skipNA): + '''response parsing for getVertexStats()''' + ret = {} + for vt, res in responses: + if res["error"]: + if "stat_vertex_attr is skip" in res["message"]: + if not skipNA: + ret[vt] = {} + else: + raise TigerGraphException(res["message"], + (res["code"] if "code" in res else None)) + else: + res = res["results"] + for r in res: + ret[r["v_type"]] = r["attributes"] + + return ret + +def _prep_del_vertices(restppUrl: str, graphname: str, vertexType, + where, limit, sort, permanent, timeout): + '''url builder for delVertices()''' + url = restppUrl + "/graph/" + graphname + "/vertices/" + vertexType + isFirst = True + if where: + url += "?filter=" + where + isFirst = False + if limit and sort: # These two must be provided together + url += ("?" if isFirst else "&") + "limit=" + \ + str(limit) + "&sort=" + sort + isFirst = False + if permanent: + url += ("?" if isFirst else "&") + "permanent=true" + isFirst = False + if timeout and timeout > 0: + url += ("?" if isFirst else "&") + "timeout=" + str(timeout) + + return url + +def _prep_del_vertices_by_id(restppUrl: str, graphname: str, + vertexIds, vertexType, permanent, timeout): + '''url builder and param parser for delVerticesById()''' + if not vertexIds: + raise TigerGraphException("No vertex ID was specified.", None) + vids = [] + if isinstance(vertexIds, (int, str)): + vids.append(_safe_char(vertexIds)) + else: + vids = [_safe_char(f) for f in vertexIds] + + url1 = restppUrl + "/graph/" + \ + graphname + "/vertices/" + vertexType + "/" + url2 = "" + if permanent: + url2 = "?permanent=true" + if timeout and timeout > 0: + url2 += ("&" if url2 else "?") + "timeout=" + str(timeout) + return url1, url2, vids + +def vertexSetToDataFrame(vertexSet: list, withId: bool = True, + withType: bool = False) -> 'pd.DataFrame': + """Converts a vertex set to Pandas DataFrame. + + Vertex sets are used for both the input and output of `SELECT` statements. They contain + instances of vertices of the same type. + For each vertex instance, the vertex ID, the vertex type and the (optional) attributes are + present under the `v_id`, `v_type` and `attributes` keys, respectively. / + See an example in `edgeSetToDataFrame()`. + + A vertex set has this structure (when serialised as JSON): + [source.wrap,json] + ---- + [ + { + "v_id": <vertex_id>, + "v_type": <vertex_type_name>, + "attributes": + { + "attr1": <attribute1_value>, + "attr2": <attribute2_value>, + ⋮ + } + }, + ⋮ + ] + ---- + For more information on vertex sets see xref:gsql-ref:querying:declaration-and-assignment-statements.adoc#_vertex_set_variables[Vertex set variables]. + + Args: + vertexSet: + A JSON array containing a vertex set in the format returned by queries (see below). + withId: + Whether to include vertex primary ID as a column. + withType: + Whether to include vertex type info as a column. + + Returns: + A pandas DataFrame containing the vertex attributes (and optionally the vertex primary + ID and type).
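For example, a minimal usage sketch (the connection and the "Person" vertex type are hypothetical):

[source,python]
----
vs = conn.getVertices("Person", limit=3)  # any vertex set in the format above
df = vertexSetToDataFrame(vs, withId=True, withType=False)
# df has a "v_id" column plus one column per vertex attribute
----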
+ """ + logger.info("entry: vertexSetToDataFrame") + logger.debug("params: " + str(locals())) + + try: + import pandas as pd + except ImportError: + raise ImportError("Pandas is required to use this function. " + "Download pandas using 'pip install pandas'.") + + df = pd.DataFrame(vertexSet) + cols = [] + if withId: + cols.append(df["v_id"]) + if withType: + cols.append(df["v_type"]) + cols.append(pd.DataFrame(df["attributes"].tolist())) + + ret = pd.concat(cols, axis=1) + + if logger.level == logging.DEBUG: + logger.debug("return: " + str(ret)) + logger.info("exit: vertexSetToDataFrame") + + return ret diff --git a/pyTigerGraph/datasets.py b/pyTigerGraph/datasets.py index 75031d86..756cc6e1 100644 --- a/pyTigerGraph/datasets.py +++ b/pyTigerGraph/datasets.py @@ -18,6 +18,7 @@ class BaseDataset(ABC): "NO DOC" + def __init__(self, name: str = None) -> None: self.name = name self.ingest_ready = False @@ -72,7 +73,8 @@ def __init__(self, name: str = None, tmp_dir: str = "./tmp") -> None: name, tmp_dir ) ) - else: + + if not isdir(pjoin(tmp_dir, name)): dataset_url = self.get_dataset_url() # Check if it is an in-stock dataset. if not dataset_url: @@ -106,7 +108,8 @@ def download_extract(self) -> None: with tarfile.open(fileobj=raw, mode="r|gz") as tarobj: tarobj.extractall(path=self.tmp_dir) except ImportError: - warnings.warn("Cannot import tqdm. Downloading without progress report.") + warnings.warn( + "Cannot import tqdm. Downloading without progress report.") with tarfile.open(fileobj=resp.raw, mode="r|gz") as tarobj: tarobj.extractall(path=self.tmp_dir) print("Dataset downloaded.") @@ -162,4 +165,3 @@ def list(self) -> None: print("Available datasets:") for k in resp.json(): print("- {}".format(k)) - diff --git a/pyTigerGraph/gds/featurizer.py b/pyTigerGraph/gds/featurizer.py index bd0754d0..81024939 100644 --- a/pyTigerGraph/gds/featurizer.py +++ b/pyTigerGraph/gds/featurizer.py @@ -40,6 +40,7 @@ class AsyncFeaturizerResult: """AsyncFeaturizerResult Object to keep track of featurizer algorithms being ran in asynchronous mode. (`runAsync=True`). """ + def __init__(self, conn, algorithm, query_id, results=None): """NO DOC: class for asynchronous featurizer results. Populated during `runAlgorithm()` if `runAsync = True`. @@ -126,10 +127,10 @@ class Featurizer: print(res) ``` """ + def __init__( self, conn: "TigerGraphConnection", repo: str = None, algo_version: str = None ): - """NO DOC: Class for feature extraction. The job of a feature extracter is to install and run algorithms in the Graph Data Science (GDS) libarary. @@ -159,7 +160,8 @@ def __init__( self.major_ver ) else: - raise ValueError("Database version {} not supported.".format(self.algo_ver)) + raise ValueError( + "Database version {} not supported.".format(self.algo_ver)) self.repo = repo # Get algo dict from manifest try: @@ -198,7 +200,8 @@ def _get_db_version(self) -> Tuple[str, str, str]: def _get_algo_dict(self, manifest_file: str) -> dict: # Get algo dict from manifest if manifest_file.startswith("http"): - resp = requests.get(manifest_file, allow_redirects=False, timeout=10) + resp = requests.get( + manifest_file, allow_redirects=False, timeout=10) resp.raise_for_status() algo_dict = resp.json() else: @@ -228,7 +231,8 @@ def print_algos(algo_dict: dict, depth: int, algo_num: int = 0) -> int: for k, v in algo_dict.items(): if k == "name": algo_num += 1 - print("{}{:02}. name: {}".format(" " * depth, algo_num, v)) + print("{}{:02}. 
name: {}".format( + " " * depth, algo_num, v)) return algo_num if isinstance(v, dict): print("{}{}:".format(" " * depth, k)) @@ -245,7 +249,8 @@ def print_algos(algo_dict: dict, depth: int, algo_num: int = 0) -> int: else: print("Available algorithms per category:") for k in self.algo_dict: - print("- {}: {} algorithms".format(k, get_num_algos(self.algo_dict[k]))) + print("- {}: {} algorithms".format(k, + get_num_algos(self.algo_dict[k]))) print( "Call listAlgorithms() with the category name to see the list of algorithms" ) @@ -286,7 +291,8 @@ def _install_query_file( # Get query name from the first line firstline = query.split("\n", 1)[0] try: - query_name = re.search(r"QUERY (.+?)\(", firstline).group(1).strip() + query_name = re.search( + r"QUERY (.+?)\(", firstline).group(1).strip() except: raise ValueError( "Cannot parse the query file. It should start with CREATE QUERY ... " @@ -312,13 +318,14 @@ def _install_query_file( query = query.replace(placeholder, replace[placeholder]) self.query = query if ( - query_name == "tg_fastRP" + query_name == "tg_fastRP" and self.major_ver != "master" - and int(self.major_ver) <= 3 + and int(self.major_ver) <= 3 and int(self.minor_ver) <= 7 ): # Drop all jobs on the graph - self.conn.gsql("USE GRAPH {}\n".format(self.conn.graphname) + "drop job *") + self.conn.gsql("USE GRAPH {}\n".format( + self.conn.graphname) + "drop job *") res = add_attribute( self.conn, schema_type="VERTEX", @@ -381,9 +388,11 @@ def installAlgorithm( self.sch_type, ) = self._get_algo_details(self.algo_dict) if query_name not in self.algo_paths: - raise ValueError("Cannot find {} in the library.".format(query_name)) + raise ValueError( + "Cannot find {} in the library.".format(query_name)) for query in self.algo_paths[query_name]: - _ = self._install_query_file(query, global_change=global_change, distributed_mode=distributed_query) + _ = self._install_query_file( + query, global_change=global_change, distributed_mode=distributed_query) self.query_name = query_name return self.query_name @@ -392,9 +401,11 @@ def get_details(d: dict, paths: dict, types: dict, sch_obj: dict) -> None: if "name" in d.keys(): if "path" not in d.keys(): raise Exception( - "Cannot find path for {} in the manifest file".format(d["name"]) + "Cannot find path for {} in the manifest file".format( + d["name"]) ) - paths[d["name"]] = [pjoin(self.repo, p) for p in d["path"].split(";")] + paths[d["name"]] = [pjoin(self.repo, p) + for p in d["path"].split(";")] if "value_type" in d.keys(): types[d["name"]] = d["value_type"] if "schema_type" in d.keys(): @@ -419,7 +430,8 @@ def _get_query(self, query_name: str) -> str: self.sch_type, ) = self._get_algo_details(self.algo_dict) if query_name not in self.algo_paths: - raise ValueError("Cannot find {} in the library.".format(query_name)) + raise ValueError( + "Cannot find {} in the library.".format(query_name)) query_path = self.algo_paths[query_name][-1] if query_path.startswith("http"): resp = requests.get(query_path, allow_redirects=False, timeout=10) @@ -478,7 +490,7 @@ def _get_params(self, query: str): """ param_values = {} param_types = {} - header = query[query.find("(") + 1 : query.find(")")].strip() + header = query[query.find("(") + 1: query.find(")")].strip() if not header: return {}, {} header = header.split(",") @@ -507,7 +519,8 @@ def _get_params(self, query: str): param_values[param] = None param_types[param] = "bool" elif param_type.lower() == "string": - param_values[param] = default.strip('"').strip("'") if default else None + param_values[param] 
= default.strip( + '"').strip("'") if default else None param_types[param] = "str" else: param_values[param] = default @@ -654,14 +667,16 @@ def runAlgorithm( raise ValueError( "Please run installAlgorithm() to install this custom query first." ) - self.installAlgorithm(query_name, global_change=global_schema, distributed_query=distributed_query) + self.installAlgorithm( + query_name, global_change=global_schema, distributed_query=distributed_query) # Check query parameters for built-in queries. if not custom_query: if params is None: params = self.getParams(query_name, printout=False) if params: - missing_params = [k for k, v in params.items() if v is None] + missing_params = [ + k for k, v in params.items() if v is None] if missing_params: raise ValueError( 'Missing mandatory parameters: {}. Please run getParams("{}") for parameter details.'.format( @@ -678,7 +693,8 @@ def runAlgorithm( ) ) query_params.update(params) - missing_params = [k for k, v in query_params.items() if v is None] + missing_params = [ + k for k, v in query_params.items() if v is None] if missing_params: raise ValueError( 'Missing mandatory parameters: {}. Please run getParams("{}") for parameter details.'.format( @@ -740,9 +756,11 @@ def runAlgorithm( return result def _get_template_queries(self): - categories = self.conn.gsql("SHOW PACKAGE GDBMS_ALGO").strip().split("\n")[2:] + categories = self.conn.gsql( + "SHOW PACKAGE GDBMS_ALGO").strip().split("\n")[2:] for cat in categories: - resp = self.conn.gsql("SHOW PACKAGE GDBMS_ALGO.{}".format(cat.strip("- "))) + resp = self.conn.gsql( + "SHOW PACKAGE GDBMS_ALGO.{}".format(cat.strip("- "))) self.template_queries[cat.strip("- ")] = resp.strip() def _add_result_attribute( @@ -781,7 +799,8 @@ def _add_result_attribute( elif isinstance(params[key], list): schema_name = params[key] else: - raise ValueError("v_type should be either a list or string") + raise ValueError( + "v_type should be either a list or string") elif schema_type == "EDGE" and ( "e_type" in params or "e_type_set" in params ): @@ -791,7 +810,8 @@ def _add_result_attribute( elif isinstance(params[key], list): schema_name = params[key] else: - raise ValueError("e_type should be either a list or string") + raise ValueError( + "e_type should be either a list or string") # Find whether global or local changes are needed by checking schema type. global_types = [] local_types = [] diff --git a/pyTigerGraph/gds/metrics.py b/pyTigerGraph/gds/metrics.py index a2368117..65e70547 100644 --- a/pyTigerGraph/gds/metrics.py +++ b/pyTigerGraph/gds/metrics.py @@ -10,7 +10,8 @@ import warnings from typing import Union -__all__ = ["Accumulator", "Accuracy", "BinaryPrecision", "BinaryRecall", "Precision", "Recall"] +__all__ = ["Accumulator", "Accuracy", "BinaryPrecision", + "BinaryRecall", "Precision", "Recall"] class Accumulator: @@ -112,12 +113,13 @@ class BinaryRecall(Accumulator): * Call the update function to add predictions and labels. * Get recall score at any point by accessing the value property. """ + def __init__(self) -> None: """NO DOC""" super().__init__() warnings.warn( - "The `BinaryRecall` metric is deprecated; use `Recall` metric instead.", - DeprecationWarning) + "The `BinaryRecall` metric is deprecated; use `Recall` metric instead.", + DeprecationWarning) def update(self, preds: ndarray, labels: ndarray) -> None: """Add predictions and labels to be compared. @@ -145,6 +147,7 @@ def value(self) -> float: else: return None + class ConfusionMatrix(Accumulator): """Confusion Matrix Metric. 
Updates a confusion matrix as new updates occur. @@ -153,6 +156,7 @@ class ConfusionMatrix(Accumulator): num_classes (int): Number of classes in your classification task. """ + def __init__(self, num_classes: int) -> None: """Instantiate the Confusion Matrix metric. Args: @@ -175,7 +179,6 @@ def update(self, preds: ndarray, labels: ndarray) -> None: labels ), "The lists of predictions and labels must have same length" - confusion_mat = np.zeros((self.num_classes, self.num_classes)) for pair in zip(labels.tolist(), preds.tolist()): confusion_mat[int(pair[0]), int(pair[1])] += 1 @@ -190,10 +193,11 @@ def value(self) -> np.array: Confusion matrix in dataframe form. ''' if self._count > 0: - return pd.DataFrame(self._cumsum, columns=["predicted_"+ str(i) for i in range(self.num_classes)], index=["label_"+str(i) for i in range(self.num_classes)]) + return pd.DataFrame(self._cumsum, columns=["predicted_" + str(i) for i in range(self.num_classes)], index=["label_"+str(i) for i in range(self.num_classes)]) else: return None + class Recall(ConfusionMatrix): """Recall Metric. @@ -218,7 +222,7 @@ def value(self) -> Union[dict, float]: recalls = {} for c in range(self.num_classes): - tp = cm[c,c] + tp = cm[c, c] fn = sum(cm[c, :]) - tp recalls[c] = tp/(tp+fn) if self.num_classes == 2: @@ -228,6 +232,7 @@ def value(self) -> Union[dict, float]: else: return None + class BinaryPrecision(Accumulator): """DEPRECATED: Binary Precision Metric. This metric is deprecated. Use the Precision metric instead. @@ -245,8 +250,8 @@ def __init__(self) -> None: """NO DOC""" super().__init__() warnings.warn( - "The `BinaryPrecision` metric is deprecated; use `Precision` metric instead.", - DeprecationWarning) + "The `BinaryPrecision` metric is deprecated; use `Precision` metric instead.", + DeprecationWarning) def update(self, preds: ndarray, labels: ndarray) -> None: """Add predictions and labels to be compared. @@ -274,6 +279,7 @@ def value(self) -> float: else: return None + class Precision(ConfusionMatrix): """Precision Metric. @@ -298,7 +304,7 @@ def value(self) -> Union[dict, float]: precs = {} for c in range(self.num_classes): - tp = cm[c,c] + tp = cm[c, c] fp = sum(cm[:, c]) - tp precs[c] = tp/(tp+fp) if self.num_classes == 2: @@ -308,9 +314,10 @@ def value(self) -> Union[dict, float]: else: return None + class MSE(Accumulator): """MSE Metric. - + MSE = stem:[\sum(predicted-actual)^2/n] This metric is for regression tasks, i.e. predicting an n-dimensional vector of float values. @@ -320,6 +327,7 @@ class MSE(Accumulator): * Call the update function to add predictions and labels. * Get MSE value at any point by accessing the value property. """ + def update(self, preds: ndarray, labels: ndarray) -> None: """Add predictions and labels to be compared. @@ -346,6 +354,7 @@ def value(self) -> float: else: return None + class RMSE(MSE): """RMSE Metric. @@ -358,6 +367,7 @@ class RMSE(MSE): * Call the update function to add predictions and labels. * Get RMSE score at any point by accessing the value property. """ + def __init__(self): """NO DOC""" super().__init__() @@ -373,6 +383,7 @@ def value(self) -> float: else: return None + class MAE(Accumulator): """MAE Metric. @@ -385,6 +396,7 @@ class MAE(Accumulator): * Call the update function to add predictions and labels. * Get MAE value at any point by accessing the value property. """ + def update(self, preds: ndarray, labels: ndarray) -> None: """Add predictions and labels to be compared. 
@@ -411,6 +423,7 @@ def value(self) -> float: else: return None + class HitsAtK(Accumulator): """Hits@K Metric. This metric is used in link prediction tasks, i.e. determining if two vertices have an edge between them. @@ -425,7 +438,8 @@ class HitsAtK(Accumulator): k (int): Top k number of entities to compare. """ - def __init__(self, k:int) -> None: + + def __init__(self, k: int) -> None: """Instantiate the Hits@K Metric Args: k (int): @@ -461,6 +475,7 @@ def value(self) -> float: else: return None + class RecallAtK(Accumulator): """Recall@K Metric. This metric is used in link prediction tasks, i.e. determining if two vertices have an edge between them @@ -474,7 +489,8 @@ class RecallAtK(Accumulator): k (int): Top k number of entities to compare. """ - def __init__(self, k:int) -> None: + + def __init__(self, k: int) -> None: """Instantiate the Recall@K Metric Args: k (int): @@ -510,12 +526,14 @@ def value(self) -> float: else: return None + class BaseMetrics(): """NO DOC""" + def __init__(self): """NO DOC""" self.reset_metrics() - + def reset_metrics(self): self.loss = Accumulator() @@ -530,7 +548,8 @@ class ClassificationMetrics(BaseMetrics): """Classification Metrics collection. Collects Loss, Accuracy, Precision, Recall, and Confusion Matrix Metrics. """ - def __init__(self, num_classes: int=2): + + def __init__(self, num_classes: int = 2): """Instantiate the Classification Metrics collection. Args: num_classes (int): @@ -560,25 +579,35 @@ def update_metrics(self, loss, out, batch, target_type=None): pred = out.argmax(dim=1) if isinstance(batch, dict): if target_type: - self.accuracy.update(pred[target_type], batch[target_type]["y"]) - self.confusion_matrix.update(pred[target_type], batch[target_type]["y"]) - self.precision.update(pred[target_type], batch[target_type]["y"]) + self.accuracy.update( + pred[target_type], batch[target_type]["y"]) + self.confusion_matrix.update( + pred[target_type], batch[target_type]["y"]) + self.precision.update( + pred[target_type], batch[target_type]["y"]) self.recall.update(pred[target_type], batch[target_type]["y"]) else: self.accuracy.update(pred, batch["y"]) self.confusion_matrix.update(pred, batch["y"]) self.precision.update(pred, batch["y"]) self.recall.update(pred, batch["y"]) - else: # batch is a PyG Object (has is_seed attribute) + else: # batch is a PyG Object (has is_seed attribute) if target_type: - self.accuracy.update(pred[batch[target_type].is_seed], batch[target_type].y[batch[target_type].is_seed]) - self.confusion_matrix.update(pred[batch[target_type].is_seed], batch[target_type].y[batch[target_type].is_seed]) - self.precision.update(pred[batch[target_type].is_seed], batch[target_type].y[batch[target_type].is_seed]) - self.recall.update(pred[batch[target_type].is_seed], batch[target_type].y[batch[target_type].is_seed]) + self.accuracy.update( + pred[batch[target_type].is_seed], batch[target_type].y[batch[target_type].is_seed]) + self.confusion_matrix.update( + pred[batch[target_type].is_seed], batch[target_type].y[batch[target_type].is_seed]) + self.precision.update( + pred[batch[target_type].is_seed], batch[target_type].y[batch[target_type].is_seed]) + self.recall.update( + pred[batch[target_type].is_seed], batch[target_type].y[batch[target_type].is_seed]) else: - self.accuracy.update(pred[batch.is_seed], batch.y[batch.is_seed]) - self.confusion_matrix.update(pred[batch.is_seed], batch.y[batch.is_seed]) - self.precision.update(pred[batch.is_seed], batch.y[batch.is_seed]) + self.accuracy.update( + pred[batch.is_seed], 
batch.y[batch.is_seed]) + self.confusion_matrix.update( + pred[batch.is_seed], batch.y[batch.is_seed]) + self.precision.update( + pred[batch.is_seed], batch.y[batch.is_seed]) self.recall.update(pred[batch.is_seed], batch.y[batch.is_seed]) def get_metrics(self): @@ -587,14 +616,17 @@ def get_metrics(self): Dictionary of Accuracy, Precision, Recall, and Confusion Matrix """ super_met = super().get_metrics() - metrics = {"accuracy": self.accuracy.value, "precision": self.precision.value, "recall": self.recall.value, "confusion_matrix": self.confusion_matrix.value} + metrics = {"accuracy": self.accuracy.value, "precision": self.precision.value, + "recall": self.recall.value, "confusion_matrix": self.confusion_matrix.value} metrics.update(super_met) return metrics + class RegressionMetrics(BaseMetrics): """Regression Metrics Collection. Collects Loss, MSE, RMSE, and MAE metrics. """ + def __init__(self): """Instantiate the Regression Metrics collection. """ @@ -628,9 +660,12 @@ def update_metrics(self, loss, out, batch, target_type=None): self.mae.update(out, batch["y"]) else: if target_type: - self.mse.update(out[batch[target_type].is_seed], batch[target_type].y[batch[target_type].is_seed]) - self.rmse.update(out[batch[target_type].is_seed], batch[target_type].y[batch[target_type].is_seed]) - self.mae.update(out[batch[target_type].is_seed], batch[target_type].y[batch[target_type].is_seed]) + self.mse.update(out[batch[target_type].is_seed], + batch[target_type].y[batch[target_type].is_seed]) + self.rmse.update(out[batch[target_type].is_seed], + batch[target_type].y[batch[target_type].is_seed]) + self.mae.update(out[batch[target_type].is_seed], + batch[target_type].y[batch[target_type].is_seed]) else: self.mse.update(out[batch.is_seed], batch.y[batch.is_seed]) self.rmse.update(out[batch.is_seed], batch.y[batch.is_seed]) @@ -654,6 +689,7 @@ class LinkPredictionMetrics(BaseMetrics): Collects Loss, Recall@K, and Hits@K metrics. """ + def __init__(self, k): """Instantiate the Classification Metrics collection. Args: @@ -693,4 +729,3 @@ def get_metrics(self): "k": self.k} metrics.update(super_met) return metrics - diff --git a/pyTigerGraph/gds/models/GraphSAGE.py b/pyTigerGraph/gds/models/GraphSAGE.py index e299c56a..a90e8480 100644 --- a/pyTigerGraph/gds/models/GraphSAGE.py +++ b/pyTigerGraph/gds/models/GraphSAGE.py @@ -9,10 +9,13 @@ from torch_geometric.nn import to_hetero import torch_geometric.nn as gnn except: - raise Exception("PyTorch Geometric required to use GraphSAGE. Please install PyTorch Geometric") + raise Exception( + "PyTorch Geometric required to use GraphSAGE. Please install PyTorch Geometric") + class BaseGraphSAGEModel(bm.BaseModel): """NO DOC.""" + def __init__(self, num_layers, out_dim, hidden_dim, dropout=0.0, heterogeneous=None): super().__init__() self.dropout = dropout @@ -33,14 +36,17 @@ def forward(self, batch, target_type=None): x = batch.x.float() edge_index = batch.edge_index return self.model(x, edge_index) - - def compute_loss(self, loss_fn = None): - raise NotImplementedError("Loss computation not implemented for BaseGraphSAGEModel") + + def compute_loss(self, loss_fn=None): + raise NotImplementedError( + "Loss computation not implemented for BaseGraphSAGEModel") + class GraphSAGEForVertexClassification(BaseGraphSAGEModel): """GraphSAGEForVertexClassification Use a GraphSAGE model to classify vertices. By default, this model collects `ClassificationMetrics`, and uses cross entropy as its loss function. 
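A minimal usage sketch (the dimensions, loaders, and epoch count are hypothetical):

[source,python]
----
model = GraphSAGEForVertexClassification(
    num_layers=2, out_dim=2, hidden_dim=64, dropout=0.2)
model.fit(train_loader, valid_loader, number_epochs=10)
out, metrics = model.predict(batch)
----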
""" + def __init__(self, num_layers: int, out_dim: int, hidden_dim: int, dropout=0.0, heterogeneous=None, class_weights=None): """Initialize the GraphSAGE Vertex Classification Model. Args: @@ -87,7 +93,7 @@ def forward(self, batch, get_probs=False, target_type=None): else: return logits - def compute_loss(self, logits, batch, target_type=None, loss_fn = None): + def compute_loss(self, logits, batch, target_type=None, loss_fn=None): """Compute loss. Args: logits (torch.Tensor or dict of torch.Tensor): @@ -99,26 +105,30 @@ def compute_loss(self, logits, batch, target_type=None, loss_fn = None): loss_fn (callable, optional): The function to compute the loss with. Uses cross entropy loss if not defined. """ - if not(loss_fn): + if not (loss_fn): loss_fn = F.cross_entropy if self.heterogeneous: - loss = loss_fn(logits[batch[target_type].is_seed], - batch[target_type].y[batch[target_type].is_seed].long(), - self.class_weight) + loss = loss_fn(logits[batch[target_type].is_seed], + batch[target_type].y[batch[target_type].is_seed].long(), + self.class_weight) else: - loss = loss_fn(logits[batch.is_seed], batch.y[batch.is_seed].long(), self.class_weight) - else: # can't assume custom loss supports class weights + loss = loss_fn( + logits[batch.is_seed], batch.y[batch.is_seed].long(), self.class_weight) + else: # can't assume custom loss supports class weights if self.heterogeneous: - loss = loss_fn(logits[batch[target_type].is_seed], - batch[target_type].y[batch[target_type].is_seed].long()) + loss = loss_fn(logits[batch[target_type].is_seed], + batch[target_type].y[batch[target_type].is_seed].long()) else: - loss = loss_fn(logits[batch.is_seed], batch.y[batch.is_seed].long()) + loss = loss_fn(logits[batch.is_seed], + batch.y[batch.is_seed].long()) return loss + class GraphSAGEForVertexRegression(BaseGraphSAGEModel): """GraphSAGEForVertexRegression Use GraphSAGE for vertex regression tasks. By default, this model collects `RegressionMetrics`, and uses MSE as its loss function. """ + def __init__(self, num_layers: int, out_dim: int, hidden_dim: int, dropout=0.0, heterogeneous=None): """Initialize the GraphSAGE Vertex Regression Model. Args: @@ -163,11 +173,11 @@ def compute_loss(self, logits, batch, target_type=None, loss_fn=None): loss_fn (callable, optional): The function to compute the loss with. Uses MSE loss if not defined. """ - if not(loss_fn): + if not (loss_fn): loss_fn = F.mse_loss if self.heterogeneous: - loss = loss_fn(logits[target_type][batch[target_type].is_seed], - batch[target_type].y[batch[target_type].is_seed]) + loss = loss_fn(logits[target_type][batch[target_type].is_seed], + batch[target_type].y[batch[target_type].is_seed]) else: loss = loss_fn(logits[batch.is_seed], batch.y[batch.is_seed]) return loss @@ -177,7 +187,8 @@ class GraphSAGEForLinkPrediction(BaseGraphSAGEModel): """GraphSAGEForLinkPrediction By default, this model collects `LinkPredictionMetrics` with k = 10, and uses binary cross entropy as its loss function. """ - def __init__(self, num_layers, embedding_dim, hidden_dim, dropout = 0.0, heterogeneous=None): + + def __init__(self, num_layers, embedding_dim, hidden_dim, dropout=0.0, heterogeneous=None): """Initialize the GraphSAGE Link Prediction Model. 
Args: num_layers (int): @@ -219,8 +230,10 @@ def forward(self, batch, target_type=None): def decode(self, src_z, dest_z, pos_edge_index, neg_edge_index): """NO DOC.""" - edge_index = torch.cat([pos_edge_index, neg_edge_index], dim=-1) # concatenate pos and neg edges - logits = (src_z[edge_index[0]] * dest_z[edge_index[1]]).sum(dim=-1) # dot product + edge_index = torch.cat( + [pos_edge_index, neg_edge_index], dim=-1) # concatenate pos and neg edges + logits = (src_z[edge_index[0]] * dest_z[edge_index[1]] + ).sum(dim=-1) # dot product return logits def get_link_labels(self, pos_edge_index, neg_edge_index): @@ -233,13 +246,17 @@ def get_link_labels(self, pos_edge_index, neg_edge_index): def generate_edges(self, batch, target_edge_type=None): """NO DOC.""" if self.heterogeneous: - pos_edges = batch[target_edge_type].edge_index[:, batch[target_edge_type].is_seed] - src_neg_edges = torch.randint(0, batch[target_edge_type[0]].x.shape[0], (pos_edges.shape[1],), dtype=torch.long) - dest_neg_edges = torch.randint(0, batch[target_edge_type[-1]].x.shape[0], (pos_edges.shape[1],), dtype=torch.long) + pos_edges = batch[target_edge_type].edge_index[:, + batch[target_edge_type].is_seed] + src_neg_edges = torch.randint( + 0, batch[target_edge_type[0]].x.shape[0], (pos_edges.shape[1],), dtype=torch.long) + dest_neg_edges = torch.randint( + 0, batch[target_edge_type[-1]].x.shape[0], (pos_edges.shape[1],), dtype=torch.long) neg_edges = torch.stack((src_neg_edges, dest_neg_edges)) else: pos_edges = batch.edge_index[:, batch.is_seed] - neg_edges = torch.randint(0, batch.x.shape[0], pos_edges.size(), dtype=torch.long) + neg_edges = torch.randint( + 0, batch.x.shape[0], pos_edges.size(), dtype=torch.long) return pos_edges, neg_edges def compute_loss(self, logits, batch, target_type=None, loss_fn=None): @@ -254,7 +271,7 @@ def compute_loss(self, logits, batch, target_type=None, loss_fn=None): loss_fn (callable, optional): The function to compute the loss with. Uses binary cross entropy loss if not defined. """ - if not(loss_fn): + if not (loss_fn): loss_fn = F.binary_cross_entropy_with_logits loss = loss_fn(logits, batch.y) return loss @@ -265,4 +282,4 @@ def get_embeddings(self, batch): batch (torch_geometric.Data or torch_geometric.HeteroData): Get the embeddings for all vertices in a batch. """ - return super().forward(batch) \ No newline at end of file + return super().forward(batch) diff --git a/pyTigerGraph/gds/models/NodePieceMLP.py b/pyTigerGraph/gds/models/NodePieceMLP.py index 37499fa6..541e53de 100644 --- a/pyTigerGraph/gds/models/NodePieceMLP.py +++ b/pyTigerGraph/gds/models/NodePieceMLP.py @@ -8,14 +8,17 @@ import torch.nn as nn import torch.nn.functional as F except: - raise Exception("PyTorch is required to use NodePiece MLPs. Please install PyTorch") + raise Exception( + "PyTorch is required to use NodePiece MLPs. 
Please install PyTorch") + class BaseNodePieceEmbeddingTable(nn.Module): """NO DOC.""" + def __init__(self, vocab_size: int, sequence_length: int, - embedding_dim: int=768): + embedding_dim: int = 768): super().__init__() self.embedding_dim = embedding_dim self.seq_len = sequence_length @@ -32,27 +35,31 @@ def forward(self, x): class BaseNodePieceMLPModel(nn.Module): """NO DOC.""" - def __init__(self, num_layers, out_dim, hidden_dim, vocab_size, sequence_length, embedding_dim = 768, dropout = 0.0): + + def __init__(self, num_layers, out_dim, hidden_dim, vocab_size, sequence_length, embedding_dim=768, dropout=0.0): super().__init__() self.embedding_dim = embedding_dim self.vocab_size = vocab_size self.sequence_length = sequence_length - self.base_embedding = BaseNodePieceEmbeddingTable(vocab_size, sequence_length, embedding_dim) - + self.base_embedding = BaseNodePieceEmbeddingTable( + vocab_size, sequence_length, embedding_dim) self.num_embedding_dim = embedding_dim*sequence_length self.in_layer = None self.hidden_dim = hidden_dim self.dropout = dropout self.out_layer = nn.Linear(self.hidden_dim, out_dim) - self.hidden_layers = nn.ModuleList([nn.Linear(hidden_dim, hidden_dim) for _ in range(num_layers-2)]) + self.hidden_layers = nn.ModuleList( + [nn.Linear(hidden_dim, hidden_dim) for _ in range(num_layers-2)]) def forward(self, batch): if not self.in_layer: - if "features" in list(batch.keys()): - self.in_layer = nn.Linear(batch["features"].shape[1] + self.num_embedding_dim, self.hidden_dim) + if "features" in list(batch.keys()): + self.in_layer = nn.Linear( + batch["features"].shape[1] + self.num_embedding_dim, self.hidden_dim) else: - self.in_layer = nn.Linear(self.num_embedding_dim, self.hidden_dim) + self.in_layer = nn.Linear( + self.num_embedding_dim, self.hidden_dim) x = self.base_embedding(batch) x = torch.flatten(x, start_dim=1) if "features" in list(batch.keys()): @@ -63,13 +70,15 @@ def forward(self, batch): x = self.out_layer(x) return x + class NodePieceMLPForVertexClassification(bm.BaseModel): """NodePieceMLPForVertexClassification. This model is for training an multi-layer perceptron (MLP) on batches produced by NodePiece dataloaders, and transformed by the `NodePieceMLPTransform`. The architecture is for a vertex classification task, and assumes the label of each vertex is in a batch attribute called `"y"`, such as what is produced by the `NodePieceMLPTransform`. By default, this model collects `ClassficiationMetrics`, and uses cross entropy as its loss function. """ - def __init__(self, num_layers: int, out_dim: int, hidden_dim: int, vocab_size: int, sequence_length: int, embedding_dim = 768, dropout = 0.0, class_weights = None): + + def __init__(self, num_layers: int, out_dim: int, hidden_dim: int, vocab_size: int, sequence_length: int, embedding_dim=768, dropout=0.0, class_weights=None): """Initialize a NodePieceMLPForVertexClassification. Initializes the model. Args: @@ -91,7 +100,8 @@ def __init__(self, num_layers: int, out_dim: int, hidden_dim: int, vocab_size: i Weight the importance of each class in the classification task when computing loss. Helpful in imbalanced classification tasks. 
""" super().__init__() - self.model = BaseNodePieceMLPModel(num_layers, out_dim, hidden_dim, vocab_size, sequence_length, embedding_dim, dropout) + self.model = BaseNodePieceMLPModel( + num_layers, out_dim, hidden_dim, vocab_size, sequence_length, embedding_dim, dropout) self.metrics = ClassificationMetrics(out_dim) self.class_weight = class_weights @@ -109,7 +119,7 @@ def forward(self, batch, get_probs=False, **kwargs): else: return logits - def compute_loss(self, logits, batch, loss_fn = None, **kwargs): + def compute_loss(self, logits, batch, loss_fn=None, **kwargs): """Compute loss. Args: logits (torch.Tensor): @@ -120,9 +130,9 @@ def compute_loss(self, logits, batch, loss_fn = None, **kwargs): A PyTorch-compatible function to produce the loss of the model, which takes in logits, the labels, and optionally the class_weights. Defaults to Cross Entropy. """ - if not(loss_fn): + if not (loss_fn): loss_fn = F.cross_entropy loss = loss_fn(logits, batch["y"].long(), self.class_weight) else: loss = loss_fn(logits, batch["y"].long()) - return loss \ No newline at end of file + return loss diff --git a/pyTigerGraph/gds/models/__init__.py b/pyTigerGraph/gds/models/__init__.py index 8dc99dff..9a102da7 100644 --- a/pyTigerGraph/gds/models/__init__.py +++ b/pyTigerGraph/gds/models/__init__.py @@ -1,2 +1,2 @@ from . import GraphSAGE -from . import NodePieceMLP \ No newline at end of file +from . import NodePieceMLP diff --git a/pyTigerGraph/gds/models/base_model.py b/pyTigerGraph/gds/models/base_model.py index 9a174a1d..b2a1361d 100644 --- a/pyTigerGraph/gds/models/base_model.py +++ b/pyTigerGraph/gds/models/base_model.py @@ -1,10 +1,12 @@ try: import torch except: - raise Exception("PyTorch required to use built-in models. Please install PyTorch") + raise Exception( + "PyTorch required to use built-in models. Please install PyTorch") from ..trainer import Trainer + class BaseModel(torch.nn.Module): def __init__(self): super().__init__() @@ -17,12 +19,13 @@ def forward(self, batch, target_type=None, **kwargs): raise NotImplementedError("Forward pass not implemented for BaseModel") def compute_loss(self, logits, batch, loss_fn=None, **kwargs): - raise NotImplementedError("Loss computation not implemented in BaseModel") + raise NotImplementedError( + "Loss computation not implemented in BaseModel") def fit(self, training_dataloader, eval_dataloader, number_epochs, target_type=None, trainer_kwargs={}): trainer_kwargs.update({"model": self, "training_dataloader": training_dataloader, - "eval_dataloader": eval_dataloader, + "eval_dataloader": eval_dataloader, "target_type": target_type}) self.trainer = Trainer(**trainer_kwargs) self.trainer.train(number_epochs) @@ -31,4 +34,5 @@ def predict(self, batch): if self.trainer: return self.trainer.predict(batch) else: - raise Exception("Model has not been fit yet. Call model.fit() before model.predict()") \ No newline at end of file + raise Exception( + "Model has not been fit yet. 
Call model.fit() before model.predict()") diff --git a/pyTigerGraph/gds/splitters.py b/pyTigerGraph/gds/splitters.py index 92cb1481..f11644d8 100644 --- a/pyTigerGraph/gds/splitters.py +++ b/pyTigerGraph/gds/splitters.py @@ -38,7 +38,8 @@ def _validate_args(self, split_ratios) -> None: raise ValueError("Can take at most 3 partition ratios in input.") for v in split_ratios.values(): if v < 0 or v > 1: - raise ValueError("All partition ratios have to be between 0 and 1.") + raise ValueError( + "All partition ratios have to be between 0 and 1.") if sum(split_ratios.values()) > 1: raise ValueError("Sum of all partition ratios has to be <=1") @@ -69,17 +70,17 @@ def run(self, schema_type, **split_ratios) -> None: else: global_types.append(e_type) if len(global_types) > 0: - add_attribute(self._graph, - schema_type, - attr_name = {key: "BOOL" for key in split_ratios}, - schema_name = global_types, - global_change = True) + add_attribute(self._graph, + schema_type, + attr_name={key: "BOOL" for key in split_ratios}, + schema_name=global_types, + global_change=True) if len(local_types) > 0: - add_attribute(self._graph, - schema_type, - attr_name = {key: "BOOL" for key in split_ratios}, - schema_name = local_types, - global_change = False) + add_attribute(self._graph, + schema_type, + attr_name={key: "BOOL" for key in split_ratios}, + schema_name=local_types, + global_change=False) payload = {} payload["stypes"] = self.schema_types for i, key in enumerate(split_ratios): @@ -110,11 +111,11 @@ class RandomVertexSplitter(BaseRandomSplitter): splitter = RandomVertexSplitter(conn, timeout, attr_name=0.6) splitter.run() ---- - + * A random 60% of vertices will have their attribute "attr_name" set to True, and a random 20% of vertices will have their attribute "attr_name2" set to True. The two parts are disjoint. Example: - + + [source,python] ---- @@ -134,7 +135,7 @@ class RandomVertexSplitter(BaseRandomSplitter): splitter = RandomVertexSplitter(conn, timeout, attr_name=0.6, attr_name2=0.2, attr_name3=0.2) splitter.run() ---- - + Args: conn (TigerGraphConnection): Connection to TigerGraph database. @@ -163,7 +164,7 @@ def run(self, **split_ratios) -> None: """Perform the split. The split ratios set in initialization can be overridden here. - + For example: [,python] @@ -171,7 +172,7 @@ def run(self, **split_ratios) -> None: splitter = RandomVertexSplitter(conn, timeout, attr_name=0.6); splitter.run(attr_name=0.3) ---- - + The splitter above uses the ratio 0.3 instead of 0.6. """ @@ -187,7 +188,7 @@ class RandomEdgeSplitter(BaseRandomSplitter): indicates which part an edge belongs to. Usage: - + * A random 60% of edges will have their attribute "attr_name" set to True, and others False. `attr_name` can be any attribute that exists in the database (same below). Example: diff --git a/pyTigerGraph/gds/trainer.py b/pyTigerGraph/gds/trainer.py index f44ca992..ad35715d 100644 --- a/pyTigerGraph/gds/trainer.py +++ b/pyTigerGraph/gds/trainer.py @@ -17,6 +17,7 @@ import os import warnings + class BaseCallback(): """Base class for training callbacks. @@ -25,13 +26,14 @@ class BaseCallback(): during that point in time of the trainer's execution, such as the beginning or end of an epoch. Inherit from this class if a custom callback implementation is desired. """ + def __init__(self): """NO DOC""" pass def on_init_end(self, trainer): """Run operations after the initialization of the trainer. - + Args: trainer (pyTigerGraph Trainer): Takes in the trainer in order to perform operations on it. 
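For instance, a minimal custom callback sketch using only hooks shown in this file (the printed fields are illustrative; trainer.cur_step and trainer.loss are populated by the training loop):

[source,python]
----
class LossLogger(BaseCallback):
    def on_train_step_end(self, trainer):
        print("step", trainer.cur_step, "loss", float(trainer.loss))

# DefaultCallback must be re-listed when custom callbacks are supplied
trainer = Trainer(model, train_loader, valid_loader,
                  callbacks=[DefaultCallback, LossLogger])
----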
@@ -40,7 +42,7 @@ def on_init_end(self, trainer): def on_epoch_start(self, trainer): """Run operations at the start of a training epoch. - + Args: trainer (pyTigerGraph Trainer): Takes in the trainer in order to perform operations on it. @@ -67,7 +69,7 @@ def on_train_step_end(self, trainer): def on_epoch_end(self, trainer): """Run operations at the end of an epoch. - + Args: trainer (pyTigerGraph Trainer): Takes in the trainer in order to perform operations on it. @@ -76,7 +78,7 @@ def on_epoch_end(self, trainer): def on_eval_start(self, trainer): """Run operations at the start of the evaluation process. - + Args: trainer (pyTigerGraph Trainer): Takes in the trainer in order to perform operations on it. @@ -103,7 +105,7 @@ def on_eval_step_end(self, trainer): def on_eval_end(self, trainer): """Run operations at the end of the evaluation process. - + Args: trainer (pyTigerGraph Trainer): Takes in the trainer in order to perform operations on it. @@ -113,7 +115,7 @@ def on_eval_end(self, trainer): class PrinterCallback(BaseCallback): """Callback for printing metrics during training. - + To use, import the class and pass it to the Trainer's callback argument. [.wrap,python] @@ -123,6 +125,7 @@ class PrinterCallback(BaseCallback): trainer = Trainer(model, training_dataloader, eval_dataloader, callbacks=[PrinterCallback]) ---- """ + def __init__(self): """NO DOC""" pass @@ -135,40 +138,47 @@ def on_eval_end(self, trainer): """NO DOC""" print(trainer.get_eval_metrics()) + class MetricsCallback(BaseCallback): """NO DOC""" + def on_train_step_end(self, trainer): """NO DOC""" trainer.reset_train_step_metrics() for metric in trainer.metrics: - metric.update_metrics(trainer.loss, trainer.out, trainer.batch, target_type=trainer.target_type) + metric.update_metrics(trainer.loss, trainer.out, + trainer.batch, target_type=trainer.target_type) trainer.update_train_step_metrics(metric.get_metrics()) metric.reset_metrics() trainer.update_train_step_metrics({"global_step": trainer.cur_step}) - trainer.update_train_step_metrics({"epoch": int(trainer.cur_step/trainer.train_loader.num_batches)}) - + trainer.update_train_step_metrics( + {"epoch": int(trainer.cur_step/trainer.train_loader.num_batches)}) + def on_eval_start(self, trainer): """NO DOC""" for metric in trainer.metrics: metric.reset_metrics() - + def on_eval_step_end(self, trainer): """NO DOC""" for metric in trainer.metrics: - metric.update_metrics(trainer.loss, trainer.out, trainer.batch, target_type=trainer.target_type) - + metric.update_metrics(trainer.loss, trainer.out, + trainer.batch, target_type=trainer.target_type) + def on_eval_end(self, trainer): """NO DOC""" for metric in trainer.metrics: trainer.update_eval_metrics(metric.get_metrics()) + class DefaultCallback(BaseCallback): """Default Callback - + The `DefaultCallback` class logs metrics and updates progress bars during the training process. The Trainer `callbacks` parameter is populated with this callback. If you define other callbacks with that parameter, you will have to pass `DefaultCallback` again in your list of callbacks. """ + def __init__(self, output_dir="./logs", use_tqdm=True): """Instantiate the Default Callback. @@ -188,7 +198,8 @@ def __init__(self, output_dir="./logs", use_tqdm=True): self.valid_bar = None except: self.tqdm = None - warnings.warn("tqdm not installed. Please install tqdm if progress bar support is desired.") + warnings.warn( + "tqdm not installed. 
Please install tqdm if progress bar support is desired.") else: self.tqdm = False self.output_dir = output_dir @@ -196,7 +207,8 @@ def __init__(self, output_dir="./logs", use_tqdm=True): os.makedirs(self.output_dir, exist_ok=True) curDT = time.time() logging.basicConfig(format='%(asctime)s %(levelname)s:%(name)s:%(message)s', - filename=output_dir+'/train_results_'+str(curDT)+'.log', + filename=output_dir + + '/train_results_'+str(curDT)+'.log', filemode='w', encoding='utf-8', level=logging.INFO) @@ -204,13 +216,16 @@ def __init__(self, output_dir="./logs", use_tqdm=True): def on_epoch_start(self, trainer): """NO DOC""" if self.tqdm: - if not(self.epoch_bar): + if not (self.epoch_bar): if trainer.num_epochs: - self.epoch_bar = self.tqdm(desc="Epochs", total=trainer.num_epochs) + self.epoch_bar = self.tqdm( + desc="Epochs", total=trainer.num_epochs) else: - self.epoch_bar = self.tqdm(desc="Training Steps", total=trainer.max_num_steps) - if not(self.batch_bar): - self.batch_bar = self.tqdm(desc="Training Batches", total=trainer.train_loader.num_batches) + self.epoch_bar = self.tqdm( + desc="Training Steps", total=trainer.max_num_steps) + if not (self.batch_bar): + self.batch_bar = self.tqdm( + desc="Training Batches", total=trainer.train_loader.num_batches) def on_train_step_end(self, trainer): """NO DOC""" @@ -224,8 +239,9 @@ def on_eval_start(self, trainer): """NO DOC""" trainer.reset_eval_metrics() if self.tqdm: - if not(self.valid_bar): - self.valid_bar = self.tqdm(desc="Eval Batches", total=trainer.eval_loader.num_batches) + if not (self.valid_bar): + self.valid_bar = self.tqdm( + desc="Eval Batches", total=trainer.eval_loader.num_batches) def on_eval_step_end(self, trainer): """NO DOC""" @@ -256,22 +272,23 @@ def on_epoch_end(self, trainer): class Trainer(): """Trainer - + Train graph machine learning models that comply with the `BaseModel` object in pyTigerGraph. Performs training and evaluation loops and automatically collects metrics for the given task. - + PyTorch is required to use the Trainer. """ - def __init__(self, + + def __init__(self, model, training_dataloader: BaseLoader, eval_dataloader: BaseLoader, callbacks: List[BaseCallback] = [DefaultCallback], - metrics = None, - target_type = None, - loss_fn = None, - optimizer = None, - optimizer_kwargs = {}): + metrics=None, + target_type=None, + loss_fn=None, + optimizer=None, + optimizer_kwargs={}): """Instantiate a Trainer. Create a Trainer object to train graph machine learning models. @@ -301,7 +318,8 @@ def __init__(self, try: import torch except: - raise Exception("PyTorch is required to use the trainer. Please install PyTorch.") + raise Exception( + "PyTorch is required to use the trainer. 
Please install PyTorch.") self.model = model self.train_loader = training_dataloader self.eval_loader = eval_dataloader @@ -328,16 +346,17 @@ def __init__(self, try: if self.train_loader.v_out_labels: if self.is_hetero: - self.target_type = list(self.train_loader.v_out_labels.keys())[0] + self.target_type = list( + self.train_loader.v_out_labels.keys())[0] else: - self.target_type = None #self.train_loader.v_out_labels + self.target_type = None # self.train_loader.v_out_labels else: self.target_type = target_type except: self.target_type = None callbacks = [MetricsCallback] + callbacks - for callback in callbacks: # instantiate callbacks if not already done so + for callback in callbacks: # instantiate callbacks if not already done so if isinstance(callback, type): callback = callback() self.callbacks.append(callback) @@ -348,7 +367,7 @@ def __init__(self, def update_train_step_metrics(self, metrics): """Update the metrics for a training step. - + Args: metrics (dict): Dictionary of calculated metrics. @@ -357,7 +376,7 @@ def update_train_step_metrics(self, metrics): def get_train_step_metrics(self): """Get the metrics for a training step. - + Returns: Dictionary of training metrics results. """ @@ -373,7 +392,7 @@ def reset_train_step_metrics(self): def update_eval_metrics(self, metrics): """Update the metrics of an evaluation loop. - + Args: metrics (dict): Dictionary of calculated metrics. @@ -382,7 +401,7 @@ def update_eval_metrics(self, metrics): def get_eval_metrics(self): """Get the metrics for an evaluation loop. - + Returns: Dictionary of evaluation loop metrics results. """ @@ -398,7 +417,7 @@ def reset_eval_metrics(self): def train(self, num_epochs=None, max_num_steps=None): """Train a model. - + Args: num_epochs (int, optional): Number of epochs to train for. Defaults to 1 full iteration through the `training_dataloader`. @@ -423,9 +442,9 @@ def train(self, num_epochs=None, max_num_steps=None): self.out = self.model(batch, target_type=self.target_type) self.batch = batch self.loss = self.model.compute_loss(self.out, - batch, - target_type = self.target_type, - loss_fn = self.loss_fn) + batch, + target_type=self.target_type, + loss_fn=self.loss_fn) self.optimizer.zero_grad() self.loss.backward() self.optimizer.step() @@ -434,7 +453,7 @@ def train(self, num_epochs=None, max_num_steps=None): callback.on_train_step_end(trainer=self) for callback in self.callbacks: - callback.on_epoch_end(trainer=self) + callback.on_epoch_end(trainer=self) def eval(self, loader=None): """Evaluate a model. @@ -457,9 +476,9 @@ def eval(self, loader=None): self.out = self.model(batch, target_type=self.target_type) self.batch = batch self.loss = self.model.compute_loss(self.out, - batch, - target_type = self.target_type, - loss_fn = self.loss_fn) + batch, + target_type=self.target_type, + loss_fn=self.loss_fn) for callback in self.callbacks: callback.on_eval_step_end(trainer=self) for callback in self.callbacks: @@ -472,9 +491,9 @@ def predict(self, batch): batch (any): Data object that is compatible with the model being trained. Make predictions on the batch passed in. 
- +  Returns: + Returns a tuple of `(model output, evaluation metrics)` """ self.eval(loader=[batch]) - return self.out, self.get_eval_metrics() \ No newline at end of file + return self.out, self.get_eval_metrics() diff --git a/pyTigerGraph/gds/transforms/nodepiece_transforms.py b/pyTigerGraph/gds/transforms/nodepiece_transforms.py index a11eee18..40534d0d 100644 --- a/pyTigerGraph/gds/transforms/nodepiece_transforms.py +++ b/pyTigerGraph/gds/transforms/nodepiece_transforms.py @@ -1,18 +1,22 @@ """NodePiece Transforms""" + class BaseNodePieceTransform(): """NO DOC.""" + def __call__(self, data): return data def __repr__(self): return f'{self.__class__.__name__}()' + class NodePieceMLPTransform(BaseNodePieceTransform): """NodePieceMLPTransform. The NodePieceMLPTransform converts a batch of data from the NodePieceLoader into a format that can be used in an MLP implemented in PyTorch. """ # Assumes numerical types for features and labels. No support for complex datatypes as features. + def __init__(self, label: str, features: list = [], target_type: str = None): """Instantiate a NodePieceMLPTransform. Args: @@ -26,7 +30,8 @@ def __init__(self, label: str, features: list = [], target_type: str = None): try: import torch except: - raise Exception("PyTorch is required to use this transform. Please install PyTorch") + raise Exception( + "PyTorch is required to use this transform. Please install PyTorch") self.features = features self.target_type = target_type self.label = label @@ -42,11 +47,12 @@ def __call__(self, data): if self.target_type: data = data[self.target_type] batch["y"] = torch.tensor(data[self.label].astype(int)) - batch["relational_context"] = torch.tensor(data["relational_context"], dtype=torch.long) + batch["relational_context"] = torch.tensor( + data["relational_context"], dtype=torch.long) batch["anchors"] = torch.tensor(data["anchors"], dtype=torch.long) - batch["distance"] = torch.tensor(data["anchor_distances"], dtype=torch.long) + batch["distance"] = torch.tensor( + data["anchor_distances"], dtype=torch.long) if len(self.features) > 0: - batch["features"] = torch.stack([torch.tensor(data[feat]) for feat in self.features]).T + batch["features"] = torch.stack( + [torch.tensor(data[feat]) for feat in self.features]).T return batch - - \ No newline at end of file diff --git a/pyTigerGraph/gds/transforms/pyg_transforms.py b/pyTigerGraph/gds/transforms/pyg_transforms.py index f8385772..a596c284 100644 --- a/pyTigerGraph/gds/transforms/pyg_transforms.py +++ b/pyTigerGraph/gds/transforms/pyg_transforms.py @@ -1,17 +1,21 @@ """PyTorch Geometric Transforms""" + class BasePyGTransform(): """NO DOC""" + def __call__(self, data): return data def __repr__(self): return f'{self.__class__.__name__}()' + class TemporalPyGTransform(BasePyGTransform): """TemporalPyGTransform. The TemporalPyGTransform creates a sequence of subgraph batches out of a single batch of data produced by a NeighborLoader or HGTLoader. It assumes that there are datetime attributes on vertices and edges. If vertex attributes change over time, children vertex attributes are moved to the appropriate parent, and then the children are removed from the graph. 
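A minimal usage sketch (the attribute names and time window are hypothetical, and the keyword names beyond the two visible in the signature below are assumptions):

[source,python]
----
transform = TemporalPyGTransform(
    vertex_start_attrs={"Person": "start_dt"},
    vertex_end_attrs={"Person": "end_dt"},
    edge_start_attrs={}, edge_end_attrs={},
    start_dt=0, end_dt=100, timestep=10)
snapshots = transform(batch)  # one PyG data object per timestep
----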
""" + def __init__(self, vertex_start_attrs: dict, vertex_end_attrs: dict, @@ -65,9 +69,11 @@ def __init__(self, import torch_geometric as pyg import torch if (int(pyg.__version__.split(".")[1]) < 3 and int(pyg.__version__.split(".")[0]) == 2) or int(pyg.__version__.split(".")[0]) < 2: - raise Exception("PyTorch Geometric version must be 2.3.0 or greater") + raise Exception( + "PyTorch Geometric version must be 2.3.0 or greater") except: - raise Exception("PyTorch Geometric required to use PyG models. Please install PyTorch Geometric") + raise Exception( + "PyTorch Geometric required to use PyG models. Please install PyTorch Geometric") def __call__(self, data) -> list: """Perform the transform. Returns a list of PyTorch Geometric data objects, a sequence of snapshots in time of the graph. @@ -88,39 +94,48 @@ def __call__(self, data) -> list: v_start_attr = self.vertex_start[v_type] if v_type in self.vertex_end.keys(): v_end_attr = self.vertex_end[v_type] - v_to_keep[v_type] = torch.logical_and(data[v_type][v_start_attr] <= i, torch.logical_or(data[v_type][v_end_attr] > i, data[v_type][v_end_attr] == -1)) + v_to_keep[v_type] = torch.logical_and(data[v_type][v_start_attr] <= i, torch.logical_or( + data[v_type][v_end_attr] > i, data[v_type][v_end_attr] == -1)) else: v_to_keep[v_type] = data[v_type][v_start_attr] <= i elif v_type in self.vertex_end.keys(): v_end_attr = self.vertex_end[v_type] if v_type not in self.vertex_start[v_type]: - v_to_keep[v_type] = torch.logical_or(data[v_type][v_end_attr] >= i, data[v_type][v_end_attr] == -1) + v_to_keep[v_type] = torch.logical_or( + data[v_type][v_end_attr] >= i, data[v_type][v_end_attr] == -1) else: - v_to_keep[v_type] = torch.tensor([i for i in data[v_type].is_seed]) + v_to_keep[v_type] = torch.tensor( + [i for i in data[v_type].is_seed]) data[v_type]["vertex_present"] = v_to_keep[v_type] - - + e_to_keep = {} for e_type in data.edge_types: v_src_type = e_type[0] v_dest_type = e_type[-1] - src_idx_to_keep = torch.argwhere(v_to_keep[v_src_type]).flatten() - dest_idx_to_keep = torch.argwhere(v_to_keep[v_dest_type]).flatten() + src_idx_to_keep = torch.argwhere( + v_to_keep[v_src_type]).flatten() + dest_idx_to_keep = torch.argwhere( + v_to_keep[v_dest_type]).flatten() edges = data.edge_index_dict[e_type] - filtered_edges = torch.logical_and(torch.tensor([True if i in src_idx_to_keep else False for i in edges[0]]), torch.tensor([True if i in dest_idx_to_keep else False for i in edges[1]])) + filtered_edges = torch.logical_and(torch.tensor([True if i in src_idx_to_keep else False for i in edges[0]]), torch.tensor([ + True if i in dest_idx_to_keep else False for i in edges[1]])) if e_type in self.edge_start.keys(): - filtered_edges = torch.logical_and(filtered_edges, data[e_type][self.edge_start[e_type]] <= i) + filtered_edges = torch.logical_and( + filtered_edges, data[e_type][self.edge_start[e_type]] <= i) if e_type in self.edge_end.keys(): - filtered_edges = torch.logical_and(filtered_edges, data[e_type][self.edge_end[e_type]] >= i) + filtered_edges = torch.logical_and( + filtered_edges, data[e_type][self.edge_end[e_type]] >= i) e_to_keep[e_type] = filtered_edges - + subgraph = data.edge_subgraph(e_to_keep) - + for triple in self.feat_tr.keys(): for feat in self.feat_tr[triple]: - subgraph[triple[-1]][str(triple[0])+"_"+str(feat)] = torch.zeros(subgraph[triple[-1]]["vertex_present"].size(), dtype=subgraph[triple[0]][feat].dtype) - subgraph[triple[-1]][str(triple[0])+"_"+str(feat)][subgraph[triple].edge_index[1]] = 
subgraph[triple[0]][feat][subgraph[triple].edge_index[0]] - + subgraph[triple[-1]][str(triple[0])+"_"+str(feat)] = torch.zeros( + subgraph[triple[-1]]["vertex_present"].size(), dtype=subgraph[triple[0]][feat].dtype) + subgraph[triple[-1]][str(triple[0])+"_"+str(feat)][subgraph[triple].edge_index[1] + ] = subgraph[triple[0]][feat][subgraph[triple].edge_index[0]] + for triple in self.feat_tr.keys(): del subgraph[triple[0]] for es in subgraph.edge_types: @@ -131,22 +146,28 @@ def __call__(self, data) -> list: return sequence elif isinstance(data, pyg.data.Data): if self.feat_tr: - raise Exception("No feature transformations are supported on homogeneous data") + raise Exception( + "No feature transformations are supported on homogeneous data") sequence = [] for i in range(self.start_dt, self.end_dt, self.timestep): - v_to_keep = torch.logical_and(data[self.vertex_start] <= i, torch.logical_or(data[self.vertex_end] > i, data[self.vertex_end] == -1)) + v_to_keep = torch.logical_and(data[self.vertex_start] <= i, torch.logical_or( + data[self.vertex_end] > i, data[self.vertex_end] == -1)) src_idx_to_keep = torch.argwhere(v_to_keep).flatten() dest_idx_to_keep = torch.argwhere(v_to_keep).flatten() edges = data.edge_index - filtered_edges = torch.logical_and(torch.tensor([True if i in src_idx_to_keep else False for i in edges[0]]), torch.tensor([True if i in dest_idx_to_keep else False for i in edges[1]])) + filtered_edges = torch.logical_and(torch.tensor([True if i in src_idx_to_keep else False for i in edges[0]]), torch.tensor([ + True if i in dest_idx_to_keep else False for i in edges[1]])) if self.edge_start: - filtered_edges = torch.logical_and(filtered_edges, data[self.edge_start] <= i) + filtered_edges = torch.logical_and( + filtered_edges, data[self.edge_start] <= i) if self.edge_end: - filtered_edges = torch.logical_and(filtered_edges, data[self.edge_end] >= i) + filtered_edges = torch.logical_and( + filtered_edges, data[self.edge_end] >= i) e_to_keep = filtered_edges subgraph = data.edge_subgraph(e_to_keep) subgraph.vertex_present = v_to_keep sequence.append(subgraph) return sequence else: - raise Exception("Passed batch of data must be of type torch_geometric.data.Data or torch_geometric.data.HeteroData") \ No newline at end of file + raise Exception( + "Passed batch of data must be of type torch_geometric.data.Data or torch_geometric.data.HeteroData") diff --git a/pyTigerGraph/gds/utilities.py b/pyTigerGraph/gds/utilities.py index 5089f7ec..d7b5c285 100644 --- a/pyTigerGraph/gds/utilities.py +++ b/pyTigerGraph/gds/utilities.py @@ -129,12 +129,15 @@ def install_query_file( ) # If a suffix is to be added to query name if replace and ("{QUERYSUFFIX}" in replace): - query_name = query_name.replace("{QUERYSUFFIX}", replace["{QUERYSUFFIX}"]) + query_name = query_name.replace( + "{QUERYSUFFIX}", replace["{QUERYSUFFIX}"]) # If query is already installed, skip unless force install. 
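# --- Editor's note (illustrative, not part of the patch) ---
# A sketch of how a caller supplies the `replace` mapping handled above; the
# file path and suffix value are hypothetical and the positional layout is
# abbreviated from install_query_file()'s full signature.
# install_query_file(conn, "gsql/dataloaders/my_query.gsql",
#                    replace={"{QUERYSUFFIX}": "v2"}, force=True)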
- is_installed, is_enabled = is_query_installed(conn, query_name, return_status=True) + is_installed, is_enabled = is_query_installed( + conn, query_name, return_status=True) if is_installed: if force or (not is_enabled): - query = "USE GRAPH {}\nDROP QUERY {}\n".format(conn.graphname, query_name) + query = "USE GRAPH {}\nDROP QUERY {}\n".format( + conn.graphname, query_name) resp = conn.gsql(query) if "Successfully dropped queries" not in resp: raise ConnectionError(resp) @@ -166,7 +169,7 @@ def install_query_file( return query_name -def add_attribute(conn: "TigerGraphConnection", schema_type:str, attr_type:str = None, attr_name:Union[str, dict] = None, schema_name:list = None, global_change:bool = False): +def add_attribute(conn: "TigerGraphConnection", schema_type: str, attr_type: str = None, attr_name: Union[str, dict] = None, schema_name: list = None, global_change: bool = False): ''' If the current attribute is not already added to the schema, it will create the schema job to do that. Check whether to add the attribute to vertex(vertices) or edge(s). @@ -203,7 +206,7 @@ def add_attribute(conn: "TigerGraphConnection", schema_type:str, attr_type:str = for t in target: attributes = [] if v_type: - meta_data = conn.getVertexType(t, force=True) + meta_data = conn.getVertexType(t, force=True) else: meta_data = conn.getEdgeType(t, force=True) for i in range(len(meta_data['Attributes'])): @@ -211,10 +214,11 @@ def add_attribute(conn: "TigerGraphConnection", schema_type:str, attr_type:str = # If attribute is not in list of vertex attributes, do the schema change to add it if isinstance(attr_name, str): if not attr_type: - raise Exception("attr_type must be defined if attr_name is of type string") + raise Exception( + "attr_type must be defined if attr_name is of type string") if attr_name != None and attr_name not in attributes: tasks.append("ALTER {} {} ADD ATTRIBUTE ({} {});\n".format( - schema_type, t, attr_name, attr_type)) + schema_type, t, attr_name, attr_type)) elif isinstance(attr_name, dict): for aname in attr_name: if aname != None and aname not in attributes: @@ -226,9 +230,9 @@ def add_attribute(conn: "TigerGraphConnection", schema_type:str, attr_type:str = return "Attribute already exists" # Drop all jobs on the graph # self.conn.gsql("USE GRAPH {}\n".format(self.conn.graphname) + "DROP JOB *") - # Create schema change job - job_name = "add_{}_attr_{}".format(schema_type,random_string(6)) - if not(global_change): + # Create schema change job + job_name = "add_{}_attr_{}".format(schema_type, random_string(6)) + if not (global_change): job = "USE GRAPH {}\n".format(conn.graphname) + "CREATE SCHEMA_CHANGE JOB {} {{\n".format( job_name) + ''.join(tasks) + "}}\nRUN SCHEMA_CHANGE JOB {}".format(job_name) else: @@ -242,4 +246,4 @@ def add_attribute(conn: "TigerGraphConnection", schema_type:str, attr_type:str = raise ConnectionError(resp) else: print(status, flush=True) - return 'Schema change succeeded.' \ No newline at end of file + return 'Schema change succeeded.' diff --git a/pyTigerGraph/pyTigerGraph.py b/pyTigerGraph/pyTigerGraph.py index 5c2e909b..c6583be7 100644 --- a/pyTigerGraph/pyTigerGraph.py +++ b/pyTigerGraph/pyTigerGraph.py @@ -23,17 +23,17 @@ # TODO Proper deprecation handling; import deprecation? 
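# --- Editor's note (illustrative, not part of the patch) ---
# A typical call to the add_attribute() helper reformatted above, using its
# signature as shown; the vertex type and attribute names are hypothetical and
# `conn` is assumed to be an existing TigerGraphConnection.
# add_attribute(conn, schema_type="VERTEX", attr_type="INT",
#               attr_name="degree", schema_name=["Person"], global_change=False)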
class TigerGraphConnection(pyTigerGraphVertex, pyTigerGraphEdge, pyTigerGraphUDT, - pyTigerGraphLoading, pyTigerGraphPath, pyTigerGraphDataset, object): + pyTigerGraphLoading, pyTigerGraphPath, pyTigerGraphDataset, object): """Python wrapper for TigerGraph's REST++ and GSQL APIs""" def __init__(self, host: str = "http://127.0.0.1", graphname: str = "MyGraph", - gsqlSecret: str = "", username: str = "tigergraph", password: str = "tigergraph", - tgCloud: bool = False, restppPort: Union[int, str] = "9000", - gsPort: Union[int, str] = "14240", gsqlVersion: str = "", version: str = "", - apiToken: str = "", useCert: bool = None, certPath: str = None, debug: bool = None, - sslPort: Union[int, str] = "443", gcp: bool = False, jwtToken: str = ""): + gsqlSecret: str = "", username: str = "tigergraph", password: str = "tigergraph", + tgCloud: bool = False, restppPort: Union[int, str] = "9000", + gsPort: Union[int, str] = "14240", gsqlVersion: str = "", version: str = "", + apiToken: str = "", useCert: bool = None, certPath: str = None, debug: bool = None, + sslPort: Union[int, str] = "443", gcp: bool = False, jwtToken: str = ""): super().__init__(host, graphname, gsqlSecret, username, password, tgCloud, restppPort, - gsPort, gsqlVersion, version, apiToken, useCert, certPath, debug, sslPort, gcp, jwtToken) + gsPort, gsqlVersion, version, apiToken, useCert, certPath, debug, sslPort, gcp, jwtToken) self.gds = None self.ai = None diff --git a/pyTigerGraph/pyTigerGraphAuth.py b/pyTigerGraph/pyTigerGraphAuth.py index 4953f6d9..2b6332f2 100644 --- a/pyTigerGraph/pyTigerGraphAuth.py +++ b/pyTigerGraph/pyTigerGraphAuth.py @@ -3,24 +3,29 @@ The functions on this page authenticate connections and manage TigerGraph credentials. All functions in this module are called as methods on a link:https://docs.tigergraph.com/pytigergraph/current/core-functions/base[`TigerGraphConnection` object]. """ -import json + import logging -import time import warnings -from datetime import datetime, timezone -from typing import Union - import requests -from pyTigerGraph.pyTigerGraphException import TigerGraphException +from typing import Union, Tuple, Dict + +from pyTigerGraph.common.auth import ( + _parse_get_secrets, + _parse_create_secret, + _prep_token_request, + _parse_token_response +) +from pyTigerGraph.common.exception import TigerGraphException from pyTigerGraph.pyTigerGraphGSQL import pyTigerGraphGSQL + logger = logging.getLogger(__name__) class pyTigerGraphAuth(pyTigerGraphGSQL): - def getSecrets(self) -> dict: + def getSecrets(self) -> Dict[str, str]: """Issues a `SHOW SECRET` GSQL statement and returns the secret generated by that statement. Secrets are unique strings that serve as credentials when generating authentication tokens. @@ -37,40 +42,25 @@ def getSecrets(self) -> dict: res = self.gsql(""" USE GRAPH {} SHOW SECRET""".format(self.graphname), ) - ret = {} - lines = res.split("\n") - i = 0 - while i < len(lines): - l = lines[i] - s = "" - if "- Secret" in l: - s = l.split(": ")[1] - i += 1 - l = lines[i] - if "- Alias" in l: - ret[l.split(": ")[1]] = s - i += 1 + ret = _parse_get_secrets(res) if logger.level == logging.DEBUG: logger.debug("return: " + str(ret)) logger.info("exit: getSecrets") return ret - # TODO Process response, return a dictionary of alias/secret pairs - def showSecrets(self) -> dict: + def showSecrets(self) -> Dict[str, str]: """DEPRECATED Use `getSecrets()` instead. 
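# --- Editor's note (illustrative, not part of the patch) ---
# getSecrets() above now delegates parsing of the `SHOW SECRET` output to
# pyTigerGraph.common.auth._parse_get_secrets and returns alias -> secret
# pairs, e.g. (values hypothetical):
# secrets = conn.getSecrets()   # {"my_alias": "abc123..."}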
""" warnings.warn("The `showSecrets()` function is deprecated; use `getSecrets()` instead.", - DeprecationWarning) + DeprecationWarning) return self.getSecrets() - # TODO getSecret() - - def createSecret(self, alias: str = "", withAlias: bool = False) -> Union[str, dict]: + def createSecret(self, alias: str = "", withAlias: bool = False) -> Union[str, Dict[str, str]]: """Issues a `CREATE SECRET` GSQL statement and returns the secret generated by that statement. Secrets are unique strings that serve as credentials when generating authentication tokens. @@ -100,47 +90,23 @@ def createSecret(self, alias: str = "", withAlias: bool = False) -> Union[str, d res = self.gsql(""" USE GRAPH {} CREATE SECRET {} """.format(self.graphname, alias)) - try: - if ("already exists" in res): - errorMsg = "The secret " - if alias != "": - errorMsg += "with alias {} ".format(alias) - errorMsg += "already exists." - raise TigerGraphException(errorMsg, "E-00001") - - secret = "".join(res).replace('\n', '').split('The secret: ')[1].split(" ")[0].strip() - - if not withAlias: - if logger.level == logging.DEBUG: - logger.debug("return: " + str(secret)) - logger.info("exit: createSecret (withAlias") - - return secret - - if alias: - ret = {alias: secret} + secret = _parse_create_secret( + res, alias=alias, withAlias=withAlias) - if logger.level == logging.DEBUG: - logger.debug("return: " + str(ret)) - logger.info("exit: createSecret (alias)") - - return ret - - # Alias was not provided, let's find out the autogenerated one + # Alias was not provided, let's find out the autogenerated one + # done in createSecret since need to call self.getSecrets which is a possibly async function + if withAlias and not alias: masked = secret[:3] + "****" + secret[-3:] secs = self.getSecrets() for (a, s) in secs.items(): if s == masked: - ret = {a: secret} + secret = {a: secret} - if logger.level == logging.DEBUG: - logger.debug("return: " + str(ret)) - logger.info("exit: createSecret") - - return ret + if logger.level == logging.DEBUG: + logger.debug("return: " + str(secret)) + logger.info("exit: createSecret") - except: - raise + return secret def dropSecret(self, alias: Union[str, list], ignoreErrors: bool = True) -> str: """Drops a secret. @@ -177,338 +143,97 @@ def dropSecret(self, alias: Union[str, list], ignoreErrors: bool = True) -> str: return res - def getToken(self, secret: str = None, setToken: bool = True, lifetime: int = None) -> Union[tuple, str]: - """Requests an authorization token. - - This function returns a token only if REST++ authentication is enabled. If not, an exception - will be raised. - See https://docs.tigergraph.com/admin/admin-guide/user-access-management/user-privileges-and-authentication#rest-authentication - - Args: - secret (str, Optional): - The secret (string) generated in GSQL using `CREATE SECRET`. - See https://docs.tigergraph.com/tigergraph-server/current/user-access/managing-credentials#_create_a_secret - setToken (bool, Optional): - Set the connection's API token to the new value (default: `True`). - lifetime (int, Optional): - Duration of token validity (in seconds, default 30 days = 2,592,000 seconds). - - Returns: - If your TigerGraph instance is running version <=3.10, the return value is - a tuple of `(, , )`. - The return value can be ignored, as the token is automatically set for the connection after this call. - - If your TigerGraph instance is running version 4.0, the return value is a tuple of `(, ). 
- - [NOTE] - The expiration timestamp's time zone might be different from your computer's local time - zone. - - Raises: - `TigerGraphException` if REST++ authentication is not enabled or if an authentication - error occurred. + def _token(self, secret: str = None, lifetime: int = None, token: str = None, _method: str = None) -> Union[tuple, str]: + method, url, alt_url, authMode, data, alt_data = _prep_token_request(self.restppUrl, + self.gsUrl, + self.graphname, + self.version, + secret, + lifetime, + token) + # _method Used for delete and refresh token + + # method == GET when using old version since _prep_newToken() gets the method for getting a new token for a version + if method == "GET": + if _method: + method = _method + + # Use TG < 3.5 format (no json data) + res = self._req(method, url, authMode=authMode, + data=data, resKey=None) + mainVer = 3 + else: + if _method: + method = _method - Endpoint: - - `POST /requesttoken` (In TigerGraph versions 3.x) - See https://docs.tigergraph.com/tigergraph-server/current/api/built-in-endpoints#_request_a_token - - `POST /gsql/v1/tokens` (In TigerGraph versions 4.x) - """ + # Try using TG 4.1 endpoint first, if url not found then try <4.1 endpoint + try: + res = self._req(method, url, authMode=authMode, + data=data, resKey=None, jsonData=True) + mainVer = 4 + except: + try: + res = self._req( + method, alt_url, authMode=authMode, data=alt_data, resKey=None) + mainVer = 3 + except Exception as e: + raise TigerGraphException("Error requesting token. Check if the connection's graphname is correct.", 400) + + # uses mainVer instead of _versionGreaterThan4_0 since you need a token for verson checking + return res, mainVer + + def getToken(self, + secret: str = None, + setToken: bool = True, + lifetime: int = None) -> Union[Tuple[str, str], str]: logger.info("entry: getToken") if logger.level == logging.DEBUG: logger.debug("params: " + self._locals(locals())) - s, m, i = (0, 0, 0) - res = {} - if self.version: - s, m, i = self.version.split(".") - success = False - - if not(secret) and self.graphname: - if self.graphname: - _json = {"graph": self.graphname} - try: - res = self._post(self.restppUrl+"/requesttoken", authMode="pwd", data=str(_json), resKey="results") - mainVer = 3 - - # The old endpoint doesn't exist (since on TigerGraph Ver >=4.1). Use new endpoint - # have to handle it in this order since _req changes the url to the new url path if first request fails - except Exception as e: - try: - res = self._post(self.gsUrl + "/gsql/v1/tokens", - data=_json, - authMode="pwd", - jsonData=True, - resKey=None) - mainVer = 4 - except requests.exceptions.HTTPError as e: - if e.response.status_code == 404: - raise TigerGraphException( - "Error requesting token. 
Check if the connection's graphname is correct and that REST authentication is enabled.", - 404 - ) - else: - raise e - pass - success = True - elif secret: - try: - data = {"secret": secret} - - if lifetime: - data["lifetime"] = str(lifetime) - - res = json.loads(requests.post(self.restppUrl + "/requesttoken", - data=json.dumps(data), verify=False).text) - success = True - mainVer = 3 - except Exception as e: # case of less than version 3.5 - try: - res = requests.request("GET", self.restppUrl + - "/requesttoken?secret=" + secret + - ("&lifetime=" + str(lifetime) if lifetime else ""), verify=False) - mainVer = 3 # Can't use _verGreaterThan4_0 to check version since you need to set a token for that - - res = json.loads(res.text) - - if not res["error"]: - success = True - except: - raise e - else: - raise TigerGraphException("Cannot request a token with username/password for versions < 3.5.") - - - - if not res.get("error"): - if setToken: - self.apiToken = res["token"] - self.authHeader = {'Authorization': "Bearer " + self.apiToken} - else: - self.apiToken = None - self.authHeader = {'Authorization': 'Basic {0}'.format(self.base64_credential)} - - if res.get("expiration"): - # On >=4.1 the format for the date of expiration changed, can't get utc time stamp from it - if mainVer == 4: - ret = res["token"], res.get("expiration") - else: - ret = res["token"], res.get("expiration"), \ - datetime.utcfromtimestamp(float(res.get("expiration"))).strftime('%Y-%m-%d %H:%M:%S') - else: - ret = res["token"] - - if logger.level == logging.DEBUG: - logger.debug("return: " + str(ret)) - logger.info("exit: parseVertices") - - return ret - - if "Endpoint is not found from url = /requesttoken" in res["message"]: - raise TigerGraphException("REST++ authentication is not enabled, can't generate token.", - None) - raise TigerGraphException(res["message"], (res["code"] if "code" in res else None)) - - def refreshToken(self, secret: str, token: str = "", setToken: bool = True, lifetime: int = None) -> tuple: - """Extends a token's lifetime. - - This function works only if REST++ authentication is enabled. If not, an exception will be - raised. - See https://docs.tigergraph.com/admin/admin-guide/user-access-management/user-privileges-and-authentication#rest-authentication - - Args: - secret: - The secret (string) generated in GSQL using `CREATE SECRET`. - See https://docs.tigergraph.com/tigergraph-server/current/user-access/managing-credentials#_create_a_secret - token: - The token requested earlier. If not specified, refreshes current connection's token. - lifetime: - Duration of token validity (in seconds, default 30 days = 2,592,000 seconds) from - current system timestamp. - - Returns: - A tuple of `(, , )`. - The return value can be ignored. / - New expiration timestamp will be now + lifetime seconds, _not_ current expiration - timestamp + lifetime seconds. - - [NOTE] - The expiration timestamp's time zone might be different from your computer's local time - zone. + res, mainVer = self._token(secret, lifetime) + token, auth_header = _parse_token_response(res, + setToken, + mainVer, + self.base64_credential + ) + self.apiToken = token + self.authHeader = auth_header + logger.info("exit: getToken") + return token - Raises: - `TigerGraphException` if REST++ authentication is not enabled, if an authentication error - occurs, or if calling while using TigerGraph 4.x. 
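# --- Editor's note (illustrative, not part of the patch) ---
# The rewritten getToken() above routes the request through _token(), which
# tries the TigerGraph 4.x /gsql/v1/tokens endpoint first and falls back to
# the 3.x /requesttoken endpoint. A typical bootstrap (hypothetical host):
from pyTigerGraph import TigerGraphConnection

conn = TigerGraphConnection(host="https://example.tgcloud.io", graphname="MyGraph",
                            username="tigergraph", password="tigergraph")
token = conn.getToken(conn.createSecret())  # also sets conn.apiToken/authHeader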
- - Note: - Not avaliable on TigerGraph version 4.x - - Endpoint: - - `PUT /requesttoken` - See https://docs.tigergraph.com/tigergraph-server/current/api/built-in-endpoints#_refresh_a_token - TODO Rework lifetime parameter handling the same as in getToken() - """ + def refreshToken(self, secret: str = None, setToken: bool = True, lifetime: int = None, token: str = None) -> Union[Tuple[str, str], str]: logger.info("entry: refreshToken") if logger.level == logging.DEBUG: logger.debug("params: " + self._locals(locals())) - s, m, i = (0, 0, 0) - res = {} - if self.version: - s, m, i = self.version.split(".") - success = False + if self._version_greater_than_4_0(): + logger.info("exit: refreshToken") + raise TigerGraphException( + "Refreshing tokens is only supported on versions of TigerGraph <= 4.0.0.", 0) if not token: token = self.apiToken + res, mainVer = self._token(secret, lifetime, token, "PUT") - if self._versionGreaterThan4_0(): - logger.info("exit: refreshToken") - raise TigerGraphException("Refreshing tokens is only supported on versions of TigerGraph <= 4.0.0.", 0) - - if int(s) < 3 or (int(s) == 3 and int(m) < 5): - if self.useCert and self.certPath: - res = json.loads(requests.request("PUT", self.restppUrl + "/requesttoken?secret=" + - secret + "&token=" + token + ("&lifetime=" + str(lifetime) if lifetime else ""), - verify=False).text) - else: - res = json.loads(requests.request("PUT", self.restppUrl + "/requesttoken?secret=" + - secret + "&token=" + token + ("&lifetime=" + str(lifetime) if lifetime else ""), - verify=False).text) - if not res["error"]: - success = True - if "Endpoint is not found from url = /requesttoken" in res["message"]: - raise TigerGraphException("REST++ authentication is not enabled, can't refresh token.", - None) - - if not success: - data = {"secret": secret, "token": token} - if lifetime: - data["lifetime"] = str(lifetime) - if self.useCert is True and self.certPath is not None: - res = json.loads(requests.put(self.restppUrl + "/requesttoken", - data=json.dumps(data), verify=False).text) - else: - res = json.loads(requests.put(self.restppUrl + "/requesttoken", - data=json.dumps(data), verify=False).text) - if not res["error"]: - success = True - if "Endpoint is not found from url = /requesttoken" in res["message"]: - raise TigerGraphException("REST++ authentication is not enabled, can't refresh token.", - None) - - if success: - exp = time.time() + res["expiration"] - ret = res["token"], int(exp), \ - datetime.fromtimestamp(exp, timezone.utc).strftime('%Y-%m-%d %H:%M:%S') - - if logger.level == logging.DEBUG: - logger.debug("return: " + str(ret)) - logger.info("exit: refreshToken") - - return ret - - raise TigerGraphException(res["message"], (res["code"] if "code" in res else None)) + newToken = _parse_token_response(res, setToken, mainVer, self.base64_credential) - def deleteToken(self, secret, token=None, skipNA=True) -> bool: - """Deletes a token. + logger.info("exit: refreshToken") - This function works only if REST++ authentication is enabled. If not, an exception will be - raised. - See https://docs.tigergraph.com/tigergraph-server/current/user-access/enabling-user-authentication#_enable_restpp_authentication - - Args: - secret: - The secret (string) generated in GSQL using `CREATE SECRET`. - See https://docs.tigergraph.com/tigergraph-server/current/user-access/managing-credentials#_create_a_secret - token: - The token requested earlier. If not specified, deletes current connection's token, - so be careful. 
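# --- Editor's note (illustrative, not part of the patch) ---
# refreshToken() above now guards on _version_greater_than_4_0() and raises on
# newer servers; on 3.x it reuses _token() with a PUT (secret hypothetical):
# conn.refreshToken(secret="MY_SECRET", lifetime=3600)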
- skipNA: - Don't raise an exception if the specified token does not exist. - - Returns: - `True`, if deletion was successful, or if the token did not exist but `skipNA` was - `True`. - - Raises: - `TigerGraphException` if REST++ authentication is not enabled or an authentication error - occurred, for example if the specified token does not exist. - - Endpoint: - - `DELETE /requesttoken` (In TigerGraph version 3.x) - See https://docs.tigergraph.com/tigergraph-server/current/api/built-in-endpoints#_delete_a_token - - `DELETE /gsql/v1/tokens` (In TigerGraph version 4.x) - """ - logger.info("entry: deleteToken") - if logger.level == logging.DEBUG: - logger.debug("params: " + self._locals(locals())) - - s, m, i = (0, 0, 0) - res = {} - if self.version: - s, m, i = self.version.split(".") - success = False + return newToken + def deleteToken(self, secret: str, token: str = None, skipNA: bool = False) -> bool: if not token: token = self.apiToken + res, _ = self._token(secret, None, token, "DELETE") - if int(s) < 3 or (int(s) == 3 and int(m) < 5): - if self.useCert is True and self.certPath is not None: - if self._versionGreaterThan4_0(): - res = requests.request("DELETE", self.gsUrl + - "/gsql/v1/tokens", verify=False, json={"secret": secret, "token": token}, - headers={"X-User-Agent": "pyTigerGraph"}) - res = json.loads(res.text) - else: - res = json.loads( - requests.request("DELETE", - self.restppUrl + "/requesttoken?secret=" + secret + "&token=" + token, - verify=False).text) - else: - if self._versionGreaterThan4_0(): - res = requests.request("DELETE", self.gsUrl + - "/gsql/v1/tokens", verify=False, json={"tokens": token}, - headers={"X-User-Agent": "pyTigerGraph"}) - res = json.loads(res.text) - else: - res = json.loads( - requests.request("DELETE", - self.restppUrl + "/requesttoken?secret=" + secret + "&token=" + token).text) - if not res["error"]: - success = True - - if not success: - data = {"secret": secret, "token": token} - if self.useCert is True and self.certPath is not None: - res = json.loads(requests.delete(self.restppUrl + "/requesttoken", - data=json.dumps(data)).text) - else: - if self._versionGreaterThan4_0(): - res = requests.request("DELETE", self.gsUrl + - "/gsql/v1/tokens", verify=False, data=json.dumps(data), - headers={"X-User-Agent": "pyTigerGraph"}) - res = json.loads(res.text) - else: - res = json.loads(requests.delete(self.restppUrl + "/requesttoken", - data=json.dumps(data), verify=False).text) - - if "Endpoint is not found from url = /requesttoken" in res["message"]: - raise TigerGraphException("REST++ authentication is not enabled, can't delete token.", - None) - - if not res["error"]: + if not res["error"] or (res["code"] == "REST-3300" and skipNA): if logger.level == logging.DEBUG: logger.debug("return: " + str(True)) logger.info("exit: deleteToken") return True - if res["code"] == "REST-3300" and skipNA: - if logger.level == logging.DEBUG: - logger.debug("return: " + str(True)) - logger.info("exit: parseVertices") - - return True - + raise TigerGraphException( + res["message"], (res["code"] if "code" in res else None)) - raise TigerGraphException(res["message"], (res["code"] if "code" in res else None)) diff --git a/pyTigerGraph/pyTigerGraphBase.py b/pyTigerGraph/pyTigerGraphBase.py index 13ef2f6c..4d336f79 100644 --- a/pyTigerGraph/pyTigerGraphBase.py +++ b/pyTigerGraph/pyTigerGraphBase.py @@ -10,16 +10,18 @@ import sys import re import warnings +import requests + from typing import Union from urllib.parse import urlparse -import requests -from 
pyTigerGraph.pyTigerGraphException import TigerGraphException +from pyTigerGraph.common.exception import TigerGraphException +from pyTigerGraph.common.base import PyTigerGraphCore def excepthook(type, value, traceback): """NO DOC - + This function prints out a given traceback and exception to sys.stderr. See: https://docs.python.org/3/library/sys.html#sys.excepthook @@ -30,13 +32,14 @@ def excepthook(type, value, traceback): logger = logging.getLogger(__name__) -class pyTigerGraphBase(object): + +class pyTigerGraphBase(PyTigerGraphCore, object): def __init__(self, host: str = "http://127.0.0.1", graphname: str = "MyGraph", - gsqlSecret: str = "", username: str = "tigergraph", password: str = "tigergraph", - tgCloud: bool = False, restppPort: Union[int, str] = "9000", - gsPort: Union[int, str] = "14240", gsqlVersion: str = "", version: str = "", - apiToken: str = "", useCert: bool = None, certPath: str = None, debug: bool = None, - sslPort: Union[int, str] = "443", gcp: bool = False, jwtToken: str = ""): + gsqlSecret: str = "", username: str = "tigergraph", password: str = "tigergraph", + tgCloud: bool = False, restppPort: Union[int, str] = "9000", + gsPort: Union[int, str] = "14240", gsqlVersion: str = "", version: str = "", + apiToken: str = "", useCert: bool = None, certPath: str = None, debug: bool = None, + sslPort: Union[int, str] = "443", gcp: bool = False, jwtToken: str = ""): """Initiate a connection object. Args: @@ -90,7 +93,7 @@ def __init__(self, host: str = "http://127.0.0.1", graphname: str = "MyGraph", inputHost = urlparse(host) if inputHost.scheme not in ["http", "https"]: raise TigerGraphException("Invalid URL scheme. Supported schemes are http and https.", - "E-0003") + "E-0003") self.netloc = inputHost.netloc self.host = "{0}://{1}".format(inputHost.scheme, self.netloc) if gsqlSecret != "": @@ -101,13 +104,13 @@ def __init__(self, host: str = "http://127.0.0.1", graphname: str = "MyGraph", self.password = password self.graphname = graphname self.responseConfigHeader = {} - self.awsIamHeaders={} + self.awsIamHeaders = {} self.jwtToken = jwtToken self.apiToken = apiToken self.base64_credential = base64.b64encode( - "{0}:{1}".format(self.username, self.password).encode("utf-8")).decode("utf-8") - + "{0}:{1}".format(self.username, self.password).encode("utf-8")).decode("utf-8") + self.authHeader = self._set_auth_header() # TODO Eliminate version and use gsqlVersion only, meaning TigerGraph server version @@ -151,13 +154,15 @@ def __init__(self, host: str = "http://127.0.0.1", graphname: str = "MyGraph", # TODO Remove gcp parameter if gcp: - warnings.warn("The `gcp` parameter is deprecated.", DeprecationWarning) + warnings.warn("The `gcp` parameter is deprecated.", + DeprecationWarning) self.tgCloud = tgCloud or gcp if "tgcloud" in self.netloc.lower(): try: # If get request succeeds, using TG Cloud instance provisioned after 6/20/2022 self._get(self.host + "/api/ping", resKey="message") self.tgCloud = True - except requests.exceptions.RequestException: # If get request fails, using TG Cloud instance provisioned before 6/20/2022, before new firewall config + # If get request fails, using TG Cloud instance provisioned before 6/20/2022, before new firewall config + except requests.exceptions.RequestException: self.tgCloud = False except TigerGraphException: raise (TigerGraphException("Incorrect graphname.")) @@ -166,7 +171,7 @@ def __init__(self, host: str = "http://127.0.0.1", graphname: str = "MyGraph", sslPort = str(sslPort) if self.tgCloud and (restppPort == "9000" or 
restppPort == "443"): self.restppPort = sslPort - self.restppUrl = self.host + ":"+ sslPort + "/restpp" + self.restppUrl = self.host + ":" + sslPort + "/restpp" else: self.restppPort = restppPort self.restppUrl = self.host + ":" + self.restppPort @@ -192,7 +197,8 @@ def __init__(self, host: str = "http://127.0.0.1", graphname: str = "MyGraph", 'Host': 'sts.amazonaws.com' }) # Get headers - SigV4Auth(boto3.Session().get_credentials(), "sts", "us-east-1").add_auth(request) + SigV4Auth(boto3.Session().get_credentials(), + "sts", "us-east-1").add_auth(request) self.awsIamHeaders["X-Amz-Date"] = request.headers["X-Amz-Date"] self.awsIamHeaders["X-Amz-Security-Token"] = request.headers["X-Amz-Security-Token"] self.awsIamHeaders["Authorization"] = request.headers["Authorization"] @@ -200,6 +206,8 @@ def __init__(self, host: str = "http://127.0.0.1", graphname: str = "MyGraph", if self.jwtToken: self._verify_jwt_token_support() + self.asynchronous = False + logger.info("exit: __init__") def _set_auth_header(self): @@ -241,25 +249,12 @@ def _locals(self, _locals: dict) -> str: del _locals["self"] return str(_locals) - def _errorCheck(self, res: dict): - """Checks if the JSON document returned by an endpoint has contains `error: true`. If so, - it raises an exception. - - Args: - res: - The output from a request. - - Raises: - TigerGraphException: if request returned with error, indicated in the returned JSON. - """ - if "error" in res and res["error"] and res["error"] != "false": - # Endpoint might return string "false" rather than Boolean false - raise TigerGraphException(res["message"], (res["code"] if "code" in res else None)) + logger.info("exit: __init__") def _req(self, method: str, url: str, authMode: str = "token", headers: dict = None, - data: Union[dict, list, str] = None, resKey: str = "results", skipCheck: bool = False, - params: Union[dict, list, str] = None, strictJson: bool = True, jsonData: bool = False, - jsonResponse: bool = True) -> Union[dict, list]: + data: Union[dict, list, str] = None, resKey: str = "results", skipCheck: bool = False, + params: Union[dict, list, str] = None, strictJson: bool = True, jsonData: bool = False, + jsonResponse: bool = True) -> Union[dict, list]: """Generic REST++ API request. Args: @@ -288,63 +283,27 @@ def _req(self, method: str, url: str, authMode: str = "token", headers: dict = N Returns: The (relevant part of the) response from the request (as a dictionary). 
""" - logger.info("entry: _req") - if logger.level == logging.DEBUG: - logger.debug("params: " + self._locals(locals())) - - # If JWT token is provided, always use jwtToken as token - if authMode == "token": - if isinstance(self.jwtToken, str) and self.jwtToken.strip() != "": - token = self.jwtToken - elif isinstance(self.apiToken, tuple): - token = self.apiToken[0] - elif isinstance(self.apiToken, str) and self.apiToken.strip() != "": - token = self.apiToken - else: - token = None - - if token: - self.authHeader = {'Authorization': "Bearer " + token} - _headers = self.authHeader - else: - self.authHeader = {'Authorization': 'Basic {0}'.format(self.base64_credential)} - _headers = self.authHeader - authMode = 'pwd' - - if authMode == "pwd": - if self.jwtToken: - _headers = {'Authorization': "Bearer " + self.jwtToken} - else: - _headers = {'Authorization': 'Basic {0}'.format(self.base64_credential)} - - if headers: - _headers.update(headers) - if self.awsIamHeaders: - if url.startswith(self.gsUrl + "/gsqlserver/") or (self._versionGreaterThan4_0() and url.startswith(self.gsUrl)): # version >=4.1 has removed /gsqlserver/ - _headers.update(self.awsIamHeaders) - if self.responseConfigHeader: - _headers.update(self.responseConfigHeader) - if method == "POST" or method == "PUT" or method == "DELETE": - _data = data - else: - _data = None - - if self.useCert is True or self.certPath is not None: - verify = False - else: - verify = True - - _headers.update({"X-User-Agent": "pyTigerGraph"}) + _headers, _data, verify = self._prep_req( + authMode, headers, url, method, data) if jsonData: - res = requests.request(method, url, headers=_headers, json=_data, params=params, verify=verify) + res = requests.request( + method, url, headers=_headers, json=_data, params=params, verify=verify) else: - res = requests.request(method, url, headers=_headers, data=_data, params=params, verify=verify) + res = requests.request( + method, url, headers=_headers, data=_data, params=params, verify=verify) try: + if not skipCheck and not (200 <= res.status_code < 300): + try: + self._error_check(json.loads(res.text)) + except json.decoder.JSONDecodeError: + # could not parse the res text (probably returned an html response) + pass res.raise_for_status() except Exception as e: - # In TG 4.x the port for restpp has changed from 9000 to 14240. + + # In TG 4.x the port for restpp has changed from 9000 to 14240. # This block should only be called once. 
When using 4.x, using port 9000 should fail so self.restppurl will change to host:14240/restpp # ---- # Changes port to gsql port, adds /restpp to end to url, tries again, saves changes if successful @@ -354,42 +313,33 @@ def _req(self, method: str, url: str, authMode: str = "token", headers: dict = N if self.tgCloud: url = newRestppUrl + '/' + '/'.join(url.split(':')[2].split('/')[2:]) else: - url = newRestppUrl + '/' + '/'.join(url.split(':')[2].split('/')[1:]) + url = newRestppUrl + '/' + \ + '/'.join(url.split(':')[2].split('/')[1:]) if jsonData: - res = requests.request(method, url, headers=_headers, json=_data, params=params, verify=verify) + res = requests.request( + method, url, headers=_headers, json=_data, params=params, verify=verify) else: - res = requests.request(method, url, headers=_headers, data=_data, params=params, verify=verify) + res = requests.request( + method, url, headers=_headers, data=_data, params=params, verify=verify) + + # Run error check if there might be an error before raising for status + # raising for status gives less descriptive error message + if not skipCheck and not (200 <= res.status_code < 300) and res.status_code != 404: + try: + self._error_check(json.loads(res.text)) + except json.decoder.JSONDecodeError: + # could not parse the res text (probably returned an html response) + pass res.raise_for_status() self.restppUrl = newRestppUrl self.restppPort = self.gsPort else: raise e - if jsonResponse: - try: - res = json.loads(res.text, strict=strictJson) - except: - raise TigerGraphException("Cannot parse json: " + res.text) - else: - res = res.text - - if not skipCheck: - self._errorCheck(res) - if not resKey: - if logger.level == logging.DEBUG: - logger.debug("return: " + str(res)) - logger.info("exit: _req (no resKey)") - - return res + return self._parse_req(res, jsonResponse, strictJson, skipCheck, resKey) - if logger.level == logging.DEBUG: - logger.debug("return: " + str(res[resKey])) - logger.info("exit: _req (resKey)") - - return res[resKey] - def _get(self, url: str, authMode: str = "token", headers: dict = None, resKey: str = "results", - skipCheck: bool = False, params: Union[dict, list, str] = None, strictJson: bool = True) -> Union[dict, list]: + skipCheck: bool = False, params: Union[dict, list, str] = None, strictJson: bool = True) -> Union[dict, list]: """Generic GET method. Args: @@ -414,7 +364,8 @@ def _get(self, url: str, authMode: str = "token", headers: dict = None, resKey: if logger.level == logging.DEBUG: logger.debug("params: " + self._locals(locals())) - res = self._req("GET", url, authMode, headers, None, resKey, skipCheck, params, strictJson) + res = self._req("GET", url, authMode, headers, None, + resKey, skipCheck, params, strictJson) if logger.level == logging.DEBUG: logger.debug("return: " + str(res)) @@ -423,8 +374,8 @@ def _get(self, url: str, authMode: str = "token", headers: dict = None, resKey: return res def _post(self, url: str, authMode: str = "token", headers: dict = None, - data: Union[dict, list, str, bytes] = None, resKey: str = "results", skipCheck: bool = False, - params: Union[dict, list, str] = None, jsonData: bool = False) -> Union[dict, list]: + data: Union[dict, list, str, bytes] = None, resKey: str = "results", skipCheck: bool = False, + params: Union[dict, list, str] = None, jsonData: bool = False) -> Union[dict, list]: """Generic POST method. 
Args: @@ -451,15 +402,16 @@ def _post(self, url: str, authMode: str = "token", headers: dict = None, if logger.level == logging.DEBUG: logger.debug("params: " + self._locals(locals())) - res = self._req("POST", url, authMode, headers, data, resKey, skipCheck, params, jsonData=jsonData) + res = self._req("POST", url, authMode, headers, data, + resKey, skipCheck, params, jsonData=jsonData) if logger.level == logging.DEBUG: logger.debug("return: " + str(res)) logger.info("exit: _post") return res - - def _put(self, url: str, authMode: str = "token", data = None, resKey=None, jsonData=False) -> Union[dict, list]: + + def _put(self, url: str, authMode: str = "token", data=None, resKey=None, jsonData=False) -> Union[dict, list]: """Generic PUT method. Args: @@ -475,7 +427,8 @@ def _put(self, url: str, authMode: str = "token", data = None, resKey=None, json if logger.level == logging.DEBUG: logger.debug("params: " + self._locals(locals())) - res = self._req("PUT", url, authMode, data=data, resKey=resKey, jsonData=jsonData) + res = self._req("PUT", url, authMode, data=data, + resKey=resKey, jsonData=jsonData) if logger.level == logging.DEBUG: logger.debug("return: " + str(res)) @@ -499,28 +452,15 @@ def _delete(self, url: str, authMode: str = "token", data: dict = None, resKey=" if logger.level == logging.DEBUG: logger.debug("params: " + self._locals(locals())) - res = self._req("DELETE", url, authMode, data=data, resKey=resKey, jsonData=jsonData) + res = self._req("DELETE", url, authMode, data=data, + resKey=resKey, jsonData=jsonData) if logger.level == logging.DEBUG: logger.debug("return: " + str(res)) logger.info("exit: _delete") return res - - def customizeHeader(self, timeout:int = 16_000, responseSize:int = 3.2e+7): - """Method to configure the request header. - - Args: - tiemout (int, optional): - The timeout value desired in milliseconds. Defaults to 16,000 ms (16 sec) - responseSize: - The size of the response in bytes. Defaults to 3.2E7 bytes (32 MB). - - Returns: - Nothing. Sets `responseConfigHeader` class attribute. - """ - self.responseConfigHeader = {"GSQL-TIMEOUT": str(timeout), "RESPONSE-LIMIT": str(responseSize)} - + def getVersion(self, raw: bool = False) -> Union[str, list]: """Retrieves the git versions of all components of the system. 
@@ -540,22 +480,13 @@ def getVersion(self, raw: bool = False) -> Union[str, list]: logger.info("entry: getVersion") if logger.level == logging.DEBUG: logger.debug("params: " + self._locals(locals())) - response = self._get(self.restppUrl+"/version", strictJson=False, resKey="message") - if raw: - return response - res = response.split("\n") - components = [] - for i in range(len(res)): - if 2 < i < len(res) - 1: - m = res[i].split() - component = {"name": m[0], "version": m[1], "hash": m[2], - "datetime": m[3] + " " + m[4] + " " + m[5]} - components.append(component) + response = self._get(self.restppUrl+"/version", + strictJson=False, resKey="message") + components = self._parse_get_version(response, raw) if logger.level == logging.DEBUG: logger.debug("return: " + str(components)) logger.info("exit: getVersion") - return components def getVer(self, component: str = "product", full: bool = False) -> str: @@ -578,32 +509,22 @@ def getVer(self, component: str = "product", full: bool = False) -> str: logger.info("entry: getVer") if logger.level == logging.DEBUG: logger.debug("params: " + self._locals(locals())) + version = self.getVersion() + ret = self._parse_get_ver(version, component, full) - ret = "" - for v in self.getVersion(): - if v["name"] == component.lower(): - ret = v["version"] - if ret != "": - if full: - return ret - ret = re.search("_.+_", ret) - ret = ret.group().strip("_") - - if logger.level == logging.DEBUG: - logger.debug("return: " + str(ret)) - logger.info("exit: getVer") - - return ret - else: - raise TigerGraphException("\"" + component + "\" is not a valid component.", None) - - def _versionGreaterThan4_0(self) -> bool: + if logger.level == logging.DEBUG: + logger.debug("return: " + str(ret)) + logger.info("exit: getVer") + + return ret + + def _version_greater_than_4_0(self) -> bool: """Gets if the TigerGraph database version is greater than 4.0 using getVer(). Returns: Boolean of whether database version is greater than 4.0.
""" version = self.getVer().split('.') - if version[0]>="4" and version[1]>"0": + if version[0] >= "4" and version[1] > "0": return True - return False + return False \ No newline at end of file diff --git a/pyTigerGraph/pyTigerGraphDataset.py b/pyTigerGraph/pyTigerGraphDataset.py index 6e6c22af..3fffb765 100644 --- a/pyTigerGraph/pyTigerGraphDataset.py +++ b/pyTigerGraph/pyTigerGraphDataset.py @@ -5,8 +5,9 @@ """ import logging -from .datasets import Datasets -from .pyTigerGraphAuth import pyTigerGraphAuth +from pyTigerGraph.datasets import Datasets +from pyTigerGraph.common.dataset import _parse_ingest_dataset +from pyTigerGraph.pyTigerGraphAuth import pyTigerGraphAuth logger = logging.getLogger(__name__) @@ -75,30 +76,11 @@ def ingestDataset( if getToken: self.getToken(self.createSecret()) + responses = [] for resp in dataset.run_load_job(self): - stats = resp[0]["statistics"] - if "vertex" in stats: - for vstats in stats["vertex"]: - print( - "Ingested {} objects into VERTEX {}".format( - vstats["validObject"], vstats["typeName"] - ), - flush=True, - ) - if "edge" in stats: - for estats in stats["edge"]: - print( - "Ingested {} objects into EDGE {}".format( - estats["validObject"], estats["typeName"] - ), - flush=True, - ) - if logger.level == logging.DEBUG: - logger.debug(str(resp)) - - if cleanup: - print("---- Cleaning ----", flush=True) - dataset.clean_up() + responses.append(resp) + + _parse_ingest_dataset(responses, cleanup, dataset) print("---- Finished ingestion ----", flush=True) logger.info("exit: ingestDataset") diff --git a/pyTigerGraph/pyTigerGraphEdge.py b/pyTigerGraph/pyTigerGraphEdge.py index c08ad73e..304dcdba 100644 --- a/pyTigerGraph/pyTigerGraphEdge.py +++ b/pyTigerGraph/pyTigerGraphEdge.py @@ -12,9 +12,31 @@ if TYPE_CHECKING: import pandas as pd -from pyTigerGraph.pyTigerGraphException import TigerGraphException +from pyTigerGraph.common.edge import ( + _parse_get_edge_source_vertex_type, + _parse_get_edge_target_vertex_type, + _prep_get_edge_count_from, + _parse_get_edge_count_from, + _prep_upsert_edge, + _dumps, + _prep_upsert_edges, + _prep_upsert_edge_dataframe, + _prep_get_edges, + _prep_get_edges_by_type, + _parse_get_edge_stats, + _prep_del_edges +) + +from pyTigerGraph.common.edge import edgeSetToDataFrame as _eS2DF + +from pyTigerGraph.common.schema import ( + _get_attr_type, + _upsert_attrs +) + from pyTigerGraph.pyTigerGraphQuery import pyTigerGraphQuery + logger = logging.getLogger(__name__) @@ -100,7 +122,8 @@ def getEdgeAttrs(self, edgeType: str) -> list: ret = [] for at in et["Attributes"]: - ret.append((at["AttributeName"], self._getAttrType(at["AttributeType"]))) + ret.append( + (at["AttributeName"], _get_attr_type(at["AttributeType"]))) if logger.level == logging.DEBUG: logger.debug("return: " + str(ret)) @@ -119,7 +142,8 @@ def getEdgeSourceVertexType(self, edgeType: str) -> Union[str, set]: - A single source vertex type name string if the edge has a single source vertex type. - "*" if the edge can originate from any vertex type (notation used in 2.6.1 and earlier versions). - See https://docs.tigergraph.com/v/2.6/dev/gsql-ref/ddl-and-loading/defining-a-graph-schema#creating-an-edge-from-or-to-any-vertex-type + #creating-an-edge-from-or-to-any-vertex-type + See https://docs.tigergraph.com/v/2.6/dev/gsql-ref/ddl-and-loading/defining-a-graph-schema - A set of vertex type name strings (unique values) if the edge has multiple source vertex types (notation used in 3.0 and later versions). 
/ Even if the source vertex types were defined as `"*"`, the REST API will list them as @@ -136,36 +160,8 @@ def getEdgeSourceVertexType(self, edgeType: str) -> Union[str, set]: logger.debug("params: " + self._locals(locals())) edgeTypeDetails = self.getEdgeType(edgeType) - - # Edge type with a single source vertex type - if edgeTypeDetails["FromVertexTypeName"] != "*": - ret = edgeTypeDetails["FromVertexTypeName"] - - if logger.level == logging.DEBUG: - logger.debug("return: " + str(ret)) - logger.info("exit: getEdgeSourceVertexType (single source)") - - return ret - - # Edge type with multiple source vertex types - if "EdgePairs" in edgeTypeDetails: - # v3.0 and later notation - vts = set() - for ep in edgeTypeDetails["EdgePairs"]: - vts.add(ep["From"]) - - if logger.level == logging.DEBUG: - logger.debug("return: " + str(vts)) - logger.info("exit: getEdgeSourceVertexType (multi source)") - - return vts - else: - # 2.6.1 and earlier notation - if logger.level == logging.DEBUG: - logger.debug("return: *") - logger.info("exit: getEdgeSourceVertexType (multi source, pre-3.x)") - - return "*" + res = _parse_get_edge_source_vertex_type(edgeTypeDetails) + return res def getEdgeTargetVertexType(self, edgeType: str) -> Union[str, set]: """Returns the type(s) of the edge type's target vertex. @@ -178,7 +174,8 @@ def getEdgeTargetVertexType(self, edgeType: str) -> Union[str, set]: - A single target vertex type name string if the edge has a single target vertex type. - "*" if the edge can end in any vertex type (notation used in 2.6.1 and earlier versions). - See https://docs.tigergraph.com/v/2.6/dev/gsql-ref/ddl-and-loading/defining-a-graph-schema#creating-an-edge-from-or-to-any-vertex-type + #creating-an-edge-from-or-to-any-vertex-type + See https://docs.tigergraph.com/v/2.6/dev/gsql-ref/ddl-and-loading/defining-a-graph-schema - A set of vertex type name strings (unique values) if the edge has multiple target vertex types (notation used in 3.0 and later versions). / Even if the target vertex types were defined as "*", the REST API will list them as @@ -194,36 +191,8 @@ def getEdgeTargetVertexType(self, edgeType: str) -> Union[str, set]: logger.debug("params: " + self._locals(locals())) edgeTypeDetails = self.getEdgeType(edgeType) - - # Edge type with a single target vertex type - if edgeTypeDetails["ToVertexTypeName"] != "*": - ret = edgeTypeDetails["ToVertexTypeName"] - - if logger.level == logging.DEBUG: - logger.debug("return: " + str(ret)) - logger.info("exit: getEdgeTargetVertexType (single target)") - - return ret - - # Edge type with multiple target vertex types - if "EdgePairs" in edgeTypeDetails: - # v3.0 and later notation - vts = set() - for ep in edgeTypeDetails["EdgePairs"]: - vts.add(ep["To"]) - - if logger.level == logging.DEBUG: - logger.debug("return: " + str(vts)) - logger.info("exit: getEdgeTargetVertexType (multi target)") - - return vts - else: - # 2.6.1 and earlier notation - if logger.level == logging.DEBUG: - logger.debug("return: *") - logger.info("exit: getEdgeTargetVertexType (multi target, pre-3.x)") - - return "*" + ret = _parse_get_edge_target_vertex_type(edgeTypeDetails) + return ret def isDirected(self, edgeType: str) -> bool: """Is the specified edge type directed? 
@@ -324,7 +293,8 @@ def getDiscriminators(self, edgeType: str) -> list: for at in et["Attributes"]: if "IsDiscriminator" in at and at["IsDiscriminator"]: - ret.append((at["AttributeName"], self._getAttrType(at["AttributeType"]))) + ret.append( + (at["AttributeName"], _get_attr_type(at["AttributeType"]))) if logger.level == logging.DEBUG: logger.debug("return: " + str(ret)) @@ -333,8 +303,8 @@ def getDiscriminators(self, edgeType: str) -> list: return ret def getEdgeCountFrom(self, sourceVertexType: str = "", sourceVertexId: Union[str, int] = None, - edgeType: str = "", targetVertexType: str = "", targetVertexId: Union[str, int] = None, - where: str = "") -> dict: + edgeType: str = "", targetVertexType: str = "", targetVertexId: Union[str, int] = None, + where: str = "") -> dict: """Returns the number of edges from a specific vertex. Args: @@ -372,54 +342,29 @@ def getEdgeCountFrom(self, sourceVertexType: str = "", sourceVertexId: Union[str Endpoints: - `GET /graph/{graph_name}/edges/{source_vertex_type}/{source_vertex_id}` - See https://docs.tigergraph.com/tigergraph-server/current/api/built-in-endpoints#_list_edges_of_a_vertex + #_list_edges_of_a_vertex + See https://docs.tigergraph.com/tigergraph-server/current/api/built-in-endpoints - `POST /builtins/{graph_name}` - See https://docs.tigergraph.com/tigergraph-server/current/api/built-in-endpoints#_run_built_in_functions_on_graph + #_run_built_in_functions_on_graph + See https://docs.tigergraph.com/tigergraph-server/current/api/built-in-endpoints """ logger.info("entry: getEdgeCountFrom") if logger.level == logging.DEBUG: logger.debug("params: " + self._locals(locals())) - # If WHERE condition is not specified, use /builtins else user /vertices - if where or (sourceVertexType and sourceVertexId): - if not sourceVertexType or not sourceVertexId: - raise TigerGraphException( - "If where condition is specified, then both sourceVertexType and sourceVertexId" - " must be provided too.", None) - url = self.restppUrl + "/graph/" + self._safeChar(self.graphname) + "/edges/" + \ - self._safeChar(sourceVertexType) + "/" + self._safeChar(sourceVertexId) - if edgeType: - url += "/" + self._safeChar(edgeType) - if targetVertexType: - url += "/" + self._safeChar(targetVertexType) - if targetVertexId: - url += "/" + self._safeChar(targetVertexId) - url += "?count_only=true" - if where: - url += "&filter=" + self._safeChar(where) - res = self._get(url) + url, data = _prep_get_edge_count_from(restppUrl=self.restppUrl, + graphname=self.graphname, + sourceVertexType=sourceVertexType, + sourceVertexId=sourceVertexId, + edgeType=edgeType, + targetVertexType=targetVertexType, + targetVertexId=targetVertexId, + where=where) + if data: + res = self._req("POST", url, data=data) else: - if not edgeType: # TODO Is this a valid check? 
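# --- Editor's note (illustrative, not part of the patch) ---
# _prep_get_edge_count_from() above returns (url, data); a non-empty payload
# selects the POST /builtins path, otherwise the GET /edges path is used.
# With a hypothetical schema:
# n = conn.getEdgeCountFrom("Person", "p1", edgeType="VISITED")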
- raise TigerGraphException( - "A valid edge type or \"*\" must be specified for edge type.", None) - data = '{"function":"stat_edge_number","type":"' + edgeType + '"' \ - + (',"from_type":"' + sourceVertexType + '"' if sourceVertexType else '') \ - + (',"to_type":"' + targetVertexType + '"' if targetVertexType else '') \ - + '}' - res = self._post(self.restppUrl + "/builtins/" + self.graphname, data=data) - - if len(res) == 1 and res[0]["e_type"] == edgeType: - ret = res[0]["count"] - - if logger.level == logging.DEBUG: - logger.debug("return: " + str(ret)) - logger.info("exit: getEdgeCountFrom (single edge type)") - - return ret - - ret = {} - for r in res: - ret[r["e_type"]] = r["count"] + res = self._req("GET", url) + ret = _parse_get_edge_count_from(res, edgeType) if logger.level == logging.DEBUG: logger.debug("return: " + str(ret)) @@ -428,7 +373,7 @@ def getEdgeCountFrom(self, sourceVertexType: str = "", sourceVertexId: Union[str return ret def getEdgeCount(self, edgeType: str = "*", sourceVertexType: str = "", - targetVertexType: str = "") -> dict: + targetVertexType: str = "") -> dict: """Returns the number of edges of an edge type. This is a simplified version of `getEdgeCountFrom()`, to be used when the total number of @@ -451,7 +396,7 @@ def getEdgeCount(self, edgeType: str = "*", sourceVertexType: str = "", logger.debug("params: " + self._locals(locals())) ret = self.getEdgeCountFrom(edgeType=edgeType, sourceVertexType=sourceVertexType, - targetVertexType=targetVertexType) + targetVertexType=targetVertexType) if logger.level == logging.DEBUG: logger.debug("return: " + str(ret)) @@ -459,16 +404,8 @@ def getEdgeCount(self, edgeType: str = "*", sourceVertexType: str = "", return ret - def upsertEdge( - self, - sourceVertexType: str, - sourceVertexId: str, - edgeType: str, - targetVertexType: str, - targetVertexId: str, - attributes: dict = None, - vertexMustExist: bool = False, - ) -> int: + def upsertEdge(self, sourceVertexType: str, sourceVertexId: str, edgeType: str, + targetVertexType: str, targetVertexId: str, attributes: dict = None, vertexMustExist: bool = False) -> int: """Upserts an edge. Data is upserted: @@ -499,7 +436,8 @@ def upsertEdge( ``` {"visits": (1482, "+"), "max_duration": (371, "max")} ``` - For valid values of `` see https://docs.tigergraph.com/dev/restpp-api/built-in-endpoints#operation-codes . + #operation-codes . + For valid values of `` see https://docs.tigergraph.com/dev/restpp-api/built-in-endpoints Returns: A single number of accepted (successfully upserted) edges (0 or 1). 
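# --- Editor's note (illustrative, hypothetical schema, not part of the patch) ---
# Attribute values may carry a (value, operator) pair, as in the docstring
# example above:
# conn.upsertEdge("Person", "p1", "VISITED", "Page", "home_page",
#                 attributes={"visits": (1, "+"), "max_duration": (371, "max")})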
@@ -515,16 +453,26 @@ def upsertEdge( if logger.level == logging.DEBUG: logger.debug("params: " + self._locals(locals())) - if attributes is None: - attributes = {} + data = _prep_upsert_edge(sourceVertexType, + sourceVertexId, + edgeType, + targetVertexType, + targetVertexId, + attributes + ) - vals = self._upsertAttrs(attributes) + ret = self._req("POST", self.restppUrl + "/graph/" + self.graphname, data=data)[0][ + "accepted_edges"] + + vals = _upsert_attrs(attributes) data = json.dumps( { "edges": { sourceVertexType: { sourceVertexId: { - edgeType: {targetVertexType: {targetVertexId: vals}} + edgeType: {targetVertexType: { + targetVertexId: vals}} + } } } @@ -544,14 +492,8 @@ def upsertEdge( return ret - def upsertEdges( - self, - sourceVertexType: str, - edgeType: str, - targetVertexType: str, - edges: list, - vertexMustExist=False, - ) -> int: + def upsertEdges(self, sourceVertexType: str, edgeType: str, targetVertexType: str, + edges: list, vertexMustExist=False) -> int: """Upserts multiple edges (of the same type). Args: @@ -573,11 +515,14 @@ def upsertEdges( Example: ``` [ - (17, "home_page", {"visits": (35, "+"), "max_duration": (93, "max")}), - (42, "search", {"visits": (17, "+"), "max_duration": (41, "max")}) + (17, "home_page", {"visits": (35, "+"), + "max_duration": (93, "max")}), + (42, "search", {"visits": (17, "+"), + "max_duration": (41, "max")}) ] ``` - For valid values of `` see https://docs.tigergraph.com/dev/restpp-api/built-in-endpoints#operation-codes . + #operation-codes . + For valid values of `` see https://docs.tigergraph.com/dev/restpp-api/built-in-endpoints Returns: A single number of accepted (successfully upserted) edges (0 or positive integer). @@ -590,47 +535,6 @@ def upsertEdges( parameters and functionality. """ - def _dumps(data) -> str: - """Generates the JSON format expected by the endpoint. - - The important thing this function does is converting the list of target vertex IDs and - the attributes belonging to the edge instances into a JSON object that can contain - multiple occurrences of the same key. If the these details were stored in a dictionary - then in case of MultiEdge only the last instance would be retained (as the key would be - the target vertex ID). - - Args: - data: - The Python data structure containing the edge instance details. - - Returns: - The JSON to be sent to the endpoint. - """ - ret = "" - if isinstance(data, dict): - c1 = 0 - for k1, v1 in data.items(): - if c1 > 0: - ret += "," - if k1 == self.___trgvtxids: - # Dealing with the (possibly multiple instances of) edge details - # v1 should be a dict of lists - c2 = 0 - for k2, v2 in v1.items(): - if c2 > 0: - ret += "," - c3 = 0 - for v3 in v2: - if c3 > 0: - ret += "," - ret += json.dumps(k2) + ":" + json.dumps(v3) - c3 += 1 - c2 += 1 - else: - ret += json.dumps(k1) + ":" + _dumps(data[k1]) - c1 += 1 - return "{" + ret + "}" - logger.info("entry: upsertEdges") if logger.level == logging.DEBUG: logger.debug("params: " + self._locals(locals())) @@ -644,16 +548,23 @@ def _dumps(data) -> str: Converting the primary IDs to string here prevents inconsistencies as Python dict would otherwise handle 1 and "1" as two separate keys. 
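# --- Editor's note (illustrative, hypothetical schema, not part of the patch) ---
# The edges list format from the docstring above; per the note just above,
# source/target IDs are normalized to strings when the payload is built.
# edges = [
#     (17, "home_page", {"visits": (35, "+"), "max_duration": (93, "max")}),
#     (42, "search", {"visits": (17, "+")}),
# ]
# conn.upsertEdges("Person", "VISITED", "Page", edges)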
""" + data = _prep_upsert_edges(sourceVertexType=sourceVertexType, + edgeType=edgeType, + targetVertexType=targetVertexType, + edges=edges) + ret = self._req("POST", self.restppUrl + "/graph/" + self.graphname, data=data)[0][ + "accepted_edges"] data = {sourceVertexType: {}} l1 = data[sourceVertexType] for e in edges: if len(e) > 2: - vals = self._upsertAttrs(e[2]) + vals = _upsert_attrs(e[2]) else: vals = {} # sourceVertexId - sourceVertexId = str(e[0]) # Converted to string as the key in the JSON payload must be a string + # Converted to string as the key in the JSON payload must be a string + sourceVertexId = str(e[0]) if sourceVertexId not in l1: l1[sourceVertexId] = {} l2 = l1[sourceVertexId] @@ -669,7 +580,8 @@ def _dumps(data) -> str: l4[self.___trgvtxids] = {} l4 = l4[self.___trgvtxids] # targetVertexId - targetVertexId = str(e[1]) # Converted to string as the key in the JSON payload must be a string + # Converted to string as the key in the JSON payload must be a string + targetVertexId = str(e[1]) if targetVertexId not in l4: l4[targetVertexId] = [] l4[targetVertexId].append(vals) @@ -687,17 +599,9 @@ def _dumps(data) -> str: return ret - def upsertEdgeDataFrame( - self, - df: "pd.DataFrame", - sourceVertexType: str, - edgeType: str, - targetVertexType: str, - from_id: str = "", - to_id: str = "", - attributes: dict = None, - vertexMustExist: bool = False, - ) -> int: + def upsertEdgeDataFrame(self, df: 'pd.DataFrame', sourceVertexType: str, edgeType: str, + targetVertexType: str, from_id: str = "", to_id: str = "", + attributes: dict = None, vertexMustExist: bool = False) -> int: """Upserts edges from a Pandas DataFrame. Args: @@ -728,6 +632,10 @@ def upsertEdgeDataFrame( if logger.level == logging.DEBUG: logger.debug("params: " + self._locals(locals())) + json_up = _prep_upsert_edge_dataframe(df, from_id, to_id, attributes) + ret = self.upsertEdges(sourceVertexType, edgeType, + targetVertexType, json_up) + json_up = [] for index in df.index: @@ -757,9 +665,9 @@ def upsertEdgeDataFrame( return ret def getEdges(self, sourceVertexType: str, sourceVertexId: str, edgeType: str = "", - targetVertexType: str = "", targetVertexId: str = "", select: str = "", where: str = "", - limit: Union[int, str] = None, sort: str = "", fmt: str = "py", withId: bool = True, - withType: bool = False, timeout: int = 0) -> Union[dict, str, 'pd.DataFrame']: + targetVertexType: str = "", targetVertexId: str = "", select: str = "", where: str = "", + limit: Union[int, str] = None, sort: str = "", fmt: str = "py", withId: bool = True, + withType: bool = False, timeout: int = 0) -> Union[dict, str, 'pd.DataFrame']: """Retrieves edges of the given edge type originating from a specific source vertex. Only `sourceVertexType` and `sourceVertexId` are required. @@ -813,38 +721,24 @@ def getEdges(self, sourceVertexType: str, sourceVertexId: str, edgeType: str = " # TODO Change sourceVertexId to sourceVertexIds and allow passing both str and list as # parameter - if not sourceVertexType or not sourceVertexId: - raise TigerGraphException( - "Both source vertex type and source vertex ID must be provided.", None) - url = self.restppUrl + "/graph/" + self.graphname + "/edges/" + sourceVertexType + "/" + \ - str(sourceVertexId) - if edgeType: - url += "/" + edgeType - if targetVertexType: - url += "/" + targetVertexType - if targetVertexId: - url += "/" + str(targetVertexId) - isFirst = True - if select: - url += "?select=" + select - isFirst = False - if where: - url += ("?" 
if isFirst else "&") + "filter=" + where - isFirst = False - if limit: - url += ("?" if isFirst else "&") + "limit=" + str(limit) - isFirst = False - if sort: - url += ("?" if isFirst else "&") + "sort=" + sort - isFirst = False - if timeout and timeout > 0: - url += ("?" if isFirst else "&") + "timeout=" + str(timeout) - ret = self._get(url) + url = _prep_get_edges(self.restppUrl, + self.graphname, + sourceVertexType, + sourceVertexId, + edgeType, + targetVertexType, + targetVertexId, + select, + where, + limit, + sort, + timeout) + ret = self._req("GET", url) if fmt == "json": ret = json.dumps(ret) elif fmt == "df": - ret = self.edgeSetToDataFrame(ret, withId, withType) + ret = _eS2DF(ret, withId, withType) if logger.level == logging.DEBUG: logger.debug("return: " + str(ret)) @@ -853,8 +747,8 @@ def getEdges(self, sourceVertexType: str, sourceVertexId: str, edgeType: str = " return ret def getEdgesDataFrame(self, sourceVertexType: str, sourceVertexId: str, edgeType: str = "", - targetVertexType: str = "", targetVertexId: str = "", select: str = "", where: str = "", - limit: Union[int, str] = None, sort: str = "", timeout: int = 0) -> 'pd.DataFrame': + targetVertexType: str = "", targetVertexId: str = "", select: str = "", where: str = "", + limit: Union[int, str] = None, sort: str = "", timeout: int = 0) -> 'pd.DataFrame': """Retrieves edges of the given edge type originating from a specific source vertex. This is a shortcut to ``getEdges(..., fmt="df", withId=True, withType=False)``. @@ -894,7 +788,7 @@ def getEdgesDataFrame(self, sourceVertexType: str, sourceVertexId: str, edgeType logger.debug("params: " + self._locals(locals())) ret = self.getEdges(sourceVertexType, sourceVertexId, edgeType, targetVertexType, - targetVertexId, select, where, limit, sort, fmt="df", timeout=timeout) + targetVertexId, select, where, limit, sort, fmt="df", timeout=timeout) if logger.level == logging.DEBUG: logger.debug("return: " + str(ret)) @@ -903,8 +797,8 @@ def getEdgesDataFrame(self, sourceVertexType: str, sourceVertexId: str, edgeType return ret def getEdgesDataframe(self, sourceVertexType: str, sourceVertexId: str, edgeType: str = "", - targetVertexType: str = "", targetVertexId: str = "", select: str = "", where: str = "", - limit: Union[int, str] = None, sort: str = "", timeout: int = 0) -> 'pd.DataFrame': + targetVertexType: str = "", targetVertexId: str = "", select: str = "", where: str = "", + limit: Union[int, str] = None, sort: str = "", timeout: int = 0) -> 'pd.DataFrame': """DEPRECATED Use `getEdgesDataFrame()` instead. @@ -914,10 +808,10 @@ def getEdgesDataframe(self, sourceVertexType: str, sourceVertexId: str, edgeType DeprecationWarning) return self.getEdgesDataFrame(sourceVertexType, sourceVertexId, edgeType, targetVertexType, - targetVertexId, select, where, limit, sort, timeout) + targetVertexId, select, where, limit, sort, timeout) def getEdgesByType(self, edgeType: str, fmt: str = "py", withId: bool = True, - withType: bool = False) -> Union[dict, str, 'pd.DataFrame']: + withType: bool = False) -> Union[dict, str, 'pd.DataFrame']: """Retrieves edges of the given edge type regardless the source vertex. 
Args: @@ -951,27 +845,7 @@ def getEdgesByType(self, edgeType: str, fmt: str = "py", withId: bool = True, return {} sourceVertexType = self.getEdgeSourceVertexType(edgeType) - # TODO Support edges with multiple source vertex types - if isinstance(sourceVertexType, set) or sourceVertexType == "*": - raise TigerGraphException( - "Edges with multiple source vertex types are not currently supported.", None) - - queryText = \ - 'INTERPRET QUERY () FOR GRAPH $graph { \ - SetAccum @@edges; \ - start = {ANY}; \ - res = \ - SELECT s \ - FROM start:s-(:e)->ANY:t \ - WHERE e.type == "$edgeType" \ - AND s.type == "$sourceEdgeType" \ - ACCUM @@edges += e; \ - PRINT @@edges AS edges; \ - }' - - queryText = queryText.replace("$graph", self.graphname) \ - .replace('$sourceEdgeType', sourceVertexType) \ - .replace('$edgeType', edgeType) + queryText = _prep_get_edges_by_type(self.graphname, sourceVertexType, edgeType) ret = self.runInterpretedQuery(queryText) ret = ret[0]["edges"] @@ -979,7 +853,7 @@ def getEdgesByType(self, edgeType: str, fmt: str = "py", withId: bool = True, if fmt == "json": ret = json.dumps(ret) elif fmt == "df": - ret = self.edgeSetToDataFrame(ret, withId, withType) + ret = _eS2DF(ret, withId, withType) if logger.level == logging.DEBUG: logger.debug("return: " + str(ret)) @@ -1023,23 +897,14 @@ def getEdgeStats(self, edgeTypes: Union[str, list], skipNA: bool = False) -> dic return {} - ret = {} + responses = [] for et in ets: - data = '{"function":"stat_edge_attr","type":"' + et + '","from_type":"*","to_type":"*"}' - res = self._post(self.restppUrl + "/builtins/" + self.graphname, data=data, resKey="", - skipCheck=True) - if res["error"]: - if "stat_edge_attr is skip" in res["message"] or \ - "No valid edge for the input edge type" in res["message"]: - if not skipNA: - ret[et] = {} - else: - raise TigerGraphException(res["message"], - (res["code"] if "code" in res else None)) - else: - res = res["results"] - for r in res: - ret[r["e_type"]] = r["attributes"] + data = '{"function":"stat_edge_attr","type":"' + \ + et + '","from_type":"*","to_type":"*"}' + res = self._req("POST", self.restppUrl + "/builtins/" + self.graphname, data=data, resKey="", + skipCheck=True) + responses.append((et, res)) + ret = _parse_get_edge_stats(responses, skipNA) if logger.level == logging.DEBUG: logger.debug("return: " + str(ret)) @@ -1048,8 +913,8 @@ def getEdgeStats(self, edgeTypes: Union[str, list], skipNA: bool = False) -> dic return ret def delEdges(self, sourceVertexType: str, sourceVertexId: str, edgeType: str = "", - targetVertexType: str = "", targetVertexId: str = "", where: str = "", - limit: str = "", sort: str = "", timeout: int = 0) -> dict: + targetVertexType: str = "", targetVertexId: str = "", where: str = "", + limit: str = "", sort: str = "", timeout: int = 0) -> dict: """Deletes edges from the graph. Only `sourceVertexType` and `sourceVertexId` are required. 
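A hedged sketch of the `delEdges()` call shape documented above, reusing the placeholder connection and schema from the earlier example:
```
# Delete the "visited" edges leaving one source vertex. The method returns a
# dict keyed by edge type, e.g. {"visited": 3}, built from the "deleted_edges"
# counts in the REST++ response handled below.
deleted = conn.delEdges("person", "alice", edgeType="visited")
print(deleted)
```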
@@ -1088,31 +953,18 @@ def delEdges(self, sourceVertexType: str, sourceVertexId: str, edgeType: str = " if logger.level == logging.DEBUG: logger.debug("params: " + self._locals(locals())) - if not sourceVertexType or not sourceVertexId: - raise TigerGraphException("Both sourceVertexType and sourceVertexId must be provided.", - None) - - url = self.restppUrl + "/graph/" + self.graphname + "/edges/" + sourceVertexType + "/" + str( - sourceVertexId) - - if edgeType: - url += "/" + edgeType - if targetVertexType: - url += "/" + targetVertexType - if targetVertexId: - url += "/" + str(targetVertexId) - - isFirst = True - if where: - url += ("?" if isFirst else "&") + "filter=" + where - isFirst = False - if limit and sort: # These two must be provided together - url += ("?" if isFirst else "&") + "limit=" + str(limit) + "&sort=" + sort - isFirst = False - if timeout and timeout > 0: - url += ("?" if isFirst else "&") + "timeout=" + str(timeout) - - res = self._delete(url) + url = _prep_del_edges(self.restppUrl, + self.graphname, + sourceVertexType, + sourceVertexId, + edgeType, + targetVertexType, + targetVertexId, + where, + limit, + sort, + timeout) + res = self._req("DELETE", url) ret = {} for r in res: ret[r["e_type"]] = r["deleted_edges"] @@ -1123,93 +975,28 @@ def delEdges(self, sourceVertexType: str, sourceVertexId: str, edgeType: str = " return ret - def edgeSetToDataFrame(self, edgeSet: list, withId: bool = True, - withType: bool = False) -> 'pd.DataFrame': - """Converts an edge set to Pandas DataFrame - - Edge sets contain instances of the same edge type. Edge sets are not generated "naturally" - like vertex sets. Instead, you need to collect edges in (global) accumulators, like when you - want to visualize them in GraphStudio or by other tools. - - For example: - ``` - SetAccum @@edges; - - start = {country.*}; - - result = - SELECT trg - FROM start:src -(city_in_country:e)- city:trg - ACCUM @@edges += e; - - PRINT start, result, @@edges; - ``` - - The `@@edges` is an edge set. - It contains, for each edge instance, the source and target vertex type and ID, the edge type, - a directedness indicator and the (optional) attributes. / - - [NOTE] - `start` and `result` are vertex sets. - - An edge set has this structure (when serialised as JSON): - - [source.wrap, json] - ---- - [ - { - "e_type": , - "from_type": , - "from_id": , - "to_type": , - "to_id": , - "directed": , - "attributes": - { - "attr1": , - "attr2": , - ⋮ - } - }, - ⋮ - ] - ---- + def edgeSetToDataFrame(self, edgeSet: list, withId: bool = True, withType: bool = False) -> 'pd.DataFrame': + """Converts an edge set to a pandas DataFrame. Args: edgeSet: - A JSON array containing an edge set in the format returned by queries (see below). + The edge set to convert. withId: - Whether to include the type and primary ID of source and target vertices as a column. Default is `True`. + Should the source and target vertex types and IDs be included in the dataframe? withType: - Whether to include edge type info as a column. Default is `False`. + Should the edge type be included in the dataframe? Returns: - A pandas DataFrame containing the edge attributes and optionally the type and primary - ID or source and target vertices, and the edge type. - + The edge set as a pandas DataFrame. """ logger.info("entry: edgeSetToDataFrame") if logger.level == logging.DEBUG: logger.debug("params: " + self._locals(locals())) - try: - import pandas as pd - except ImportError: - raise ImportError("Pandas is required to use this function. 
" - "Download pandas using 'pip install pandas'.") - - df = pd.DataFrame(edgeSet) - cols = [] - if withId: - cols.extend([df["from_type"], df["from_id"], df["to_type"], df["to_id"]]) - if withType: - cols.append(df["e_type"]) - cols.append(pd.DataFrame(df["attributes"].tolist())) - - ret = pd.concat(cols, axis=1) + ret = _eS2DF(edgeSet, withId, withType) if logger.level == logging.DEBUG: logger.debug("return: " + str(ret)) logger.info("exit: edgeSetToDataFrame") - return ret + return ret \ No newline at end of file diff --git a/pyTigerGraph/pyTigerGraphGSQL.py b/pyTigerGraph/pyTigerGraphGSQL.py index 406b8f84..9ff67bd9 100644 --- a/pyTigerGraph/pyTigerGraphGSQL.py +++ b/pyTigerGraph/pyTigerGraphGSQL.py @@ -4,17 +4,20 @@ All functions in this module are called as methods on a link:https://docs.tigergraph.com/pytigergraph/current/core-functions/base[`TigerGraphConnection` object]. """ import logging -import os -import sys -from typing import Union, Tuple, Dict -from urllib.parse import urlparse, quote_plus import re +import requests +from typing import Union, Tuple, Dict +from urllib.parse import urlparse, quote_plus -import requests +from pyTigerGraph.common.gsql import ( + _parse_gsql, + _prep_get_udf, + _parse_get_udf +) +from pyTigerGraph.common.exception import TigerGraphException +from pyTigerGraph.pyTigerGraphBase import pyTigerGraphBase -from .pyTigerGraphBase import pyTigerGraphBase -from .pyTigerGraphException import TigerGraphException logger = logging.getLogger(__name__) @@ -22,7 +25,8 @@ class pyTigerGraphGSQL(pyTigerGraphBase): - def gsql(self, query: str, graphname: str = None, options: list[str] = None) -> Union[str, dict]: + + def gsql(self, query: str, graphname: str = None, options=None) -> Union[str, dict]: """Runs a GSQL query and processes the output. Args: @@ -41,54 +45,19 @@ def gsql(self, query: str, graphname: str = None, options: list[str] = None) -> - `POST /gsqlserver/gsql/file` (In TigerGraph versions 3.x) - `POST /gsql/v1/statements` (In TigerGraph versions 4.x) """ - logger.info("entry: gsql") - if logger.level == logging.DEBUG: - logger.debug("params: " + self._locals(locals())) - - def check_error(query: str, resp: str) -> None: - if "CREATE VERTEX" in query.upper(): - if "Failed to create vertex types" in resp: - raise TigerGraphException(resp) - if ("CREATE DIRECTED EDGE" in query.upper()) or ("CREATE UNDIRECTED EDGE" in query.upper()): - if "Failed to create edge types" in resp: - raise TigerGraphException(resp) - if "CREATE GRAPH" in query.upper(): - if ("The graph" in resp) and ("could not be created!" 
in resp): - raise TigerGraphException(resp) - if "CREATE DATA_SOURCE" in query.upper(): - if ("Successfully created local data sources" not in resp) and ("Successfully created data sources" not in resp): - raise TigerGraphException(resp) - if "CREATE LOADING JOB" in query.upper(): - if "Successfully created loading jobs" not in resp: - raise TigerGraphException(resp) - if "RUN LOADING JOB" in query.upper(): - if "LOAD SUCCESSFUL" not in resp: - raise TigerGraphException(resp) - - def clean_res(resp: list) -> str: - ret = [] - for line in resp: - if not line.startswith("__GSQL__"): - ret.append(line) - return "\n".join(ret) - - if graphname is None: - graphname = self.graphname - if str(graphname).upper() == "GLOBAL" or str(graphname).upper() == "": - graphname = "" - # Can't use self._isVersionGreaterThan4_0 since you need a token to call /version url # but you need a secret to get a token and you need this function to get a secret try: res = self._req("POST", - self.gsUrl + "/gsql/v1/statements", - data=query.encode("utf-8"), # quote_plus would not work with the new endpoint - authMode="pwd", resKey=None, skipCheck=True, - jsonResponse=False, - headers={"Content-Type": "text/plain"}) + self.gsUrl + "/gsql/v1/statements", + # quote_plus would not work with the new endpoint + data=query.encode("utf-8"), + authMode="pwd", resKey=None, skipCheck=True, + jsonResponse=False, + headers={"Content-Type": "text/plain"}) except requests.exceptions.HTTPError as e: - if e.response.status_code == 404: + if e.response.status_code == 404: res = self._req("POST", self.gsUrl + "/gsqlserver/gsql/file", data=quote_plus(query.encode("utf-8")), @@ -96,22 +65,7 @@ def clean_res(resp: list) -> str: jsonResponse=False) else: raise e - - - if isinstance(res, list): - ret = clean_res(res) - else: - ret = clean_res(res.splitlines()) - - check_error(query, ret) - - string_without_ansi = ANSI_ESCAPE.sub('', ret) - - if logger.level == logging.DEBUG: - logger.debug("return: " + str(ret)) - logger.info("exit: gsql (success)") - - return string_without_ansi + return _parse_gsql(res, query, graphname=graphname, options=options) def installUDF(self, ExprFunctions: str = "", ExprUtil: str = "") -> None: """Install user defined functions (UDF) to the database. @@ -141,15 +95,15 @@ def installUDF(self, ExprFunctions: str = "", ExprUtil: str = "") -> None: # A local file: read from disk. 
with open(ExprFunctions) as infile: data = infile.read() - + if self._versionGreaterThan4_0(): res = self._req("PUT", - url="{}/gsql/v1/udt/files/ExprFunctions".format( - self.gsUrl), authMode="pwd", data=data, resKey="") - else: + url="{}/gsql/v1/udt/files/ExprFunctions".format( + self.gsUrl), authMode="pwd", data=data, resKey="") + else: res = self._req("PUT", - url="{}/gsqlserver/gsql/userdefinedfunction?filename=ExprFunctions".format( - self.gsUrl), authMode="pwd", data=data, resKey="") + url="{}/gsqlserver/gsql/userdefinedfunction?filename=ExprFunctions".format( + self.gsUrl), authMode="pwd", data=data, resKey="") if not res["error"]: logger.info("ExprFunctions installed successfully") else: @@ -166,12 +120,14 @@ def installUDF(self, ExprFunctions: str = "", ExprUtil: str = "") -> None: data = infile.read() if self._versionGreaterThan4_0(): res = self._req("PUT", - url="{}/gsql/v1/udt/files/ExprUtil".format(self.gsUrl), - authMode="pwd", data=data, resKey="") + url="{}/gsql/v1/udt/files/ExprUtil".format( + self.gsUrl), + authMode="pwd", data=data, resKey="") else: res = self._req("PUT", - url="{}/gsqlserver/gsql/userdefinedfunction?filename=ExprUtil".format(self.gsUrl), - authMode="pwd", data=data, resKey="") + url="{}/gsqlserver/gsql/userdefinedfunction?filename=ExprUtil".format( + self.gsUrl), + authMode="pwd", data=data, resKey="") if not res["error"]: logger.info("ExprUtil installed successfully") else: @@ -184,7 +140,7 @@ def installUDF(self, ExprFunctions: str = "", ExprUtil: str = "") -> None: return 0 - def getUDF(self, ExprFunctions: bool = True, ExprUtil: bool = True, json_out: bool = False) -> Union[str, Tuple[str, str], Dict[str,str]]: + def getUDF(self, ExprFunctions: bool = True, ExprUtil: bool = True, json_out=False) -> Union[str, Tuple[str, str], Dict[str, str]]: """Get user defined functions (UDF) installed in the database. See https://docs.tigergraph.com/gsql-ref/current/querying/func/query-user-defined-functions for details on UDFs. @@ -198,9 +154,9 @@ def getUDF(self, ExprFunctions: bool = True, ExprUtil: bool = True, json_out: bo Only supported on version >=4.1 Returns: - - `str`: If only one of `ExprFunctions` or `ExprUtil` is True and json_out is False, return of the content of that file. - - `Tuple[str, str]`: If both `ExprFunctions` and `ExprUtil` are True and json_out is False, return content of ExprFunctions and content of ExprUtil. - - `Dict[str, str]`: If json_out is True, return dict with `ExprFunctions` and/or `ExprUtil` as keys and content of file as value. + str: If only one of `ExprFunctions` or `ExprUtil` is True and `json_out` is False, returns the content of that file. + Tuple[str, str]: If both `ExprFunctions` and `ExprUtil` are True and `json_out` is False, returns the content of ExprFunctions and the content of ExprUtil. + Dict[str, str]: If `json_out` is True, returns a dict with `ExprFunctions` and/or `ExprUtil` as keys and the file content as values.
Endpoints: - `GET /gsqlserver/gsql/userdefinedfunction?filename={ExprFunctions or ExprUtil}` (In TigerGraph versions 3.x) @@ -210,53 +165,18 @@ def getUDF(self, ExprFunctions: bool = True, ExprUtil: bool = True, json_out: bo if logger.level == logging.DEBUG: logger.debug("params: " + self._locals(locals())) - functions_ret = None - if ExprFunctions: - if self._versionGreaterThan4_0(): - resp = self._get( - "{}/gsql/v1/udt/files/ExprFunctions".format(self.gsUrl), - resKey="") - else: - resp = self._get( - "{}/gsqlserver/gsql/userdefinedfunction".format(self.gsUrl), - params={"filename": "ExprFunctions"}, resKey="") - if not resp["error"]: - logger.info("ExprFunctions get successfully") - functions_ret = resp["results"] - if type(functions_ret) == dict and not json_out: #Endpoint returns a dict when above 4.0 - functions_ret = functions_ret['ExprFunctions'] - else: - logger.error("Failed to get ExprFunctions") - raise TigerGraphException(resp["message"]) - - util_ret = None - if ExprUtil: - if self._versionGreaterThan4_0(): - resp = self._get( - "{}/gsql/v1/udt/files/ExprUtil".format(self.gsUrl), - resKey="") - else: - resp = self._get( - "{}/gsqlserver/gsql/userdefinedfunction".format(self.gsUrl), - params={"filename": "ExprUtil"}, resKey="") - if not resp["error"]: - logger.info("ExprUtil get successfully") - util_ret = resp["results"] - if type(util_ret) == dict and not json_out: #Endpoint returns a dict when above 4.0 - util_ret = util_ret['ExprUtil'] - else: - logger.error("Failed to get ExprUtil") - raise TigerGraphException(resp["message"]) - - if (functions_ret is not None) and (util_ret is not None): + urls, alt_urls = _prep_get_udf( + ExprFunctions=ExprFunctions, ExprUtil=ExprUtil) + if not self._version_greater_than_4_0(): if json_out: - functions_ret.update(util_ret) - return functions_ret - return (functions_ret, util_ret) - elif functions_ret is not None: - return functions_ret - elif util_ret is not None: - return util_ret - else: - return "" + raise TigerGraphException( + "The 'json_out' parameter is only supported in TigerGraph Versions >=4.1.") + urls = alt_urls + responses = {} + + for file_name in urls: + resp = self._req( + "GET", f"{self.gsUrl}{urls[file_name]}", resKey="") + responses[file_name] = resp + return _parse_get_udf(responses, json_out=json_out) diff --git a/pyTigerGraph/pyTigerGraphLoading.py b/pyTigerGraph/pyTigerGraphLoading.py index 16272445..dce96690 100644 --- a/pyTigerGraph/pyTigerGraphLoading.py +++ b/pyTigerGraph/pyTigerGraphLoading.py @@ -5,7 +5,11 @@ """ import logging import warnings + from typing import Union + +from pyTigerGraph.common.loading import _prep_run_loading_job_with_file + from pyTigerGraph.pyTigerGraphBase import pyTigerGraphBase logger = logging.getLogger(__name__) @@ -14,7 +18,7 @@ class pyTigerGraphLoading(pyTigerGraphBase): def runLoadingJobWithFile(self, filePath: str, fileTag: str, jobName: str, sep: str = None, - eol: str = None, timeout: int = 16000, sizeLimit: int = 128000000) -> Union[dict, None]: + eol: str = None, timeout: int = 16000, sizeLimit: int = 128000000) -> Union[dict, None]: """Execute a loading job with the referenced file. 
The file will first be uploaded to the TigerGraph server and the value of the appropriate @@ -49,25 +53,15 @@ def runLoadingJobWithFile(self, filePath: str, fileTag: str, jobName: str, sep: if logger.level == logging.DEBUG: logger.debug("params: " + self._locals(locals())) - try: - data = open(filePath, 'rb').read() - params = { - "tag": jobName, - "filename": fileTag, - } - if sep is not None: - params["sep"] = sep - if eol is not None: - params["eol"] = eol - except OSError as ose: - logger.error(ose.strerror) - logger.info("exit: runLoadingJobWithFile") + data, params = _prep_run_loading_job_with_file( + filePath, jobName, fileTag, sep, eol) + if not data and not params: + # failed to read file return None - # TODO Should throw exception instead? - res = self._post(self.restppUrl + "/ddl/" + self.graphname, params=params, data=data, - headers={"RESPONSE-LIMIT": str(sizeLimit), "GSQL-TIMEOUT": str(timeout)}) + res = self._req("POST", self.restppUrl + "/ddl/" + self.graphname, params=params, data=data, + headers={"RESPONSE-LIMIT": str(sizeLimit), "GSQL-TIMEOUT": str(timeout)}) if logger.level == logging.DEBUG: logger.debug("return: " + str(res)) @@ -76,7 +70,7 @@ def runLoadingJobWithFile(self, filePath: str, fileTag: str, jobName: str, sep: return res def uploadFile(self, filePath, fileTag, jobName="", sep=None, eol=None, timeout=16000, - sizeLimit=128000000) -> dict: + sizeLimit=128000000) -> dict: """DEPRECATED Use `runLoadingJobWithFile()` instead. diff --git a/pyTigerGraph/pyTigerGraphPath.py b/pyTigerGraph/pyTigerGraphPath.py index 34172d88..4df70713 100644 --- a/pyTigerGraph/pyTigerGraphPath.py +++ b/pyTigerGraph/pyTigerGraphPath.py @@ -6,147 +6,23 @@ import json import logging + from typing import Union +from pyTigerGraph.common.path import ( + _prepare_path_params +) from pyTigerGraph.pyTigerGraphBase import pyTigerGraphBase logger = logging.getLogger(__name__) class pyTigerGraphPath(pyTigerGraphBase): - def _preparePathParams(self, sourceVertices: Union[dict, tuple, list], - targetVertices: Union[dict, tuple, list], maxLength: int = None, - vertexFilters: Union[list, dict] = None, edgeFilters: Union[list, dict] = None, - allShortestPaths: bool = False) -> str: - """Prepares the input parameters by transforming them to the format expected by the path - algorithms. - - See xref:tigergraph-server:API:built-in-endpoints.adoc#[Parameters and output format for path finding] - - A vertex set is a dict that has three top-level keys: `v_type`, `v_id`, `attributes` (also a dictionary). - - Args: - sourceVertices: - A vertex set (a list of vertices) or a list of `(vertexType, vertexID)` tuples; - the source vertices of the shortest paths sought. - targetVertices: - A vertex set (a list of vertices) or a list of `(vertexType, vertexID)` tuples; - the target vertices of the shortest paths sought. - maxLength: - The maximum length of a shortest path. Optional, default is `6`. - vertexFilters: - An optional list of `(vertexType, condition)` tuples or - `{"type": , "condition": }` dictionaries. - edgeFilters: - An optional list of `(edgeType, condition)` tuples or - `{"type": , "condition": }` dictionaries. - allShortestPaths: - If `True`, the endpoint will return all shortest paths between the source and target. - Default is `False`, meaning that the endpoint will return only one path. - - Returns: - A string representation of the dictionary of end-point parameters. 
- """ - - def parseVertices(vertices: Union[dict, tuple, list]) -> list: - """Parses vertex input parameters and converts it to the format required by the path - finding endpoints. - - Args: - vertices: - A vertex set (a list of vertices) or a list of `(vertexType, vertexID)` tuples; - the source or target vertices of the shortest paths sought. - Returns: - A list of vertices in the format required by the path finding endpoints. - """ - logger.info("entry: parseVertices") - if logger.level == logging.DEBUG: - logger.debug("params: " + self._locals(locals())) - - ret = [] - if not isinstance(vertices, list): - vertices = [vertices] - for v in vertices: - if isinstance(v, tuple): - tmp = {"type": v[0], "id": v[1]} - ret.append(tmp) - elif isinstance(v, dict) and "v_type" in v and "v_id" in v: - tmp = {"type": v["v_type"], "id": v["v_id"]} - ret.append(tmp) - else: - logger.warning("Invalid vertex type or value: " + str(v)) - - if logger.level == logging.DEBUG: - logger.debug("return: " + str(ret)) - logger.info("exit: parseVertices") - - return ret - - def parseFilters(filters: Union[dict, tuple, list]) -> list: - """Parses filter input parameters and converts it to the format required by the path - finding endpoints. - - Args: - filters: - A list of `(vertexType, condition)` tuples or - `{"type": , "condition": }` dictionaries. - - Returns: - A list of filters in the format required by the path finding endpoints. - """ - logger.info("entry: parseFilters") - if logger.level == logging.DEBUG: - logger.debug("params: " + self._locals(locals())) - - ret = [] - if not isinstance(filters, list): - filters = [filters] - for f in filters: - if isinstance(f, tuple): - tmp = {"type": f[0], "condition": f[1]} - ret.append(tmp) - elif isinstance(f, dict) and "type" in f and "condition" in f: - tmp = {"type": f["type"], "condition": f["condition"]} - ret.append(tmp) - else: - logger.warning("Invalid filter type or value: " + str(f)) - - if logger.level == logging.DEBUG: - logger.debug("return: " + str(ret)) - logger.info("exit: parseFilters") - - return ret - - logger.info("entry: _preparePathParams") - if logger.level == logging.DEBUG: - logger.debug("params: " + self._locals(locals())) - - # Assembling the input payload - if not sourceVertices or not targetVertices: - return "" - # TODO Should allow returning error instead of handling missing parameters here? - data = {"sources": parseVertices(sourceVertices), "targets": parseVertices(targetVertices)} - if vertexFilters: - data["vertexFilters"] = parseFilters(vertexFilters) - if edgeFilters: - data["edgeFilters"] = parseFilters(edgeFilters) - if maxLength: - data["maxLength"] = maxLength - if allShortestPaths: - data["allShortestPaths"] = True - - ret = json.dumps(data) - - if logger.level == logging.DEBUG: - logger.debug("return: " + str(ret)) - logger.info("exit: _preparePathParams") - - return ret def shortestPath(self, sourceVertices: Union[dict, tuple, list], - targetVertices: Union[dict, tuple, list], maxLength: int = None, - vertexFilters: Union[list, dict] = None, edgeFilters: Union[list, dict] = None, - allShortestPaths: bool = False) -> dict: + targetVertices: Union[dict, tuple, list], maxLength: int = None, + vertexFilters: Union[list, dict] = None, edgeFilters: Union[list, dict] = None, + allShortestPaths: bool = False) -> dict: """Finds the shortest path (or all shortest paths) between the source and target vertex sets. 
A vertex set is a set of dictionaries that each has three top-level keys: `v_type`, `v_id`, @@ -194,9 +70,10 @@ def shortestPath(self, sourceVertices: Union[dict, tuple, list], if logger.level == logging.DEBUG: logger.debug("params: " + self._locals(locals())) - data = self._preparePathParams(sourceVertices, targetVertices, maxLength, vertexFilters, - edgeFilters, allShortestPaths) - ret = self._post(self.restppUrl + "/shortestpath/" + self.graphname, data=data) + data = _prepare_path_params(sourceVertices, targetVertices, maxLength, vertexFilters, + edgeFilters, allShortestPaths) + ret = self._post(self.restppUrl + "/shortestpath/" + + self.graphname, data=data) if logger.level == logging.DEBUG: logger.debug("return: " + str(ret)) @@ -205,8 +82,8 @@ def shortestPath(self, sourceVertices: Union[dict, tuple, list], return ret def allPaths(self, sourceVertices: Union[dict, tuple, list], - targetVertices: Union[dict, tuple, list], maxLength: int, - vertexFilters: Union[list, dict] = None, edgeFilters: Union[list, dict] = None) -> dict: + targetVertices: Union[dict, tuple, list], maxLength: int, + vertexFilters: Union[list, dict] = None, edgeFilters: Union[list, dict] = None) -> dict: """Find all possible paths up to a given maximum path length between the source and target vertex sets. @@ -249,9 +126,10 @@ def allPaths(self, sourceVertices: Union[dict, tuple, list], if logger.level == logging.DEBUG: logger.debug("params: " + self._locals(locals())) - data = self._preparePathParams(sourceVertices, targetVertices, maxLength, vertexFilters, - edgeFilters) - ret = self._post(self.restppUrl + "/allpaths/" + self.graphname, data=data) + data = _prepare_path_params(sourceVertices, targetVertices, maxLength, vertexFilters, + edgeFilters) + ret = self._post(self.restppUrl + "/allpaths/" + + self.graphname, data=data) if logger.level == logging.DEBUG: logger.debug("return: " + str(ret)) diff --git a/pyTigerGraph/pyTigerGraphQuery.py b/pyTigerGraph/pyTigerGraphQuery.py index fd90fe03..f8751ecf 100644 --- a/pyTigerGraph/pyTigerGraphQuery.py +++ b/pyTigerGraph/pyTigerGraphQuery.py @@ -5,25 +5,31 @@ """ import json import logging -from datetime import datetime +from datetime import datetime from typing import TYPE_CHECKING, Union, Optional if TYPE_CHECKING: import pandas as pd -from pyTigerGraph.pyTigerGraphException import TigerGraphException +from pyTigerGraph.common.exception import TigerGraphException +from pyTigerGraph.common.query import ( + _parse_get_installed_queries, + _parse_query_parameters, + _prep_run_installed_query, + _prep_get_statistics +) from pyTigerGraph.pyTigerGraphSchema import pyTigerGraphSchema -from pyTigerGraph.pyTigerGraphUtils import pyTigerGraphUtils from pyTigerGraph.pyTigerGraphGSQL import pyTigerGraphGSQL + logger = logging.getLogger(__name__) -class pyTigerGraphQuery(pyTigerGraphUtils, pyTigerGraphSchema, pyTigerGraphGSQL): +class pyTigerGraphQuery(pyTigerGraphGSQL, pyTigerGraphSchema): # TODO getQueries() # List _all_ query names def showQuery(self, queryName: str) -> str: """Returns the string of the given GSQL query. - + Args: queryName (str): Name of the query to get metadata of. 
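By way of illustration (the query name is a placeholder), the query-introspection helpers touched in this file can be exercised as follows:
```
# Fetch the GSQL source text of an installed query, then its signature metadata.
print(conn.showQuery("pagerank"))          # raw GSQL text of the query
meta = conn.getQueryMetadata("pagerank")   # served by the version-dependent endpoints below
```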
@@ -51,19 +57,21 @@ def getQueryMetadata(self, queryName: str) -> dict: """ if logger.level == logging.DEBUG: logger.debug("entry: getQueryMetadata") - if self._versionGreaterThan4_0(): + if self._version_greater_than_4_0(): params = {"graph": self.graphname, "queryName": queryName} - res = self._post(self.gsUrl+"/gsql/v1/queries/signature", params=params, authMode="pwd", resKey="") - else: + res = self._post(self.gsUrl+"/gsql/v1/queries/signature", + params=params, authMode="pwd", resKey="") + else: params = {"graph": self.graphname, "query": queryName} - res = self._get(self.gsUrl+"/gsqlserver/gsql/queryinfo", params=params, authMode="pwd", resKey="") - if not res["error"]: + res = self._get(self.gsUrl+"/gsqlserver/gsql/queryinfo", + params=params, authMode="pwd", resKey="") + if not res["error"]: if logger.level == logging.DEBUG: logger.debug("exit: getQueryMetadata") return res else: TigerGraphException(res["message"], res["code"]) - + def getInstalledQueries(self, fmt: str = "py") -> Union[dict, str, 'pd.DataFrame']: """Returns a list of installed queries. @@ -86,15 +94,7 @@ def getInstalledQueries(self, fmt: str = "py") -> Union[dict, str, 'pd.DataFrame logger.debug("params: " + self._locals(locals())) ret = self.getEndpoints(dynamic=True) - if fmt == "json": - ret = json.dumps(ret) - if fmt == "df": - try: - import pandas as pd - except ImportError: - raise ImportError("Pandas is required to use this function. " - "Download pandas using 'pip install pandas'.") - ret = pd.DataFrame(ret).T + ret = _parse_get_installed_queries(fmt, ret) if logger.level == logging.DEBUG: logger.debug("return: " + str(ret)) @@ -110,62 +110,9 @@ def getInstalledQueries(self, fmt: str = "py") -> Union[dict, str, 'pd.DataFrame # GET /gsql/queries/install/{request_id} # xref:tigergraph-server:API:built-in-endpoints.adoc#_check_query_installation_status[Check query installation status] - def _parseQueryParameters(self, params: dict) -> str: - """Parses a dictionary of query parameters and converts them to query strings. - - While most of the values provided for various query parameter types can be easily converted - to query strings (key1=value1&key2=value2), `SET` and `BAG` parameter types, and especially - `VERTEX` and `SET` (i.e. vertex primary ID types without vertex type specification) - require special handling. - - See xref:tigergraph-server:API:built-in-endpoints.adoc#_query_parameter_passing[Query parameter passing] - - TODO Accept this format for SET: - "key": [([p_id1, p_id2, ...], "vtype"), ...] - I.e. 
multiple primary IDs of the same vertex type - """ - logger.info("entry: _parseQueryParameters") - if logger.level == logging.DEBUG: - logger.debug("params: " + self._locals(locals())) - - ret = "" - for k, v in params.items(): - if isinstance(v, tuple): - if len(v) == 2 and isinstance(v[1], str): - ret += k + "=" + str(v[0]) + "&" + k + ".type=" + self._safeChar(v[1]) + "&" - else: - raise TigerGraphException( - "Invalid parameter value: (vertex_primary_id, vertex_type)" - " was expected.") - elif isinstance(v, list): - i = 0 - for vv in v: - if isinstance(vv, tuple): - if len(vv) == 2 and isinstance(vv[1], str): - ret += k + "[" + str(i) + "]=" + self._safeChar(vv[0]) + "&" + \ - k + "[" + str(i) + "].type=" + vv[1] + "&" - else: - raise TigerGraphException( - "Invalid parameter value: (vertex_primary_id , vertex_type)" - " was expected.") - else: - ret += k + "=" + self._safeChar(vv) + "&" - i += 1 - elif isinstance(v, datetime): - ret += k + "=" + self._safeChar(v.strftime("%Y-%m-%d %H:%M:%S")) + "&" - else: - ret += k + "=" + self._safeChar(v) + "&" - ret = ret[:-1] - - if logger.level == logging.DEBUG: - logger.debug("return: " + str(ret)) - logger.info("exit: _parseQueryParameters") - - return ret - def runInstalledQuery(self, queryName: str, params: Union[str, dict] = None, - timeout: int = None, sizeLimit: int = None, usePost: bool = False, runAsync: bool = False, - replica: int = None, threadLimit: int = None, memoryLimit: int = None) -> list: + timeout: int = None, sizeLimit: int = None, usePost: bool = False, runAsync: bool = False, + replica: int = None, threadLimit: int = None, memoryLimit: int = None) -> list: """Runs an installed query. The query must be already created and installed in the graph. @@ -232,25 +179,11 @@ def runInstalledQuery(self, queryName: str, params: Union[str, dict] = None, if logger.level == logging.DEBUG: logger.debug("params: " + self._locals(locals())) - headers = {} - res_key = "results" - if timeout and timeout > 0: - headers["GSQL-TIMEOUT"] = str(timeout) - if sizeLimit and sizeLimit > 0: - headers["RESPONSE-LIMIT"] = str(sizeLimit) - if runAsync: - headers["GSQL-ASYNC"] = "true" - res_key = "request_id" - if replica: - headers["GSQL-REPLICA"] = str(replica) - if threadLimit: - headers["GSQL-THREAD-LIMIT"] = str(threadLimit) - if memoryLimit: - headers["GSQL-QueryLocalMemLimitMB"] = str(memoryLimit) - + headers, res_key = _prep_run_installed_query(timeout=timeout, sizeLimit=sizeLimit, runAsync=runAsync, + replica=replica, threadLimit=threadLimit, memoryLimit=memoryLimit) if usePost: - ret = self._post(self.restppUrl + "/query/" + self.graphname + "/" + queryName, - data=params, headers=headers, resKey=res_key, jsonData=True) + ret = self._req("POST", self.restppUrl + "/query/" + self.graphname + "/" + queryName, + data=params, headers=headers, resKey=res_key, jsonData=True) if logger.level == logging.DEBUG: logger.debug("return: " + str(ret)) @@ -259,9 +192,9 @@ def runInstalledQuery(self, queryName: str, params: Union[str, dict] = None, return ret else: if isinstance(params, dict): - params = self._parseQueryParameters(params) - ret = self._get(self.restppUrl + "/query/" + self.graphname + "/" + queryName, - params=params, headers=headers, resKey=res_key) + params = _parse_query_parameters(params) + ret = self._req("GET", self.restppUrl + "/query/" + self.graphname + "/" + queryName, + params=params, headers=headers, resKey=res_key) if logger.level == logging.DEBUG: logger.debug("return: " + str(ret)) @@ -353,15 +286,15 @@ def 
runInterpretedQuery(self, queryText: str, params: Union[str, dict] = None) - queryText = queryText.replace("$graphname", self.graphname) queryText = queryText.replace("@graphname@", self.graphname) if isinstance(params, dict): - params = self._parseQueryParameters(params) + params = _parse_query_parameters(params) - if self._versionGreaterThan4_0(): + if self._version_greater_than_4_0(): ret = self._post(self.gsUrl + "/gsql/v1/queries/interpret", params=params, data=queryText, authMode="pwd", headers={'Content-Type': 'text/plain'}) else: ret = self._post(self.gsUrl + "/gsqlserver/interpreted_query", data=queryText, - params=params, authMode="pwd") + params=params, authMode="pwd") if logger.level == logging.DEBUG: logger.debug("return: " + str(ret)) @@ -374,7 +307,8 @@ def getRunningQueries(self) -> dict: """ if logger.level == logging.DEBUG: logger.debug("entry: getRunningQueries") - res = self._get(self.restppUrl+"/showprocesslist/"+self.graphname, resKey="") + res = self._get(self.restppUrl+"/showprocesslist/" + + self.graphname, resKey="") if not res["error"]: if logger.level == logging.DEBUG: logger.debug("exit: getRunningQueries") @@ -385,7 +319,7 @@ def getRunningQueries(self) -> dict: def abortQuery(self, request_id: Union[str, list] = None, url: str = None): """This function safely abortsa a selected query by ID or all queries of an endpoint by endpoint URL of a graph. If neither `request_id` or `url` are specified, all queries currently running on the graph are aborted. - + Args: request_id (str, list, optional): The ID(s) of the query(s) to abort. If set to "all", it will abort all running queries. @@ -398,9 +332,10 @@ def abortQuery(self, request_id: Union[str, list] = None, url: str = None): params["requestid"] = request_id if url: params["url"] = url - res = self._get(self.restppUrl+"/abortquery/"+self.graphname, params=params, resKey="") + res = self._get(self.restppUrl+"/abortquery/" + + self.graphname, params=params, resKey="") if not res["error"]: - if logger.level == logging.DEBUG: + if logger.level == logging.DEBUG: logger.debug("exit: abortQuery") return res else: @@ -531,7 +466,7 @@ def addOccurrences(obj: dict, src: str): # Then handle the edge itself eId = o3["from_type"] + "(" + o3["from_id"] + ")->" + o3["to_type"] + \ - "(" + o3["to_id"] + ")" + "(" + o3["to_id"] + ")" o3["e_id"] = eId # Add reverse edge name, if applicable @@ -589,17 +524,9 @@ def getStatistics(self, seconds: int = 10, segments: int = 10) -> dict: if logger.level == logging.DEBUG: logger.debug("params: " + self._locals(locals())) - if not seconds: - seconds = 10 - else: - seconds = max(min(seconds, 0), 60) - if not segments: - segments = 10 - else: - segments = max(min(segments, 0), 100) - - ret = self._get(self.restppUrl + "/statistics/" + self.graphname + "?seconds=" + - str(seconds) + "&segment=" + str(segments), resKey="") + seconds, segments = _prep_get_statistics(self, seconds, segments) + ret = self._req("GET", self.restppUrl + "/statistics/" + self.graphname + "?seconds=" + + str(seconds) + "&segment=" + str(segments), resKey="") if logger.level == logging.DEBUG: logger.debug("return: " + str(ret)) @@ -609,7 +536,7 @@ def getStatistics(self, seconds: int = 10, segments: int = 10) -> dict: def describeQuery(self, queryName: str, queryDescription: str, parameterDescriptions: dict = {}): """Add a query description and parameter descriptions. Only supported on versions of TigerGraph >= 4.0.0. - + Args: queryName: The name of the query to describe. 
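A minimal sketch of the description API introduced above; the query name, description text, and parameter name are invented for illustration, and the call requires TigerGraph >= 4.0.0, as the version check in this hunk enforces:
```
conn.describeQuery(
    "pagerank",
    "Computes PageRank over the person graph.",
    parameterDescriptions={"max_iter": "Maximum number of iterations."},
)
```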
@@ -617,7 +544,7 @@ def describeQuery(self, queryName: str, queryDescription: str, parameterDescript A description of the query. parameterDescriptions (optional): A dictionary of parameter descriptions. The keys are the parameter names and the values are the descriptions. - + Returns: The response from the database. @@ -630,44 +557,47 @@ def describeQuery(self, queryName: str, queryDescription: str, parameterDescript major_ver, minor_ver, patch_ver = self.ver.split(".") if int(major_ver) < 4: logger.info("exit: describeQuery") - raise TigerGraphException("This function is only supported on versions of TigerGraph >= 4.0.0.", 0) - + raise TigerGraphException( + "This function is only supported on versions of TigerGraph >= 4.0.0.", 0) + if parameterDescriptions: params = {"queries": [ {"queryName": queryName, - "description": queryDescription, - "parameters": [{"paramName": k, "description": v} for k, v in parameterDescriptions.items()]} + "description": queryDescription, + "parameters": [{"paramName": k, "description": v} for k, v in parameterDescriptions.items()]} ]} else: params = {"queries": [ {"queryName": queryName, - "description": queryDescription} + "description": queryDescription} ]} if logger.level == logging.DEBUG: logger.debug("params: " + params) - if self._versionGreaterThan4_0(): - res = self._put(self.gsUrl+"/gsql/v1/description?graph="+self.graphname, data=params, authMode="pwd", jsonData=True) + if self._version_greater_than_4_0(): + res = self._put(self.gsUrl+"/gsql/v1/description?graph=" + + self.graphname, data=params, authMode="pwd", jsonData=True) else: - res = self._put(self.gsUrl+"/gsqlserver/gsql/description?graph="+self.graphname, data=params, authMode="pwd", jsonData=True) + res = self._put(self.gsUrl+"/gsqlserver/gsql/description?graph=" + + self.graphname, data=params, authMode="pwd", jsonData=True) if logger.level == logging.DEBUG: logger.debug("return: " + str(res)) logger.info("exit: describeQuery") return res - + def getQueryDescription(self, queryName: Optional[Union[str, list]] = "all"): """Get the description of a query. Only supported on versions of TigerGraph >= 4.0.0. - + Args: queryName: The name of the query to get the description of. If multiple query descriptions are desired, pass a list of query names. If set to "all", returns the description of all queries. - + Returns: The description of the query(ies). 
- + Endpoints: - `GET /gsqlserver/gsql/description?graph={graph_name}` (In TigerGraph version 4.0) - `GET /gsql/v1/description?graph={graph_name}` (In TigerGraph versions >4.0) @@ -677,38 +607,41 @@ def getQueryDescription(self, queryName: Optional[Union[str, list]] = "all"): major_ver, minor_ver, patch_ver = self.ver.split(".") if int(major_ver) < 4: logger.info("exit: getQueryDescription") - raise TigerGraphException("This function is only supported on versions of TigerGraph >= 4.0.0.", 0) - + raise TigerGraphException( + "This function is only supported on versions of TigerGraph >= 4.0.0.", 0) + if logger.level == logging.DEBUG: logger.debug("params: " + self._locals(locals())) - + if isinstance(queryName, list): queryName = ",".join(queryName) - if self._versionGreaterThan4_0(): - res = self._get(self.gsUrl+"/gsql/v1/description?graph="+self.graphname+"&query="+queryName, authMode="pwd", resKey=None) - else: - res = self._get(self.gsUrl+"/gsqlserver/gsql/description?graph="+self.graphname+"&query="+queryName, authMode="pwd", resKey=None) + if self._version_greater_than_4_0(): + res = self._get(self.gsUrl+"/gsql/v1/description?graph=" + + self.graphname+"&query="+queryName, authMode="pwd", resKey=None) + else: + res = self._get(self.gsUrl+"/gsqlserver/gsql/description?graph=" + + self.graphname+"&query="+queryName, authMode="pwd", resKey=None) if not res["error"]: if logger.level == logging.DEBUG: logger.debug("exit: getQueryDescription") return res["results"]["queries"] else: raise TigerGraphException(res["message"], res["code"]) - + def dropQueryDescription(self, queryName: str, dropParamDescriptions: bool = True): """Drop the description of a query. Only supported on versions of TigerGraph >= 4.0.0. - + Args: queryName: The name of the query to drop the description of. If set to "*", drops the description of all queries. dropParamDescriptions: Whether to drop the parameter descriptions as well. Defaults to True. - + Returns: The response from the database. 
- + Endpoints: - `DELETE /gsqlserver/gsql/description?graph={graph_name}` (In TigerGraph version 4.0) - `DELETE /gsql/v1/description?graph={graph_name}` (In TigerGraph versions >4.0) @@ -718,22 +651,26 @@ def dropQueryDescription(self, queryName: str, dropParamDescriptions: bool = Tru major_ver, minor_ver, patch_ver = self.ver.split(".") if int(major_ver) < 4: logger.info("exit: describeQuery") - raise TigerGraphException("This function is only supported on versions of TigerGraph >= 4.0.0.", 0) - + raise TigerGraphException( + "This function is only supported on versions of TigerGraph >= 4.0.0.", 0) + if logger.level == logging.DEBUG: logger.debug("params: " + self._locals(locals())) if dropParamDescriptions: - params = {"queries": [queryName], "queryParameters": [queryName+".*"]} + params = {"queries": [queryName], + "queryParameters": [queryName+".*"]} else: params = {"queries": [queryName]} print(params) - if self._versionGreaterThan4_0(): - res = self._delete(self.gsUrl+"/gsql/v1/description?graph="+self.graphname, authMode="pwd", data=params, jsonData=True, resKey=None) + if self._version_greater_than_4_0(): + res = self._delete(self.gsUrl+"/gsql/v1/description?graph="+self.graphname, + authMode="pwd", data=params, jsonData=True, resKey=None) else: - res = self._delete(self.gsUrl+"/gsqlserver/gsql/description?graph="+self.graphname, authMode="pwd", data=params, jsonData=True, resKey=None) - + res = self._delete(self.gsUrl+"/gsqlserver/gsql/description?graph=" + + self.graphname, authMode="pwd", data=params, jsonData=True, resKey=None) + if logger.level == logging.DEBUG: logger.debug("return: " + str(res)) logger.info("exit: dropQueryDescription") - - return res \ No newline at end of file + + return res diff --git a/pyTigerGraph/pyTigerGraphSchema.py b/pyTigerGraph/pyTigerGraphSchema.py index 4b18de75..5adff214 100644 --- a/pyTigerGraph/pyTigerGraphSchema.py +++ b/pyTigerGraph/pyTigerGraphSchema.py @@ -6,8 +6,13 @@ import json import logging import re + from typing import Union +from pyTigerGraph.common.schema import ( + _prep_upsert_data, + _prep_get_endpoints +) from pyTigerGraph.pyTigerGraphBase import pyTigerGraphBase logger = logging.getLogger(__name__) @@ -27,12 +32,12 @@ def _getUDTs(self) -> dict: """ logger.info("entry: _getUDTs") - if self._versionGreaterThan4_0(): + if self._version_greater_than_4_0(): res = self._get(self.gsUrl + "/gsql/v1/udt/tuples?graph=" + self.graphname, - authMode="pwd") - else: + authMode="pwd") + else: res = self._get(self.gsUrl + "/gsqlserver/gsql/udtlist?graph=" + self.graphname, - authMode="pwd") + authMode="pwd") if logger.level == logging.DEBUG: logger.debug("return: " + str(res)) @@ -40,66 +45,6 @@ def _getUDTs(self) -> dict: return res - def _getAttrType(self, attrType: dict) -> str: - """Returns attribute data type in simple format. - - Args: - attribute: - The details of the attribute's data type. - - Returns: - Either "(scalar_type)" or "(complex_type, scalar_type)" string. - """ - ret = attrType["Name"] - if "KeyTypeName" in attrType: - ret += "(" + attrType["KeyTypeName"] + "," + attrType["ValueTypeName"] + ")" - elif "ValueTypeName" in attrType: - ret += "(" + attrType["ValueTypeName"] + ")" - - return ret - - def _upsertAttrs(self, attributes: dict) -> dict: - """Transforms attributes (provided as a table) into a hierarchy as expected by the upsert - functions. 
- - Args: - attributes: A dictionary of attribute/value pairs (with an optional operator) in this - format: - {: |(, ), …} - - Returns: - A dictionary in this format: - { - : {"value": }, - : {"value": , "op": } - } - - Documentation: - xref:tigergraph-server:API:built-in-endpoints.adoc#operation-codes[Operation codes] - """ - logger.info("entry: _upsertAttrs") - if logger.level == logging.DEBUG: - logger.debug("params: " + self._locals(locals())) - - if not isinstance(attributes, dict): - return {} - # TODO Should return something else or raise exception? - vals = {} - for attr in attributes: - val = attributes[attr] - if isinstance(val, tuple): - vals[attr] = {"value": val[0], "op": val[1]} - elif isinstance(val, dict): - vals[attr] = {"value": {"keylist": list(val.keys()), "valuelist": list(val.values())}} - else: - vals[attr] = {"value": val} - - if logger.level == logging.DEBUG: - logger.debug("return: " + str(vals)) - logger.info("exit: _upsertAttrs") - - return vals - def getSchema(self, udts: bool = True, force: bool = False) -> dict: """Retrieves the schema metadata (of all vertex and edge type and, if not disabled, the User-Defined Type details) of the graph. @@ -124,12 +69,12 @@ def getSchema(self, udts: bool = True, force: bool = False) -> dict: logger.debug("params: " + self._locals(locals())) if not self.schema or force: - if self._versionGreaterThan4_0(): + if self._version_greater_than_4_0(): self.schema = self._get(self.gsUrl + "/gsql/v1/schema/graphs/" + self.graphname, - authMode="pwd") + authMode="pwd") else: self.schema = self._get(self.gsUrl + "/gsqlserver/gsql/schema?graph=" + self.graphname, - authMode="pwd") + authMode="pwd") if udts and ("UDTs" not in self.schema or force): self.schema["UDTs"] = self._getUDTs() @@ -140,8 +85,8 @@ def getSchema(self, udts: bool = True, force: bool = False) -> dict: return self.schema def upsertData(self, data: Union[str, object], atomic: bool = False, ackAll: bool = False, - newVertexOnly: bool = False, vertexMustExist: bool = False, - updateVertexOnly: bool = False) -> dict: + newVertexOnly: bool = False, vertexMustExist: bool = False, + updateVertexOnly: bool = False) -> dict: """Upserts data (vertices and edges) from a JSON file or a file with equivalent object structure. Args: @@ -176,23 +121,11 @@ def upsertData(self, data: Union[str, object], atomic: bool = False, ackAll: boo if logger.level == logging.DEBUG: logger.debug("params: " + self._locals(locals())) - if not isinstance(data, str): - data = json.dumps(data) - headers = {} - if atomic: - headers["gsql-atomic-level"] = "atomic" - params = {} - if ackAll: - params["ack"] = "all" - if newVertexOnly: - params["new_vertex_only"] = True - if vertexMustExist: - params["vertex_must_exist"] = True - if updateVertexOnly: - params["update_vertex_only"] = True + data, headers, params = _prep_upsert_data(data=data, atomic=atomic, ackAll=ackAll, newVertexOnly=newVertexOnly, + vertexMustExist=vertexMustExist, updateVertexOnly=updateVertexOnly) res = self._post(self.restppUrl + "/graph/" + self.graphname, headers=headers, data=data, - params=params)[0] + params=params)[0] if logger.level == logging.DEBUG: logger.debug("return: " + str(res)) @@ -201,7 +134,7 @@ def upsertData(self, data: Union[str, object], atomic: bool = False, ackAll: boo return res def getEndpoints(self, builtin: bool = False, dynamic: bool = False, - static: bool = False) -> dict: + static: bool = False) -> dict: """Lists the REST++ endpoints and their parameters. 
Args: @@ -222,30 +155,29 @@ def getEndpoints(self, builtin: bool = False, dynamic: bool = False, if logger.level == logging.DEBUG: logger.debug("params: " + self._locals(locals())) - ret = {} - if not (builtin or dynamic or static): - bui = dyn = sta = True - else: - bui = builtin - dyn = dynamic - sta = static - url = self.restppUrl + "/endpoints/" + self.graphname + "?" + bui, dyn, sta, url, ret = _prep_get_endpoints( + restppUrl=self.restppUrl, + graphname=self.graphname, + builtin=builtin, + dynamic=dynamic, + static=static + ) if bui: eps = {} - res = self._get(url + "builtin=true", resKey="") + res = self._req("GET", url + "builtin=true", resKey="") for ep in res: if not re.search(" /graph/", ep) or re.search(" /graph/{graph_name}/", ep): eps[ep] = res[ep] ret.update(eps) if dyn: eps = {} - res = self._get(url + "dynamic=true", resKey="") + res = self._req("GET", url + "dynamic=true", resKey="") for ep in res: if re.search("^GET /query/" + self.graphname, ep): eps[ep] = res[ep] ret.update(eps) if sta: - ret.update(self._get(url + "static=true", resKey="")) + ret.update(self._req("GET", url + "static=true", resKey="")) if logger.level == logging.DEBUG: logger.debug("return: " + str(ret)) diff --git a/pyTigerGraph/pyTigerGraphUtils.py b/pyTigerGraph/pyTigerGraphUtils.py index e10be18d..849b0bed 100644 --- a/pyTigerGraph/pyTigerGraphUtils.py +++ b/pyTigerGraph/pyTigerGraphUtils.py @@ -5,36 +5,23 @@ """ import json import logging -import urllib -from typing import Any, Union + +from typing import Any, TYPE_CHECKING from urllib.parse import urlparse -import requests -from typing import TYPE_CHECKING, Union +from pyTigerGraph.common.util import ( + _parse_get_license_info, + _prep_get_system_metrics +) +from pyTigerGraph.common.exception import TigerGraphException from pyTigerGraph.pyTigerGraphBase import pyTigerGraphBase -from pyTigerGraph.pyTigerGraphException import TigerGraphException logger = logging.getLogger(__name__) class pyTigerGraphUtils(pyTigerGraphBase): - def _safeChar(self, inputString: Any) -> str: - """Replace special characters in string using the %xx escape. - - Args: - inputString: - The string to process - - Returns: - Processed string. - - Documentation: - https://docs.python.org/3/library/urllib.parse.html#url-quoting - """ - return urllib.parse.quote(str(inputString), safe='') - def echo(self, usePost: bool = False) -> str: """Pings the database. @@ -80,18 +67,9 @@ def getLicenseInfo(self) -> dict: """ logger.info("entry: getLicenseInfo") - res = self._get(self.restppUrl + "/showlicenseinfo", resKey="", skipCheck=True) - ret = {} - if not res["error"]: - ret["message"] = res["message"] - ret["expirationDate"] = res["results"][0]["Expiration date"] - ret["daysRemaining"] = res["results"][0]["Days remaining"] - elif "code" in res and res["code"] == "REST-5000": - ret["message"] = \ - "This instance does not have a valid enterprise license. Is this a trial version?" 
- ret["daysRemaining"] = -1 - else: - raise TigerGraphException(res["message"], res["code"]) + res = self._req("GET", self.restppUrl + + "/showlicenseinfo", resKey="", skipCheck=True) + ret = _parse_get_license_info(res) if logger.level == logging.DEBUG: logger.debug("return: " + str(ret)) @@ -115,9 +93,9 @@ def ping(self) -> dict: else: raise TigerGraphException(res["message"], res["code"]) - def getSystemMetrics(self, from_ts:int = None, to_ts:int = None, latest:int = None, what:str = None, who:str = None, where:str = None): + def getSystemMetrics(self, from_ts: int = None, to_ts: int = None, latest: int = None, what: str = None, who: str = None, where: str = None): """Monitor system usage metrics. - + Args: from_ts (int, optional): The epoch timestamp that indicates the start of the time filter. @@ -151,44 +129,33 @@ def getSystemMetrics(self, from_ts:int = None, to_ts:int = None, latest:int = No """ if logger.level == logging.DEBUG: logger.debug("entry: getSystemMetrics") - params = {} - _json = {} # in >=4.1 we need a json request of different parameter names - if from_ts or to_ts: - _json["TimeRange"] = {} - if from_ts: - params["from"] = from_ts - _json['TimeRange']['StartTimestampNS'] = str(from_ts) - if to_ts: - params["to"] = to_ts - _json['TimeRange']['EndTimestampNS'] = str(from_ts) - if latest: - params["latest"] = latest - _json["LatestNum"] = str(latest) + + params, _json = _prep_get_system_metrics( + from_ts=from_ts, to_ts=to_ts, latest=latest, who=who, where=where) + + # Couldn't be placed in prep since version checking requires await statements if what: - if self._versionGreaterThan4_0(): + if self._version_greater_than_4_0(): if what == "servicestate" or what == "connection": - raise TigerGraphException("This 'what' parameter is only supported on versions of TigerGraph < 4.1.0.", 0) + raise TigerGraphException( + "This 'what' parameter is only supported on versions of TigerGraph < 4.1.0.", 0) if what == "cpu" or what == "mem": - what = "cpu-memory" # in >=4.1 cpu and mem have been conjoined into one category + what = "cpu-memory" # in >=4.1 cpu and mem have been conjoined into one category params["what"] = what - if who: - params["who"] = who - if where: - params["where"] = where - _json["HostID"] = where # in >=4.1 the datapoints endpoint has been removed and replaced - if self._versionGreaterThan4_0(): - res = self._post(self.gsUrl+"/informant/metrics/get/"+what, data=_json, jsonData=True, resKey="") + if self._version_greater_than_4_0(): + res = self._req("POST", self.gsUrl+"/informant/metrics/get/" + + what, data=_json, jsonData=True, resKey="") else: - res = self._get(self.gsUrl+"/ts3/api/datapoints", authMode="pwd", params=params, resKey="") + res = self._req("GET", self.gsUrl+"/ts3/api/datapoints", + authMode="pwd", params=params, resKey="") if logger.level == logging.DEBUG: logger.debug("exit: getSystemMetrics") return res - - def getQueryPerformance(self, seconds:int = 10): + def getQueryPerformance(self, seconds: int = 10): """Returns real-time query performance statistics over the given time period, as specified by the seconds parameter. - + Args: seconds (int, optional): Seconds are measured up to 60, so the seconds parameter must be a positive integer less than or equal to 60. 
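`getSystemMetrics()` now builds both parameter formats through `_prep_get_system_metrics()`: flat query params for TigerGraph < 4.1 and a JSON body for >= 4.1. A sketch consistent with the removed inline code (note the removed code set `EndTimestampNS` from `from_ts`, which looks like a copy-paste bug; the sketch uses `to_ts`):

def _prep_get_system_metrics(from_ts: int = None, to_ts: int = None,
                             latest: int = None, who: str = None, where: str = None):
    params = {}   # query params for the <4.1 /ts3/api/datapoints endpoint
    _json = {}    # JSON body for the >=4.1 /informant/metrics/get endpoint
    if from_ts or to_ts:
        _json["TimeRange"] = {}
    if from_ts:
        params["from"] = from_ts
        _json["TimeRange"]["StartTimestampNS"] = str(from_ts)
    if to_ts:
        params["to"] = to_ts
        _json["TimeRange"]["EndTimestampNS"] = str(to_ts)
    if latest:
        params["latest"] = latest
        _json["LatestNum"] = str(latest)
    if who:
        params["who"] = who
    if where:
        params["where"] = where
        _json["HostID"] = where
    return params, _json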
@@ -199,7 +166,8 @@ def getQueryPerformance(self, seconds:int = 10): params = {} if seconds: params["seconds"] = seconds - res = self._get(self.restppUrl+"/statistics/"+self.graphname, params=params, resKey="") + res = self._get(self.restppUrl+"/statistics/" + + self.graphname, params=params, resKey="") if logger.level == logging.DEBUG: logger.debug("exit: getQueryPerformance") return res @@ -214,7 +182,8 @@ def getServiceStatus(self, request_body: dict): """ if logger.level == logging.DEBUG: logger.debug("entry: getServiceStatus") - res = self._post(self.gsUrl+"/informant/current-service-status", data=json.dumps(request_body), resKey="") + res = self._post(self.gsUrl+"/informant/current-service-status", + data=json.dumps(request_body), resKey="") if logger.level == logging.DEBUG: logger.debug("exit: getServiceStatus") return res @@ -251,11 +220,11 @@ def rebuildGraph(self, threadnum: int = None, vertextype: str = "", segid: str = params["path"] = path if force: params["force"] = force - res = self._get(self.restppUrl+"/rebuildnow/"+self.graphname, params=params, resKey="") + res = self._get(self.restppUrl+"/rebuildnow/" + + self.graphname, params=params, resKey="") if not res["error"]: if logger.level == logging.DEBUG: logger.debug("exit: rebuildGraph") return res else: raise TigerGraphException(res["message"], res["code"]) - \ No newline at end of file diff --git a/pyTigerGraph/pyTigerGraphVertex.py b/pyTigerGraph/pyTigerGraphVertex.py index 5c230b63..899e5d87 100644 --- a/pyTigerGraph/pyTigerGraphVertex.py +++ b/pyTigerGraph/pyTigerGraphVertex.py @@ -13,7 +13,21 @@ if TYPE_CHECKING: import pandas as pd -from pyTigerGraph.pyTigerGraphException import TigerGraphException +from pyTigerGraph.common.exception import TigerGraphException +from pyTigerGraph.common.vertex import ( + _parse_get_vertex_count, + _prep_upsert_vertex_dataframe, + _prep_get_vertices, + _prep_get_vertices_by_id, + _parse_get_vertex_stats, + _prep_del_vertices, + _prep_del_vertices_by_id +) + +from pyTigerGraph.common.schema import _upsert_attrs +from pyTigerGraph.common.util import _safe_char +from pyTigerGraph.common.vertex import vertexSetToDataFrame as _vS2DF + from pyTigerGraph.pyTigerGraphSchema import pyTigerGraphSchema from pyTigerGraph.pyTigerGraphUtils import pyTigerGraphUtils @@ -70,7 +84,8 @@ def getVertexAttrs(self, vertexType: str) -> list: ret = [] for at in et["Attributes"]: - ret.append((at["AttributeName"], self._getAttrType(at["AttributeType"]))) + ret.append( + (at["AttributeName"], self._getAttrType(at["AttributeType"]))) if logger.level == logging.DEBUG: logger.debug("return: " + str(ret)) @@ -152,12 +167,13 @@ def getVertexCount(self, vertexType: Union[str, list] = "*", where: str = "", re # If WHERE condition is not specified, use /builtins else use /vertices if isinstance(vertexType, str) and vertexType != "*": if where: - res = self._get(self.restppUrl + "/graph/" + self.graphname + "/vertices/" + vertexType - + "?count_only=true" + "&filter=" + where)[0]["count"] + res = self._req("GET", self.restppUrl + "/graph/" + self.graphname + "/vertices/" + vertexType + + "?count_only=true" + "&filter=" + where)[0]["count"] else: - res = self._post(self.restppUrl + "/builtins/" + self.graphname + ("?realtime=true" if realtime else ""), - data={"function": "stat_vertex_number", "type": vertexType}, - jsonData=True)[0]["count"] + res = self._req("POST", self.restppUrl + "/builtins/" + self.graphname + ("?realtime=true" if realtime else ""), + data={"function": "stat_vertex_number", + "type": vertexType}, 
+ jsonData=True)[0]["count"] if logger.level == logging.DEBUG: logger.debug("return: " + str(res)) @@ -165,21 +181,11 @@ def getVertexCount(self, vertexType: Union[str, list] = "*", where: str = "", re return res - if where: - if vertexType == "*": - raise TigerGraphException( - "VertexType cannot be \"*\" if where condition is specified.", None) - else: - raise TigerGraphException( - "VertexType cannot be a list if where condition is specified.", None) - - res = self._post(self.restppUrl + "/builtins/" + self.graphname + ("?realtime=true" if realtime else ""), - data={"function": "stat_vertex_number", "type": "*"}, - jsonData=True) - ret = {d["v_type"]: d["count"] for d in res} + res = self._req("POST", self.restppUrl + "/builtins/" + self.graphname + ("?realtime=true" if realtime else ""), + data={"function": "stat_vertex_number", "type": "*"}, + jsonData=True) - if isinstance(vertexType, list): - ret = {vt: ret[vt] for vt in vertexType} + ret = _parse_get_vertex_count(res, vertexType, where) if logger.level == logging.DEBUG: logger.debug("return: " + str(ret)) @@ -223,10 +229,11 @@ def upsertVertex(self, vertexType: str, vertexId: str, attributes: dict = None) if logger.level == logging.DEBUG: logger.debug("params: " + self._locals(locals())) - vals = self._upsertAttrs(attributes) + vals = _upsert_attrs(attributes) data = json.dumps({"vertices": {vertexType: {vertexId: vals}}}) - ret = self._post(self.restppUrl + "/graph/" + self.graphname, data=data)[0]["accepted_vertices"] + ret = self._req("POST", self.restppUrl + "/graph/" + + self.graphname, data=data)[0]["accepted_vertices"] if logger.level == logging.DEBUG: logger.debug("return: " + str(ret)) @@ -279,11 +286,12 @@ def upsertVertices(self, vertexType: str, vertices: list) -> int: data = {} for v in vertices: - vals = self._upsertAttrs(v[1]) + vals = _upsert_attrs(v[1]) data[v[0]] = vals data = json.dumps({"vertices": {vertexType: data}}) - ret = self._post(self.restppUrl + "/graph/" + self.graphname, data=data)[0]["accepted_vertices"] + ret = self._req("POST", self.restppUrl + "/graph/" + + self.graphname, data=data)[0]["accepted_vertices"] if logger.level == logging.DEBUG: logger.debug("return: " + str(ret)) @@ -292,7 +300,7 @@ def upsertVertices(self, vertexType: str, vertices: list) -> int: return ret def upsertVertexDataFrame(self, df: 'pd.DataFrame', vertexType: str, v_id: bool = None, - attributes: dict = None) -> int: + attributes: dict = None) -> int: """Upserts vertices from a Pandas DataFrame. 
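Both upsert methods now share `_upsert_attrs()`, whose behavior is documented by the removed `_upsertAttrs()` above: plain values pass through, `(value, operator)` tuples attach a REST++ operation code, and dicts are converted to the keylist/valuelist form used for MAP attributes. A hypothetical call (the vertex type, attribute names and the "add" operation code are illustrative, not from this diff):

from pyTigerGraph import TigerGraphConnection

conn = TigerGraphConnection(host="http://localhost", graphname="MyGraph")
conn.upsertVertex("Person", "alice", {
    "name": "Alice",                     # -> {"value": "Alice"}
    "visit_count": (1, "add"),           # -> {"value": 1, "op": "add"}
    "scores": {"math": 95, "art": 88},   # -> keylist/valuelist MAP payload
})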
Args: @@ -316,16 +324,8 @@ def upsertVertexDataFrame(self, df: 'pd.DataFrame', vertexType: str, v_id: bool if logger.level == logging.DEBUG: logger.debug("params: " + self._locals(locals())) - json_up = [] - - for index in df.index: - json_up.append(json.loads(df.loc[index].to_json())) - json_up[-1] = ( - index if v_id is None else json_up[-1][v_id], - json_up[-1] if attributes is None - else {target: json_up[-1][source] for target, source in attributes.items()} - ) - + json_up = _prep_upsert_vertex_dataframe( + df=df, v_id=v_id, attributes=attributes) ret = self.upsertVertices(vertexType=vertexType, vertices=json_up) if logger.level == logging.DEBUG: @@ -335,8 +335,8 @@ def upsertVertexDataFrame(self, df: 'pd.DataFrame', vertexType: str, v_id: bool return ret def getVertices(self, vertexType: str, select: str = "", where: str = "", - limit: Union[int, str] = None, sort: str = "", fmt: str = "py", withId: bool = True, - withType: bool = False, timeout: int = 0) -> Union[dict, str, 'pd.DataFrame']: + limit: Union[int, str] = None, sort: str = "", fmt: str = "py", withId: bool = True, + withType: bool = False, timeout: int = 0) -> Union[dict, str, 'pd.DataFrame']: """Retrieves vertices of the given vertex type. *Note*: @@ -383,29 +383,21 @@ def getVertices(self, vertexType: str, select: str = "", where: str = "", if logger.level == logging.DEBUG: logger.debug("params: " + self._locals(locals())) - url = self.restppUrl + "/graph/" + self.graphname + "/vertices/" + vertexType - isFirst = True - if select: - url += "?select=" + select - isFirst = False - if where: - url += ("?" if isFirst else "&") + "filter=" + where - isFirst = False - if limit: - url += ("?" if isFirst else "&") + "limit=" + str(limit) - isFirst = False - if sort: - url += ("?" if isFirst else "&") + "sort=" + sort - isFirst = False - if timeout and timeout > 0: - url += ("?" if isFirst else "&") + "timeout=" + str(timeout) - - ret = self._get(url) + url = _prep_get_vertices( + restppUrl=self.restppUrl, + graphname=self.graphname, + vertexType=vertexType, + select=select, + where=where, + limit=limit, + sort=sort, + timeout=timeout) + ret = self._req("GET", url) if fmt == "json": ret = json.dumps(ret) elif fmt == "df": - ret = self.vertexSetToDataFrame(ret, withId, withType) + ret = _vS2DF(ret, withId, withType) if logger.level == logging.DEBUG: logger.debug("return: " + str(ret)) @@ -414,7 +406,7 @@ def getVertices(self, vertexType: str, select: str = "", where: str = "", return ret def getVertexDataFrame(self, vertexType: str, select: str = "", where: str = "", - limit: Union[int, str] = None, sort: str = "", timeout: int = 0) -> 'pd.DataFrame': + limit: Union[int, str] = None, sort: str = "", timeout: int = 0) -> 'pd.DataFrame': """Retrieves vertices of the given vertex type and returns them as pandas DataFrame. This is a shortcut to `getVertices(..., fmt="df", withId=True, withType=False)`. 
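The DataFrame-to-tuples conversion removed here now lives in `_prep_upsert_vertex_dataframe()`; it still takes the ID from the `v_id` column (or the index) and maps DataFrame columns to attributes via a `{target: source}` dict. A hypothetical usage (column and type names are illustrative):

import pandas as pd
from pyTigerGraph import TigerGraphConnection

conn = TigerGraphConnection(host="http://localhost", graphname="MyGraph")
df = pd.DataFrame({
    "id": ["alice", "bob"],
    "full_name": ["Alice A.", "Bob B."],
    "age": [34, 29],
})
# "name" on the vertex is filled from the DataFrame's "full_name" column.
count = conn.upsertVertexDataFrame(
    df, "Person", v_id="id", attributes={"name": "full_name", "age": "age"})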
@@ -451,7 +443,7 @@ def getVertexDataFrame(self, vertexType: str, select: str = "", where: str = "", logger.debug("params: " + self._locals(locals())) ret = self.getVertices(vertexType, select=select, where=where, limit=limit, sort=sort, - fmt="df", withId=True, withType=False, timeout=timeout) + fmt="df", withId=True, withType=False, timeout=timeout) if logger.level == logging.DEBUG: logger.debug("return: " + str(ret)) @@ -460,7 +452,7 @@ def getVertexDataFrame(self, vertexType: str, select: str = "", where: str = "", return ret def getVertexDataframe(self, vertexType: str, select: str = "", where: str = "", - limit: Union[int, str] = None, sort: str = "", timeout: int = 0) -> 'pd.DataFrame': + limit: Union[int, str] = None, sort: str = "", timeout: int = 0) -> 'pd.DataFrame': """DEPRECATED Use `getVertexDataFrame()` instead. @@ -470,11 +462,11 @@ def getVertexDataframe(self, vertexType: str, select: str = "", where: str = "", DeprecationWarning) return self.getVertexDataFrame(vertexType, select=select, where=where, limit=limit, - sort=sort, timeout=timeout) + sort=sort, timeout=timeout) def getVerticesById(self, vertexType: str, vertexIds: Union[int, str, list], select: str = "", - fmt: str = "py", withId: bool = True, withType: bool = False, - timeout: int = 0) -> Union[list, str, 'pd.DataFrame']: + fmt: str = "py", withId: bool = True, withType: bool = False, + timeout: int = 0) -> Union[list, str, 'pd.DataFrame']: """Retrieves vertices of the given vertex type, identified by their ID. Args: @@ -510,23 +502,21 @@ def getVerticesById(self, vertexType: str, vertexIds: Union[int, str, list], sel if logger.level == logging.DEBUG: logger.debug("params: " + self._locals(locals())) - if not vertexIds: - raise TigerGraphException("No vertex ID was specified.", None) - vids = [] - if isinstance(vertexIds, (int, str)): - vids.append(vertexIds) - else: - vids = vertexIds - url = self.restppUrl + "/graph/" + self.graphname + "/vertices/" + vertexType + "/" + vids, url = _prep_get_vertices_by_id( + restppUrl=self.restppUrl, + graphname=self.graphname, + vertexIds=vertexIds, + vertexType=vertexType + ) ret = [] for vid in vids: - ret += self._get(url + self._safeChar(vid)) + ret += self._req("GET", url + _safe_char(vid)) if fmt == "json": ret = json.dumps(ret) elif fmt == "df": - ret = self.vertexSetToDataFrame(ret, withId, withType) + ret = _vS2DF(ret, withId, withType) if logger.level == logging.DEBUG: logger.debug("return: " + str(ret)) @@ -535,7 +525,7 @@ def getVerticesById(self, vertexType: str, vertexIds: Union[int, str, list], sel return ret def getVertexDataFrameById(self, vertexType: str, vertexIds: Union[int, str, list], - select: str = "") -> 'pd.DataFrame': + select: str = "") -> 'pd.DataFrame': """Retrieves vertices of the given vertex type, identified by their ID. This is a shortcut to ``getVerticesById(..., fmt="df", withId=True, withType=False)``. 
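A sketch of `_prep_get_vertices_by_id()` as implied by the removed code and the unpacking at the call site (`vids, url = ...`); the real helper in `pyTigerGraph.common.vertex` is not shown in this hunk:

from typing import Union
from pyTigerGraph.common.exception import TigerGraphException

def _prep_get_vertices_by_id(restppUrl: str, graphname: str,
                             vertexIds: Union[int, str, list], vertexType: str):
    if not vertexIds:
        raise TigerGraphException("No vertex ID was specified.", None)
    # Normalize a scalar ID into a one-element list.
    vids = [vertexIds] if isinstance(vertexIds, (int, str)) else vertexIds
    url = restppUrl + "/graph/" + graphname + "/vertices/" + vertexType + "/"
    return vids, url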
@@ -556,7 +546,7 @@ def getVertexDataFrameById(self, vertexType: str, vertexIds: Union[int, str, lis logger.debug("params: " + self._locals(locals())) ret = self.getVerticesById(vertexType, vertexIds, select, fmt="df", withId=True, - withType=False) + withType=False) if logger.level == logging.DEBUG: logger.debug("return: " + str(ret)) @@ -565,7 +555,7 @@ def getVertexDataFrameById(self, vertexType: str, vertexIds: Union[int, str, lis return ret def getVertexDataframeById(self, vertexType: str, vertexIds: Union[int, str, list], - select: str = "") -> 'pd.DataFrame': + select: str = "") -> 'pd.DataFrame': """DEPRECATED Use `getVertexDataFrameById()` instead. @@ -606,22 +596,14 @@ def getVertexStats(self, vertexTypes: Union[str, list], skipNA: bool = False) -> else: vts = vertexTypes - ret = {} + responses = [] for vt in vts: data = '{"function":"stat_vertex_attr","type":"' + vt + '"}' - res = self._post(self.restppUrl + "/builtins/" + self.graphname, data=data, resKey="", - skipCheck=True) - if res["error"]: - if "stat_vertex_attr is skip" in res["message"]: - if not skipNA: - ret[vt] = {} - else: - raise TigerGraphException(res["message"], - (res["code"] if "code" in res else None)) - else: - res = res["results"] - for r in res: - ret[r["v_type"]] = r["attributes"] + res = self._req("POST", self.restppUrl + "/builtins/" + self.graphname, data=data, resKey="", + skipCheck=True) + responses.append((vt, res)) + + ret = _parse_get_vertex_stats(responses, skipNA) if logger.level == logging.DEBUG: logger.debug("return: " + str(ret)) @@ -630,7 +612,7 @@ def getVertexStats(self, vertexTypes: Union[str, list], skipNA: bool = False) -> return ret def delVertices(self, vertexType: str, where: str = "", limit: str = "", sort: str = "", - permanent: bool = False, timeout: int = 0) -> int: + permanent: bool = False, timeout: int = 0) -> int: """Deletes vertices from graph. *Note*: @@ -671,21 +653,17 @@ def delVertices(self, vertexType: str, where: str = "", limit: str = "", sort: s if logger.level == logging.DEBUG: logger.debug("params: " + self._locals(locals())) - url = self.restppUrl + "/graph/" + self.graphname + "/vertices/" + vertexType - isFirst = True - if where: - url += "?filter=" + where - isFirst = False - if limit and sort: # These two must be provided together - url += ("?" if isFirst else "&") + "limit=" + str(limit) + "&sort=" + sort - isFirst = False - if permanent: - url += ("?" if isFirst else "&") + "permanent=true" - isFirst = False - if timeout and timeout > 0: - url += ("?" if isFirst else "&") + "timeout=" + str(timeout) - - ret = self._delete(url)["deleted_vertices"] + url = _prep_del_vertices( + restppUrl=self.restppUrl, + graphname=self.graphname, + vertexType=vertexType, + where=where, + limit=limit, + sort=sort, + permanent=permanent, + timeout=timeout + ) + ret = self._req("DELETE", url)["deleted_vertices"] if logger.level == logging.DEBUG: logger.debug("return: " + str(ret)) @@ -694,7 +672,7 @@ def delVertices(self, vertexType: str, where: str = "", limit: str = "", sort: s return ret def delVerticesById(self, vertexType: str, vertexIds: Union[int, str, list], - permanent: bool = False, timeout: int = 0) -> int: + permanent: bool = False, timeout: int = 0) -> int: """Deletes vertices from graph identified by their ID. 
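`delVertices()` reduces to URL construction plus one DELETE call. A sketch of `_prep_del_vertices()` mirroring the removed inline logic (assumed signature, matching the keyword arguments at the call site above):

def _prep_del_vertices(restppUrl: str, graphname: str, vertexType: str,
                       where: str, limit: str, sort: str,
                       permanent: bool, timeout: int) -> str:
    url = restppUrl + "/graph/" + graphname + "/vertices/" + vertexType
    is_first = True
    if where:
        url += "?filter=" + where
        is_first = False
    if limit and sort:  # REST++ only accepts these two together
        url += ("?" if is_first else "&") + "limit=" + str(limit) + "&sort=" + sort
        is_first = False
    if permanent:
        url += ("?" if is_first else "&") + "permanent=true"
        is_first = False
    if timeout and timeout > 0:
        url += ("?" if is_first else "&") + "timeout=" + str(timeout)
    return url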
Args: @@ -719,23 +697,18 @@ def delVerticesById(self, vertexType: str, vertexIds: Union[int, str, list], if logger.level == logging.DEBUG: logger.debug("params: " + self._locals(locals())) - if not vertexIds: - raise TigerGraphException("No vertex ID was specified.", None) - vids = [] - if isinstance(vertexIds, (int, str)): - vids.append(self._safeChar(vertexIds)) - else: - vids = [self._safeChar(f) for f in vertexIds] - - url1 = self.restppUrl + "/graph/" + self.graphname + "/vertices/" + vertexType + "/" - url2 = "" - if permanent: - url2 = "?permanent=true" - if timeout and timeout > 0: - url2 += ("&" if url2 else "?") + "timeout=" + str(timeout) + url1, url2, vids = _prep_del_vertices_by_id( + restppUrl=self.restppUrl, + graphname=self.graphname, + vertexIds=vertexIds, + vertexType=vertexType, + permanent=permanent, + timeout=timeout + ) ret = 0 for vid in vids: - ret += self._delete(url1 + str(vid) + url2)["deleted_vertices"] + res = self._req("DELETE", url1 + str(vid) + url2) + ret += res["deleted_vertices"] if logger.level == logging.DEBUG: logger.debug("return: " + str(ret)) @@ -784,69 +757,28 @@ def delVerticesByType(self, vertexType: str, permanent: bool = False, ack: str = # TODO GET /deleted_vertex_check/{graph_name} - def vertexSetToDataFrame(self, vertexSet: list, withId: bool = True, - withType: bool = False) -> 'pd.DataFrame': - """Converts a vertex set to Pandas DataFrame. - - Vertex sets are used for both the input and output of `SELECT` statements. They contain - instances of vertices of the same type. - For each vertex instance, the vertex ID, the vertex type and the (optional) attributes are - present under the `v_id`, `v_type` and `attributes` keys, respectively. / - See an example in `edgeSetToDataFrame()`. - - A vertex set has this structure (when serialised as JSON): - [source.wrap,json] - ---- - [ - { - "v_id": , - "v_type": , - "attributes": - { - "attr1": , - "attr2": , - ⋮ - } - }, - ⋮ - ] - ---- - For more information on vertex sets see xref:gsql-ref:querying:declaration-and-assignment-statements.adoc#_vertex_set_variables[Vertex set variables]. + def vertexSetToDataFrame(self, vertexSet: dict, withId: bool = True, withType: bool = False) -> 'pd.DataFrame': + """Converts a vertex set (dictionary) to a pandas DataFrame. Args: vertexSet: - A JSON array containing a vertex set in the format returned by queries (see below). + The vertex set to convert. withId: - Whether to include vertex primary ID as a column. + Should the vertex ID be included in the DataFrame? withType: - Whether to include vertex type info as a column. + Should the vertex type be included in the DataFrame? Returns: - A pandas DataFrame containing the vertex attributes (and optionally the vertex primary - ID and type). + The vertex set as a pandas DataFrame. """ logger.info("entry: vertexSetToDataFrame") if logger.level == logging.DEBUG: logger.debug("params: " + self._locals(locals())) - try: - import pandas as pd - except ImportError: - raise ImportError("Pandas is required to use this function. 
" - "Download pandas using 'pip install pandas'.") - - df = pd.DataFrame(vertexSet) - cols = [] - if withId: - cols.append(df["v_id"]) - if withType: - cols.append(df["v_type"]) - cols.append(pd.DataFrame(df["attributes"].tolist())) - - ret = pd.concat(cols, axis=1) + ret = _vS2DF(vertexSet, withId, withType) if logger.level == logging.DEBUG: logger.debug("return: " + str(ret)) logger.info("exit: vertexSetToDataFrame") - return ret + return ret \ No newline at end of file diff --git a/pyTigerGraph/pytgasync/__init__.py b/pyTigerGraph/pytgasync/__init__.py new file mode 100644 index 00000000..e69de29b diff --git a/pyTigerGraph/pytgasync/datasets.py b/pyTigerGraph/pytgasync/datasets.py new file mode 100644 index 00000000..ddd806ba --- /dev/null +++ b/pyTigerGraph/pytgasync/datasets.py @@ -0,0 +1,160 @@ +"""Datasets + +Stock datasets that can be ingested into a TigerGraph database through the `ingestDataset` +function in pyTigerGraph. +""" +import json +import tarfile +import warnings +from abc import ABC, abstractmethod +from os import makedirs +from os.path import isdir +from os.path import join as pjoin +from shutil import rmtree +from urllib.parse import urljoin +from typing_extensions import Self +import io + +import httpx + +from pyTigerGraph.datasets import Datasets, BaseDataset + + +class AsyncDatasets(Datasets): + # using a factory method instead of __init__ since we need to call asynchronous methods when creating this. + # AsyncDatasets can only be created by calling: dataset = await AsyncDatasets.create() + + @classmethod + async def create(cls, name: str = None, tmp_dir: str = "./tmp") -> Self: + """Stock datasets. + + Please see https://tigergraph-public-data.s3.us-west-1.amazonaws.com/inventory.json[this link] + for datasets that are currently available. The files for the dataset with `name` will be + downloaded to local `tmp_dir` automatically when this class is instantiated. + For offline environments, download the desired .tar manually from the inventory page, and extract in the desired location. + Specify the `tmp_dir` parameter to point to where the unzipped directory resides. + + + Args: + name (str, optional): + Name of the dataset to get. If not provided or None, available datasets will be printed out. + Defaults to None. + tmp_dir (str, optional): + Where to store the artifacts of this dataset. Defaults to "./tmp". + """ + self = cls() + BaseDataset.__init__(self, name) + self.base_url = "https://tigergraph-public-data.s3.us-west-1.amazonaws.com/" + self.tmp_dir = tmp_dir + + if not name: + await self.list() + return self + + # Download the dataset and extract + if isdir(pjoin(tmp_dir, name)): + print( + "A folder with name {} already exists in {}. Skip downloading.".format( + name, tmp_dir + ) + ) + + if not isdir(pjoin(tmp_dir, name)): + dataset_url = await self.get_dataset_url() + # Check if it is an in-stock dataset. 
+            if not dataset_url:
+                raise Exception("Cannot find this dataset in the inventory.")
+            self.dataset_url = dataset_url
+            # download_extract() is a coroutine, so it must be awaited here;
+            # otherwise the files are never actually downloaded.
+            await self.download_extract()
+
+        self.ingest_ready = True
+
+        return self
+
+    # For overriding Dataset's init method
+    def __init__(self):
+        pass
+
+    async def get_dataset_url(self) -> str:
+        "NO DOC"
+        inventory_url = urljoin(self.base_url, "inventory.json")
+        async with httpx.AsyncClient() as client:
+            resp = await client.request("GET", inventory_url)
+            resp.raise_for_status()
+            resp = resp.json()
+            if self.name in resp:
+                return resp[self.name]
+            else:
+                return None
+
+    async def download_extract(self) -> None:
+        "NO DOC"
+        makedirs(self.tmp_dir, exist_ok=True)
+        # Use the client as an async context manager so the connection pool
+        # is closed once the download finishes.
+        async with httpx.AsyncClient() as client:
+            async with client.stream("GET", self.dataset_url) as resp:
+                raw_content = b''
+                async for byte in resp.aiter_raw():
+                    raw_content += byte
+                raw_content = io.BytesIO(raw_content)
+                try:
+                    from tqdm.auto import tqdm
+                    total_length = int(resp.headers.get("Content-Length"))
+                    with tqdm.wrapattr(
+                        raw_content, "read", total=total_length, desc="Downloading"
+                    ) as raw:
+                        with tarfile.open(fileobj=raw, mode="r|gz") as tarobj:
+                            tarobj.extractall(path=self.tmp_dir)
+                except ImportError:
+                    warnings.warn(
+                        "Cannot import tqdm. Downloading without progress report.")
+                    with tarfile.open(fileobj=raw_content, mode="r|gz") as tarobj:
+                        tarobj.extractall(path=self.tmp_dir)
+        print("Dataset downloaded.")
+
+    async def create_graph(self, conn) -> str:
+        "NO DOC"
+        with open(pjoin(self.tmp_dir, self.name, "create_graph.gsql"), "r") as infile:
+            resp = await conn.gsql(infile.read())
+        return resp
+
+    async def create_schema(self, conn) -> str:
+        "NO DOC"
+        with open(pjoin(self.tmp_dir, self.name, "create_schema.gsql"), "r") as infile:
+            resp = await conn.gsql(infile.read())
+        return resp
+
+    async def create_load_job(self, conn) -> None:
+        "NO DOC"
+        with open(
+            pjoin(self.tmp_dir, self.name, "create_load_job.gsql"), "r"
+        ) as infile:
+            resp = await conn.gsql(infile.read())
+        return resp
+
+    async def run_load_job(self, conn) -> dict:
+        "NO DOC"
+        with open(pjoin(self.tmp_dir, self.name, "run_load_job.json"), "r") as infile:
+            jobs = json.load(infile)
+
+        for job in jobs:
+            resp = await conn.runLoadingJobWithFile(
+                pjoin(self.tmp_dir, self.name, job["filePath"]),
+                job["fileTag"],
+                job["jobName"],
+                sep=job.get("sep", ","),
+                eol=job.get("eol", "\n"),
+                timeout=job.get("timeout", 60000),
+                sizeLimit=job.get("sizeLimit", 128000000),
+            )
+            yield resp
+
+    async def list(self) -> None:
+        """List available stock datasets
+        """
+        inventory_url = urljoin(self.base_url, "inventory.json")
+        async with httpx.AsyncClient() as client:
+            resp = await client.request("GET", inventory_url)
+            resp.raise_for_status()
+            print("Available datasets:")
+            for k in resp.json():
+                print("- {}".format(k))
diff --git a/pyTigerGraph/pytgasync/pyTigerGraph.py b/pyTigerGraph/pytgasync/pyTigerGraph.py
new file mode 100644
index 00000000..357e775c
--- /dev/null
+++ b/pyTigerGraph/pytgasync/pyTigerGraph.py
@@ -0,0 +1,69 @@
+import sys
+import warnings
+from typing import TYPE_CHECKING, Union
+
+import urllib3
+
+from pyTigerGraph.pytgasync.pyTigerGraphVertex import AsyncPyTigerGraphVertex
+from pyTigerGraph.pytgasync.pyTigerGraphDataset import AsyncPyTigerGraphDataset
+from pyTigerGraph.pytgasync.pyTigerGraphEdge import AsyncPyTigerGraphEdge
+from pyTigerGraph.pytgasync.pyTigerGraphLoading import AsyncPyTigerGraphLoading
+from pyTigerGraph.pytgasync.pyTigerGraphPath import AsyncPyTigerGraphPath
+from
pyTigerGraph.pytgasync.pyTigerGraphUDT import AsyncPyTigerGraphUDT + +if TYPE_CHECKING: + from ..gds import gds + +urllib3.disable_warnings(urllib3.exceptions.InsecureRequestWarning) + +if not sys.warnoptions: + warnings.filterwarnings("once", category=DeprecationWarning) + + +# TODO Proper deprecation handling; import deprecation? + +class AsyncTigerGraphConnection(AsyncPyTigerGraphVertex, AsyncPyTigerGraphEdge, AsyncPyTigerGraphUDT, + AsyncPyTigerGraphLoading, AsyncPyTigerGraphPath, AsyncPyTigerGraphDataset, object): + """Python wrapper for TigerGraph's REST++ and GSQL APIs""" + + def __init__(self, host: str = "http://127.0.0.1", graphname: str = "MyGraph", + gsqlSecret: str = "", username: str = "tigergraph", password: str = "tigergraph", + tgCloud: bool = False, restppPort: Union[int, str] = "9000", + gsPort: Union[int, str] = "14240", gsqlVersion: str = "", version: str = "", + apiToken: str = "", useCert: bool = None, certPath: str = None, debug: bool = None, + sslPort: Union[int, str] = "443", gcp: bool = False, jwtToken: str = ""): + super().__init__(host, graphname, gsqlSecret, username, password, tgCloud, restppPort, + gsPort, gsqlVersion, version, apiToken, useCert, certPath, debug, sslPort, gcp, jwtToken) + + self.gds = None + self.ai = None + + def __getattribute__(self, name): + if name == "gds": + if super().__getattribute__(name) is None: + try: + from ..gds import gds + self.gds = gds.GDS(self) + return super().__getattribute__(name) + except: + raise Exception( + "Please install the GDS package requirements to use the GDS functionality." + "Check the https://docs.tigergraph.com/pytigergraph/current/getting-started/install#_install_pytigergraphgds for more details.") + else: + return super().__getattribute__(name) + elif name == "ai": + if super().__getattribute__(name) is None: + try: + from ..ai import ai + self.ai = ai.AI(self) + return super().__getattribute__(name) + except Exception as e: + raise Exception( + "Error importing AI submodule. "+str(e) + ) + else: + return super().__getattribute__(name) + else: + return super().__getattribute__(name) + +# EOF diff --git a/pyTigerGraph/pytgasync/pyTigerGraphAuth.py b/pyTigerGraph/pytgasync/pyTigerGraphAuth.py new file mode 100644 index 00000000..4849b1d7 --- /dev/null +++ b/pyTigerGraph/pytgasync/pyTigerGraphAuth.py @@ -0,0 +1,230 @@ +"""Authentication Functions + +The functions on this page authenticate connections and manage TigerGraph credentials. +All functions in this module are called as methods on a link:https://docs.tigergraph.com/pytigergraph/current/core-functions/base[`TigerGraphConnection` object]. +""" + +import logging +from typing import Union, Dict +import warnings +import httpx + +from pyTigerGraph.common.exception import TigerGraphException +from pyTigerGraph.common.auth import ( + _parse_get_secrets, + _parse_create_secret, + _prep_token_request, + _parse_token_response +) +from pyTigerGraph.pytgasync.pyTigerGraphGSQL import AsyncPyTigerGraphGSQL + +logger = logging.getLogger(__name__) + + +class AsyncPyTigerGraphAuth(AsyncPyTigerGraphGSQL): + + async def getSecrets(self) -> Dict[str, str]: + """Issues a `SHOW SECRET` GSQL statement and returns the secret generated by that + statement. + Secrets are unique strings that serve as credentials when generating authentication tokens. + + Returns: + A dictionary of `alias: secret_string` pairs. + + Notes: + This function returns the masked version of the secret. The original value of the secret cannot + be retrieved after creation. 
+ """ + logger.info("entry: getSecrets") + + res = await self.gsql(""" + USE GRAPH {} + SHOW SECRET""".format(self.graphname), ) + ret = _parse_get_secrets(res) + + if logger.level == logging.DEBUG: + logger.debug("return: " + str(ret)) + logger.info("exit: getSecrets") + + return ret + # TODO Process response, return a dictionary of alias/secret pairs + + async def showSecrets(self) -> Dict[str, str]: + """DEPRECATED + + Use `getSecrets()` instead. + """ + warnings.warn("The `showSecrets()` function is deprecated; use `getSecrets()` instead.", + DeprecationWarning) + + ret = await self.getSecrets() + return ret + + async def createSecret(self, alias: str = "", withAlias: bool = False) -> Union[str, Dict[str, str]]: + """Issues a `CREATE SECRET` GSQL statement and returns the secret generated by that statement. + Secrets are unique strings that serve as credentials when generating authentication tokens. + + Args: + alias: + The alias of the secret. / + The system will generate a random alias for the secret if the user does not provide + an alias for that secret. Randomly generated aliases begin with + `AUTO_GENERATED_ALIAS_` and include a random 7-character string. + withAlias: + Return the new secret as an `{"alias": "secret"}` dictionary. This can be useful if + an alias was not provided, for example if it is auto-generated). + + Returns: + The secret string. + + Notes: + Generally, secrets are generated by the database administrator and + used to generate a token. If you use this function, please consider reviewing your + internal processes of granting access to TigerGraph instances. Normally, this function + should not be necessary and should not be executable by generic users. + """ + logger.info("entry: createSecret") + if logger.level == logging.DEBUG: + logger.debug("params: " + self._locals(locals())) + + res = await self.gsql(""" + USE GRAPH {} + CREATE SECRET {} """.format(self.graphname, alias)) + secret = _parse_create_secret( + res, alias=alias, withAlias=withAlias) + + # Alias was not provided, let's find out the autogenerated one + # done in createSecret since need to call self.getSecrets which is a possibly async function + if withAlias and not alias: + masked = secret[:3] + "****" + secret[-3:] + secs = await self.getSecrets() + for a, s in secs.items(): + if s == masked: + secret = {a: secret} + + if logger.level == logging.DEBUG: + logger.debug("return: " + str(secret)) + logger.info("exit: createSecret") + return secret + + async def dropSecret(self, alias: Union[str, list], ignoreErrors: bool = True) -> str: + """Drops a secret. + See https://docs.tigergraph.com/tigergraph-server/current/user-access/managing-credentials#_drop_a_secret + + Args: + alias: + One or more alias(es) of secret(s). + ignoreErrors: + Ignore errors arising from trying to drop non-existent secrets. + + Raises: + `TigerGraphException` if a non-existent secret is attempted to be dropped (unless + `ignoreErrors` is `True`). Re-raises other exceptions. 
+ """ + logger.info("entry: dropSecret") + if logger.level == logging.DEBUG: + logger.debug("params: " + self._locals(locals())) + + if isinstance(alias, str): + alias = [alias] + cmd = """ + USE GRAPH {}""".format(self.graphname) + for a in alias: + cmd += """ + DROP SECRET {}""".format(a) + res = await self.gsql(cmd) + + if "Failed to drop secrets" in res and not ignoreErrors: + raise TigerGraphException(res) + + if logger.level == logging.DEBUG: + logger.debug("return: " + str(res)) + logger.info("exit: dropSecret") + + return res + + async def _token(self, secret: str = None, lifetime: int = None, token=None, _method=None) -> Union[tuple, str]: + method, url, alt_url, authMode, data, alt_data = _prep_token_request(self.restppUrl, + self.gsUrl, + self.graphname, + secret=secret, + lifetime=lifetime, + token=token) + # _method Used for delete and refresh token + + # method == GET when using old version since _prep_newToken() gets the method for getting a new token for a version + if method == "GET": + if _method: + method = _method + + # Use TG < 3.5 format (no json data) + res = await self._req(method, url, authMode=authMode, data=data, resKey=None) + mainVer = 3 + else: + if _method: + method = _method + + # Try using TG 4.1 endpoint first, if url not found then try <4.1 endpoint + try: + res = await self._req(method, url, authMode=authMode, data=data, resKey=None, jsonData=True) + mainVer = 4 + except: + try: + res = await self._req(method, alt_url, authMode=authMode, data=alt_data, resKey=None) + mainVer = 3 + except: + raise TigerGraphException("Error requesting token. Check if the connection's graphname is correct.", 400) + + # uses mainVer instead of _versionGreaterThan4_0 since you need a token for verson checking + return res, mainVer + + async def getToken(self, secret: str = None, setToken: bool = True, lifetime: int = None) -> Union[tuple, str]: + logger.info("entry: getToken") + if logger.level == logging.DEBUG: + logger.debug("params: " + self._locals(locals())) + + res, mainVer = await self._token(secret, lifetime) + token, auth_header = _parse_token_response(res, + setToken, + mainVer, + self.base64_credential + ) + + self.apiToken = token + self.authHeader = auth_header + + logger.info("exit: getToken") + return token + + async def refreshToken(self, secret: str = None, setToken: bool = True, lifetime: int = None, token="") -> Union[tuple, str]: + logger.info("entry: refreshToken") + if logger.level == logging.DEBUG: + logger.debug("params: " + self._locals(locals())) + + if await self._version_greater_than_4_0(): + logger.info("exit: refreshToken") + raise TigerGraphException( + "Refreshing tokens is only supported on versions of TigerGraph <= 4.0.0.", 0) + + if not token: + token = self.apiToken + res, mainVer = await self._token(secret=secret, lifetime=lifetime, token=token, _method="PUT") + newToken = _parse_token_response(res, setToken, mainVer) + + logger.info("exit: refreshToken") + return newToken + + async def deleteToken(self, secret: str, token=None, skipNA=True) -> bool: + if not token: + token = self.apiToken + res, _ = await self._token(secret=secret, token=token, _method="DELETE") + + if not res["error"] or (res["code"] == "REST-3300" and skipNA): + if logger.level == logging.DEBUG: + logger.debug("return: " + str(True)) + logger.info("exit: deleteToken") + + return True + + raise TigerGraphException( + res["message"], (res["code"] if "code" in res else None)) diff --git a/pyTigerGraph/pytgasync/pyTigerGraphBase.py 
b/pyTigerGraph/pytgasync/pyTigerGraphBase.py new file mode 100644 index 00000000..b6fdcc83 --- /dev/null +++ b/pyTigerGraph/pytgasync/pyTigerGraphBase.py @@ -0,0 +1,340 @@ +import json +import logging +import httpx + +from typing import Union +from urllib.parse import urlparse + +from pyTigerGraph.common.base import PyTigerGraphCore + +logger = logging.getLogger(__name__) + + +class AsyncPyTigerGraphBase(PyTigerGraphCore): + def __init__(self, host: str = "http://127.0.0.1", graphname: str = "MyGraph", + gsqlSecret: str = "", username: str = "tigergraph", password: str = "tigergraph", + tgCloud: bool = False, restppPort: Union[int, str] = "9000", + gsPort: Union[int, str] = "14240", gsqlVersion: str = "", version: str = "", + apiToken: str = "", useCert: bool = None, certPath: str = None, debug: bool = None, + sslPort: Union[int, str] = "443", gcp: bool = False, jwtToken: str = ""): + """Initiate a connection object (doc string copied from synchronous __init__). + + Args: + host: + The host name or IP address of the TigerGraph server. Make sure to include the + protocol (http:// or https://). If `certPath` is `None` and the protocol is https, + a self-signed certificate will be used. + graphname: + The default graph for running queries. + gsqlSecret: + The secret key for GSQL. See https://docs.tigergraph.com/tigergraph-server/current/user-access/managing-credentials#_secrets. + username: + The username on the TigerGraph server. + password: + The password for that user. + tgCloud: + Set to `True` if using TigerGraph Cloud. If your hostname contains `tgcloud`, then + this is automatically set to `True`, and you do not need to set this argument. + restppPort: + The port for REST++ queries. + gsPort: + The port for gsql server. + gsqlVersion: + The version of the GSQL client to be used. Effectively the version of the database + being connected to. + version: + DEPRECATED; use `gsqlVersion`. + apiToken: + DEPRECATED; use `getToken()` with a secret to get a session token. + useCert: + DEPRECATED; the need for a CA certificate is now determined by URL scheme. + certPath: + The filesystem path to the CA certificate. Required in case of https connections. + debug: + DEPRECATED; configure standard logging in your app. + sslPort: + Port for fetching SSL certificate in case of firewall. + gcp: + DEPRECATED. Previously used for connecting to databases provisioned on GCP in TigerGraph Cloud. + jwtToken: + The JWT token generated from customer side for authentication + + Raises: + TigerGraphException: In case on invalid URL scheme. + + """ + + super().__init__(host=host, graphname=graphname, gsqlSecret=gsqlSecret, + username=username, password=password, tgCloud=tgCloud, + restppPort=restppPort, gsPort=gsPort, gsqlVersion=gsqlVersion, + version=version, apiToken=apiToken, useCert=useCert, certPath=certPath, + debug=debug, sslPort=sslPort, gcp=gcp, jwtToken=jwtToken) + + async def _req(self, method: str, url: str, authMode: str = "token", headers: dict = None, + data: Union[dict, list, str] = None, resKey: str = "results", skipCheck: bool = False, + params: Union[dict, list, str] = None, strictJson: bool = True, jsonData: bool = False, + jsonResponse: bool = True, func=None) -> Union[dict, list]: + """Generic REST++ API request. Copied from synchronous version, changing requests to httpx with async functionality. + + Args: + method: + HTTP method, currently one of `GET`, `POST` or `DELETE`. + url: + Complete REST++ API URL including path and parameters. 
+ authMode: + Authentication mode, either `"token"` (default) or `"pwd"`. + headers: + Standard HTTP request headers. + data: + Request payload, typically a JSON document. + resKey: + The JSON subdocument to be returned, default is `"result"`. + skipCheck: + Some endpoints return an error to indicate that the requested + action is not applicable. This argument skips error checking. + params: + Request URL parameters. + strictJson: + If JSON should load the response in strict mode or not. + jsonData: + If data in data var is a JSON document. + + Returns: + The (relevant part of the) response from the request (as a dictionary). + """ + _headers, _data, verify = self._prep_req( + authMode, headers, url, method, data) + + async with httpx.AsyncClient() as client: + if jsonData: + res = await client.request(method, url, headers=_headers, json=_data, params=params) + else: + res = await client.request(method, url, headers=_headers, data=_data, params=params) + + try: + if not skipCheck and not (200 <= res.status_code < 300) and res.status_code != 404: + try: + self._error_check(json.loads(res.text)) + except json.decoder.JSONDecodeError: + # could not parse the res text (probably returned an html response) + pass + res.raise_for_status() + except Exception as e: + # In TG 4.x the port for restpp has changed from 9000 to 14240. + # This block should only be called once. When using 4.x, using port 9000 should fail so self.restppurl will change to host:14240/restpp + # ---- + # Changes port to 14240, adds /restpp to end to url, tries again, saves changes if successful + if self.restppPort == "9000" and "9000" in url: + newRestppUrl = self.host + ":14240/restpp" + # In tgcloud /restpp can already be in the restpp url. We want to extract everything after the port or /restpp + if '/restpp' in url: + url = newRestppUrl + '/' + \ + '/'.join(url.split(':')[2].split('/')[2:]) + else: + url = newRestppUrl + '/' + \ + '/'.join(url.split(':')[2].split('/')[1:]) + async with httpx.AsyncClient() as client: + if jsonData: + res = await client.request(method, url, headers=_headers, json=_data, params=params) + else: + res = await client.request(method, url, headers=_headers, data=_data, params=params) + if not skipCheck and not (200 <= res.status_code < 300) and res.status_code != 404: + try: + self._error_check(json.loads(res.text)) + except json.decoder.JSONDecodeError: + # could not parse the res text (probably returned an html response) + pass + res.raise_for_status() + self.restppUrl = newRestppUrl + self.restppPort = "14240" + else: + raise e + + return self._parse_req(res, jsonResponse, strictJson, skipCheck, resKey) + + async def _get(self, url: str, authMode: str = "token", headers: dict = None, resKey: str = "results", + skipCheck: bool = False, params: Union[dict, list, str] = None, strictJson: bool = True) -> Union[dict, list]: + """Generic GET method. + + Args: + url: + Complete REST++ API URL including path and parameters. + authMode: + Authentication mode, either `"token"` (default) or `"pwd"`. + headers: + Standard HTTP request headers. + resKey: + The JSON subdocument to be returned, default is `"result"`. + skipCheck: + Some endpoints return an error to indicate that the requested + action is not applicable. This argument skips error checking. + params: + Request URL parameters. + + Returns: + The (relevant part of the) response from the request (as a dictionary). 
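The fallback branch above rewrites a failed port-9000 URL onto the TigerGraph 4.x combined port. A worked example of that string surgery (sample URLs only):

host = "http://localhost"
newRestppUrl = host + ":14240/restpp"

url = "http://localhost:9000/graph/MyGraph"              # plain REST++ URL
rewritten = newRestppUrl + '/' + '/'.join(url.split(':')[2].split('/')[1:])
assert rewritten == "http://localhost:14240/restpp/graph/MyGraph"

url = "http://localhost:9000/restpp/graph/MyGraph"       # tgcloud-style URL
rewritten = newRestppUrl + '/' + '/'.join(url.split(':')[2].split('/')[2:])
assert rewritten == "http://localhost:14240/restpp/graph/MyGraph"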
+ """ + logger.info("entry: _get") + if logger.level == logging.DEBUG: + logger.debug("params: " + self._locals(locals())) + + res = await self._req("GET", url, authMode, headers, None, resKey, skipCheck, params, strictJson) + + if logger.level == logging.DEBUG: + logger.debug("return: " + str(res)) + logger.info("exit: _get") + + return res + + async def _post(self, url: str, authMode: str = "token", headers: dict = None, + data: Union[dict, list, str, bytes] = None, resKey: str = "results", skipCheck: bool = False, + params: Union[dict, list, str] = None, jsonData: bool = False) -> Union[dict, list]: + """Generic POST method. + + Args: + url: + Complete REST++ API URL including path and parameters. + authMode: + Authentication mode, either `"token"` (default) or `"pwd"`. + headers: + Standard HTTP request headers. + data: + Request payload, typically a JSON document. + resKey: + The JSON subdocument to be returned, default is `"result"`. + skipCheck: + Some endpoints return an error to indicate that the requested + action is not applicable. This argument skips error checking. + params: + Request URL parameters. + + Returns: + The (relevant part of the) response from the request (as a dictionary). + """ + logger.info("entry: _post") + if logger.level == logging.DEBUG: + logger.debug("params: " + self._locals(locals())) + + res = await self._req("POST", url, authMode, headers, data, resKey, skipCheck, params, jsonData=jsonData) + + if logger.level == logging.DEBUG: + logger.debug("return: " + str(res)) + logger.info("exit: _post") + + return res + + async def _put(self, url: str, authMode: str = "token", data=None, resKey=None, jsonData=False) -> Union[dict, list]: + """Generic PUT method. + + Args: + url: + Complete REST++ API URL including path and parameters. + authMode: + Authentication mode, either `"token"` (default) or `"pwd"`. + + Returns: + The response from the request (as a dictionary). + """ + logger.info("entry: _put") + if logger.level == logging.DEBUG: + logger.debug("params: " + self._locals(locals())) + + res = await self._req("PUT", url, authMode, data=data, resKey=resKey, jsonData=jsonData) + + if logger.level == logging.DEBUG: + logger.debug("return: " + str(res)) + logger.info("exit: _put") + + return res + + async def _delete(self, url: str, authMode: str = "token", data: dict = None, resKey="results", jsonData=False) -> Union[dict, list]: + """Generic DELETE method. + + Args: + url: + Complete REST++ API URL including path and parameters. + authMode: + Authentication mode, either `"token"` (default) or `"pwd"`. + + Returns: + The response from the request (as a dictionary). + """ + logger.info("entry: _delete") + if logger.level == logging.DEBUG: + logger.debug("params: " + self._locals(locals())) + + res = await self._req("DELETE", url, authMode, data=data, resKey=resKey, jsonData=jsonData) + + if logger.level == logging.DEBUG: + logger.debug("return: " + str(res)) + logger.info("exit: _delete") + + return res + + async def getVersion(self, raw: bool = False) -> Union[str, list]: + """Retrieves the git versions of all components of the system. + + Args: + raw: + Return unprocessed version info string, or extract version info for each component + into a list. + + Returns: + Either an unprocessed string containing the version info details, or a list with version + info for each component. 
+
+        Endpoint:
+            - `GET /version`
+                See xref:tigergraph-server:API:built-in-endpoints.adoc#_show_component_versions[Show component versions]
+        """
+        logger.info("entry: getVersion")
+        if logger.level == logging.DEBUG:
+            logger.debug("params: " + self._locals(locals()))
+        response = await self._get(self.restppUrl+"/version", strictJson=False, resKey="message")
+        components = self._parse_get_version(response, raw)
+
+        if logger.level == logging.DEBUG:
+            logger.debug("return: " + str(components))
+        logger.info("exit: getVersion")
+        return components
+
+    async def getVer(self, component: str = "product", full: bool = False) -> str:
+        """Gets the version information of a specific component.
+
+        Get the full list of components using `getVersion()`.
+
+        Args:
+            component:
+                One of TigerGraph's components (e.g. product, gpe, gse).
+            full:
+                Return the full version string (with timestamp, etc.) or just X.Y.Z.
+
+        Returns:
+            Version info for the specified component.
+
+        Raises:
+            `TigerGraphException` if an invalid/non-existent component is specified.
+        """
+        logger.info("entry: getVer")
+        if logger.level == logging.DEBUG:
+            logger.debug("params: " + self._locals(locals()))
+        version = await self.getVersion()
+        ret = self._parse_get_ver(version, component, full)
+
+        if logger.level == logging.DEBUG:
+            logger.debug("return: " + str(ret))
+        logger.info("exit: getVer")
+
+        return ret
+
+    async def _version_greater_than_4_0(self) -> bool:
+        """Returns whether the TigerGraph database version is greater than 4.0, using getVer().
+
+        Returns:
+            Boolean of whether the database version is greater than 4.0.
+        """
+        version = await self.getVer()
+        version = version.split('.')
+        # Compare major/minor numerically so that e.g. 5.0 or 10.1 is not
+        # misclassified by string comparison.
+        return (int(version[0]), int(version[1])) > (4, 0)
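A quick usage sketch for the version helpers (host and graph name are placeholders):

import asyncio
from pyTigerGraph import AsyncTigerGraphConnection

async def main():
    conn = AsyncTigerGraphConnection(host="http://localhost", graphname="MyGraph")
    print(await conn.getVersion())                  # per-component version info
    print(await conn.getVer("product"))             # e.g. "4.1.0"
    print(await conn._version_greater_than_4_0())   # True on 4.1 and later

asyncio.run(main())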
+ """ + logger.info("entry: ingestDataset") + if logger.level == logging.DEBUG: + logger.debug("params: " + self._locals(locals())) + + if not dataset.ingest_ready: + raise Exception("This dataset is not ingestable.") + + print("---- Checking database ----", flush=True) + if await self.check_exist_graphs(dataset.name): + # self.gsql("USE GRAPH {}\nDROP JOB ALL\nDROP GRAPH {}".format( + # dataset.name, dataset.name + # )) + self.graphname = dataset.name + if getToken: + await self.getToken(await self.createSecret()) + print( + "A graph with name {} already exists in the database. " + "Skip ingestion.".format(dataset.name) + ) + print("Graph name is set to {} for this connection.".format(dataset.name)) + return + + print("---- Creating graph ----", flush=True) + resp = await dataset.create_graph(self) + print(resp, flush=True) + if "Failed" in resp: + return + + print("---- Creating schema ----", flush=True) + resp = await dataset.create_schema(self) + print(resp, flush=True) + if "Failed" in resp: + return + + print("---- Creating loading job ----", flush=True) + resp = await dataset.create_load_job(self) + print(resp, flush=True) + if "Failed" in resp: + return + + print("---- Ingesting data ----", flush=True) + self.graphname = dataset.name + if getToken: + secret = await self.createSecret() + await self.getToken(secret) + + responses = [] + for resp in await dataset.run_load_job(self): + responses.append(resp) + + _parse_ingest_dataset(responses, cleanup, dataset) + + print("---- Finished ingestion ----", flush=True) + logger.info("exit: ingestDataset") + + async def check_exist_graphs(self, name: str) -> bool: + "NO DOC" + resp = await self.gsql("ls") + return "Graph {}".format(name) in resp diff --git a/pyTigerGraph/pytgasync/pyTigerGraphEdge.py b/pyTigerGraph/pytgasync/pyTigerGraphEdge.py new file mode 100644 index 00000000..6a62c530 --- /dev/null +++ b/pyTigerGraph/pytgasync/pyTigerGraphEdge.py @@ -0,0 +1,911 @@ +"""Edge Functions + +Functions to upsert, retrieve and delete edges. +All functions in this module are called as methods on a link:https://docs.tigergraph.com/pytigergraph/current/core-functions/base[`TigerGraphConnection` object]. +""" +import json +import logging +import warnings + +from typing import TYPE_CHECKING, Union + +if TYPE_CHECKING: + import pandas as pd + +from pyTigerGraph.common.edge import ( + _parse_get_edge_source_vertex_type, + _parse_get_edge_target_vertex_type, + _prep_get_edge_count_from, + _parse_get_edge_count_from, + _prep_upsert_edge, + _dumps, + _prep_upsert_edges, + _prep_upsert_edge_dataframe, + _prep_get_edges, + _prep_get_edges_by_type, + _parse_get_edge_stats, + _prep_del_edges +) + +from pyTigerGraph.common.edge import edgeSetToDataFrame as _eS2DF + +from pyTigerGraph.common.schema import ( + _get_attr_type, + _upsert_attrs +) + +from pyTigerGraph.pytgasync.pyTigerGraphQuery import AsyncPyTigerGraphQuery + +logger = logging.getLogger(__name__) + + +class AsyncPyTigerGraphEdge(AsyncPyTigerGraphQuery): + + ___trgvtxids = "___trgvtxids" + + async def getEdgeTypes(self, force: bool = False) -> list: + """Returns the list of edge type names of the graph. + + Args: + force: + If `True`, forces the retrieval the schema metadata again, otherwise returns a + cached copy of edge type metadata (if they were already fetched previously). + + Returns: + The list of edge types defined in the current graph. 
+ """ + logger.info("entry: getEdgeTypes") + if logger.level == logging.DEBUG: + logger.debug("params: " + self._locals(locals())) + + ret = [] + schema = await self.getSchema(force=force) + for et in schema["EdgeTypes"]: + ret.append(et["Name"]) + + if logger.level == logging.DEBUG: + logger.debug("return: " + str(ret)) + logger.info("exit: getEdgeTypes") + + return ret + + async def getEdgeType(self, edgeType: str, force: bool = False) -> dict: + """Returns the details of the edge type. + + Args: + edgeType: + The name of the edge type. + force: + If `True`, forces the retrieval the schema details again, otherwise returns a cached + copy of edge type details (if they were already fetched previously). + + Returns: + The metadata of the edge type. + """ + logger.info("entry: getEdgeType") + if logger.level == logging.DEBUG: + logger.debug("params: " + self._locals(locals())) + + schema = await self.getSchema(force=force) + for et in schema["EdgeTypes"]: + if et["Name"] == edgeType: + if logger.level == logging.DEBUG: + logger.debug("return: " + str(et)) + logger.info("exit: getEdgeType (found)") + + return et + + logger.warning("Edge type `" + edgeType + "` was not found.") + logger.info("exit: getEdgeType (not found)") + + return {} + + async def getEdgeAttrs(self, edgeType: str) -> list: + """Returns the names and types of the attributes of the edge type. + + Args: + edgeType: + The name of the edge type. + + Returns: + A list of (attribute_name, attribute_type) tuples. + The format of attribute_type is one of + - "scalar_type" + - "complex_type(scalar_type)" + - "map_type(key_type,value_type)" + and it is a string. + """ + logger.info("entry: getAttributes") + if logger.level == logging.DEBUG: + logger.debug("params: " + self._locals(locals())) + + et = await self.getEdgeType(edgeType) + ret = [] + + for at in et["Attributes"]: + ret.append( + (at["AttributeName"], _get_attr_type(at["AttributeType"]))) + + if logger.level == logging.DEBUG: + logger.debug("return: " + str(ret)) + logger.info("exit: getAttributes") + + return ret + + async def getEdgeSourceVertexType(self, edgeType: str) -> Union[str, set]: + """Returns the type(s) of the edge type's source vertex. + + Args: + edgeType: + The name of the edge type. + + Returns: + - A single source vertex type name string if the edge has a single source vertex type. + - "*" if the edge can originate from any vertex type (notation used in 2.6.1 and earlier + versions). + See https://docs.tigergraph.com/v/2.6/dev/gsql-ref/ddl-and-loading/defining-a-graph-schema#creating-an-edge-from-or-to-any-vertex-type + - A set of vertex type name strings (unique values) if the edge has multiple source + vertex types (notation used in 3.0 and later versions). / + Even if the source vertex types were defined as `"*"`, the REST API will list them as + pairs (i.e. not as `"*"` in 2.6.1 and earlier versions), just like as if there were + defined one by one (e.g. `FROM v1, TO v2 | FROM v3, TO v4 | …`). + + The returned set contains all source vertex types, but it does not certainly mean that + the edge is defined between all source and all target vertex types. You need to look + at the individual source/target pairs to find out which combinations are + valid/defined. 
+ """ + logger.info("entry: getEdgeSourceVertexType") + if logger.level == logging.DEBUG: + logger.debug("params: " + self._locals(locals())) + + edgeTypeDetails = await self.getEdgeType(edgeType) + res = _parse_get_edge_source_vertex_type(edgeTypeDetails) + return res + + async def getEdgeTargetVertexType(self, edgeType: str) -> Union[str, set]: + """Returns the type(s) of the edge type's target vertex. + + Args: + edgeType: + The name of the edge type. + + Returns: + - A single target vertex type name string if the edge has a single target vertex type. + - "*" if the edge can end in any vertex type (notation used in 2.6.1 and earlier + versions). + See https://docs.tigergraph.com/v/2.6/dev/gsql-ref/ddl-and-loading/defining-a-graph-schema#creating-an-edge-from-or-to-any-vertex-type + - A set of vertex type name strings (unique values) if the edge has multiple target + vertex types (notation used in 3.0 and later versions). / + Even if the target vertex types were defined as "*", the REST API will list them as + pairs (i.e. not as "*" in 2.6.1 and earlier versions), just like as if there were + defined one by one (e.g. `FROM v1, TO v2 | FROM v3, TO v4 | …`). + + The returned set contains all target vertex types, but does not certainly mean that the + edge is defined between all source and all target vertex types. You need to look at + the individual source/target pairs to find out which combinations are valid/defined. + """ + logger.info("entry: getEdgeTargetVertexType") + if logger.level == logging.DEBUG: + logger.debug("params: " + self._locals(locals())) + + edgeTypeDetails = await self.getEdgeType(edgeType) + ret = _parse_get_edge_target_vertex_type(edgeTypeDetails) + return ret + + async def isDirected(self, edgeType: str) -> bool: + """Is the specified edge type directed? + + Args: + edgeType: + The name of the edge type. + + Returns: + `True`, if the edge is directed. + """ + logger.info("entry: isDirected") + if logger.level == logging.DEBUG: + logger.debug("params: " + self._locals(locals())) + + ret = await self.getEdgeType(edgeType) + ret = ret["IsDirected"] + + if logger.level == logging.DEBUG: + logger.debug("return: " + str(ret)) + logger.info("exit: isDirected") + + return ret + + async def getReverseEdge(self, edgeType: str) -> str: + """Returns the name of the reverse edge of the specified edge type, if applicable. + + Args: + edgeType: + The name of the edge type. + + Returns: + The name of the reverse edge, if it was defined. + """ + logger.info("entry: getReverseEdge") + if logger.level == logging.DEBUG: + logger.debug("params: " + self._locals(locals())) + + if not await self.isDirected(edgeType): + logger.error(edgeType + " is not a directed edge") + logger.info("exit: getReverseEdge (not directed)") + + return "" + + config = await self.getEdgeType(edgeType) + config = config["Config"] + if "REVERSE_EDGE" in config: + ret = config["REVERSE_EDGE"] + + if logger.level == logging.DEBUG: + logger.debug("return: " + str(ret)) + logger.info("exit: getReverseEdge (reverse edge found)") + + return ret + + logger.info("exit: getReverseEdge (reverse edge not found)") + + return "" + # TODO Should return some other value or raise exception? + + async def isMultiEdge(self, edgeType: str) -> bool: + """Can the edge have multiple instances between the same pair of vertices? + + Args: + edgeType: + The name of the edge type. + + Returns: + `True`, if the edge can have multiple instances between the same pair of vertices. 
+ """ + logger.info("entry: isMultiEdge") + if logger.level == logging.DEBUG: + logger.debug("params: " + self._locals(locals())) + + et = await self.getEdgeType(edgeType) + ret = ("DiscriminatorCount" in et) and et["DiscriminatorCount"] > 0 + + if logger.level == logging.DEBUG: + logger.debug("return: " + str(ret)) + logger.info("exit: isMultiEdge") + + return ret + + async def getDiscriminators(self, edgeType: str) -> list: + """Returns the names and types of the discriminators of the edge type. + + Args: + edgeType: + The name of the edge type. + + Returns: + A list of (attribute_name, attribute_type) tuples. + """ + logger.info("entry: getDiscriminators") + if logger.level == logging.DEBUG: + logger.debug("params: " + self._locals(locals())) + + et = await self.getEdgeType(edgeType) + ret = [] + + for at in et["Attributes"]: + if "IsDiscriminator" in at and at["IsDiscriminator"]: + ret.append( + (at["AttributeName"], _get_attr_type(at["AttributeType"]))) + + if logger.level == logging.DEBUG: + logger.debug("return: " + str(ret)) + logger.info("exit: getDiscriminators") + + return ret + + async def getEdgeCountFrom(self, sourceVertexType: str = "", sourceVertexId: Union[str, int] = None, + edgeType: str = "", targetVertexType: str = "", targetVertexId: Union[str, int] = None, + where: str = "") -> dict: + """Returns the number of edges from a specific vertex. + + Args: + sourceVertexType: + The name of the source vertex type. + sourceVertexId: + The primary ID value of the source vertex instance. + edgeType: + The name of the edge type. + targetVertexType: + The name of the target vertex type. + targetVertexId: + The primary ID value of the target vertex instance. + where: + A comma separated list of conditions that are all applied on each edge's attributes. + The conditions are in logical conjunction (i.e. they are "AND'ed" together). + + Returns: + A dictionary of `edge_type: edge_count` pairs. + + Uses: + - If `edgeType` = "*": edge count of all edge types (no other arguments can be specified + in this case). + - If `edgeType` is specified only: edge count of the given edge type. + - If `sourceVertexType`, `edgeType`, `targetVertexType` are specified: edge count of the + given edge type between source and target vertex types. + - If `sourceVertexType`, `sourceVertexId` are specified: edge count of all edge types + from the given vertex instance. + - If `sourceVertexType`, `sourceVertexId`, `edgeType` are specified: edge count of all + edge types from the given vertex instance. + - If `sourceVertexType`, `sourceVertexId`, `edgeType`, `where` are specified: the edge + count of the given edge type after filtered by `where` condition. + - If `targetVertexId` is specified, then `targetVertexType` must also be specified. + - If `targetVertexType` is specified, then `edgeType` must also be specified. 
+
+ Endpoints:
+ - `GET /graph/{graph_name}/edges/{source_vertex_type}/{source_vertex_id}`
+ See https://docs.tigergraph.com/tigergraph-server/current/api/built-in-endpoints#_list_edges_of_a_vertex
+ - `POST /builtins/{graph_name}`
+ See https://docs.tigergraph.com/tigergraph-server/current/api/built-in-endpoints#_run_built_in_functions_on_graph
+ """
+ logger.info("entry: getEdgeCountFrom")
+ if logger.level == logging.DEBUG:
+ logger.debug("params: " + self._locals(locals()))
+
+ url, data = _prep_get_edge_count_from(restppUrl=self.restppUrl,
+ graphname=self.graphname,
+ sourceVertexType=sourceVertexType,
+ sourceVertexId=sourceVertexId,
+ edgeType=edgeType,
+ targetVertexType=targetVertexType,
+ targetVertexId=targetVertexId,
+ where=where)
+ if data:
+ res = await self._req("POST", url, data=data)
+ else:
+ res = await self._req("GET", url)
+ ret = _parse_get_edge_count_from(res, edgeType)
+
+ if logger.level == logging.DEBUG:
+ logger.debug("return: " + str(ret))
+ logger.info("exit: getEdgeCountFrom")
+
+ return ret
+
+ async def getEdgeCount(self, edgeType: str = "*", sourceVertexType: str = "",
+ targetVertexType: str = "") -> dict:
+ """Returns the number of edges of an edge type.
+
+ This is a simplified version of `getEdgeCountFrom()`, to be used when the total number of
+ edges of a given type is needed, regardless of which vertex instance they originate from.
+ See the documentation of `getEdgeCountFrom()` above for more details.
+
+ Args:
+ edgeType:
+ The name of the edge type.
+ sourceVertexType:
+ The name of the source vertex type.
+ targetVertexType:
+ The name of the target vertex type.
+
+ Returns:
+ A dictionary of `edge_type: edge_count` pairs.
+ """
+ logger.info("entry: getEdgeCount")
+ if logger.level == logging.DEBUG:
+ logger.debug("params: " + self._locals(locals()))
+
+ ret = await self.getEdgeCountFrom(edgeType=edgeType, sourceVertexType=sourceVertexType,
+ targetVertexType=targetVertexType)
+
+ if logger.level == logging.DEBUG:
+ logger.debug("return: " + str(ret))
+ logger.info("exit: getEdgeCount")
+
+ return ret
+
+ async def upsertEdge(self, sourceVertexType: str, sourceVertexId: str, edgeType: str,
+ targetVertexType: str, targetVertexId: str, attributes: dict = None) -> int:
+ """Upserts an edge.
+
+ Data is upserted:
+
+ - If the edge is not yet present in the graph, it will be created (see the special case below).
+ - If it is already in the graph, it is updated with the values specified in the request.
+ - If `vertex_must_exist` is `True`, the edge will only be created if both vertices already
+ exist in the graph. Otherwise, missing vertices are created along with the new edge; the
+ newly created vertices' attributes (if any) are set to their default values.
+
+ Args:
+ sourceVertexType:
+ The name of the source vertex type.
+ sourceVertexId:
+ The primary ID value of the source vertex instance.
+ edgeType:
+ The name of the edge type.
+ targetVertexType:
+ The name of the target vertex type.
+ targetVertexId:
+ The primary ID value of the target vertex instance.
+ attributes:
+ A dictionary in this format:
+ ```
+ {<attribute_name>: <attribute_value>|(<attribute_value>, <operator>), …}
+ ```
+ Example:
+ ```
+ {"visits": (1482, "+"), "max_duration": (371, "max")}
+ ```
+ For valid values of `<operator>` see https://docs.tigergraph.com/dev/restpp-api/built-in-endpoints#operation-codes .
+
+ Returns:
+ A single number of accepted (successfully upserted) edges (0 or 1).
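+
+ Example (a minimal sketch; the graph-specific names are assumptions):
+ ```
+ n = await conn.upsertEdge("person", "p1", "visited", "web_page", "home",
+ {"visits": (1, "+")})
+ # n == 1 if the edge was accepted; the "+" operator increments the
+ # existing `visits` value instead of overwriting it
+ ```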
+
+ Endpoint:
+ - `POST /graph/{graph_name}`
+ See https://docs.tigergraph.com/dev/restpp-api/built-in-endpoints#upsert-data-to-graph
+
+ TODO Add ack, new_vertex_only, vertex_must_exist, update_vertex_only and atomic_level
+ parameters and functionality.
+ """
+ logger.info("entry: upsertEdge")
+ if logger.level == logging.DEBUG:
+ logger.debug("params: " + self._locals(locals()))
+
+ data = _prep_upsert_edge(sourceVertexType,
+ sourceVertexId,
+ edgeType,
+ targetVertexType,
+ targetVertexId,
+ attributes)
+ ret = await self._req("POST", self.restppUrl + "/graph/" + self.graphname, data=data)
+ ret = ret[0]["accepted_edges"]
+
+ if logger.level == logging.DEBUG:
+ logger.debug("return: " + str(ret))
+ logger.info("exit: upsertEdge")
+
+ return ret
+
+ async def upsertEdges(self, sourceVertexType: str, edgeType: str, targetVertexType: str,
+ edges: list) -> int:
+ """Upserts multiple edges (of the same type).
+
+ Args:
+ sourceVertexType:
+ The name of the source vertex type.
+ edgeType:
+ The name of the edge type.
+ targetVertexType:
+ The name of the target vertex type.
+ edges:
+ A list of tuples in this format:
+ ```
+ [
+ (<source_vertex_id>, <target_vertex_id>, {<attribute_name>: <attribute_value>, …}),
+ (<source_vertex_id>, <target_vertex_id>, {<attribute_name>: (<attribute_value>, <operator>), …})
+ ⋮
+ ]
+ ```
+ Example:
+ ```
+ [
+ (17, "home_page", {"visits": (35, "+"), "max_duration": (93, "max")}),
+ (42, "search", {"visits": (17, "+"), "max_duration": (41, "max")})
+ ]
+ ```
+ For valid values of `<operator>` see https://docs.tigergraph.com/dev/restpp-api/built-in-endpoints#operation-codes .
+
+ Returns:
+ A single number of accepted (successfully upserted) edges (0 or a positive integer).
+
+ Endpoint:
+ - `POST /graph/{graph_name}`
+ See https://docs.tigergraph.com/dev/restpp-api/built-in-endpoints#upsert-data-to-graph
+
+ TODO Add ack, new_vertex_only, vertex_must_exist, update_vertex_only and atomic_level
+ parameters and functionality.
+ """
+
+ logger.info("entry: upsertEdges")
+ if logger.level == logging.DEBUG:
+ logger.debug("params: " + self._locals(locals()))
+
+ """
+ NOTE: The source and target vertex primary IDs are converted below to string as the keys
+ in a JSON document must be string.
+ This probably should not be an issue as the primary ID has a predefined data type, so if
+ the same primary ID is sent as two different literals (say: 1 as number and "1" as
+ string), it will be converted anyhow to the same (numerical or string) data type.
+ Converting the primary IDs to string here prevents inconsistencies as a Python dict would
+ otherwise handle 1 and "1" as two separate keys.
+ """
+
+ data = _prep_upsert_edges(sourceVertexType=sourceVertexType,
+ edgeType=edgeType,
+ targetVertexType=targetVertexType,
+ edges=edges)
+ ret = await self._req("POST", self.restppUrl + "/graph/" + self.graphname, data=data)
+ ret = ret[0]["accepted_edges"]
+
+ if logger.level == logging.DEBUG:
+ logger.debug("return: " + str(ret))
+ logger.info("exit: upsertEdges")
+
+ return ret
+
+ async def upsertEdgeDataFrame(self, df: 'pd.DataFrame', sourceVertexType: str, edgeType: str,
+ targetVertexType: str, from_id: str = "", to_id: str = "",
+ attributes: dict = None) -> int:
+ """Upserts edges from a Pandas DataFrame.
+
+ Args:
+ df:
+ The DataFrame to upsert.
+ sourceVertexType:
+ The type of source vertex for the edge.
+ edgeType:
+ The type of edge to upsert data to.
+ targetVertexType:
+ The type of target vertex for the edge.
+ from_id:
+ The field name where the source vertex primary id is given. If omitted, the
+ dataframe index will be used instead.
+ to_id:
+ The field name where the target vertex primary id is given.
If omitted, the
+ dataframe index will be used instead.
+ attributes:
+ A dictionary in the form of `{target: source}` where source is the column name in
+ the dataframe and target is the attribute name on the edge. When omitted,
+ all columns will be upserted with their current names. In this case column names
+ must match the edge's attribute names.
+
+ Returns:
+ The number of edges upserted.
+ """
+ logger.info("entry: upsertEdgeDataFrame")
+ if logger.level == logging.DEBUG:
+ logger.debug("params: " + self._locals(locals()))
+
+ json_up = _prep_upsert_edge_dataframe(df, from_id, to_id, attributes)
+ ret = await self.upsertEdges(sourceVertexType, edgeType, targetVertexType, json_up)
+
+ if logger.level == logging.DEBUG:
+ logger.debug("return: " + str(ret))
+ logger.info("exit: upsertEdgeDataFrame")
+
+ return ret
+
+ async def getEdges(self, sourceVertexType: str, sourceVertexId: str, edgeType: str = "",
+ targetVertexType: str = "", targetVertexId: str = "", select: str = "", where: str = "",
+ limit: Union[int, str] = None, sort: str = "", fmt: str = "py", withId: bool = True,
+ withType: bool = False, timeout: int = 0) -> Union[dict, str, 'pd.DataFrame']:
+ """Retrieves edges of the given edge type originating from a specific source vertex.
+
+ Only `sourceVertexType` and `sourceVertexId` are required.
+ If `targetVertexId` is specified, then `targetVertexType` must also be specified.
+ If `targetVertexType` is specified, then `edgeType` must also be specified.
+
+ Args:
+ sourceVertexType:
+ The name of the source vertex type.
+ sourceVertexId:
+ The primary ID value of the source vertex instance.
+ edgeType:
+ The name of the edge type.
+ targetVertexType:
+ The name of the target vertex type.
+ targetVertexId:
+ The primary ID value of the target vertex instance.
+ select:
+ Comma separated list of edge attributes to be retrieved or omitted.
+ where:
+ Comma separated list of conditions that are all applied on each edge's attributes.
+ The conditions are in logical conjunction (i.e. they are "AND'ed" together).
+ sort:
+ Comma separated list of attributes the results should be sorted by.
+ limit:
+ Maximum number of edge instances to be returned (after sorting).
+ fmt:
+ Format of the results returned:
+ - "py": Python objects
+ - "json": JSON document
+ - "df": pandas DataFrame
+ withId:
+ (When the output format is "df") Should the source and target vertex types and IDs
+ be included in the dataframe?
+ withType:
+ (When the output format is "df") Should the edge type be included in the dataframe?
+ timeout:
+ Time allowed for successful execution (0 = no time limit, default).
+
+ Returns:
+ The (selected) details of the (matching) edge instances (sorted, limited) as dictionary,
+ JSON or pandas DataFrame.
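+
+ Example (illustrative only; the type, attribute and ID values are assumptions):
+ ```
+ edges = await conn.getEdges("account", "A123", edgeType="transfer",
+ where="amount>100", sort="-amount", limit=5)
+ df = await conn.getEdges("account", "A123", edgeType="transfer", fmt="df")
+ ```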
+ + Endpoint: + - `GET /graph/{graph_name}/edges/{source_vertex_type}/{source_vertex_id}` + See https://docs.tigergraph.com/dev/restpp-api/built-in-endpoints#list-edges-of-a-vertex + """ + logger.info("entry: getEdges") + if logger.level == logging.DEBUG: + logger.debug("params: " + self._locals(locals())) + + url = _prep_get_edges(self.restppUrl, + self.graphname, + sourceVertexType, + sourceVertexId, + edgeType, + targetVertexType, + targetVertexId, + select, + where, + limit, + sort, + timeout) + ret = await self._req("GET", url) + + if fmt == "json": + ret = json.dumps(ret) + elif fmt == "df": + ret = _eS2DF(ret, withId, withType) + + if logger.level == logging.DEBUG: + logger.debug("return: " + str(ret)) + logger.info("exit: getEdges") + + return ret + + async def getEdgesDataFrame(self, sourceVertexType: str, sourceVertexId: str, edgeType: str = "", + targetVertexType: str = "", targetVertexId: str = "", select: str = "", where: str = "", + limit: Union[int, str] = None, sort: str = "", timeout: int = 0) -> 'pd.DataFrame': + """Retrieves edges of the given edge type originating from a specific source vertex. + + This is a shortcut to ``getEdges(..., fmt="df", withId=True, withType=False)``. + Only ``sourceVertexType`` and ``sourceVertexId`` are required. + If ``targetVertexId`` is specified, then ``targetVertexType`` must also be specified. + If ``targetVertexType`` is specified, then ``edgeType`` must also be specified. + + Args: + sourceVertexType: + The name of the source vertex type. + sourceVertexId: + The primary ID value of the source vertex instance. + edgeType: + The name of the edge type. + targetVertexType: + The name of the target vertex type. + targetVertexId: + The primary ID value of the target vertex instance. + select: + Comma separated list of edge attributes to be retrieved or omitted. + where: + Comma separated list of conditions that are all applied on each edge's attributes. + The conditions are in logical conjunction (i.e. they are "AND'ed" together). + sort: + Comma separated list of attributes the results should be sorted by. + limit: + Maximum number of edge instances to be returned (after sorting). + timeout: + Time allowed for successful execution (0 = no limit, default). + + Returns: + The (selected) details of the (matching) edge instances (sorted, limited) as dictionary, + JSON or pandas DataFrame. + """ + logger.info("entry: getEdgesDataFrame") + if logger.level == logging.DEBUG: + logger.debug("params: " + self._locals(locals())) + + ret = await self.getEdges(sourceVertexType, sourceVertexId, edgeType, targetVertexType, + targetVertexId, select, where, limit, sort, fmt="df", timeout=timeout) + + if logger.level == logging.DEBUG: + logger.debug("return: " + str(ret)) + logger.info("exit: getEdgesDataFrame") + + return ret + + async def getEdgesDataframe(self, sourceVertexType: str, sourceVertexId: str, edgeType: str = "", + targetVertexType: str = "", targetVertexId: str = "", select: str = "", where: str = "", + limit: Union[int, str] = None, sort: str = "", timeout: int = 0) -> 'pd.DataFrame': + """DEPRECATED + + Use `getEdgesDataFrame()` instead. 
+ """ + warnings.warn( + "The `getEdgesDataframe()` function is deprecated; use `getEdgesDataFrame()` instead.", + DeprecationWarning) + + return await self.getEdgesDataFrame(sourceVertexType, sourceVertexId, edgeType, targetVertexType, + targetVertexId, select, where, limit, sort, timeout) + + async def getEdgesByType(self, edgeType: str, fmt: str = "py", withId: bool = True, + withType: bool = False) -> Union[dict, str, 'pd.DataFrame']: + """Retrieves edges of the given edge type regardless the source vertex. + + Args: + edgeType: + The name of the edge type. + fmt: + Format of the results returned: + - "py": Python objects + - "json": JSON document + - "df": pandas DataFrame + withId: + (When the output format is "df") Should the source and target vertex types and IDs + be included in the dataframe? + withType: + (When the output format is "df") should the edge type be included in the dataframe? + + Returns: + The details of the edge instances of the given edge type as dictionary, JSON or pandas + DataFrame. + + TODO Add limit parameter + """ + logger.info("entry: getEdgesByType") + if logger.level == logging.DEBUG: + logger.debug("params: " + self._locals(locals())) + + if not edgeType: + logger.warning("Edge type is not specified") + logger.info("exit: getEdgesByType") + + return {} + + sourceVertexType = await self.getEdgeSourceVertexType(edgeType) + queryText = _prep_get_edges_by_type(self.graphname, sourceVertexType, edgeType) + ret = await self.runInterpretedQuery(queryText) + + ret = ret[0]["edges"] + + if fmt == "json": + ret = json.dumps(ret) + elif fmt == "df": + ret = _eS2DF(ret, withId, withType) + + if logger.level == logging.DEBUG: + logger.debug("return: " + str(ret)) + logger.info("exit: _upsertAttrs") + + return ret + + # TODO getEdgesDataFrameByType + + async def getEdgeStats(self, edgeTypes: Union[str, list], skipNA: bool = False) -> dict: + """Returns edge attribute statistics. + + Args: + edgeTypes: + A single edge type name or a list of edges types names or '*' for all edges types. + skipNA: + Skip those edges that do not have attributes or none of their attributes have + statistics gathered. + + Returns: + Attribute statistics of edges; a dictionary of dictionaries. + + Endpoint: + - `POST /builtins/{graph_name}` + See https://docs.tigergraph.com/dev/restpp-api/built-in-endpoints#run-built-in-functions-on-graph + """ + logger.info("entry: getEdgeStats") + if logger.level == logging.DEBUG: + logger.debug("params: " + self._locals(locals())) + + ets = [] + if edgeTypes == "*": + ets = await self.getEdgeTypes() + elif isinstance(edgeTypes, str): + ets = [edgeTypes] + elif isinstance(edgeTypes, list): + ets = edgeTypes + else: + logger.warning("The `edgeTypes` parameter is invalid.") + logger.info("exit: getEdgeStats") + + return {} + + responses = [] + for et in ets: + data = '{"function":"stat_edge_attr","type":"' + \ + et + '","from_type":"*","to_type":"*"}' + res = await self._req("POST", self.restppUrl + "/builtins/" + self.graphname, data=data, resKey="", + skipCheck=True) + responses.append((et, res)) + ret = _parse_get_edge_stats(responses, skipNA) + + if logger.level == logging.DEBUG: + logger.debug("return: " + str(ret)) + logger.info("exit: getEdgeStats") + + return ret + + async def delEdges(self, sourceVertexType: str, sourceVertexId: str, edgeType: str = "", + targetVertexType: str = "", targetVertexId: str = "", where: str = "", + limit: str = "", sort: str = "", timeout: int = 0) -> dict: + """Deletes edges from the graph. 
+ + Only `sourceVertexType` and `sourceVertexId` are required. + If `targetVertexId` is specified, then `targetVertexType` must also be specified. + If `targetVertexType` is specified, then `edgeType` must also be specified. + + Args: + sourceVertexType: + The name of the source vertex type. + sourceVertexId: + The primary ID value of the source vertex instance. + edgeType: + The name of the edge type. + targetVertexType: + The name of the target vertex type. + targetVertexId: + The primary ID value of the target vertex instance. + where: + Comma separated list of conditions that are all applied on each edge's attributes. + The conditions are in logical conjunction (they are connected as if with an `AND` statement). + limit: + Maximum number of edge instances to be returned after sorting. + sort: + Comma-separated list of attributes the results should be sorted by. + timeout: + Time allowed for successful execution. The default is `0`, or no limit. + + Returns: + A dictionary of `edge_type: deleted_edge_count` pairs. + + Endpoint: + - `DELETE /graph/{graph_name}/edges/{source_vertex_type}/{source_vertex_id}/{edge_type}/{target_vertex_type}/{target_vertex_id}` + See https://docs.tigergraph.com/dev/restpp-api/built-in-endpoints#delete-an-edge + """ + logger.info("entry: delEdges") + if logger.level == logging.DEBUG: + logger.debug("params: " + self._locals(locals())) + + url = _prep_del_edges(self.restppUrl, + self.graphname, + sourceVertexType, + sourceVertexId, + edgeType, + targetVertexType, + targetVertexId, + where, + limit, + sort, + timeout) + res = await self._req("DELETE", url) + ret = {} + for r in res: + ret[r["e_type"]] = r["deleted_edges"] + + if logger.level == logging.DEBUG: + logger.debug("return: " + str(ret)) + logger.info("exit: delEdges") + + return ret + + async def edgeSetToDataFrame(self, edgeSet: list, withId: bool = True, withType: bool = False) -> 'pd.DataFrame': + """Converts an edge set to a pandas DataFrame. + + Args: + edgeSet: + The edge set to convert. + withId: + Should the source and target vertex types and IDs be included in the dataframe? + withType: + Should the edge type be included in the dataframe? + + Returns: + The edge set as a pandas DataFrame. + """ + logger.info("entry: edgeSetToDataFrame") + if logger.level == logging.DEBUG: + logger.debug("params: " + self._locals(locals())) + + ret = _eS2DF(edgeSet, withId, withType) + + if logger.level == logging.DEBUG: + logger.debug("return: " + str(ret)) + logger.info("exit: edgeSetToDataFrame") + + return ret \ No newline at end of file diff --git a/pyTigerGraph/pytgasync/pyTigerGraphGSQL.py b/pyTigerGraph/pytgasync/pyTigerGraphGSQL.py new file mode 100644 index 00000000..bf9ba1ef --- /dev/null +++ b/pyTigerGraph/pytgasync/pyTigerGraphGSQL.py @@ -0,0 +1,110 @@ +"""GSQL Interface + +Use GSQL within pyTigerGraph. +All functions in this module are called as methods on a link:https://docs.tigergraph.com/pytigergraph/current/core-functions/base[`TigerGraphConnection` object]. 
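+
+A minimal usage sketch (the statement is an assumption; any GSQL statement works the same way):
+
+[source.wrap,python]
+----
+res = await conn.gsql("LS")
+print(res)
+----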
+""" +import logging +import re +import httpx + +from typing import Union, Tuple, Dict +from urllib.parse import urlparse, quote_plus + +from pyTigerGraph.common.exception import TigerGraphException +from pyTigerGraph.common.gsql import ( + _parse_gsql, + _prep_get_udf, + _parse_get_udf +) + +from pyTigerGraph.pytgasync.pyTigerGraphBase import AsyncPyTigerGraphBase + + +logger = logging.getLogger(__name__) + +ANSI_ESCAPE = re.compile(r'\x1B(?:[@-Z\\-_]|\[[0-?]*[ -/]*[@-~])') + + +class AsyncPyTigerGraphGSQL(AsyncPyTigerGraphBase): + async def gsql(self, query: str, graphname: str = None, options=None) -> Union[str, dict]: + """Runs a GSQL query and processes the output. + + Args: + query: + The text of the query to run as one string. The query is one or more GSQL statement. + graphname: + The name of the graph to attach to. If not specified, the graph name provided at the + time of establishing the connection will be used. + options: + DEPRECATED + + Returns: + The output of the statement(s) executed. + + Endpoint: + - `POST /gsqlserver/gsql/file` (In TigerGraph versions 3.x) + - `POST /gsql/v1/statements` (In TigerGraph versions 4.x) + """ + # Can't use self._isVersionGreaterThan4_0 since you need a token to call /version url + # but you need a secret to get a token and you need this function to get a secret + try: + res = await self._req("POST", + self.gsUrl + "/gsql/v1/statements", + # quote_plus would not work with the new endpoint + data=query.encode("utf-8"), + authMode="pwd", resKey=None, skipCheck=True, + jsonResponse=False, + headers={"Content-Type": "text/plain"}) + + except httpx.HTTPError as e: + if e.response.status_code == 404: + res = await self._req("POST", + self.gsUrl + "/gsqlserver/gsql/file", + data=quote_plus(query.encode("utf-8")), + authMode="pwd", resKey=None, skipCheck=True, + jsonResponse=False) + else: + raise e + return _parse_gsql(res, query, graphname=graphname, options=options) + + # TODO IMPLEMENT INSTALL_UDF + + async def getUDF(self, ExprFunctions: bool = True, ExprUtil: bool = True, json_out=False) -> Union[str, Tuple[str, str], Dict[str, str]]: + """Get user defined functions (UDF) installed in the database. + See https://docs.tigergraph.com/gsql-ref/current/querying/func/query-user-defined-functions for details on UDFs. + + Args: + ExprFunctions (bool, optional): + Whether to get ExprFunctions. Defaults to True. + ExprUtil (bool, optional): + Whether to get ExprUtil. Defaults to True. + json_out (bool, optional): + Whether to output as JSON. Defaults to False. + Only supported on version >=4.1 + + Returns: + str: If only one of `ExprFunctions` or `ExprUtil` is True, return of the content of that file. + Tuple[str, str]: content of ExprFunctions and content of ExprUtil. 
+
+ Endpoints:
+ - `GET /gsqlserver/gsql/userdefinedfunction?filename={ExprFunctions or ExprUtil}` (In TigerGraph versions 3.x)
+ - `GET /gsql/v1/udt/files/{ExprFunctions or ExprUtil}` (In TigerGraph versions 4.x)
+ """
+ logger.info("entry: getUDF")
+ if logger.level == logging.DEBUG:
+ logger.debug("params: " + self._locals(locals()))
+
+ urls, alt_urls = _prep_get_udf(
+ ExprFunctions=ExprFunctions, ExprUtil=ExprUtil)
+ if not await self._version_greater_than_4_0():
+ if json_out:
+ raise TigerGraphException(
+ "The 'json_out' parameter is only supported in TigerGraph Versions >=4.1.")
+ urls = alt_urls
+ responses = {}
+
+ for file_name in urls:
+ resp = await self._req("GET", f"{self.gsUrl}{urls[file_name]}", resKey="")
+ responses[file_name] = resp
+
+ return _parse_get_udf(responses, json_out=json_out)
diff --git a/pyTigerGraph/pytgasync/pyTigerGraphLoading.py b/pyTigerGraph/pytgasync/pyTigerGraphLoading.py
new file mode 100644
index 00000000..15da68a7
--- /dev/null
+++ b/pyTigerGraph/pytgasync/pyTigerGraphLoading.py
@@ -0,0 +1,83 @@
+"""Loading Job Functions
+
+The functions on this page run loading jobs on the TigerGraph server.
+All functions in this module are called as methods on a link:https://docs.tigergraph.com/pytigergraph/current/core-functions/base[`TigerGraphConnection` object].
+"""
+import logging
+import warnings
+
+from typing import Union
+
+from pyTigerGraph.common.loading import _prep_run_loading_job_with_file
+from pyTigerGraph.pytgasync.pyTigerGraphBase import AsyncPyTigerGraphBase
+
+logger = logging.getLogger(__name__)
+
+
+class AsyncPyTigerGraphLoading(AsyncPyTigerGraphBase):
+
+ async def runLoadingJobWithFile(self, filePath: str, fileTag: str, jobName: str, sep: str = None,
+ eol: str = None, timeout: int = 16000, sizeLimit: int = 128000000) -> Union[dict, None]:
+ """Execute a loading job with the referenced file.
+
+ The file will first be uploaded to the TigerGraph server and the value of the appropriate
+ FILENAME definition will be updated to point to the freshly uploaded file.
+
+ NOTE: The argument `USING HEADER="true"` in the GSQL loading job may not be enough to
+ load the file correctly. Remove the header from the data file before using this function.
+
+ Args:
+ filePath:
+ File variable name or file path for the file containing the data.
+ fileTag:
+ The name of the file variable in the loading job (`DEFINE FILENAME <fileTag>`).
+ jobName:
+ The name of the loading job.
+ sep:
+ Data value separator. If your data is JSON, you do not need to specify this
+ parameter. The default separator is a comma `,`.
+ eol:
+ End-of-line character. Only one or two characters are allowed, except for the
+ special case `\\r\\n`. The default value is `\\n`.
+ timeout:
+ Timeout in seconds. If set to `0`, use the system-wide endpoint timeout setting.
+ sizeLimit:
+ Maximum size for input file in bytes.
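+
+ Example (a sketch only; the file path, file tag and job name are assumptions):
+ ```
+ res = await conn.runLoadingJobWithFile("./people.csv", "MyDataSource",
+ "load_people", sep=",")
+ ```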
+
+ Endpoint:
+ - `POST /ddl/{graph_name}`
+ See xref:tigergraph-server:API:built-in-endpoints.adoc#_run_a_loading_job[Run a loading job]
+ """
+ logger.info("entry: runLoadingJobWithFile")
+ if logger.level == logging.DEBUG:
+ logger.debug("params: " + self._locals(locals()))
+
+ data, params = _prep_run_loading_job_with_file(
+ filePath, jobName, fileTag, sep, eol)
+
+ if not data and not params:
+ # failed to read file
+ return None
+
+ res = await self._req("POST", self.restppUrl + "/ddl/" + self.graphname, params=params, data=data,
+ headers={"RESPONSE-LIMIT": str(sizeLimit), "GSQL-TIMEOUT": str(timeout)})
+
+ if logger.level == logging.DEBUG:
+ logger.debug("return: " + str(res))
+ logger.info("exit: runLoadingJobWithFile")
+
+ return res
+
+ async def uploadFile(self, filePath, fileTag, jobName="", sep=None, eol=None, timeout=16000,
+ sizeLimit=128000000) -> dict:
+ """DEPRECATED
+
+ Use `runLoadingJobWithFile()` instead.
+ """
+ warnings.warn(
+ "The `uploadFile()` function is deprecated; use `runLoadingJobWithFile()` instead.",
+ DeprecationWarning)
+
+ return await self.runLoadingJobWithFile(filePath, fileTag, jobName, sep, eol, timeout, sizeLimit)
+
+ # TODO POST /restpploader/{graph_name}
diff --git a/pyTigerGraph/pytgasync/pyTigerGraphPath.py b/pyTigerGraph/pytgasync/pyTigerGraphPath.py
new file mode 100644
index 00000000..2f357387
--- /dev/null
+++ b/pyTigerGraph/pytgasync/pyTigerGraphPath.py
@@ -0,0 +1,132 @@
+"""Path Finding Functions.
+
+The functions on this page find paths between vertices within the graph.
+All functions in this module are called as methods on a link:https://docs.tigergraph.com/pytigergraph/current/core-functions/base[`TigerGraphConnection` object].
+"""
+
+import logging
+from typing import Union
+
+from pyTigerGraph.common.path import _prepare_path_params
+from pyTigerGraph.pytgasync.pyTigerGraphBase import AsyncPyTigerGraphBase
+
+logger = logging.getLogger(__name__)
+
+
+class AsyncPyTigerGraphPath(AsyncPyTigerGraphBase):
+
+ async def shortestPath(self, sourceVertices: Union[dict, tuple, list],
+ targetVertices: Union[dict, tuple, list], maxLength: int = None,
+ vertexFilters: Union[list, dict] = None, edgeFilters: Union[list, dict] = None,
+ allShortestPaths: bool = False) -> dict:
+ """Finds the shortest path (or all shortest paths) between the source and target vertex sets.
+
+ A vertex set is a set of dictionaries that each has three top-level keys: `v_type`, `v_id`,
+ and `attributes` (also a dictionary).
+
+ Args:
+ sourceVertices:
+ A vertex set (a list of vertices) or a list of `(vertexType, vertexID)` tuples;
+ the source vertices of the shortest paths sought.
+ targetVertices:
+ A vertex set (a list of vertices) or a list of `(vertexType, vertexID)` tuples;
+ the target vertices of the shortest paths sought.
+ maxLength:
+ The maximum length of a shortest path. Optional, default is 6.
+ vertexFilters:
+ An optional list of `(vertexType, condition)` tuples or
+ `{"type": <type>, "condition": <condition>}` dictionaries.
+ edgeFilters:
+ An optional list of `(edgeType, condition)` tuples or
+ `{"type": <type>, "condition": <condition>}` dictionaries.
+ allShortestPaths:
+ If `True`, the endpoint will return all shortest paths between the source and target.
+ Default is `False`, meaning that the endpoint will return only one path.
+
+ Returns:
+ The shortest path between the source and the target.
+ The returned value is a subgraph: all vertices and edges that are part of the path(s);
+ i.e. not a (list of individual) path(s).
+
+ Examples:
+
+ [source.wrap,python]
+ ----
+ path = await conn.shortestPath(("account", 10), ("person", 50), maxLength=3)
+
+ path = await conn.shortestPath(("account", 10), ("person", 50), allShortestPaths=True,
+ vertexFilters=("transfer", "amount>950"), edgeFilters=("receive", "type=4"))
+ ----
+
+ Endpoint:
+ - `POST /shortestpath/{graphName}`
+ See xref:tigergraph-server:API:built-in-endpoints.adoc#_find_shortest_path[Find the shortest path].
+ """
+ logger.info("entry: shortestPath")
+ if logger.level == logging.DEBUG:
+ logger.debug("params: " + self._locals(locals()))
+
+ data = _prepare_path_params(sourceVertices, targetVertices, maxLength, vertexFilters,
+ edgeFilters, allShortestPaths)
+ ret = await self._post(self.restppUrl + "/shortestpath/" + self.graphname, data=data)
+
+ if logger.level == logging.DEBUG:
+ logger.debug("return: " + str(ret))
+ logger.info("exit: shortestPath")
+
+ return ret
+
+ async def allPaths(self, sourceVertices: Union[dict, tuple, list],
+ targetVertices: Union[dict, tuple, list], maxLength: int,
+ vertexFilters: Union[list, dict] = None, edgeFilters: Union[list, dict] = None) -> dict:
+ """Find all possible paths up to a given maximum path length between the source and target
+ vertex sets.
+
+ A vertex set is a dict that has three top-level keys: `v_type`, `v_id`, and `attributes` (also a dict).
+
+ Args:
+ sourceVertices:
+ A vertex set (a list of vertices) or a list of `(vertexType, vertexID)` tuples;
+ the source vertices of the paths sought.
+ targetVertices:
+ A vertex set (a list of vertices) or a list of `(vertexType, vertexID)` tuples;
+ the target vertices of the paths sought.
+ maxLength:
+ The maximum length of the paths.
+ vertexFilters:
+ An optional list of `(vertexType, condition)` tuples or
+ `{"type": <type>, "condition": <condition>}` dictionaries.
+ edgeFilters:
+ An optional list of `(edgeType, condition)` tuples or
+ `{"type": <type>, "condition": <condition>}` dictionaries.
+
+ Returns:
+ All paths between a source vertex (or vertex set) and target vertex (or vertex set).
+ The returned value is a subgraph: all vertices and edges that are part of the path(s);
+ i.e. not a (list of individual) path(s).
+
+ Example:
+ [source.wrap, python]
+ ----
+ path = await conn.allPaths(("account", 10), ("person", 50), maxLength=3,
+ vertexFilters=("transfer", "amount>950"), edgeFilters=("receive", "type=4"))
+ ----
+
+ Endpoint:
+ - `POST /allpaths/{graphName}`
+ See xref:tigergraph-server:API:built-in-endpoints.adoc#_find_all_paths[Find all paths]
+ """
+ logger.info("entry: allPaths")
+ if logger.level == logging.DEBUG:
+ logger.debug("params: " + self._locals(locals()))
+
+ data = _prepare_path_params(sourceVertices, targetVertices, maxLength, vertexFilters,
+ edgeFilters)
+ ret = await self._post(self.restppUrl + "/allpaths/" + self.graphname, data=data)
+
+ if logger.level == logging.DEBUG:
+ logger.debug("return: " + str(ret))
+ logger.info("exit: allPaths")
+
+ return ret
diff --git a/pyTigerGraph/pytgasync/pyTigerGraphQuery.py b/pyTigerGraph/pytgasync/pyTigerGraphQuery.py
new file mode 100644
index 00000000..2f497d63
--- /dev/null
+++ b/pyTigerGraph/pytgasync/pyTigerGraphQuery.py
@@ -0,0 +1,502 @@
+"""Query Functions.
+
+The functions on this page run installed or interpreted queries in TigerGraph.
+All functions in this module are called as methods on a link:https://docs.tigergraph.com/pytigergraph/current/core-functions/base[`TigerGraphConnection` object].
+""" +import logging + +from typing import TYPE_CHECKING, Union, Optional + +if TYPE_CHECKING: + import pandas as pd + +from pyTigerGraph.common.exception import TigerGraphException +from pyTigerGraph.common.query import ( + _parse_get_installed_queries, + _parse_query_parameters, + _prep_run_installed_query, + _prep_get_statistics +) +from pyTigerGraph.pytgasync.pyTigerGraphGSQL import AsyncPyTigerGraphGSQL + +logger = logging.getLogger(__name__) + + +class AsyncPyTigerGraphQuery(AsyncPyTigerGraphGSQL): + # TODO getQueries() # List _all_ query names + async def showQuery(self, queryName: str) -> str: + """Returns the string of the given GSQL query. + + Args: + queryName (str): + Name of the query to get metadata of. + """ + if logger.level == logging.DEBUG: + logger.debug("entry: showQuery") + res = await self.gsql("USE GRAPH "+self.graphname+" SHOW QUERY "+queryName) + if logger.level == logging.DEBUG: + logger.debug("exit: showQuery") + return res + + async def getQueryMetadata(self, queryName: str) -> dict: + """Returns metadata details about a query. + Specifically, it lists the input parameters in the same order as they exist in the query + and outputs `PRINT` statement syntax. + + Args: + queryName (str): + Name of the query to get metadata of. + + Endpoints: + - `POST /gsqlserver/gsql/queryinfo` (In TigerGraph versions 3.x) + See xref:tigergraph-server:API:built-in-endpoints.adoc_get_query_metadata + - `POST /gsql/v1/queries/signature` (In TigerGraph versions 4.x) + """ + if logger.level == logging.DEBUG: + logger.debug("entry: getQueryMetadata") + if await self._version_greater_than_4_0(): + params = {"graph": self.graphname, "queryName": queryName} + res = await self._req("POST", self.gsUrl+"/gsql/v1/queries/signature", params=params, authMode="pwd", resKey="") + else: + params = {"graph": self.graphname, "query": queryName} + res = await self._req("GET", self.gsUrl+"/gsqlserver/gsql/queryinfo", params=params, authMode="pwd", resKey="") + if not res["error"]: + if logger.level == logging.DEBUG: + logger.debug("exit: getQueryMetadata") + return res + else: + TigerGraphException(res["message"], res["code"]) + + async def getInstalledQueries(self, fmt: str = "py") -> Union[dict, str, 'pd.DataFrame']: + """Returns a list of installed queries. + + Args: + fmt: + Format of the results: + - "py": Python objects (default) + - "json": JSON document + - "df": pandas DataFrame + + Returns: + The names of the installed queries. + + TODO This function returns all (installed and non-installed) queries + Modify to return only installed ones + TODO Return with query name as key rather than REST endpoint as key? 
+ """ + logger.info("entry: getInstalledQueries") + if logger.level == logging.DEBUG: + logger.debug("params: " + self._locals(locals())) + + ret = await self.getEndpoints(dynamic=True) + ret = _parse_get_installed_queries(fmt, ret) + + if logger.level == logging.DEBUG: + logger.debug("return: " + str(ret)) + logger.info("exit: getInstalledQueries") + + return ret + + # TODO installQueries() + # POST /gsql/queries/install + # xref:tigergraph-server:API:built-in-endpoints.adoc#_install_a_query[Install a query] + + # TODO checkQueryInstallationStatus() + # GET /gsql/queries/install/{request_id} + # xref:tigergraph-server:API:built-in-endpoints.adoc#_check_query_installation_status[Check query installation status] + + async def runInstalledQuery(self, queryName: str, params: Union[str, dict] = None, + timeout: int = None, sizeLimit: int = None, usePost: bool = False, runAsync: bool = False, + replica: int = None, threadLimit: int = None, memoryLimit: int = None) -> list: + """Runs an installed query. + + The query must be already created and installed in the graph. + Use `getEndpoints(dynamic=True)` or GraphStudio to find out the generated endpoint URL of + the query. Only the query name needs to be specified here. + + Args: + queryName: + The name of the query to be executed. + params: + Query parameters. A string of param1=value1¶m2=value2 format or a dictionary. + See below for special rules for dictionaries. + timeout: + Maximum duration for successful query execution (in milliseconds). + See xref:tigergraph-server:API:index.adoc#_gsql_query_timeout[GSQL query timeout] + sizeLimit: + Maximum size of response (in bytes). + See xref:tigergraph-server:API:index.adoc#_response_size[Response size] + usePost: + Defaults to False. The RESTPP accepts a maximum URL length of 8192 characters. Use POST if additional parameters cause + you to exceed this limit, or if you choose to pass an empty set into a query for database versions >= 3.8 + runAsync: + Run the query in asynchronous mode. + See xref:gsql-ref:querying:query-operations#_detached_mode_async_option[Async operation] + replica: + If your TigerGraph instance is an HA cluster, specify which replica to run the query on. Must be a + value between [1, (cluster replication factor)]. + See xref:tigergraph-server:API:built-in-endpoints#_specify_replica[Specify replica] + threadLimit: + Specify a limit of the number of threads the query is allowed to use on each node of the TigerGraph cluster. + See xref:tigergraph-server:API:built-in-endpoints#_specify_thread_limit[Thread limit] + memoryLimit: + Specify a limit to the amount of memory consumed by the query (in MB). If the limit is exceeded, the query will abort automatically. + Supported in database versions >= 3.8. + See xref:tigergraph-server:system-management:memory-management#_by_http_header[Memory limit] + + Returns: + The output of the query, a list of output elements (vertex sets, edge sets, variables, + accumulators, etc. 
+
+ Notes:
+ When specifying parameter values in a dictionary:
+
+ - For primitive parameter types use
+ `"key": value`
+ - For `SET` and `BAG` parameter types with primitive values, use
+ `"key": [value1, value2, ...]`
+ - For `VERTEX<type>` use
+ `"key": primary_id`
+ - For `VERTEX` (no vertex type specified) use
+ `"key": (primary_id, "vertex_type")`
+ - For `SET<VERTEX<type>>` use
+ `"key": [primary_id1, primary_id2, ...]`
+ - For `SET<VERTEX>` (no vertex type specified) use
+ `"key": [(primary_id1, "vertex_type1"), (primary_id2, "vertex_type2"), ...]`
+
+ Endpoints:
+ - `GET /query/{graph_name}/{query_name}`
+ See xref:tigergraph-server:API:built-in-endpoints.adoc#_run_an_installed_query_get[Run an installed query (GET)]
+ - `POST /query/{graph_name}/{query_name}`
+ See xref:tigergraph-server:API:built-in-endpoints.adoc#_run_an_installed_query_post[Run an installed query (POST)]
+ """
+ logger.info("entry: runInstalledQuery")
+ if logger.level == logging.DEBUG:
+ logger.debug("params: " + self._locals(locals()))
+
+ headers, res_key = _prep_run_installed_query(timeout=timeout, sizeLimit=sizeLimit, runAsync=runAsync,
+ replica=replica, threadLimit=threadLimit, memoryLimit=memoryLimit)
+
+ if usePost:
+ ret = await self._req("POST", self.restppUrl + "/query/" + self.graphname + "/" + queryName,
+ data=params, headers=headers, resKey=res_key, jsonData=True)
+
+ if logger.level == logging.DEBUG:
+ logger.debug("return: " + str(ret))
+ logger.info("exit: runInstalledQuery (POST)")
+
+ return ret
+ else:
+ if isinstance(params, dict):
+ params = _parse_query_parameters(params)
+ ret = await self._req("GET", self.restppUrl + "/query/" + self.graphname + "/" + queryName,
+ params=params, headers=headers, resKey=res_key)
+
+ if logger.level == logging.DEBUG:
+ logger.debug("return: " + str(ret))
+ logger.info("exit: runInstalledQuery (GET)")
+
+ return ret
+
+ async def checkQueryStatus(self, requestId: str = ""):
+ """Checks the status of the queries running on the graph specified in the connection.
+
+ Args:
+ requestId (str, optional):
+ String ID of the request. If empty, returns all running requests.
+ See xref:tigergraph-server:API:built-in-endpoints.adoc#_check_query_status_detached_mode[Check query status (detached mode)]
+
+ Endpoint:
+ - `GET /query_status/{graph_name}`
+ See xref:tigergraph-server:API:built-in-endpoints.adoc#_check_query_status_detached_mode[Check query status (detached mode)]
+ """
+ if requestId != "":
+ return await self._req("GET", self.restppUrl + "/query_status?graph_name="+self.graphname+"&requestid="+requestId)
+ else:
+ return await self._req("GET", self.restppUrl + "/query_status?graph_name="+self.graphname+"&requestid=all")
+
+ async def getQueryResult(self, requestId: str = ""):
+ """Gets the result of a detached query.
+
+ Args:
+ requestId (str):
+ String ID of the request.
+ See xref:tigergraph-server:API:built-in-endpoints.adoc#_check_query_results_detached_mode[Check query results (detached mode)]
+ """
+ return await self._req("GET", self.restppUrl + "/query_result?graph_name="+self.graphname+"&requestid="+requestId)
+
+ async def runInterpretedQuery(self, queryText: str, params: Union[str, dict] = None) -> list:
+ """Runs an interpreted query.
+
+ Use ``$graphname`` or ``@graphname@`` in the ``FOR GRAPH`` clause to avoid hardcoding the
+ name of the graph in your app. It will be replaced by the actual graph name.
+
+ Args:
+ queryText:
+ The text of the GSQL query that must be provided in this format:
+
+ [source.wrap, gsql]
+ ----
+ INTERPRET QUERY (<parameters>) FOR GRAPH <graph_name> {
+ <statements>
+ }
+ ----
+
+ params:
+ A string of `param1=value1&param2=value2...` format or a dictionary.
+ See below for special rules for dictionaries.
+
+ Returns:
+ The output of the query: a list of output elements such as vertex sets, edge sets, variables and
+ accumulators.
+
+ Notes:
+ When specifying parameter values in a dictionary:
+
+ - For primitive parameter types use
+ `"key": value`
+ - For `SET` and `BAG` parameter types with primitive values, use
+ `"key": [value1, value2, ...]`
+ - For `VERTEX<type>` use
+ `"key": primary_id`
+ - For `VERTEX` (no vertex type specified) use
+ `"key": (primary_id, "vertex_type")`
+ - For `SET<VERTEX<type>>` use
+ `"key": [primary_id1, primary_id2, ...]`
+ - For `SET<VERTEX>` (no vertex type specified) use
+ `"key": [(primary_id1, "vertex_type1"), (primary_id2, "vertex_type2"), ...]`
+
+ Endpoints:
+ - `POST /gsqlserver/interpreted_query` (In TigerGraph versions 3.x)
+ See xref:tigergraph-server:API:built-in-endpoints.adoc#_run_an_interpreted_query[Run an interpreted query]
+ - `POST /gsql/v1/queries/interpret` (In TigerGraph versions 4.x)
+
+ TODO Add "GSQL-TIMEOUT: <timeout value in ms>" and "RESPONSE-LIMIT: <size limit in bytes>"
+ plus parameters if applicable to interpreted queries (see runInstalledQuery() above)
+ """
+ logger.info("entry: runInterpretedQuery")
+ if logger.level == logging.DEBUG:
+ logger.debug("params: " + self._locals(locals()))
+
+ queryText = queryText.replace("$graphname", self.graphname)
+ queryText = queryText.replace("@graphname@", self.graphname)
+ if isinstance(params, dict):
+ params = _parse_query_parameters(params)
+
+ if await self._version_greater_than_4_0():
+ ret = await self._req("POST", self.gsUrl + "/gsql/v1/queries/interpret",
+ params=params, data=queryText, authMode="pwd",
+ headers={'Content-Type': 'text/plain'})
+ else:
+ ret = await self._req("POST", self.gsUrl + "/gsqlserver/interpreted_query", data=queryText,
+ params=params, authMode="pwd")
+
+ if logger.level == logging.DEBUG:
+ logger.debug("return: " + str(ret))
+ logger.info("exit: runInterpretedQuery")
+
+ return ret
+
+ async def getRunningQueries(self) -> dict:
+ """Reports the statistics of currently running queries on the graph.
+ """
+ if logger.level == logging.DEBUG:
+ logger.debug("entry: getRunningQueries")
+ res = await self._req("GET", self.restppUrl+"/showprocesslist/"+self.graphname, resKey="")
+ if not res["error"]:
+ if logger.level == logging.DEBUG:
+ logger.debug("exit: getRunningQueries")
+ return res
+ else:
+ raise TigerGraphException(res["message"], res["code"])
+
+ async def abortQuery(self, request_id: Union[str, list] = None, url: str = None):
+ """Safely aborts a selected query by ID, or all queries of an endpoint by the endpoint URL, on the graph.
+ If neither `request_id` nor `url` is specified, all queries currently running on the graph are aborted.
+
+ Args:
+ request_id (str, list, optional):
+ The ID(s) of the query(s) to abort. If set to "all", it will abort all running queries.
+ url (str, optional):
+ The endpoint URL of the query(s) to abort. If set, all queries running on that
+ endpoint of the graph are aborted.
+ """
+ if logger.level == logging.DEBUG:
+ logger.debug("entry: abortQuery")
+ params = {}
+ if request_id:
+ params["requestid"] = request_id
+ if url:
+ params["url"] = url
+ res = await self._get(self.restppUrl+"/abortquery/"+self.graphname, params=params, resKey="")
+ if not res["error"]:
+ if logger.level == logging.DEBUG:
+ logger.debug("exit: abortQuery")
+ return res
+ else:
+ raise TigerGraphException(res["message"], res["code"])
+
+ async def getStatistics(self, seconds: int = 10, segments: int = 10) -> dict:
+ """Retrieves real-time query performance statistics over the given time period.
+
+ Args:
+ seconds:
+ The duration of the statistics collection period (the last _n_ seconds before the function
+ call).
+ segments:
+ The number of segments of the latency distribution (shown in results as
+ `LatencyPercentile`). By default, segments is `10`, meaning the percentile range 0-100%
+ will be divided into ten equal segments: 0%-10%, 11%-20%, etc.
+ This argument must be an integer between 1 and 100.
+
+ Endpoint:
+ - `GET /statistics/{graph_name}`
+ See xref:tigergraph-server:API:built-in-endpoints.adoc#_show_query_performance[Show query performance]
+ """
+ logger.info("entry: getStatistics")
+ if logger.level == logging.DEBUG:
+ logger.debug("params: " + self._locals(locals()))
+
+ seconds, segments = _prep_get_statistics(self, seconds, segments)
+ ret = await self._req("GET", self.restppUrl + "/statistics/" + self.graphname + "?seconds=" +
+ str(seconds) + "&segment=" + str(segments), resKey="")
+
+ if logger.level == logging.DEBUG:
+ logger.debug("return: " + str(ret))
+ logger.info("exit: getStatistics")
+
+ return ret
+
+ async def describeQuery(self, queryName: str, queryDescription: str, parameterDescriptions: dict = {}):
+ """Add a query description and parameter descriptions. Only supported on versions of TigerGraph >= 4.0.0.
+
+ Args:
+ queryName:
+ The name of the query to describe.
+ queryDescription:
+ A description of the query.
+ parameterDescriptions (optional):
+ A dictionary of parameter descriptions. The keys are the parameter names and the values are the descriptions.
+
+ Returns:
+ The response from the database.
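+
+ Example (a sketch; the query and parameter names are assumptions):
+ ```
+ await conn.describeQuery("friends_of", "Returns the friends of a person.",
+ {"p": "Primary ID of the person", "depth": "Maximum number of hops"})
+ ```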
+
+ Endpoints:
+ - `PUT /gsqlserver/gsql/description?graph={graph_name}` (In TigerGraph version 4.0)
+ - `PUT /gsql/v1/description?graph={graph_name}` (In TigerGraph versions >4.0)
+ """
+ logger.info("entry: describeQuery")
+ self.ver = await self.getVer()
+ major_ver, minor_ver, patch_ver = self.ver.split(".")
+ if int(major_ver) < 4:
+ logger.info("exit: describeQuery")
+ raise TigerGraphException(
+ "This function is only supported on versions of TigerGraph >= 4.0.0.", 0)
+
+ if parameterDescriptions:
+ params = {"queries": [
+ {"queryName": queryName,
+ "description": queryDescription,
+ "parameters": [{"paramName": k, "description": v} for k, v in parameterDescriptions.items()]}
+ ]}
+ else:
+ params = {"queries": [
+ {"queryName": queryName,
+ "description": queryDescription}
+ ]}
+ if logger.level == logging.DEBUG:
+ logger.debug("params: " + str(params))
+ if await self._version_greater_than_4_0():
+ res = await self._put(self.gsUrl+"/gsql/v1/description?graph="+self.graphname, data=params, authMode="pwd", jsonData=True)
+ else:
+ res = await self._put(self.gsUrl+"/gsqlserver/gsql/description?graph="+self.graphname, data=params, authMode="pwd", jsonData=True)
+
+ if logger.level == logging.DEBUG:
+ logger.debug("return: " + str(res))
+ logger.info("exit: describeQuery")
+
+ return res
+
+ async def getQueryDescription(self, queryName: Optional[Union[str, list]] = "all"):
+ """Get the description of a query. Only supported on versions of TigerGraph >= 4.0.0.
+
+ Args:
+ queryName:
+ The name of the query to get the description of.
+ If multiple query descriptions are desired, pass a list of query names.
+ If set to "all", returns the description of all queries.
+
+ Returns:
+ The description of the query(ies).
+
+ Endpoints:
+ - `GET /gsqlserver/gsql/description?graph={graph_name}` (In TigerGraph version 4.0)
+ - `GET /gsql/v1/description?graph={graph_name}` (In TigerGraph versions >4.0)
+ """
+ logger.info("entry: getQueryDescription")
+ self.ver = await self.getVer()
+ major_ver, minor_ver, patch_ver = self.ver.split(".")
+ if int(major_ver) < 4:
+ logger.info("exit: getQueryDescription")
+ raise TigerGraphException(
+ "This function is only supported on versions of TigerGraph >= 4.0.0.", 0)
+
+ if logger.level == logging.DEBUG:
+ logger.debug("params: " + self._locals(locals()))
+
+ if isinstance(queryName, list):
+ queryName = ",".join(queryName)
+
+ if await self._version_greater_than_4_0():
+ res = await self._get(self.gsUrl+"/gsql/v1/description?graph="+self.graphname+"&query="+queryName, authMode="pwd", resKey=None)
+ else:
+ res = await self._get(self.gsUrl+"/gsqlserver/gsql/description?graph="+self.graphname+"&query="+queryName, authMode="pwd", resKey=None)
+ if not res["error"]:
+ if logger.level == logging.DEBUG:
+ logger.debug("exit: getQueryDescription")
+ return res["results"]["queries"]
+ else:
+ raise TigerGraphException(res["message"], res["code"])
+
+ async def dropQueryDescription(self, queryName: str, dropParamDescriptions: bool = True):
+ """Drop the description of a query. Only supported on versions of TigerGraph >= 4.0.0.
+
+ Args:
+ queryName:
+ The name of the query to drop the description of.
+ If set to "*", drops the description of all queries.
+ dropParamDescriptions:
+ Whether to drop the parameter descriptions as well. Defaults to True.
+
+ Returns:
+ The response from the database.
+
+ Endpoints:
+ - `DELETE /gsqlserver/gsql/description?graph={graph_name}` (In TigerGraph version 4.0)
+ - `DELETE /gsql/v1/description?graph={graph_name}` (In TigerGraph versions >4.0)
+ """
+ logger.info("entry: dropQueryDescription")
+ self.ver = await self.getVer()
+ major_ver, minor_ver, patch_ver = self.ver.split(".")
+ if int(major_ver) < 4:
+ logger.info("exit: dropQueryDescription")
+ raise TigerGraphException(
+ "This function is only supported on versions of TigerGraph >= 4.0.0.", 0)
+
+ if logger.level == logging.DEBUG:
+ logger.debug("params: " + self._locals(locals()))
+ if dropParamDescriptions:
+ params = {"queries": [queryName],
+ "queryParameters": [queryName+".*"]}
+ else:
+ params = {"queries": [queryName]}
+ if await self._version_greater_than_4_0():
+ res = await self._delete(self.gsUrl+"/gsql/v1/description?graph="+self.graphname, authMode="pwd", data=params, jsonData=True, resKey=None)
+ else:
+ res = await self._delete(self.gsUrl+"/gsqlserver/gsql/description?graph="+self.graphname, authMode="pwd", data=params, jsonData=True, resKey=None)
+
+ if logger.level == logging.DEBUG:
+ logger.debug("return: " + str(res))
+ logger.info("exit: dropQueryDescription")
+
+ return res
diff --git a/pyTigerGraph/pytgasync/pyTigerGraphSchema.py b/pyTigerGraph/pytgasync/pyTigerGraphSchema.py
new file mode 100644
index 00000000..af5ce3e0
--- /dev/null
+++ b/pyTigerGraph/pytgasync/pyTigerGraphSchema.py
@@ -0,0 +1,189 @@
+"""Schema Functions.
+
+The functions in this page retrieve information about the graph schema.
+All functions in this module are called as methods on a link:https://docs.tigergraph.com/pytigergraph/current/core-functions/base[`TigerGraphConnection` object].
+"""
+
+import logging
+import re
+
+from typing import Union
+
+from pyTigerGraph.pytgasync.pyTigerGraphBase import AsyncPyTigerGraphBase
+from pyTigerGraph.common.schema import (
+ _prep_upsert_data,
+ _prep_get_endpoints
+)
+
+logger = logging.getLogger(__name__)
+
+
+class AsyncPyTigerGraphSchema(AsyncPyTigerGraphBase):
+
+ async def _getUDTs(self) -> dict:
+ """Retrieves all User Defined Types (UDTs) of the graph.
+
+ Returns:
+ The list of names of UDTs (defined in the global scope, i.e. not in queries).
+
+ Endpoint:
+ GET /gsqlserver/gsql/udtlist (In TigerGraph versions 3.x)
+ GET /gsql/v1/udt/tuples (In TigerGraph versions 4.x)
+ """
+ logger.info("entry: _getUDTs")
+
+ if await self._version_greater_than_4_0():
+ res = await self._req("GET", self.gsUrl + "/gsql/v1/udt/tuples?graph=" + self.graphname,
+ authMode="pwd")
+ else:
+ res = await self._req("GET", self.gsUrl + "/gsqlserver/gsql/udtlist?graph=" + self.graphname,
+ authMode="pwd")
+
+ if logger.level == logging.DEBUG:
+ logger.debug("return: " + str(res))
+ logger.info("exit: _getUDTs")
+
+ return res
+
+ async def getSchema(self, udts: bool = True, force: bool = False) -> dict:
+ """Retrieves the schema metadata (of all vertex and edge types and, if not disabled, the
+ User-Defined Type details) of the graph.
+
+ Args:
+ udts:
+ If `True`, the output includes User-Defined Types in the schema details.
+ force:
+ If `True`, retrieves the schema metadata again, otherwise returns a cached copy of
+ the schema metadata (if it was already fetched previously).
+
+ Returns:
+ The schema metadata.
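+
+ Example (illustrative):
+ ```
+ schema = await conn.getSchema()
+ vertex_types = [vt["Name"] for vt in schema["VertexTypes"]]
+ edge_types = [et["Name"] for et in schema["EdgeTypes"]]
+ ```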
+ + Endpoint: + - `GET /gsqlserver/gsql/schema` + See xref:tigergraph-server:API:built-in-endpoints.adoc#_show_graph_schema_metadata[Show graph schema metadata] + - `GET /gsql/v1/schema/graphs/{graph_name}` + """ + logger.info("entry: getSchema") + if logger.level == logging.DEBUG: + logger.debug("params: " + self._locals(locals())) + + if not self.schema or force: + if await self._version_greater_than_4_0(): + self.schema = await self._req("GET", self.gsUrl + "/gsql/v1/schema/graphs/" + self.graphname, + authMode="pwd") + else: + self.schema = await self._req("GET", self.gsUrl + "/gsqlserver/gsql/schema?graph=" + self.graphname, + authMode="pwd") + if udts and ("UDTs" not in self.schema or force): + self.schema["UDTs"] = await self._getUDTs() + + if logger.level == logging.DEBUG: + logger.debug("return: " + str(self.schema)) + logger.info("exit: getSchema") + + return self.schema + + async def upsertData(self, data: Union[str, object], atomic: bool = False, ackAll: bool = False, + newVertexOnly: bool = False, vertexMustExist: bool = False, + updateVertexOnly: bool = False) -> dict: + """Upserts data (vertices and edges) from a JSON file or a file with equivalent object structure. + + Args: + data: + The data of vertex and edge instances, in a specific format. + atomic: + The request is an atomic transaction. An atomic transaction means that updates to + the database contained in the request are all-or-nothing: either all changes are + successful, or none are successful. + ackAll: + If `True`, the request will return after all GPE instances have acknowledged the + POST. Otherwise, the request will return immediately after RESTPP processes the POST. + newVertexOnly: + If `True`, the request will only insert new vertices and not update existing ones. + vertexMustExist: + If `True`, the request will only insert an edge if both the `FROM` and `TO` vertices + of the edge already exist. If the value is `False`, the request will always insert new + edges and create the necessary vertices with default values for their attributes. + Note that this parameter does not affect vertices. + updateVertexOnly: + If `True`, the request will only update existing vertices and not insert new + vertices. + + Returns: + The result of upsert (number of vertices and edges accepted/upserted). + + Endpoint: + - `POST /graph/{graph_name}` + See xref:tigergraph-server:API:built-in-endpoints.adoc#_upsert_data_to_graph[Upsert data to graph] + """ + logger.info("entry: upsertData") + if logger.level == logging.DEBUG: + logger.debug("params: " + self._locals(locals())) + + data, headers, params = _prep_upsert_data(data=data, atomic=atomic, ackAll=ackAll, newVertexOnly=newVertexOnly, + vertexMustExist=vertexMustExist, updateVertexOnly=updateVertexOnly) + + res = await self._req("POST", self.restppUrl + "/graph/" + self.graphname, headers=headers, data=data, + params=params) + res = res[0] + + if logger.level == logging.DEBUG: + logger.debug("return: " + str(res)) + logger.info("exit: upsertData") + + return res + + async def getEndpoints(self, builtin: bool = False, dynamic: bool = False, + static: bool = False) -> dict: + """Lists the REST++ endpoints and their parameters. + + Args: + builtin: + List the TigerGraph-provided REST++ endpoints. + dynamic: + List endpoints for user-installed queries. + static: + List static endpoints. + + If none of the above arguments are specified, all endpoints are listed.
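(Editor's illustration: a sketch of the upsert document shape consumed by `upsertData` above, following the RESTPP upsert format where each attribute value is wrapped in `{"value": ...}`; the vertex type, ID, and attributes are hypothetical.)
```
import asyncio
from pyTigerGraph import AsyncTigerGraphConnection

async def main():
    conn = AsyncTigerGraphConnection(host="https://example.tgcloud.io",
                                     graphname="MyGraph")
    # Instances are grouped by "vertices"/"edges", then type, then primary ID.
    payload = {
        "vertices": {
            "Person": {
                "alice": {"name": {"value": "Alice"}, "age": {"value": 30}},
            }
        }
    }
    print(await conn.upsertData(payload))  # counts of accepted vertices/edges

asyncio.run(main())
```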
+ + Endpoint: + - `GET /endpoints/{graph_name}` + See xref:tigergraph-server:API:built-in-endpoints.adoc#_list_all_endpoints[List all endpoints] + """ + logger.info("entry: getEndpoints") + if logger.level == logging.DEBUG: + logger.debug("params: " + self._locals(locals())) + + bui, dyn, sta, url, ret = _prep_get_endpoints( + restppUrl=self.restppUrl, + graphname=self.graphname, + builtin=builtin, + dynamic=dynamic, + static=static + ) + if bui: + eps = {} + res = await self._req("GET", url + "builtin=true", resKey="") + for ep in res: + if not re.search(" /graph/", ep) or re.search(" /graph/{graph_name}/", ep): + eps[ep] = res[ep] + ret.update(eps) + if dyn: + eps = {} + res = await self._req("GET", url + "dynamic=true", resKey="") + for ep in res: + if re.search("^GET /query/" + self.graphname, ep): + eps[ep] = res[ep] + ret.update(eps) + if sta: + ret.update(await self._req("GET", url + "static=true", resKey="")) + + if logger.level == logging.DEBUG: + logger.debug("return: " + str(ret)) + logger.info("exit: getEndpoints") + + return ret + + # TODO GET /rebuildnow/{graph_name} diff --git a/pyTigerGraph/pytgasync/pyTigerGraphUDT.py b/pyTigerGraph/pytgasync/pyTigerGraphUDT.py new file mode 100644 index 00000000..7976e147 --- /dev/null +++ b/pyTigerGraph/pytgasync/pyTigerGraphUDT.py @@ -0,0 +1,68 @@ +"""User Defined Tuple (UDT) Functions. + +The functions on this page retrieve information about user-defined tuples (UDT) for the graph. +All functions in this module are called as methods on a link:https://docs.tigergraph.com/pytigergraph/current/core-functions/base[`TigerGraphConnection` object]. +""" + +import logging + +from pyTigerGraph.pytgasync.pyTigerGraphSchema import AsyncPyTigerGraphSchema +# from pyTigerGraph.pyTigerGraphUDT import pyTigerGraphUDT + +logger = logging.getLogger(__name__) + + +class AsyncPyTigerGraphUDT(AsyncPyTigerGraphSchema): + + async def getUDTs(self) -> list: + """Returns the list of User-Defined Tuples (names only). + + For information on UDTs see xref:gsql-ref:ddl-and-loading:system-and-language-basics.adoc#typedef-tuple[User-Defined Tuple] + + Returns: + The list of names of UDTs (defined in the global scope, i.e. not in queries). + """ + logger.info("entry: getUDTs") + + ret = [] + for udt in await self._getUDTs(): + ret.append(udt["name"]) + + if logger.level == logging.DEBUG: + logger.debug("return: " + str(ret)) + logger.info("exit: getUDTs") + + return ret + + async def getUDT(self, udtName: str) -> list: + """Returns the details of a specific User-Defined Tuple (defined in the global scope). + + For information on UDTs see xref:gsql-ref:ddl-and-loading:system-and-language-basics.adoc#typedef-tuple[User-Defined Tuple] + + Args: + udtName: + The name of the User-Defined Tuple. + + Returns: + The metadata (the details of the fields) of the UDT. 
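(Editor's illustration: a short sketch of the UDT helpers, assuming a global `TYPEDEF TUPLE` already exists in the schema; the tuple name `MyTuple` is hypothetical.)
```
import asyncio
from pyTigerGraph import AsyncTigerGraphConnection

async def main():
    conn = AsyncTigerGraphConnection(host="https://example.tgcloud.io",
                                     graphname="MyGraph")
    names = await conn.getUDTs()             # e.g. ['MyTuple']
    if "MyTuple" in names:
        print(await conn.getUDT("MyTuple"))  # field metadata for that tuple

asyncio.run(main())
```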
+ + """ + logger.info("entry: getUDT") + if logger.level == logging.DEBUG: + logger.debug("params: " + self._locals(locals())) + + for udt in await self._getUDTs(): + if udt["name"] == udtName: + ret = udt["fields"] + + if logger.level == logging.DEBUG: + logger.debug("return: " + str(ret)) + logger.info("exit: getUDT (found)") + + return ret + + if logger.level == logging.DEBUG: + logger.warning("UDT `" + udtName + "` was not found") + logger.info("exit: getUDT (not found)") + + return [] # UDT was not found diff --git a/pyTigerGraph/pytgasync/pyTigerGraphUtils.py b/pyTigerGraph/pytgasync/pyTigerGraphUtils.py new file mode 100644 index 00000000..d273160b --- /dev/null +++ b/pyTigerGraph/pytgasync/pyTigerGraphUtils.py @@ -0,0 +1,224 @@ +"""Utility Functions. + +Utility functions for pyTigerGraph. +All functions in this module are called as methods on a link:https://docs.tigergraph.com/pytigergraph/current/core-functions/base[`TigerGraphConnection` object]. +""" +import json +import logging + +from typing import Any, Union, TYPE_CHECKING +from urllib.parse import urlparse + +from pyTigerGraph.common.exception import TigerGraphException +from pyTigerGraph.common.util import ( + _parse_get_license_info, + _prep_get_system_metrics +) +from pyTigerGraph.pytgasync.pyTigerGraphBase import AsyncPyTigerGraphBase + +logger = logging.getLogger(__name__) + + +class AsyncPyTigerGraphUtils(AsyncPyTigerGraphBase): + + async def echo(self, usePost: bool = False) -> str: + """Pings the database. + + Args: + usePost: + Use POST instead of GET + + Returns: + "Hello GSQL" if everything was OK. + + Endpoint: + - `GET /echo` + - `POST /echo` + See xref:tigergraph-server:API:built-in-endpoints.adoc#_echo[Echo] + """ + logger.info("entry: echo") + if logger.level == logging.DEBUG: + logger.debug("params: " + self._locals(locals())) + + if usePost: + ret = str(await self._req("POST", self.restppUrl + "/echo/", resKey="message")) + + if logger.level == logging.DEBUG: + logger.debug("return: " + str(ret)) + logger.info("exit: echo (POST)") + + return ret + + ret = str(await self._req("GET", self.restppUrl + "/echo/", resKey="message")) + + if logger.level == logging.DEBUG: + logger.debug("return: " + str(ret)) + logger.info("exit: echo (GET)") + + return ret + + async def getLicenseInfo(self) -> dict: + """Returns the expiration date and remaining days of the license. + + Returns: + Returns license details. For an evaluation/trial deployment, returns an information message and -1 remaining days. + + """ + logger.info("entry: getLicenseInfo") + + res = await self._req("GET", self.restppUrl + "/showlicenseinfo", resKey="", skipCheck=True) + ret = _parse_get_license_info(res) + + if logger.level == logging.DEBUG: + logger.debug("return: " + str(ret)) + logger.info("exit: getLicenseInfo") + + return ret + + async def ping(self) -> dict: + """Public health check endpoint. + + Returns: + Returns a JSON object with a key of "message" and a value of "pong" + """ + if logger.level == logging.DEBUG: + logger.debug("entry: ping") + res = await self._req("GET", self.gsUrl+"/api/ping", resKey="") + if not res["error"]: + if logger.level == logging.DEBUG: + logger.debug("exit: ping") + return res + else: + raise TigerGraphException(res["message"], res["code"]) + + async def getSystemMetrics(self, from_ts: int = None, to_ts: int = None, latest: int = None, what: str = None, who: str = None, where: str = None): + """Monitor system usage metrics. 
+ + Args: + from_ts (int, optional): + The epoch timestamp that indicates the start of the time filter. + to_ts (int, optional): + The epoch timestamp that indicates the end of the time filter. + latest (int, optional): + Number of datapoints to return. If provided, `from_ts` and `to_ts` will be ignored. + what (str, optional): + Name of the metric to filter for. Possible choices are: + - "cpu": Percentage of CPU usage by component + - "mem": Memory usage in megabytes by component + - "diskspace": Disk usage in megabytes by directory + - "network": Network traffic in bytes since the service started + - "qps": Number of requests per second by endpoint + - "servicestate": The state of the service, either online 1 or offline 0 (Only available in versions <4.1) + - "connection": Number of open TCP connections (Only available in versions <4.1) + who (str, optional): + Name of the component that reported the datapoint. (Only available in versions <4.1) + where (str, optional): + Name of the node that reported the datapoint. + + Returns: + JSON object of datapoints collected. + Note: Output format differs between 3.x and 4.x versions of TigerGraph. + + Endpoints: + See xref:tigergraph-server:API:built-in-endpoints.adoc#_show_component_versions[Show component versions] + - `GET /ts3/api/datapoints` (In TigerGraph versions 3.x) + See xref:tigergraph-server:API:built-in-endpoints.adoc#_monitor_system_metrics_ts3_deprecated + - `POST /informant/metrics/get/{metrics_category}` (In TigerGraph versions 4.x) + See xref:tigergraph-server:API:built-in-endpoints.adoc#_monitor_system_metrics_by_category + """ + if logger.level == logging.DEBUG: + logger.debug("entry: getSystemMetrics") + + params, _json = _prep_get_system_metrics( + from_ts=from_ts, to_ts=to_ts, latest=latest, who=who, where=where) + + # Couldn't be placed in prep since version checking requires await statements + if what: + if await self._version_greater_than_4_0(): + if what == "servicestate" or what == "connection": + raise TigerGraphException( + "This 'what' parameter is only supported on versions of TigerGraph < 4.1.0.", 0) + if what == "cpu" or what == "mem": + what = "cpu-memory" # in >=4.1 cpu and mem have been conjoined into one category + params["what"] = what + # in >=4.1 the datapoints endpoint has been removed and replaced + if await self._version_greater_than_4_0(): + res = await self._req("POST", self.gsUrl+"/informant/metrics/get/"+what, data=_json, jsonData=True, resKey="") + else: + res = await self._req("GET", self.gsUrl+"/ts3/api/datapoints", authMode="pwd", params=params, resKey="") + if logger.level == logging.DEBUG: + logger.debug("exit: getSystemMetrics") + return res + + async def getQueryPerformance(self, seconds: int = 10): + """Returns real-time query performance statistics over the given time period, as specified by the seconds parameter. + + Args: + seconds (int, optional): + Seconds are measured up to 60, so the seconds parameter must be a positive integer less than or equal to 60. + Defaults to 10. + """ + if logger.level == logging.DEBUG: + logger.debug("entry: getQueryPerformance") + params = {} + if seconds: + params["seconds"] = seconds + res = await self._get(self.restppUrl+"/statistics/"+self.graphname, params=params, resKey="") + if logger.level == logging.DEBUG: + logger.debug("exit: getQueryPerformance") + return res + + async def getServiceStatus(self, request_body: dict): + """Returns the status of the TigerGraph services specified in the request. + Supported on database versions 3.4 and above.
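(Editor's illustration: fetching recent CPU/memory datapoints; on >= 4.1 the code above folds "cpu"/"mem" into the "cpu-memory" category internally. Connection details are hypothetical and the output shape differs between 3.x and 4.x.)
```
import asyncio
from pyTigerGraph import AsyncTigerGraphConnection

async def main():
    conn = AsyncTigerGraphConnection(host="https://example.tgcloud.io",
                                     graphname="MyGraph")
    # Last 10 CPU datapoints; from_ts/to_ts are ignored when latest is given.
    metrics = await conn.getSystemMetrics(latest=10, what="cpu")
    print(metrics)

asyncio.run(main())
```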
+ + Args: + request_body (dict): + Must be formatted as specified here: https://docs.tigergraph.com/tigergraph-server/current/api/built-in-endpoints#_show_service_status + """ + if logger.level == logging.DEBUG: + logger.debug("entry: getServiceStatus") + res = await self._req("POST", self.gsUrl+"/informant/current-service-status", data=json.dumps(request_body), resKey="") + if logger.level == logging.DEBUG: + logger.debug("exit: getServiceStatus") + return res + + async def rebuildGraph(self, threadnum: int = None, vertextype: str = "", segid: str = "", path: str = "", force: bool = False): + """Rebuilds the graph engine immediately. See https://docs.tigergraph.com/tigergraph-server/current/api/built-in-endpoints#_rebuild_graph_engine for more information. + + Args: + threadnum (int, optional): + Number of threads to execute the rebuild. + vertextype (str, optional): + Vertex type to perform the rebuild for. Will perform for all vertex types if not specified. + segid (str, optional): + Segment ID of the segments to rebuild. If not provided, all segments will be rebuilt. + In general, it is recommended not to provide this parameter and to rebuild all segments. + path (str, optional): + Path to save the summary of the rebuild to. If not provided, the default path is "/tmp/rebuildnow". + force (bool, optional): + Boolean value that indicates whether to perform rebuilds for segments for which there are no records of new data. + Normally, a rebuild would skip such segments, but if force is set to true, the segments will not be skipped. + Returns: + JSON response with message containing the path to the summary file. + """ + if logger.level == logging.DEBUG: + logger.debug("entry: rebuildGraph") + params = {} + if threadnum: + params["threadnum"] = threadnum + if vertextype: + params["vertextype"] = vertextype + if segid: + params["segid"] = segid + if path: + params["path"] = path + if force: + params["force"] = force + res = await self._req("GET", self.restppUrl+"/rebuildnow/"+self.graphname, params=params, resKey="") + if not res["error"]: + if logger.level == logging.DEBUG: + logger.debug("exit: rebuildGraph") + return res + else: + raise TigerGraphException(res["message"], res["code"]) diff --git a/pyTigerGraph/pytgasync/pyTigerGraphVertex.py b/pyTigerGraph/pytgasync/pyTigerGraphVertex.py new file mode 100644 index 00000000..c342c012 --- /dev/null +++ b/pyTigerGraph/pytgasync/pyTigerGraphVertex.py @@ -0,0 +1,758 @@ +"""Vertex Functions. + +Functions to upsert, retrieve and delete vertices. + +All functions in this module are called as methods on a link:https://docs.tigergraph.com/pytigergraph/current/core-functions/base[`TigerGraphConnection` object].
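(Editor's illustration: probing a service's status and triggering a rebuild. The request body below follows the linked show-service-status format to the best of my understanding; the service name and thread count are illustrative, not prescribed by this patch.)
```
import asyncio
from pyTigerGraph import AsyncTigerGraphConnection

async def main():
    conn = AsyncTigerGraphConnection(host="https://example.tgcloud.io",
                                     graphname="MyGraph")
    body = {"ServiceDescriptors": [{"ServiceName": "GPE"}]}  # per linked docs
    print(await conn.getServiceStatus(body))
    # Rebuild immediately and read back the summary-file path from the response.
    print(await conn.rebuildGraph(threadnum=2))

asyncio.run(main())
```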
+""" +import json +import logging +import warnings + +from typing import TYPE_CHECKING, Union + +if TYPE_CHECKING: + import pandas as pd + +from pyTigerGraph.common.vertex import ( + _parse_get_vertex_count, + _prep_upsert_vertex_dataframe, + _prep_get_vertices, + _prep_get_vertices_by_id, + _parse_get_vertex_stats, + _prep_del_vertices, + _prep_del_vertices_by_id +) + +from pyTigerGraph.common.schema import _upsert_attrs +from pyTigerGraph.common.vertex import vertexSetToDataFrame as _vS2DF +from pyTigerGraph.common.util import _safe_char + +from pyTigerGraph.pytgasync.pyTigerGraphSchema import AsyncPyTigerGraphSchema +from pyTigerGraph.pytgasync.pyTigerGraphUtils import AsyncPyTigerGraphUtils + +logger = logging.getLogger(__name__) + + +class AsyncPyTigerGraphVertex(AsyncPyTigerGraphUtils, AsyncPyTigerGraphSchema): + + async def getVertexTypes(self, force: bool = False) -> list: + """Returns the list of vertex type names of the graph. + + Args: + force: + If `True`, forces the retrieval the schema metadata again, otherwise returns a + cached copy of vertex type metadata (if they were already fetched previously). + + Returns: + The list of vertex types defined in the current graph. + """ + logger.info("entry: getVertexTypes") + if logger.level == logging.DEBUG: + logger.debug("params: " + self._locals(locals())) + + ret = [] + vertexTypes = await self.getSchema(force=force) + vertexTypes = vertexTypes["VertexTypes"] + for vt in vertexTypes: + ret.append(vt["Name"]) + + if logger.level == logging.DEBUG: + logger.debug("return: " + str(ret)) + logger.info("exit: getVertexTypes") + + return ret + + async def getVertexAttrs(self, vertexType: str) -> list: + """Returns the names and types of the attributes of the vertex type. + + Args: + vertexType: + The name of the vertex type. + + Returns: + A list of (attribute_name, attribute_type) tuples. + The format of attribute_type is one of + - "scalar_type" + - "complex_type(scalar_type)" + - "map_type(key_type,value_type)" + and it is a string. + """ + logger.info("entry: getAttributes") + if logger.level == logging.DEBUG: + logger.debug("params: " + self._locals(locals())) + + et = await self.getVertexType(vertexType) + ret = [] + + for at in et["Attributes"]: + ret.append( + (at["AttributeName"], self._getAttrType(at["AttributeType"]))) + + if logger.level == logging.DEBUG: + logger.debug("return: " + str(ret)) + logger.info("exit: getAttributes") + + return ret + + async def getVertexType(self, vertexType: str, force: bool = False) -> dict: + """Returns the details of the specified vertex type. + + Args: + vertexType: + The name of the vertex type. + force: + If `True`, forces the retrieval the schema metadata again, otherwise returns a + cached copy of vertex type details (if they were already fetched previously). + + Returns: + The metadata of the vertex type. 
+ """ + logger.info("entry: getVertexType") + if logger.level == logging.DEBUG: + logger.debug("params: " + self._locals(locals())) + + vertexTypes = await self.getSchema(force=force) + vertexTypes = vertexTypes["VertexTypes"] + for vt in vertexTypes: + if vt["Name"] == vertexType: + if logger.level == logging.DEBUG: + logger.debug("return: " + str(vt)) + logger.info("exit: getVertexType (found)") + + return vt + + logger.warning("Vertex type `" + vertexType + "` was not found.") + logger.info("exit: getVertexType (not found)") + + return {} # Vertex type was not found + + async def getVertexCount(self, vertexType: Union[str, list] = "*", where: str = "", realtime: bool = False) -> Union[int, dict]: + """Returns the number of vertices of the specified type. + + Args: + vertexType (Union[str, list], optional): + The name of the vertex type. If `vertexType` == "*", then count the instances of all + vertex types (`where` cannot be specified in this case). Defaults to "*". + where (str, optional): + A comma separated list of conditions that are all applied on each vertex's + attributes. The conditions are in logical conjunction (i.e. they are "AND'ed" + together). Defaults to "". + realtime (bool, optional): + Whether to get the most up-to-date number by force. When there are frequent updates happening, + a slightly outdated number (up to 30 seconds delay) might be fetched. Set `realtime=True` to + force the system to recount the vertices, which will get a more up-to-date result but will + also take more time. This parameter only works with TigerGraph DB 3.6 and above. + Defaults to False. + + Returns: + - A dictionary of : pairs if `vertexType` is a list or "*". + - An integer of vertex count if `vertexType` is a single vertex type. + + Uses: + - If `vertexType` is specified only: count of the instances of the given vertex type(s). + - If `vertexType` and `where` are specified: count of the instances of the given vertex + type after being filtered by `where` condition(s). + + Raises: + `TigerGraphException` when "*" is specified as vertex type and a `where` condition is + provided; or when invalid vertex type name is specified. 
+ + Endpoints: + - `GET /graph/{graph_name}/vertices` + See xref:tigergraph-server:API:built-in-endpoints.adoc#_list_vertices[List vertices] + - `POST /builtins` + See xref:tigergraph-server:API:built-in-endpoints.adoc#_run_built_in_functions_on_graph[Run built-in functions] + """ + logger.info("entry: getVertexCount") + if logger.level == logging.DEBUG: + logger.debug("params: " + self._locals(locals())) + + # If WHERE condition is not specified, use /builtins else use /vertices + if isinstance(vertexType, str) and vertexType != "*": + if where: + res = await self._req("GET", self.restppUrl + "/graph/" + self.graphname + "/vertices/" + vertexType + + "?count_only=true" + "&filter=" + where) + res = res[0]["count"] + else: + res = await self._req("POST", self.restppUrl + "/builtins/" + self.graphname + ("?realtime=true" if realtime else ""), + data={"function": "stat_vertex_number", + "type": vertexType}, + jsonData=True) + + res = res[0]["count"] + + if logger.level == logging.DEBUG: + logger.debug("return: " + str(res)) + logger.info("exit: getVertexCount (1)") + + return res + + res = await self._req("POST", self.restppUrl + "/builtins/" + self.graphname + ("?realtime=true" if realtime else ""), + data={"function": "stat_vertex_number", "type": "*"}, + jsonData=True) + + ret = _parse_get_vertex_count(res, vertexType, where) + + if logger.level == logging.DEBUG: + logger.debug("return: " + str(ret)) + logger.info("exit: getVertexCount (2)") + + return ret + + async def upsertVertex(self, vertexType: str, vertexId: str, attributes: dict = None) -> int: + """Upserts a vertex. + + Data is upserted: + + - If vertex is not yet present in graph, it will be created. + - If it's already in the graph, its attributes are updated with the values specified in + the request. An optional operator controls how the attributes are updated. + + Args: + vertexType: + The name of the vertex type. + vertexId: + The primary ID of the vertex to be upserted. + attributes: + The attributes of the vertex to be upserted; a dictionary in this format: + ``` + {<attribute_name>: <attribute_value>|(<attribute_value>, <operator>), …} + ``` + Example: + ``` + {"name": "Thorin", "points": (10, "+"), "bestScore": (67, "max")} + ``` + For valid values of `<operator>` see xref:tigergraph-server:API:built-in-endpoints.adoc#_operation_codes[Operation codes]. + + Returns: + A single number of accepted (successfully upserted) vertices (0 or 1). + + Endpoint: + - `POST /graph/{graph_name}` + See xref:tigergraph-server:API:built-in-endpoints.adoc#_upsert_data_to_graph[Upsert data to graph] + """ + logger.info("entry: upsertVertex") + if logger.level == logging.DEBUG: + logger.debug("params: " + self._locals(locals())) + + vals = _upsert_attrs(attributes) + data = json.dumps({"vertices": {vertexType: {vertexId: vals}}}) + + ret = await self._req("POST", self.restppUrl + "/graph/" + self.graphname, data=data) + ret = ret[0]["accepted_vertices"] + + if logger.level == logging.DEBUG: + logger.debug("return: " + str(ret)) + logger.info("exit: upsertVertex") + + return ret + + async def upsertVertices(self, vertexType: str, vertices: list) -> int: + """Upserts multiple vertices (of the same type). + + See the description of ``upsertVertex`` for generic information. + + Args: + vertexType: + The name of the vertex type.
+ vertices: + A list of tuples in this format: + + [source.wrap,json] + ---- + [ + (<vertex_id>, {<attribute_name>: <attribute_value>, …}), + (<vertex_id>, {<attribute_name>: (<attribute_value>, <operator>), …}), + ⋮ + ] + ---- + + Example: + + [source.wrap, json] + ---- + [ + (2, {"name": "Balin", "points": (10, "+"), "bestScore": (67, "max")}), + (3, {"name": "Dwalin", "points": (7, "+"), "bestScore": (35, "max")}) + ] + ---- + + For valid values of `<operator>` see xref:tigergraph-server:API:built-in-endpoints.adoc#_operation_codes[Operation codes]. + + Returns: + A single number of accepted (successfully upserted) vertices (0 or positive integer). + + Endpoint: + - `POST /graph/{graph_name}` + See xref:tigergraph-server:API:built-in-endpoints.adoc#_upsert_data_to_graph[Upsert data to graph] + """ + logger.info("entry: upsertVertices") + if logger.level == logging.DEBUG: + logger.debug("params: " + self._locals(locals())) + + data = {} + for v in vertices: + vals = _upsert_attrs(v[1]) + data[v[0]] = vals + data = json.dumps({"vertices": {vertexType: data}}) + + ret = await self._req("POST", self.restppUrl + "/graph/" + self.graphname, data=data) + ret = ret[0]["accepted_vertices"] + + if logger.level == logging.DEBUG: + logger.debug("return: " + str(ret)) + logger.info("exit: upsertVertices") + + return ret + + async def upsertVertexDataFrame(self, df: 'pd.DataFrame', vertexType: str, v_id: bool = None, + attributes: dict = None) -> int: + """Upserts vertices from a Pandas DataFrame. + + Args: + df: + The DataFrame to upsert. + vertexType: + The type of vertex to upsert data to. + v_id: + The field name where the vertex primary id is given. If omitted the dataframe index + would be used instead. + attributes: + A dictionary in the form of `{target: source}` where source is the column name in + the dataframe and target is the attribute name in the graph vertex. When omitted, + all columns would be upserted with their current names. In this case column names + must match the vertex's attribute names. + + Returns: + The number of vertices upserted. + """ + logger.info("entry: upsertVertexDataFrame") + if logger.level == logging.DEBUG: + logger.debug("params: " + self._locals(locals())) + + json_up = _prep_upsert_vertex_dataframe( + df=df, v_id=v_id, attributes=attributes) + ret = await self.upsertVertices(vertexType=vertexType, vertices=json_up) + + if logger.level == logging.DEBUG: + logger.debug("return: " + str(ret)) + logger.info("exit: upsertVertexDataFrame") + + return ret + + async def getVertices(self, vertexType: str, select: str = "", where: str = "", + limit: Union[int, str] = None, sort: str = "", fmt: str = "py", withId: bool = True, + withType: bool = False, timeout: int = 0) -> Union[dict, str, 'pd.DataFrame']: + """Retrieves vertices of the given vertex type. + + *Note*: + The primary ID of a vertex instance is NOT an attribute, thus cannot be used in + `select`, `where` or `sort` parameters (unless the `WITH primary_id_as_attribute` clause + was used when the vertex type was created). / + Use `getVerticesById()` if you need to retrieve vertices by their primary ID. + + Args: + vertexType: + The name of the vertex type. + select: + Comma separated list of vertex attributes to be retrieved. + where: + Comma separated list of conditions that are all applied on each vertex' attributes. + The conditions are in logical conjunction (i.e. they are "AND'ed" together). + sort: + Comma separated list of attributes the results should be sorted by. + Must be used with `limit`. + limit: + Maximum number of vertex instances to be returned (after sorting). + Must be used with `sort`.
+ fmt: + Format of the results: + - "py": Python objects + - "json": JSON document + - "df": pandas DataFrame + withId: + (When the output format is "df") should the vertex ID be included in the dataframe? + withType: + (When the output format is "df") should the vertex type be included in the dataframe? + timeout: + Time allowed for successful execution (0 = no limit, default). + + Returns: + The (selected) details of the (matching) vertex instances (sorted, limited) as + dictionary, JSON or pandas DataFrame. + + Endpoint: + - `GET /graph/{graph_name}/vertices/{vertex_type}` + See xref:tigergraph-server:API:built-in-endpoints.adoc#_list_vertices[List vertices] + """ + logger.info("entry: getVertices") + if logger.level == logging.DEBUG: + logger.debug("params: " + self._locals(locals())) + + url = _prep_get_vertices( + restppUrl=self.restppUrl, + graphname=self.graphname, + vertexType=vertexType, + select=select, + where=where, + limit=limit, + sort=sort, + timeout=timeout + ) + ret = await self._req("GET", url) + + if fmt == "json": + ret = json.dumps(ret) + elif fmt == "df": + ret = _vS2DF(ret, withId, withType) + + if logger.level == logging.DEBUG: + logger.debug("return: " + str(ret)) + logger.info("exit: getVertices") + + return ret + + async def getVertexDataFrame(self, vertexType: str, select: str = "", where: str = "", + limit: Union[int, str] = None, sort: str = "", timeout: int = 0) -> 'pd.DataFrame': + """Retrieves vertices of the given vertex type and returns them as pandas DataFrame. + + This is a shortcut to `getVertices(..., fmt="df", withId=True, withType=False)`. + + *Note*: + The primary ID of a vertex instance is NOT an attribute, thus cannot be used in + `select`, `where` or `sort` parameters (unless the `WITH primary_id_as_attribute` clause + was used when the vertex type was created). / + Use `getVerticesById()` if you need to retrieve vertices by their primary ID. + + Args: + vertexType: + The name of the vertex type. + select: + Comma separated list of vertex attributes to be retrieved. + where: + Comma separated list of conditions that are all applied on each vertex' attributes. + The conditions are in logical conjunction (i.e. they are "AND'ed" together). + sort: + Comma separated list of attributes the results should be sorted by. + Must be used with 'limit'. + limit: + Maximum number of vertex instances to be returned (after sorting). + Must be used with `sort`. + timeout: + Time allowed for successful execution (0 = no limit, default). + + Returns: + The (selected) details of the (matching) vertex instances (sorted, limited) as pandas + DataFrame. + """ + logger.info("entry: getVertexDataFrame") + if logger.level == logging.DEBUG: + logger.debug("params: " + self._locals(locals())) + + ret = await self.getVertices(vertexType, select=select, where=where, limit=limit, sort=sort, + fmt="df", withId=True, withType=False, timeout=timeout) + + if logger.level == logging.DEBUG: + logger.debug("return: " + str(ret)) + logger.info("exit: getVertexDataFrame") + + return ret + + async def getVertexDataframe(self, vertexType: str, select: str = "", where: str = "", + limit: Union[int, str] = None, sort: str = "", timeout: int = 0) -> 'pd.DataFrame': + """DEPRECATED + + Use `getVertexDataFrame()` instead. 
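(Editor's illustration tying together the upsert and retrieval methods above: upsert one vertex, then a batch, then read them back as a DataFrame. The `Person` type, its attributes, and the IDs are hypothetical; note that `sort` and `limit` must be used together.)
```
import asyncio
from pyTigerGraph import AsyncTigerGraphConnection

async def main():
    conn = AsyncTigerGraphConnection(host="https://example.tgcloud.io",
                                     graphname="MyGraph")
    await conn.upsertVertex("Person", "thorin",
                            {"name": "Thorin", "points": (10, "+")})
    await conn.upsertVertices("Person", [
        ("balin", {"name": "Balin", "points": (10, "+")}),
        ("dwalin", {"name": "Dwalin", "points": (7, "+")}),
    ])
    df = await conn.getVertices("Person", where="points>5",
                                sort="-points", limit=10, fmt="df")
    print(df)

asyncio.run(main())
```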
+ """ + warnings.warn( + "The `getVertexDataframe()` function is deprecated; use `getVertexDataFrame()` instead.", + DeprecationWarning) + + return await self.getVertexDataFrame(vertexType, select=select, where=where, limit=limit, + sort=sort, timeout=timeout) + + async def getVerticesById(self, vertexType: str, vertexIds: Union[int, str, list], select: str = "", + fmt: str = "py", withId: bool = True, withType: bool = False, + timeout: int = 0) -> Union[list, str, 'pd.DataFrame']: + """Retrieves vertices of the given vertex type, identified by their ID. + + Args: + vertexType: + The name of the vertex type. + vertexIds: + A single vertex ID or a list of vertex IDs. + select: + Comma separated list of vertex attributes to be retrieved. + fmt: + Format of the results: + "py": Python objects (in a list) + "json": JSON document + "df": pandas DataFrame + withId: + (If the output format is "df") should the vertex ID be included in the dataframe? + withType: + (If the output format is "df") should the vertex type be included in the dataframe? + timeout: + Time allowed for successful execution (0 = no limit, default). + + Returns: + The (selected) details of the (matching) vertex instances as dictionary, JSON or pandas + DataFrame. + + Endpoint: + - `GET /graph/{graph_name}/vertices/{vertex_type}/{vertex_id}` + See xref:tigergraph-server:API:built-in-endpoints.adoc#_retrieve_a_vertex[Retrieve a vertex] + + TODO Find out how/if select and timeout can be specified + """ + logger.info("entry: getVerticesById") + if logger.level == logging.DEBUG: + logger.debug("params: " + self._locals(locals())) + + vids, url = _prep_get_vertices_by_id( + restppUrl=self.restppUrl, + graphname=self.graphname, + vertexIds=vertexIds, + vertexType=vertexType + ) + + ret = [] + for vid in vids: + ret += await self._req("GET", url + _safe_char(vid)) + + if fmt == "json": + ret = json.dumps(ret) + elif fmt == "df": + ret = _vS2DF(ret, withId, withType) + + if logger.level == logging.DEBUG: + logger.debug("return: " + str(ret)) + logger.info("exit: getVerticesById") + + return ret + + async def getVertexDataFrameById(self, vertexType: str, vertexIds: Union[int, str, list], + select: str = "") -> 'pd.DataFrame': + """Retrieves vertices of the given vertex type, identified by their ID. + + This is a shortcut to ``getVerticesById(..., fmt="df", withId=True, withType=False)``. + + Args: + vertexType: + The name of the vertex type. + vertexIds: + A single vertex ID or a list of vertex IDs. + select: + Comma separated list of vertex attributes to be retrieved. + + Returns: + The (selected) details of the (matching) vertex instances as pandas DataFrame. + """ + logger.info("entry: getVertexDataFrameById") + if logger.level == logging.DEBUG: + logger.debug("params: " + self._locals(locals())) + + ret = await self.getVerticesById(vertexType, vertexIds, select, fmt="df", withId=True, + withType=False) + + if logger.level == logging.DEBUG: + logger.debug("return: " + str(ret)) + logger.info("exit: getVertexDataFrameById") + + return ret + + async def getVertexDataframeById(self, vertexType: str, vertexIds: Union[int, str, list], + select: str = "") -> 'pd.DataFrame': + """DEPRECATED + + Use `getVertexDataFrameById()` instead. 
+ """ + warnings.warn( + "The `getVertexDataframeById()` function is deprecated; use `getVertexDataFrameById()` instead.", + DeprecationWarning) + + return await self.getVertexDataFrameById(vertexType, vertexIds, select) + + async def getVertexStats(self, vertexTypes: Union[str, list], skipNA: bool = False) -> dict: + """Returns vertex attribute statistics. + + Args: + vertexTypes: + A single vertex type name or a list of vertex types names or "*" for all vertex + types. + skipNA: + Skip those non-applicable vertices that do not have attributes or none of their + attributes have statistics gathered. + + Returns: + A dictionary of various vertex stats for each vertex type specified. + + Endpoint: + - `POST /builtins/{graph_name}` + See xref:tigergraph-server:API:built-in-endpoints.adoc#_run_built_in_functions_on_graph[Run built-in functions] + """ + logger.info("entry: getVertexStats") + if logger.level == logging.DEBUG: + logger.debug("params: " + self._locals(locals())) + + vts = [] + if vertexTypes == "*": + vts = await self.getVertexTypes() + elif isinstance(vertexTypes, str): + vts = [vertexTypes] + else: + vts = vertexTypes + + responses = [] + for vt in vts: + data = '{"function":"stat_vertex_attr","type":"' + vt + '"}' + res = await self._req("POST", self.restppUrl + "/builtins/" + self.graphname, data=data, resKey="", + skipCheck=True) + responses.append((vt, res)) + + ret = _parse_get_vertex_stats(responses, skipNA) + + if logger.level == logging.DEBUG: + logger.debug("return: " + str(ret)) + logger.info("exit: getVertexStats") + + return ret + + async def delVertices(self, vertexType: str, where: str = "", limit: str = "", sort: str = "", + permanent: bool = False, timeout: int = 0) -> int: + """Deletes vertices from graph. + + *Note*: + The primary ID of a vertex instance is not an attribute. A primary ID cannot be used in + `select`, `where` or `sort` parameters (unless the `WITH primary_id_as_attribute` clause + was used when the vertex type was created). / + Use `delVerticesById()` if you need to retrieve vertices by their primary ID. + + Args: + vertexType: + The name of the vertex type. + where: + Comma separated list of conditions that are all applied on each vertex' attributes. + The conditions are in logical conjunction (i.e. they are "AND'ed" together). + sort: + Comma separated list of attributes the results should be sorted by. + Must be used with `limit`. + limit: + Maximum number of vertex instances to be returned (after sorting). + Must be used with `sort`. + permanent: + If true, the deleted vertex IDs can never be inserted back, unless the graph is + dropped or the graph store is cleared. + timeout: + Time allowed for successful execution (0 = no limit, default). + + Returns: + A single number of vertices deleted. + + The primary ID of a vertex instance is NOT an attribute, thus cannot be used in above + arguments. 
+ + Endpoint: + - `DELETE /graph/{graph_name}/vertices/{vertex_type}` + See xref:tigergraph-server:API:built-in-endpoints.adoc#_delete_vertices[Delete vertices] + """ + logger.info("entry: delVertices") + if logger.level == logging.DEBUG: + logger.debug("params: " + self._locals(locals())) + + url = _prep_del_vertices( + restppUrl=self.restppUrl, + graphname=self.graphname, + vertexType=vertexType, + where=where, + limit=limit, + sort=sort, + permanent=permanent, + timeout=timeout + ) + ret = await self._req("DELETE", url) + ret = ret["deleted_vertices"] + + if logger.level == logging.DEBUG: + logger.debug("return: " + str(ret)) + logger.info("exit: delVertices") + + return ret + + async def delVerticesById(self, vertexType: str, vertexIds: Union[int, str, list], + permanent: bool = False, timeout: int = 0) -> int: + """Deletes vertices from graph identified by their ID. + + Args: + vertexType: + The name of the vertex type. + vertexIds: + A single vertex ID or a list of vertex IDs. + permanent: + If true, the deleted vertex IDs can never be inserted back, unless the graph is + dropped or the graph store is cleared. + timeout: + Time allowed for successful execution (0 = no limit, default). + + Returns: + A single number of vertices deleted. + + Endpoint: + - `DELETE /graph/{graph_name}/vertices/{vertex_type}/{vertex_id}` + See xref:tigergraph-server:API:built-in-endpoints.adoc#_delete_a_vertex[Delete a vertex] + """ + logger.info("entry: delVerticesById") + if logger.level == logging.DEBUG: + logger.debug("params: " + self._locals(locals())) + + url1, url2, vids = _prep_del_vertices_by_id( + restppUrl=self.restppUrl, + graphname=self.graphname, + vertexIds=vertexIds, + vertexType=vertexType, + permanent=permanent, + timeout=timeout + ) + ret = 0 + for vid in vids: + res = await self._req("DELETE", url1 + str(vid) + url2) + ret += res["deleted_vertices"] + + if logger.level == logging.DEBUG: + logger.debug("return: " + str(ret)) + logger.info("exit: delVerticesById") + + return ret + + # def delVerticesByType(self, vertexType: str, permanent: bool = False): + # TODO Implementation + # TODO DELETE /graph/{graph_name}/delete_by_type/vertices/{vertex_type}/ + # TODO Maybe call it truncateVertex[Type] or delAllVertices? + + # TODO GET /deleted_vertex_check/{graph_name} + + async def vertexSetToDataFrame(self, vertexSet: dict, withId: bool = True, withType: bool = False) -> 'pd.DataFrame': + """Converts a vertex set (dictionary) to a pandas DataFrame. + + Args: + vertexSet: + The vertex set to convert. + withId: + Should the vertex ID be included in the DataFrame? + withType: + Should the vertex type be included in the DataFrame? + + Returns: + The vertex set as a pandas DataFrame. 
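(Editor's illustration: deleting a handful of vertices by primary ID. As the loop above shows, `delVerticesById` issues one DELETE per ID and sums the counts; the IDs are hypothetical.)
```
import asyncio
from pyTigerGraph import AsyncTigerGraphConnection

async def main():
    conn = AsyncTigerGraphConnection(host="https://example.tgcloud.io",
                                     graphname="MyGraph")
    deleted = await conn.delVerticesById("Person", ["balin", "dwalin"])
    print(deleted)  # total number of vertices deleted

asyncio.run(main())
```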
+ """ + logger.info("entry: vertexSetToDataFrame") + if logger.level == logging.DEBUG: + logger.debug("params: " + self._locals(locals())) + + ret = _vS2DF(vertexSet, withId, withType) + + if logger.level == logging.DEBUG: + logger.debug("return: " + str(ret)) + logger.info("exit: vertexSetToDataFrame") + + return ret \ No newline at end of file diff --git a/pyTigerGraph/schema.py b/pyTigerGraph/schema.py index 2c5b427a..05a4eb69 100644 --- a/pyTigerGraph/schema.py +++ b/pyTigerGraph/schema.py @@ -75,32 +75,40 @@ class CITES(Edge): ``` """ -from pyTigerGraph.pyTigerGraphException import TigerGraphException -from pyTigerGraph.pyTigerGraph import TigerGraphConnection -from dataclasses import dataclass, make_dataclass, fields, _MISSING_TYPE -from typing import List, Dict, Union, get_origin, get_args -from datetime import datetime import json import hashlib import warnings -BASE_TYPES = ["string", "int", "uint", "float", "double", "bool", "datetime"] +from typing import List, Dict, Union, get_origin, get_args +from datetime import datetime +from dataclasses import dataclass, make_dataclass, fields, _MISSING_TYPE + +from pyTigerGraph.common.exception import TigerGraphException +from pyTigerGraph.pyTigerGraph import TigerGraphConnection + + +BASE_TYPES = ["string", "int", "uint", "float", "double", "bool", "datetime"] PRIMARY_ID_TYPES = ["string", "int", "uint", "datetime"] COLLECTION_TYPES = ["list", "set", "map"] -COLLECTION_VALUE_TYPES = ["int", "double", "float", "string", "datetime", "udt"] +COLLECTION_VALUE_TYPES = ["int", "double", + "float", "string", "datetime", "udt"] MAP_KEY_TYPES = ["int", "string", "datetime"] + def _parse_type(attr): """NO DOC: function to parse gsql complex types""" collection_types = "" if attr["AttributeType"].get("ValueTypeName"): if attr["AttributeType"].get("KeyTypeName"): - collection_types += "<"+ attr["AttributeType"].get("KeyTypeName") + "," + attr["AttributeType"].get("ValueTypeName") + ">" + collection_types += "<" + attr["AttributeType"].get( + "KeyTypeName") + "," + attr["AttributeType"].get("ValueTypeName") + ">" else: - collection_types += "<"+attr["AttributeType"].get("ValueTypeName") + ">" + collection_types += "<" + \ + attr["AttributeType"].get("ValueTypeName") + ">" attr_type = (attr["AttributeType"]["Name"] + collection_types).upper() return attr_type + def _get_type(attr_type): """NO DOC: function to convert GSQL type to Python type""" if attr_type == "STRING": @@ -136,7 +144,8 @@ def _py_to_tg_type(attr_type): elif attr_type == list: raise TigerGraphException("Must define value type within list") elif attr_type == dict: - raise TigerGraphException("Must define key and value types within dictionary/map") + raise TigerGraphException( + "Must define key and value types within dictionary/map") elif attr_type == datetime: return "DATETIME" elif (str(type(attr_type)) == "") and attr_type._name == "List": @@ -144,7 +153,8 @@ def _py_to_tg_type(attr_type): if val_type.lower() in COLLECTION_VALUE_TYPES: return "LIST<"+val_type+">" else: - raise TigerGraphException(val_type + " not a valid type for the value type in LISTs.") + raise TigerGraphException( + val_type + " not a valid type for the value type in LISTs.") elif (str(type(attr_type)) == "") and attr_type._name == "Dict": key_type = _py_to_tg_type(attr_type.__args__[0]) val_type = _py_to_tg_type(attr_type.__args__[1]) @@ -152,21 +162,23 @@ def _py_to_tg_type(attr_type): if val_type.lower() in COLLECTION_VALUE_TYPES: return "MAP<"+key_type+","+val_type+">" else: - raise 
TigerGraphException(val_type + " not a valid type for the value type in MAPs.") + raise TigerGraphException( + val_type + " not a valid type for the value type in MAPs.") else: - raise TigerGraphException(key_type + " not a valid type for the key type in MAPs.") + raise TigerGraphException( + key_type + " not a valid type for the key type in MAPs.") else: if str(attr_type).lower() in BASE_TYPES: return str(attr_type).upper() else: - raise TigerGraphException(attr_type+"not a valid TigerGraph datatype.") - + raise TigerGraphException( + attr_type + " is not a valid TigerGraph datatype.") @dataclass class Vertex(object): """Vertex Object - + Abstract parent class for other types of vertices to be inherited from. Contains class methods to edit the attributes associated with the vertex type. @@ -192,11 +204,11 @@ def __init_subclass__(cls): cls.incoming_edge_types = {} cls.outgoing_edge_types = {} cls._attribute_edits = {"ADD": {}, "DELETE": {}} - cls.primary_id:Union[str, List[str]] - cls.primary_id_as_attribute:bool + cls.primary_id: Union[str, List[str]] + cls.primary_id_as_attribute: bool @classmethod - def _set_attr_edit(self, add:dict = None, delete:dict = None): + def _set_attr_edit(self, add: dict = None, delete: dict = None): """NO DOC: internal updating function for attributes""" if add: self._attribute_edits["ADD"].update(add) @@ -208,11 +220,10 @@ def _get_attr_edit(self): """NO DOC: get attribute edits internal function""" return self._attribute_edits - @classmethod - def add_attribute(self, attribute_name:str, attribute_type, default_value = None): + def add_attribute(self, attribute_name: str, attribute_type, default_value=None): """Function to add an attribute to the given vertex type. - + Args: attribute_name (str): The name of the attribute to add + attribute_type: The Python type of the attribute to add. + Types are converted to their corresponding GSQL types. + default_value (optional): + The desired default value of the attribute. Defaults to None. + """ + if attribute_name in self._get_attr_edit()["ADD"].keys(): - warnings.warn(attribute_name + " already in staged edits. Overwriting previous edits.") + warnings.warn( + attribute_name + " already in staged edits. 
Overwriting previous edits.") for attr in self.attributes: if attr == attribute_name: - raise TigerGraphException(attribute_name + " already exists as an attribute on "+self.__name__ + " vertices") + raise TigerGraphException( + attribute_name + " already exists as an attribute on "+self.__name__ + " vertices") attr_type = _py_to_tg_type(attribute_type) - gsql_add = "ALTER VERTEX "+self.__name__+" ADD ATTRIBUTE ("+attribute_name+" "+attr_type + gsql_add = "ALTER VERTEX "+self.__name__ + \ + " ADD ATTRIBUTE ("+attribute_name+" "+attr_type if default_value: if attribute_type == str: gsql_add += " DEFAULT '"+default_value+"'" else: gsql_add += " DEFAULT "+str(default_value) - gsql_add +=");" - self._set_attr_edit(add ={attribute_name: gsql_add}) + gsql_add += ");" + self._set_attr_edit(add={attribute_name: gsql_add}) @classmethod def remove_attribute(self, attribute_name): @@ -247,14 +261,17 @@ def remove_attribute(self, attribute_name): """ if self.primary_id_as_attribute: if attribute_name == self.primary_id: - raise TigerGraphException("Cannot remove primary ID attribute: "+self.primary_id+".") + raise TigerGraphException( + "Cannot remove primary ID attribute: "+self.primary_id+".") removed = False for attr in self.attributes: if attr == attribute_name: - self._set_attr_edit(delete = {attribute_name: "ALTER VERTEX "+self.__name__+" DROP ATTRIBUTE ("+attribute_name+");"}) + self._set_attr_edit(delete={ + attribute_name: "ALTER VERTEX "+self.__name__+" DROP ATTRIBUTE ("+attribute_name+");"}) removed = True - if not(removed): - raise TigerGraphException("An attribute of "+ attribute_name + " is not an attribute on "+ self.__name__ + " vertices") + if not (removed): + raise TigerGraphException("An attribute of " + attribute_name + + " is not an attribute on " + self.__name__ + " vertices") @classmethod @property def attributes(self): """NO DOC: property to get attributes""" return self.__annotations__ def __getattr__(self, attr): if self.attributes.get(attr): return self.attributes.get(attr) else: - raise TigerGraphException("No attribute named "+ attr + "for vertex type " + self.vertex_type) + raise TigerGraphException( + "No attribute named " + attr + " for vertex type " + self.vertex_type) def __eq__(self, lhs): return isinstance(lhs, Vertex) and lhs.vertex_type == self.vertex_type - + def __repr__(self): return self.vertex_type + @dataclass class Edge: """Edge Object - + Abstract parent class for other types of edges to be inherited from. Contains class methods to edit the attributes associated with the edge type.
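(Editor's illustration of the `Vertex` dataclass API above: subclassing and staging an attribute edit. The schema and field names are hypothetical, the internal `_attribute_edits` dict is printed only to show the staged GSQL, and committing the edit goes through the `Graph` object defined later in this module.)
```
from dataclasses import dataclass
from pyTigerGraph.schema import Vertex

@dataclass
class Person(Vertex):
    name: str = None
    age: int = None
    primary_id: str = "name"              # required class variable
    primary_id_as_attribute: bool = True  # required class variable

# Stage an ALTER VERTEX ... ADD ATTRIBUTE statement for a later commit.
Person.add_attribute("nickname", str, default_value="n/a")
print(Person._attribute_edits["ADD"]["nickname"])
```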
@@ -300,14 +319,14 @@ class HOLDS_ACCOUNT(Edge): def __init_subclass__(cls): """NO DOC: placeholder for class variables""" cls._attribute_edits = {"ADD": {}, "DELETE": {}} - cls.is_directed:bool - cls.reverse_edge:Union[str, bool] - cls.from_vertex_types:Union[Vertex, List[Vertex]] - cls.to_vertex_types:Union[Vertex, List[Vertex]] - cls.discriminator:Union[str, List[str]] + cls.is_directed: bool + cls.reverse_edge: Union[str, bool] + cls.from_vertex_types: Union[Vertex, List[Vertex]] + cls.to_vertex_types: Union[Vertex, List[Vertex]] + cls.discriminator: Union[str, List[str]] @classmethod - def _set_attr_edit(self, add:dict = None, delete:dict = None): + def _set_attr_edit(self, add: dict = None, delete: dict = None): """NO DOC: function to edit attributes""" if add: self._attribute_edits["ADD"].update(add) @@ -320,9 +339,9 @@ def _get_attr_edit(self): return self._attribute_edits @classmethod - def add_attribute(self, attribute_name, attribute_type, default_value = None): + def add_attribute(self, attribute_name, attribute_type, default_value=None): """Function to add an attribute to the given edge type. - + Args: attribute_name (str): The name of the attribute to add. @@ -333,19 +352,22 @@ def add_attribute(self, attribute_name, attribute_type, default_value = None): The desired default value of the attribute. Defaults to None. """ if attribute_name in self._get_attr_edit()["ADD"].keys(): - warnings.warn(attribute_name + " already in staged edits. Overwriting previous edits.") + warnings.warn( + attribute_name + " already in staged edits. Overwriting previous edits.") for attr in self.attributes: if attr == attribute_name: - raise TigerGraphException(attribute_name + " already exists as an attribute on "+self.__name__ + " edges") + raise TigerGraphException( + attribute_name + " already exists as an attribute on "+self.__name__ + " edges") attr_type = _py_to_tg_type(attribute_type) - gsql_add = "ALTER EDGE "+self.__name__+" ADD ATTRIBUTE ("+attribute_name+" "+attr_type + gsql_add = "ALTER EDGE "+self.__name__ + \ + " ADD ATTRIBUTE ("+attribute_name+" "+attr_type if default_value: if attribute_type == str: gsql_add += " DEFAULT '"+default_value+"'" else: gsql_add += " DEFAULT "+str(default_value) - gsql_add +=");" - self._set_attr_edit(add ={attribute_name: gsql_add}) + gsql_add += ");" + self._set_attr_edit(add={attribute_name: gsql_add}) @classmethod def remove_attribute(self, attribute_name): @@ -358,10 +380,12 @@ def remove_attribute(self, attribute_name): removed = False for attr in self.attributes: if attr == attribute_name: - self._set_attr_edit(delete = {attribute_name:"ALTER EDGE "+self.__name__+" DROP ATTRIBUTE ("+attribute_name+");"}) + self._set_attr_edit(delete={ + attribute_name: "ALTER EDGE "+self.__name__+" DROP ATTRIBUTE ("+attribute_name+");"}) removed = True - if not(removed): - raise TigerGraphException("An attribute of "+ attribute_name + " is not an attribute on "+ self.__name__ + " edges") + if not (removed): + raise TigerGraphException( + "An attribute of " + attribute_name + " is not an attribute on " + self.__name__ + " edges") @classmethod @property def attributes(self): """NO DOC: property to get attributes""" return self.__annotations__ def __getattr__(self, attr): if self.attributes.get(attr): return self.attributes.get(attr) else: - raise TigerGraphException("No attribute named "+ attr + "for edge type " + self.edge_type) + raise TigerGraphException( + "No attribute named " + attr + " for edge type " + self.edge_type) def __eq__(self, lhs): return isinstance(lhs, Edge) and lhs.edge_type == self.edge_type and lhs.from_vertex_type
== self.from_vertex_type and lhs.to_vertex_type == self.to_vertex_type @@ -381,6 +406,7 @@ def __eq__(self, lhs): def __repr__(self): return self.edge_type + class Graph(): """Graph Object @@ -394,7 +420,8 @@ class Graph(): g = Graph(conn) ``` """ - def __init__(self, conn:TigerGraphConnection = None): + + def __init__(self, conn: TigerGraphConnection = None): """Graph class for schema representation. Args: @@ -407,46 +434,53 @@ def __init__(self, conn:TigerGraphConnection = None): self._edge_edits = {"ADD": {}, "DELETE": {}} if conn: db_rep = conn.getSchema(force=True) - self.graphname = db_rep["GraphName"] - for v_type in db_rep["VertexTypes"]: - vert = make_dataclass(v_type["Name"], - [(attr["AttributeName"], _get_type(_parse_type(attr)), None) for attr in v_type["Attributes"]] + - [(v_type["PrimaryId"]["AttributeName"], _get_type(_parse_type(v_type["PrimaryId"])), None), - ("primary_id", str, v_type["PrimaryId"]["AttributeName"]), - ("primary_id_as_attribute", bool, v_type["PrimaryId"].get("PrimaryIdAsAttribute", False))], - bases=(Vertex,), repr=False) - self._vertex_types[v_type["Name"]] = vert - - for e_type in db_rep["EdgeTypes"]: - if e_type["FromVertexTypeName"] == "*": - source_vertices = [self._vertex_types[x["From"]] for x in e_type["EdgePairs"]] - else: - source_vertices = self._vertex_types[e_type["FromVertexTypeName"]] - if e_type["ToVertexTypeName"] == "*": - target_vertices = [self._vertex_types[x["To"]] for x in e_type["EdgePairs"]] - else: - target_vertices = self._vertex_types[e_type["ToVertexTypeName"]] - - e = make_dataclass(e_type["Name"], - [(attr["AttributeName"], _get_type(_parse_type(attr)), None) for attr in e_type["Attributes"]] + - [("from_vertex", source_vertices, None), - ("to_vertex", target_vertices, None), - ("is_directed", bool, e_type["IsDirected"]), - ("reverse_edge", str, e_type["Config"].get("REVERSE_EDGE"))], - bases=(Edge,), repr=False) - if isinstance(target_vertices, list): - for tgt_v in target_vertices: - tgt_v.incoming_edge_types[e_type["Name"]] = e - else: - target_vertices.incoming_edge_types[e_type["Name"]] = e - if isinstance(source_vertices, list): - for src_v in source_vertices: - src_v.outgoing_edge_types[e_type["Name"]] = e - else: - source_vertices.outgoing_edge_types[e_type["Name"]] = e - - self._edge_types[e_type["Name"]] = e - self.conn = conn + self.setUpConn(conn, db_rep) + + def setUpConn(self, conn, db_rep): + self.graphname = db_rep["GraphName"] + for v_type in db_rep["VertexTypes"]: + vert = make_dataclass(v_type["Name"], + [(attr["AttributeName"], _get_type(_parse_type(attr)), None) for attr in v_type["Attributes"]] + + [(v_type["PrimaryId"]["AttributeName"], _get_type(_parse_type(v_type["PrimaryId"])), None), + ("primary_id", str, + v_type["PrimaryId"]["AttributeName"]), + ("primary_id_as_attribute", bool, v_type["PrimaryId"].get("PrimaryIdAsAttribute", False))], + bases=(Vertex,), repr=False) + self._vertex_types[v_type["Name"]] = vert + + for e_type in db_rep["EdgeTypes"]: + if e_type["FromVertexTypeName"] == "*": + source_vertices = [self._vertex_types[x["From"]] + for x in e_type["EdgePairs"]] + else: + source_vertices = self._vertex_types[e_type["FromVertexTypeName"]] + if e_type["ToVertexTypeName"] == "*": + target_vertices = [self._vertex_types[x["To"]] + for x in e_type["EdgePairs"]] + else: + target_vertices = self._vertex_types[e_type["ToVertexTypeName"]] + + e = make_dataclass(e_type["Name"], + [(attr["AttributeName"], _get_type(_parse_type(attr)), None) for attr in e_type["Attributes"]] + + [("from_vertex", 
source_vertices, None), + ("to_vertex", target_vertices, None), + ("is_directed", bool, + e_type["IsDirected"]), + ("reverse_edge", str, e_type["Config"].get("REVERSE_EDGE"))], + bases=(Edge,), repr=False) + if isinstance(target_vertices, list): + for tgt_v in target_vertices: + tgt_v.incoming_edge_types[e_type["Name"]] = e + else: + target_vertices.incoming_edge_types[e_type["Name"]] = e + if isinstance(source_vertices, list): + for src_v in source_vertices: + src_v.outgoing_edge_types[e_type["Name"]] = e + else: + source_vertices.outgoing_edge_types[e_type["Name"]] = e + + self._edge_types[e_type["Name"]] = e + self.conn = conn def add_vertex_type(self, vertex: Vertex, outdegree_stats=True): """Add a vertex type to the list of changes to commit to the graph. @@ -459,9 +493,11 @@ def add_vertex_type(self, vertex: Vertex, outdegree_stats=True): Used for caching outdegree, defaults to True. """ if vertex.__name__ in self._vertex_types.keys(): - raise TigerGraphException(vertex.__name__+" already exists in the database") + raise TigerGraphException( + vertex.__name__+" already exists in the database") if vertex.__name__ in self._vertex_edits.keys(): - warnings.warn(vertex.__name__ + " already in staged edits. Overwriting previous edits.") + warnings.warn( + vertex.__name__ + " already in staged edits. Overwriting previous edits.") gsql_def = "ADD VERTEX "+vertex.__name__+"(" attrs = vertex.attributes primary_id = None @@ -474,16 +510,20 @@ def add_vertex_type(self, vertex: Vertex, outdegree_stats=True): if field.name == "primary_id_as_attribute": primary_id_as_attribute = field.default - if not(primary_id): - raise TigerGraphException("primary_id of vertex type "+str(vertex.__name__)+" not defined") + if not (primary_id): + raise TigerGraphException( + "primary_id of vertex type "+str(vertex.__name__)+" not defined") - if not(primary_id_as_attribute): - raise TigerGraphException("primary_id_as_attribute of vertex type "+str(vertex.__name__)+" not defined") + if not (primary_id_as_attribute): + raise TigerGraphException( + "primary_id_as_attribute of vertex type "+str(vertex.__name__)+" not defined") - if not(_py_to_tg_type(primary_id_type).lower() in PRIMARY_ID_TYPES): - raise TigerGraphException(str(primary_id_type), "is not a supported type for primary IDs.") + if not (_py_to_tg_type(primary_id_type).lower() in PRIMARY_ID_TYPES): + raise TigerGraphException( + str(primary_id_type), "is not a supported type for primary IDs.") - gsql_def += "PRIMARY_ID "+primary_id+" "+_py_to_tg_type(primary_id_type) + gsql_def += "PRIMARY_ID "+primary_id + \ + " "+_py_to_tg_type(primary_id_type) for attr in attrs.keys(): if attr == primary_id or attr == "primary_id" or attr == "primary_id_as_attribute": continue @@ -508,9 +548,11 @@ def add_edge_type(self, edge: Edge): The edge type definition to add to the addition cache. """ if edge in self._edge_types.values(): - raise TigerGraphException(edge.__name__+" already exists in the database") + raise TigerGraphException( + edge.__name__+" already exists in the database") if edge in self._edge_edits.values(): - warnings.warn(edge.__name__ + " already in staged edits. Overwriting previous edits") + warnings.warn( + edge.__name__ + " already in staged edits. 
Overwriting previous edits") attrs = edge.attributes is_directed = None reverse_edge = None @@ -523,36 +565,40 @@ def add_edge_type(self, edge: Edge): if field.name == "discriminator": discriminator = field.default - - if not(reverse_edge) and is_directed: - raise TigerGraphException("Reverse edge definition not set. Set the reverse_edge variable to a boolean or string.") + + if not (reverse_edge) and is_directed: + raise TigerGraphException( + "Reverse edge definition not set. Set the reverse_edge variable to a boolean or string.") if is_directed is None: - raise TigerGraphConnection("is_directed variable not defined. Define is_directed as a class variable to the desired setting.") - - if not(edge.attributes.get("from_vertex", None)): - raise TigerGraphException("from_vertex is not defined. Define from_vertex class variable.") - - if not(edge.attributes.get("to_vertex", None)): - raise TigerGraphException("to_vertex is not defined. Define to_vertex class variable.") - + raise TigerGraphConnection( + "is_directed variable not defined. Define is_directed as a class variable to the desired setting.") + + if not (edge.attributes.get("from_vertex", None)): + raise TigerGraphException( + "from_vertex is not defined. Define from_vertex class variable.") + + if not (edge.attributes.get("to_vertex", None)): + raise TigerGraphException( + "to_vertex is not defined. Define to_vertex class variable.") + gsql_def = "" if is_directed: gsql_def += "ADD DIRECTED EDGE "+edge.__name__+"(" else: gsql_def += "ADD UNDIRECTED EDGE "+edge.__name__+"(" - - if not(get_origin(edge.attributes["from_vertex"]) is Union) and not(get_origin(edge.attributes["to_vertex"]) is Union): + + if not (get_origin(edge.attributes["from_vertex"]) is Union) and not (get_origin(edge.attributes["to_vertex"]) is Union): from_vert = edge.attributes["from_vertex"].__name__ to_vert = edge.attributes["to_vertex"].__name__ gsql_def += "FROM "+from_vert+", "+"TO "+to_vert - elif get_origin(edge.attributes["from_vertex"]) is Union and not(get_origin(edge.attributes["to_vertex"]) is Union): + elif get_origin(edge.attributes["from_vertex"]) is Union and not (get_origin(edge.attributes["to_vertex"]) is Union): print(get_args(edge.attributes["from_vertex"])) for v in get_args(edge.attributes["from_vertex"]): from_vert = v.__name__ to_vert = edge.attributes["to_vertex"].__name__ gsql_def += "FROM "+from_vert+", "+"TO "+to_vert + "|" gsql_def = gsql_def[:-1] - elif not(get_origin(edge.attributes["from_vertex"]) is Union) and get_origin(edge.attributes["to_vertex"]) is Union: + elif not (get_origin(edge.attributes["from_vertex"]) is Union) and get_origin(edge.attributes["to_vertex"]) is Union: for v in get_args(edge.attributes["to_vertex"]): from_vert = edge.attributes["from_vertex"].__name__ to_vert = v.__name__ @@ -560,15 +606,19 @@ def add_edge_type(self, edge: Edge): gsql_def = gsql_def[:-1] elif get_origin(edge.attributes["from_vertex"]) is Union and get_origin(edge.attributes["to_vertex"]) is Union: if len(get_args(edge.attributes["from_vertex"])) != len(get_args(edge.attributes["to_vertex"])): - raise TigerGraphException("from_vertex and to_vertex list have different lengths.") + raise TigerGraphException( + "from_vertex and to_vertex list have different lengths.") else: for i in range(len(get_args(edge.attributes["from_vertex"]))): - from_vert = get_args(edge.attributes["from_vertex"])[i].__name__ - to_vert = get_args(edge.attributes["to_vertex"])[i].__name__ + from_vert = get_args(edge.attributes["from_vertex"])[ + i].__name__ + to_vert = 
get_args(edge.attributes["to_vertex"])[ + i].__name__ gsql_def += "FROM "+from_vert+", "+"TO "+to_vert + "|" gsql_def = gsql_def[:-1] else: - raise TigerGraphException("from_vertex and to_vertex parameters have to be of type Union[Vertex, Vertex, ...] or Vertex") + raise TigerGraphException( + "from_vertex and to_vertex parameters have to be of type Union[Vertex, Vertex, ...] or Vertex") if discriminator: if isinstance(discriminator, list): @@ -578,9 +628,11 @@ def add_edge_type(self, edge: Edge): gsql_def = gsql_def[:-2] gsql_def += ")" elif isinstance(discriminator, str): - gsql_def += ", DISCRIMINATOR("+discriminator + " "+_py_to_tg_type(attrs[discriminator])+")" + gsql_def += ", DISCRIMINATOR("+discriminator + \ + " "+_py_to_tg_type(attrs[discriminator])+")" else: - raise TigerGraphException("Discriminator definitions can only be of type string (one discriminator) or list (compound discriminator)") + raise TigerGraphException( + "Discriminator definitions can only be of type string (one discriminator) or list (compound discriminator)") for attr in attrs.keys(): if attr == "from_vertex" or attr == "to_vertex" or attr == "is_directed" or attr == "reverse_edge" or (discriminator and attr in discriminator) or attr == "discriminator": continue @@ -594,8 +646,9 @@ def add_edge_type(self, edge: Edge): elif isinstance(reverse_edge, bool): gsql_def += ' WITH REVERSE_EDGE="reverse_'+edge.__name__+'"' else: - raise TigerGraphException("Reverse edge name of type: "+str(type(attrs["reverse_edge"])+" is not supported.")) - gsql_def+=";" + raise TigerGraphException( + "Reverse edge name of type: "+str(type(attrs["reverse_edge"])+" is not supported.")) + gsql_def += ";" self._edge_edits["ADD"][edge.__name__] = gsql_def def remove_vertex_type(self, vertex: Vertex): @@ -618,33 +671,24 @@ def remove_edge_type(self, edge: Edge): gsql_def = "DROP EDGE "+edge.__name__+";" self._edge_edits["DELETE"][edge.__name__] = gsql_def - def commit_changes(self, conn: TigerGraphConnection = None): - """Commit schema changes to the graph. - Args: - conn (TigerGraphConnection, optional): - Connection to the database to edit the schema of. - Not required if the Graph was instantiated with a connection object. - """ - if not(conn): - if self.conn: - conn = self.conn - else: - raise TigerGraphException("No Connection Defined. Please instantiate a TigerGraphConnection to the database to commit the schema.") - if "does not exist." 
in conn.gsql("USE GRAPH "+conn.graphname): - conn.gsql("CREATE GRAPH "+conn.graphname+"()") - all_attr = [x._attribute_edits for x in list(self._vertex_types.values()) + list(self._edge_types.values())] - for elem in list(self._vertex_types.values()) + list(self._edge_types.values()): # need to remove the changes locally + def _parsecommit_changes(self, conn): + all_attr = [x._attribute_edits for x in list( + self._vertex_types.values()) + list(self._edge_types.values())] + # need to remove the changes locally + for elem in list(self._vertex_types.values()) + list(self._edge_types.values()): elem._attribute_edits = {"ADD": {}, "DELETE": {}} all_attribute_edits = {"ADD": {}, "DELETE": {}} for change in all_attr: all_attribute_edits["ADD"].update(change["ADD"]) all_attribute_edits["DELETE"].update(change["DELETE"]) md5 = hashlib.md5() - md5.update(json.dumps({**self._vertex_edits, **self._edge_edits, **all_attribute_edits}).encode()) + md5.update(json.dumps( + {**self._vertex_edits, **self._edge_edits, **all_attribute_edits}).encode()) job_name = "pytg_change_"+md5.hexdigest() start_gsql = "USE GRAPH "+conn.graphname+"\n" start_gsql += "DROP JOB "+job_name+"\n" - start_gsql += "CREATE SCHEMA_CHANGE JOB " + job_name + " FOR GRAPH " + conn.graphname + " {\n" + start_gsql += "CREATE SCHEMA_CHANGE JOB " + \ + job_name + " FOR GRAPH " + conn.graphname + " {\n" for v_to_add in self._vertex_edits["ADD"]: start_gsql += self._vertex_edits["ADD"][v_to_add] + "\n" for e_to_add in self._edge_edits["ADD"]: @@ -656,14 +700,34 @@ def commit_changes(self, conn: TigerGraphConnection = None): for attr_to_add in all_attribute_edits["ADD"]: start_gsql += all_attribute_edits["ADD"][attr_to_add] + "\n" for attr_to_drop in all_attribute_edits["DELETE"]: - start_gsql += all_attribute_edits["DELETE"][attr_to_drop] +"\n" + start_gsql += all_attribute_edits["DELETE"][attr_to_drop] + "\n" start_gsql += "}\n" start_gsql += "RUN SCHEMA_CHANGE JOB "+job_name + return start_gsql + + def commit_changes(self, conn: TigerGraphConnection = None): + """Commit schema changes to the graph. + Args: + conn (TigerGraphConnection, optional): + Connection to the database to edit the schema of. + Not required if the Graph was instantiated with a connection object. + """ + if not conn: + if self.conn: + conn = self.conn + else: + raise TigerGraphException( + "No Connection Defined. Please instantiate a TigerGraphConnection to the database to commit the schema.") + + if "does not exist." 
in conn.gsql("USE GRAPH "+conn.graphname): + conn.gsql("CREATE GRAPH "+conn.graphname+"()") + start_gsql = self._parsecommit_changes(conn) res = conn.gsql(start_gsql) if "updated to new version" in res: self.__init__(conn) else: - raise TigerGraphException("Schema change failed with message:\n"+res) + raise TigerGraphException( + "Schema change failed with message:\n"+res) @property def vertex_types(self): @@ -673,4 +737,4 @@ def vertex_types(self): @property def edge_types(self): """Edge types property.""" - return self._edge_types \ No newline at end of file + return self._edge_types diff --git a/pyTigerGraph/visualization.py b/pyTigerGraph/visualization.py index b7cded72..b89d7287 100644 --- a/pyTigerGraph/visualization.py +++ b/pyTigerGraph/visualization.py @@ -5,8 +5,8 @@ try: import ipycytoscape except: - raise Exception("Please install ipycytoscape to use visualization functions") - + raise Exception( + "Please install ipycytoscape to use visualization functions") def drawSchema(schema: dict, style: list = []): @@ -99,7 +99,8 @@ def _convert_schema_for_ipycytoscape(schema: dict): cytoscape_edge = dict() cytoscape_edge["data"] = dict() cytoscape_edge["data"]["id"] = ( - edge["Name"] + ":" + edgePair["From"] + ":" + edgePair["To"] + edge["Name"] + ":" + edgePair["From"] + + ":" + edgePair["To"] ) cytoscape_edge["data"]["source"] = edgePair["From"] cytoscape_edge["data"]["target"] = edgePair["To"] diff --git a/setup.py b/setup.py index 5b5e0fdc..af1ec4bc 100644 --- a/setup.py +++ b/setup.py @@ -8,6 +8,8 @@ long_description = (here / "README.md").read_text(encoding="utf-8") # Get version number from a single source of truth + + def get_version(version_path): with open(version_path) as infile: for line in infile: @@ -18,6 +20,8 @@ def get_version(version_path): raise RuntimeError("Unable to find version string.") # Get non-python files under a directory recursively. 
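The next hunk shows only the opening of this helper before the diff cuts it off; a minimal standalone equivalent, assuming the pathlib-based recursive walk that `files = [str(p.relative_to(directory)) for p in ...]` suggests and a simple "not .py" filter per the comment above, could look like this:

    from pathlib import Path

    def get_data_files(directory):
        # Walk `directory` recursively and keep the path, relative to
        # `directory`, of every file that is not Python source
        # (e.g. the bundled .gsql query files).
        return [
            str(p.relative_to(directory))
            for p in Path(directory).rglob("*")
            if p.is_file() and p.suffix != ".py"
        ]

setup() then passes the result as package_data for pyTigerGraph.gds, so these non-Python files ship with the distribution.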
+ + def get_data_files(directory): files = [ str(p.relative_to(directory)) @@ -30,7 +34,8 @@ def get_data_files(directory): setup( name='pyTigerGraph', packages=find_packages(where="."), - package_data={"pyTigerGraph.gds": get_data_files(here / "pyTigerGraph" / "gds")}, + package_data={"pyTigerGraph.gds": get_data_files( + here / "pyTigerGraph" / "gds")}, version=get_version(here/"pyTigerGraph"/"__init__.py"), license='Apache 2', description='Library to connect to TigerGraph databases', @@ -40,12 +45,15 @@ def get_data_files(directory): author_email='support@tigergraph.com', url='https://docs.tigergraph.com/pytigergraph/current/intro/', download_url='', - keywords=['TigerGraph', 'Graph Database', 'Data Science', 'Machine Learning'], + keywords=['TigerGraph', 'Graph Database', + 'Data Science', 'Machine Learning'], install_requires=[ 'validators', - 'requests'], + 'requests', + 'httpx'], classifiers=[ - 'Development Status :: 5 - Production/Stable', # 3 - Alpha, 4 - Beta or 5 - Production/Stable + # 3 - Alpha, 4 - Beta or 5 - Production/Stable + 'Development Status :: 5 - Production/Stable', 'Intended Audience :: Developers', 'Topic :: Software Development :: Build Tools', 'License :: OSI Approved :: Apache Software License', diff --git a/tests/pyTigerGraphUnitTestAsync.py b/tests/pyTigerGraphUnitTestAsync.py new file mode 100644 index 00000000..a3af680a --- /dev/null +++ b/tests/pyTigerGraphUnitTestAsync.py @@ -0,0 +1,51 @@ +import json +import os +from os.path import exists + +from pyTigerGraph import AsyncTigerGraphConnection + + +async def make_connection(graphname: str = None): + server_config = { + "host": "http://127.0.0.1", + "graphname": "tests", + "username": "tigergraph", + "password": "tigergraph", + "gsqlSecret": "", + "restppPort": "9000", + "gsPort": "14240", + "gsqlVersion": "", + "userCert": None, + "certPath": None, + "sslPort": "443", + "tgCloud": False, + "gcp": False, + "jwtToken": "" + } + + path = os.path.dirname(os.path.realpath(__file__)) + fname = os.path.join(path, "testserver.json") + if exists(fname): + with open(fname, "r") as config_file: + config = json.load(config_file) + server_config.update(config) + + conn = AsyncTigerGraphConnection( + host=server_config["host"], + graphname=graphname if graphname else server_config["graphname"], + username=server_config["username"], + password=server_config["password"], + tgCloud=server_config["tgCloud"], + restppPort=server_config["restppPort"], + gsPort=server_config["gsPort"], + gsqlVersion=server_config["gsqlVersion"], + useCert=server_config["userCert"], + certPath=server_config["certPath"], + sslPort=server_config["sslPort"], + gcp=server_config["gcp"], + jwtToken=server_config["jwtToken"] + ) + if server_config.get("getToken", False): + await conn.getToken(await conn.createSecret()) + + return conn diff --git a/tests/test_OGM.py b/tests/test_OGM.py index 2efa4f78..b5b1d87f 100644 --- a/tests/test_OGM.py +++ b/tests/test_OGM.py @@ -8,7 +8,6 @@ from dataclasses import dataclass - class TestHomogeneousOGM(unittest.TestCase): @classmethod def setUpClass(cls): @@ -26,6 +25,7 @@ def test_type(self): def test_add_vertex_type(self): g = Graph(self.conn) + @dataclass class AccountHolder(Vertex): name: str @@ -41,7 +41,8 @@ class AccountHolder(Vertex): g.commit_changes() - self.assertIn("name", g.vertex_types["AccountHolder"].attributes.keys()) + self.assertIn( + "name", g.vertex_types["AccountHolder"].attributes.keys()) def test_add_edge_type(self): g = Graph(self.conn) @@ -67,8 +68,8 @@ class HOLDS_ACCOUNT(Edge): 
        g.commit_changes()

-        self.assertIn("opened_on", g.edge_types["HOLDS_ACCOUNT"].attributes.keys())
-
+        self.assertIn(
+            "opened_on", g.edge_types["HOLDS_ACCOUNT"].attributes.keys())

     def test_add_multi_target_edge_type(self):
         g = Graph(self.conn)
@@ -94,15 +95,16 @@ class SOME_EDGE_NAME(Edge):

         g.commit_changes()

-        self.assertIn("some_attr", g.edge_types["SOME_EDGE_NAME"].attributes.keys())
-
+        self.assertIn(
+            "some_attr", g.edge_types["SOME_EDGE_NAME"].attributes.keys())
+
     def test_drop_edge_type(self):
         g = Graph(self.conn)

         g.remove_edge_type(g.edge_types["HOLDS_ACCOUNT"])

         g.commit_changes()
-
+
         self.assertNotIn("HOLDS_ACCOUNT", g.edge_types)

     def test_drop_multi_target_edge_type(self):
@@ -111,7 +113,7 @@ def test_drop_multi_target_edge_type(self):
         g.remove_edge_type(g.edge_types["SOME_EDGE_NAME"])

         g.commit_changes()
-
+
         self.assertNotIn("SOME_EDGE_NAME", g.edge_types)

     def test_drop_vertex_type(self):
@@ -126,12 +128,14 @@ def test_add_vertex_attribute_default_value(self):
         g = Graph(self.conn)

-        g.vertex_types["Paper"].add_attribute("ThisIsATest", str, "test_default")
+        g.vertex_types["Paper"].add_attribute(
+            "ThisIsATest", str, "test_default")

         g.commit_changes()

         self.assertIn("ThisIsATest", g.vertex_types["Paper"].attributes.keys())
-        sample = self.conn.getVertices("Paper", limit=1)[0]["attributes"]["ThisIsATest"]
+        sample = self.conn.getVertices("Paper", limit=1)[
+            0]["attributes"]["ThisIsATest"]

         self.assertEqual("'test_default'", sample)

@@ -142,7 +146,8 @@ def test_drop_vertex_attribute(self):

         g.commit_changes()

-        self.assertNotIn("ThisIsATest", g.vertex_types["Paper"].attributes.keys())
+        self.assertNotIn(
+            "ThisIsATest", g.vertex_types["Paper"].attributes.keys())


 class TestHeterogeneousOGM(unittest.TestCase):
@@ -153,7 +158,7 @@ def setUpClass(cls):
     def test_init(self):
         g = Graph(self.conn)
         self.assertEqual(len(g.vertex_types.keys()), 3)
-
+
     def test_type(self):
         g = Graph(self.conn)
         attrs = g.vertex_types["v0"].attributes
@@ -204,6 +209,7 @@ class CITES(Edge):

         self.conn.gsql("DROP GRAPH Cora2")

+
 if __name__ == "__main__":
     suite = unittest.TestSuite()
     suite.addTest(TestHeterogeneousOGM("test_init"))
@@ -212,8 +218,8 @@ class CITES(Edge):
     suite.addTest(TestHomogeneousOGM("test_add_edge_type"))
     suite.addTest(TestHomogeneousOGM("test_drop_edge_type"))
     suite.addTest(TestHomogeneousOGM("test_drop_vertex_type"))
-    suite.addTest(TestHomogeneousOGM("test_add_vertex_attribute_default_value"))
-    suite.addTest(TestHomogeneousOGM("test_drop_vertex_attribute"))
+    # suite.addTest(TestHomogeneousOGM("test_add_vertex_attribute_default_value"))
+    # suite.addTest(TestHomogeneousOGM("test_drop_vertex_attribute"))
     suite.addTest(TestHeterogeneousOGM("test_init"))
     suite.addTest(TestHeterogeneousOGM("test_type"))
     suite.addTest(TestHeterogeneousOGM("test_outgoing_edge_types"))
diff --git a/tests/test_async.py b/tests/test_async.py
new file mode 100644
index 00000000..53f588cd
--- /dev/null
+++ b/tests/test_async.py
@@ -0,0 +1,35 @@
+import asyncio
+import unittest
+from pyTigerGraphUnitTestAsync import make_connection
+
+
+class test_async(unittest.IsolatedAsyncioTestCase):
+    async def asyncSetUp(self):
+        # asyncSetUp runs per test on the instance, so it must not be a classmethod.
+        self.conn = await make_connection()
+
+    '''
+    async def test_task_results(self):
+        if not hasattr(self, 'conn'):
+            raise AttributeError(
+                "Connection was not initialized.
Please check the setup.") + + tasks: list[asyncio.Task] = [] + + async with asyncio.TaskGroup() as tg: + for i in range(100): + if i % 2 == 0: + task = tg.create_task(self.conn.getVertexCount("vertex7")) + tasks.append(task) + else: + task = tg.create_task( + self.conn.getEdgeCount("edge1_undirected")) + tasks.append(task) + + for t in tasks: + result = t.result() + # print(result) + self.assertIsInstance(result, int) + + ''' +if __name__ == "__main__": + unittest.main() diff --git a/tests/test_datasetsAsync.py b/tests/test_datasetsAsync.py new file mode 100644 index 00000000..a5f3c8fd --- /dev/null +++ b/tests/test_datasetsAsync.py @@ -0,0 +1,66 @@ +import unittest +from io import StringIO +from os.path import exists +from textwrap import dedent +from unittest.mock import patch + +from pyTigerGraph.pytgasync.datasets import AsyncDatasets + + +class TestDatasetsAsync(unittest.IsolatedAsyncioTestCase): + async def test_get_dataset_url(self): + dataset = await AsyncDatasets.create() + + dataset.name = "Cora" + self.assertEqual( + await dataset.get_dataset_url(), + "https://tigergraph-public-data.s3.us-west-1.amazonaws.com/Cora.tar.gz", + ) + + dataset.name = "SomethingNotThere" + self.assertIsNone(await dataset.get_dataset_url()) + + async def test_download_extract(self): + dataset = await AsyncDatasets.create() + dataset.name = "Cora" + dataset.dataset_url = await dataset.get_dataset_url() + await dataset.download_extract() + self.assertTrue(exists("./tmp/Cora/create_graph.gsql")) + self.assertTrue(exists("./tmp/Cora/create_load_job.gsql")) + self.assertTrue(exists("./tmp/Cora/create_schema.gsql")) + self.assertTrue(exists("./tmp/Cora/run_load_job.json")) + self.assertTrue(exists("./tmp/Cora/edges.csv")) + self.assertTrue(exists("./tmp/Cora/nodes.csv")) + + async def test_clean_up(self): + dataset = await AsyncDatasets.create() + dataset.name = "Cora" + dataset.clean_up() + self.assertFalse(exists("./tmp/Cora")) + + @patch("sys.stdout", new_callable=StringIO) + async def test_list(self, mock_stdout): + dataset = await AsyncDatasets.create() + truth = """\ + Available datasets: + - Cora + - CoraV2 + - Ethereum + - ldbc_snb + - LastFM + - imdb + - movie + - social + """ + self.assertIn(dedent(truth), mock_stdout.getvalue()) + + +if __name__ == "__main__": + suite = unittest.TestSuite() + suite.addTest(TestDatasetsAsync("test_get_dataset_url")) + suite.addTest(TestDatasetsAsync("test_download_extract")) + suite.addTest(TestDatasetsAsync("test_clean_up")) + suite.addTest(TestDatasetsAsync("test_list")) + + runner = unittest.TextTestRunner(verbosity=2, failfast=True) + runner.run(suite) diff --git a/tests/test_gds_BaseLoader.py b/tests/test_gds_BaseLoader.py index d2987c8d..851c1954 100644 --- a/tests/test_gds_BaseLoader.py +++ b/tests/test_gds_BaseLoader.py @@ -81,15 +81,19 @@ def test_validate_vertex_attributes(self): self.assertListEqual(self.loader._validate_vertex_attributes(None), []) self.assertListEqual(self.loader._validate_vertex_attributes([]), []) self.assertListEqual(self.loader._validate_vertex_attributes({}), []) - self.assertDictEqual(self.loader._validate_vertex_attributes(None, True), {}) - self.assertDictEqual(self.loader._validate_vertex_attributes([], True), {}) - self.assertDictEqual(self.loader._validate_vertex_attributes({}, True), {}) + self.assertDictEqual( + self.loader._validate_vertex_attributes(None, True), {}) + self.assertDictEqual( + self.loader._validate_vertex_attributes([], True), {}) + self.assertDictEqual( + self.loader._validate_vertex_attributes({}, True), 
{}) # Extra spaces self.assertListEqual( self.loader._validate_vertex_attributes(["x ", " y"]), ["x", "y"] ) self.assertDictEqual( - self.loader._validate_vertex_attributes({"Paper": ["x ", " y"]}, True), + self.loader._validate_vertex_attributes( + {"Paper": ["x ", " y"]}, True), {"Paper": ["x", "y"]}, ) # Wrong input @@ -104,16 +108,20 @@ def test_validate_vertex_attributes(self): with self.assertRaises(ValueError): self.loader._validate_vertex_attributes(["x"], is_hetero=True) with self.assertRaises(ValueError): - self.loader._validate_vertex_attributes({"Paper": ["x"]}, is_hetero=False) + self.loader._validate_vertex_attributes( + {"Paper": ["x"]}, is_hetero=False) def test_validate_edge_attributes(self): # Empty input self.assertListEqual(self.loader._validate_edge_attributes(None), []) self.assertListEqual(self.loader._validate_edge_attributes([]), []) self.assertListEqual(self.loader._validate_edge_attributes({}), []) - self.assertDictEqual(self.loader._validate_edge_attributes(None, True), {}) - self.assertDictEqual(self.loader._validate_edge_attributes([], True), {}) - self.assertDictEqual(self.loader._validate_edge_attributes({}, True), {}) + self.assertDictEqual( + self.loader._validate_edge_attributes(None, True), {}) + self.assertDictEqual( + self.loader._validate_edge_attributes([], True), {}) + self.assertDictEqual( + self.loader._validate_edge_attributes({}, True), {}) # Extra spaces self.assertListEqual( self.loader._validate_edge_attributes(["time ", "is_train"]), @@ -137,7 +145,8 @@ def test_validate_edge_attributes(self): with self.assertRaises(ValueError): self.loader._validate_edge_attributes(["time"], is_hetero=True) with self.assertRaises(ValueError): - self.loader._validate_edge_attributes({"Cite": ["time"]}, is_hetero=False) + self.loader._validate_edge_attributes( + {"Cite": ["time"]}, is_hetero=False) def test_read_vertex(self): read_task_q = Queue() @@ -253,7 +262,6 @@ def test_read_edge_callback(self): data = data_q.get() self.assertEqual(data, 1) - def test_read_graph_out_df(self): read_task_q = Queue() data_q = Queue(4) @@ -298,7 +306,6 @@ def test_read_graph_out_df(self): data = data_q.get() self.assertIsNone(data) - def test_read_graph_out_df_callback(self): read_task_q = Queue() data_q = Queue(4) @@ -330,7 +337,6 @@ def test_read_graph_out_df_callback(self): self.assertEqual(data[0], 1) self.assertEqual(data[1], 2) - def test_read_graph_out_pyg(self): read_task_q = Queue() data_q = Queue(4) @@ -360,7 +366,8 @@ def test_read_graph_out_pyg(self): ["x", "time"], ["y"], ["is_train", "category"], - {"x": "DOUBLE", "time": "INT", "y": "INT", "is_train": "BOOL", "category": "LIST:STRING"}, + {"x": "DOUBLE", "time": "INT", "y": "INT", + "is_train": "BOOL", "category": "LIST:STRING"}, delimiter="|" ) data = data_q.get() @@ -372,7 +379,8 @@ def test_read_graph_out_pyg(self): ) assert_close_torch(data["edge_label"], torch.tensor([1, 0])) assert_close_torch(data["is_train"], torch.tensor([False, True])) - assert_close_torch(data["x"], torch.tensor([[1, 0, 0, 1], [1, 0, 0, 1]])) + assert_close_torch(data["x"], torch.tensor( + [[1, 0, 0, 1], [1, 0, 0, 1]])) assert_close_torch(data["y"], torch.tensor([1, 1])) assert_close_torch(data["train_mask"], torch.tensor([False, True])) assert_close_torch(data["is_seed"], torch.tensor([True, False])) @@ -410,24 +418,29 @@ def test_read_graph_out_dgl(self): ["x", "time"], ["y"], ["is_train", "category"], - {"x": "DOUBLE", "time": "INT", "y": "INT", "is_train": "BOOL", "category": "LIST:STRING"}, + {"x": "DOUBLE", "time": "INT", 
"y": "INT", + "is_train": "BOOL", "category": "LIST:STRING"}, delimiter="|" ) data = data_q.get() self.assertIsInstance(data, DGLGraph) - assert_close_torch(data.edges(), (torch.tensor([0, 1]), torch.tensor([1, 0]))) + assert_close_torch( + data.edges(), (torch.tensor([0, 1]), torch.tensor([1, 0]))) assert_close_torch( data.edata["edge_feat"], torch.tensor([[0.1, 2021], [1.5, 2020]], dtype=torch.double), ) assert_close_torch(data.edata["edge_label"], torch.tensor([1, 0])) assert_close_torch(data.edata["is_train"], torch.tensor([False, True])) - assert_close_torch(data.ndata["x"], torch.tensor([[1, 0, 0, 1], [1, 0, 0, 1]])) + assert_close_torch(data.ndata["x"], torch.tensor( + [[1, 0, 0, 1], [1, 0, 0, 1]])) assert_close_torch(data.ndata["y"], torch.tensor([1, 1])) - assert_close_torch(data.ndata["train_mask"], torch.tensor([False, True])) + assert_close_torch(data.ndata["train_mask"], + torch.tensor([False, True])) assert_close_torch(data.ndata["is_seed"], torch.tensor([True, False])) self.assertListEqual(data.extra_data["name"], ["Alex", "Bill"]) - self.assertListEqual(data.extra_data["category"], [['a', 'b'], ['c', 'd']]) + self.assertListEqual(data.extra_data["category"], [ + ['a', 'b'], ['c', 'd']]) data = data_q.get() self.assertIsNone(data) @@ -460,7 +473,8 @@ def test_read_graph_parse_error(self): ["x", "time"], ["y"], ["is_train", "category"], - {"x": "DOUBLE", "time": "INT", "y": "INT", "is_train": "BOOL", "category": "LIST:STRING"}, + {"x": "DOUBLE", "time": "INT", "y": "INT", + "is_train": "BOOL", "category": "LIST:STRING"}, delimiter="|" ) data = data_q.get() @@ -536,11 +550,12 @@ def test_read_graph_no_edge(self): ) data = data_q.get() self.assertIsInstance(data, pygData) - self.assertListEqual(list(data["edge_index"].shape), [2,0]) - self.assertListEqual(list(data["edge_feat"].shape), [0,2]) + self.assertListEqual(list(data["edge_index"].shape), [2, 0]) + self.assertListEqual(list(data["edge_feat"].shape), [0, 2]) self.assertListEqual(list(data["edge_label"].shape), [0,]) self.assertListEqual(list(data["is_train"].shape), [0,]) - assert_close_torch(data["x"], torch.tensor([[1, 0, 0, 1], [1, 0, 0, 1]])) + assert_close_torch(data["x"], torch.tensor( + [[1, 0, 0, 1], [1, 0, 0, 1]])) assert_close_torch(data["y"], torch.tensor([1, 1])) assert_close_torch(data["train_mask"], torch.tensor([False, True])) assert_close_torch(data["is_seed"], torch.tensor([True, False])) @@ -566,7 +581,8 @@ def test_read_hetero_graph_out_pyg(self): "pyg", {"People": ["x"], "Company": ["x"]}, {"People": ["y"]}, - {"People": ["train_mask", "name", "is_seed"], "Company": ["is_seed"]}, + {"People": ["train_mask", "name", "is_seed"], + "Company": ["is_seed"]}, { "People": { "x": "LIST:INT", @@ -610,14 +626,18 @@ def test_read_hetero_graph_out_pyg(self): data["Colleague"]["edge_feat"], torch.tensor([[0.1, 2021], [1.5, 2020]], dtype=torch.double), ) - assert_close_torch(data["Colleague"]["edge_label"], torch.tensor([1, 0])) - assert_close_torch(data["Colleague"]["is_train"], torch.tensor([False, True])) + assert_close_torch( + data["Colleague"]["edge_label"], torch.tensor([1, 0])) + assert_close_torch(data["Colleague"]["is_train"], + torch.tensor([False, True])) assert_close_torch( data["People"]["x"], torch.tensor([[1, 0, 0, 1], [1, 0, 0, 1]]) ) assert_close_torch(data["People"]["y"], torch.tensor([1, 1])) - assert_close_torch(data["People"]["train_mask"], torch.tensor([False, True])) - assert_close_torch(data["People"]["is_seed"], torch.tensor([True, False])) + 
assert_close_torch(data["People"]["train_mask"], + torch.tensor([False, True])) + assert_close_torch(data["People"]["is_seed"], + torch.tensor([True, False])) self.assertListEqual(data["People"]["name"], ["Alex", "Bill"]) assert_close_torch( data["Company"]["x"], torch.tensor([0.3], dtype=torch.double) @@ -690,7 +710,8 @@ def test_read_hetero_graph_no_attr(self): assert_close_torch( data["Work"]["edge_index"], torch.tensor([[0, 1], [0, 0]]) ) - assert_close_torch(data["People"]["is_seed"], torch.tensor([True, False])) + assert_close_torch(data["People"]["is_seed"], + torch.tensor([True, False])) assert_close_torch(data["Company"]["is_seed"], torch.tensor([False])) data = data_q.get() self.assertIsNone(data) @@ -713,7 +734,8 @@ def test_read_hetero_graph_no_edge(self): "pyg", {"People": ["x"], "Company": ["x"]}, {"People": ["y"]}, - {"People": ["train_mask", "name", "is_seed"], "Company": ["is_seed"]}, + {"People": ["train_mask", "name", "is_seed"], + "Company": ["is_seed"]}, { "People": { "x": "LIST:INT", @@ -755,8 +777,10 @@ def test_read_hetero_graph_no_edge(self): data["People"]["x"], torch.tensor([[1, 0, 0, 1], [1, 0, 0, 1]]) ) assert_close_torch(data["People"]["y"], torch.tensor([1, 1])) - assert_close_torch(data["People"]["train_mask"], torch.tensor([False, True])) - assert_close_torch(data["People"]["is_seed"], torch.tensor([True, False])) + assert_close_torch(data["People"]["train_mask"], + torch.tensor([False, True])) + assert_close_torch(data["People"]["is_seed"], + torch.tensor([True, False])) self.assertListEqual(data["People"]["name"], ["Alex", "Bill"]) assert_close_torch( data["Company"]["x"], torch.tensor([0.3], dtype=torch.double) @@ -784,7 +808,8 @@ def test_read_hetero_graph_out_dgl(self): "dgl", {"People": ["x"], "Company": ["x"]}, {"People": ["y"]}, - {"People": ["train_mask", "name", "is_seed"], "Company": ["is_seed"]}, + {"People": ["train_mask", "name", "is_seed"], + "Company": ["is_seed"]}, { "People": { "x": "LIST:INT", @@ -823,29 +848,41 @@ def test_read_hetero_graph_out_dgl(self): data = data_q.get() self.assertIsInstance(data, DGLGraph) assert_close_torch( - data.edges(etype="Colleague"), (torch.tensor([0, 1]), torch.tensor([1, 0])) + data.edges(etype="Colleague"), (torch.tensor( + [0, 1]), torch.tensor([1, 0])) ) assert_close_torch( data.edges["Colleague"].data["edge_feat"], torch.tensor([[0.1, 2021], [1.5, 2020]], dtype=torch.double), ) - assert_close_torch(data.edges["Colleague"].data["edge_label"], torch.tensor([1, 0])) - assert_close_torch(data.edges["Colleague"].data["is_train"], torch.tensor([False, True])) assert_close_torch( - data.nodes["People"].data["x"], torch.tensor([[1, 0, 0, 1], [1, 0, 0, 1]]) + data.edges["Colleague"].data["edge_label"], torch.tensor([1, 0])) + assert_close_torch( + data.edges["Colleague"].data["is_train"], torch.tensor([False, True])) + assert_close_torch( + data.nodes["People"].data["x"], torch.tensor( + [[1, 0, 0, 1], [1, 0, 0, 1]]) ) - assert_close_torch(data.nodes["People"].data["y"], torch.tensor([1, 1])) - assert_close_torch(data.nodes["People"].data["train_mask"], torch.tensor([False, True])) - assert_close_torch(data.nodes["People"].data["is_seed"], torch.tensor([True, False])) - self.assertListEqual(data.extra_data["People"]["name"], ["Alex", "Bill"]) assert_close_torch( - data.nodes["Company"].data["x"], torch.tensor([0.3], dtype=torch.double) + data.nodes["People"].data["y"], torch.tensor([1, 1])) + assert_close_torch( + data.nodes["People"].data["train_mask"], torch.tensor([False, True])) + assert_close_torch( + 
data.nodes["People"].data["is_seed"], torch.tensor([True, False])) + self.assertListEqual( + data.extra_data["People"]["name"], ["Alex", "Bill"]) + assert_close_torch( + data.nodes["Company"].data["x"], torch.tensor( + [0.3], dtype=torch.double) ) - assert_close_torch(data.nodes["Company"].data["is_seed"], torch.tensor([False])) assert_close_torch( - data.edges(etype="Work"), (torch.tensor([0, 1]), torch.tensor([0, 0])) + data.nodes["Company"].data["is_seed"], torch.tensor([False])) + assert_close_torch( + data.edges(etype="Work"), (torch.tensor( + [0, 1]), torch.tensor([0, 0])) ) - self.assertListEqual(data.extra_data["Work"]["category"], [['a', 'b'], ['c', 'd']]) + self.assertListEqual(data.extra_data["Work"]["category"], [ + ['a', 'b'], ['c', 'd']]) data = data_q.get() self.assertIsNone(data) @@ -890,7 +927,8 @@ def test_read_bool_label(self): ) assert_close_torch(data["edge_label"], torch.tensor([True, False])) assert_close_torch(data["is_train"], torch.tensor([False, True])) - assert_close_torch(data["x"], torch.tensor([[1, 0, 0, 1], [1, 0, 0, 1]])) + assert_close_torch(data["x"], torch.tensor( + [[1, 0, 0, 1], [1, 0, 0, 1]])) assert_close_torch(data["y"], torch.tensor([True, True])) assert_close_torch(data["train_mask"], torch.tensor([False, True])) assert_close_torch(data["is_seed"], torch.tensor([True, False])) diff --git a/tests/test_gds_EdgeLoader.py b/tests/test_gds_EdgeLoader.py index 5e22671c..c36f6a52 100644 --- a/tests/test_gds_EdgeLoader.py +++ b/tests/test_gds_EdgeLoader.py @@ -254,9 +254,10 @@ def test_iterate_as_homo(self): def test_iterate_hetero(self): loader = EdgeLoader( graph=self.conn, - attributes={"v0v0": ["is_train", "is_val"], "v2v0": ["is_train", "is_val"]}, + attributes={"v0v0": ["is_train", "is_val"], + "v2v0": ["is_train", "is_val"]}, batch_size=200, - shuffle=True, # Needed to get around VID distribution issues + shuffle=True, # Needed to get around VID distribution issues filter_by=None, loader_id=None, buffer_size=4, @@ -277,7 +278,8 @@ def test_iterate_hetero(self): def test_iterate_hetero_multichar_delimiter(self): loader = EdgeLoader( graph=self.conn, - attributes={"v0v0": ["is_train", "is_val"], "v2v0": ["is_train", "is_val"]}, + attributes={"v0v0": ["is_train", "is_val"], + "v2v0": ["is_train", "is_val"]}, batch_size=200, shuffle=True, # Needed to get around VID distribution issues filter_by=None, @@ -289,7 +291,8 @@ def test_iterate_hetero_multichar_delimiter(self): for data in loader: # print(num_batches, data) if num_batches == 0: - self.assertEqual(data["v0v0"].shape[0]+data["v2v0"].shape[0], 200) + self.assertEqual(data["v0v0"].shape[0] + + data["v2v0"].shape[0], 200) self.assertEqual(len(data), 2) self.assertIsInstance(data["v0v0"], DataFrame) self.assertIn("is_val", data["v0v0"]) @@ -313,11 +316,13 @@ def test_iterate_hetero_multichar_delimiter(self): suite.addTest(TestGDSEdgeLoaderREST("test_iterate")) suite.addTest(TestGDSEdgeLoaderREST("test_whole_edgelist")) suite.addTest(TestGDSEdgeLoaderREST("test_iterate_attr")) - suite.addTest(TestGDSEdgeLoaderREST("test_iterate_attr_multichar_delimiter")) + suite.addTest(TestGDSEdgeLoaderREST( + "test_iterate_attr_multichar_delimiter")) suite.addTest(TestGDSHeteroEdgeLoaderREST("test_init")) suite.addTest(TestGDSHeteroEdgeLoaderREST("test_iterate_as_homo")) suite.addTest(TestGDSHeteroEdgeLoaderREST("test_iterate_hetero")) - suite.addTest(TestGDSHeteroEdgeLoaderREST("test_iterate_hetero_multichar_delimiter")) + suite.addTest(TestGDSHeteroEdgeLoaderREST( + "test_iterate_hetero_multichar_delimiter")) 
runner = unittest.TextTestRunner(verbosity=2, failfast=True) runner.run(suite) diff --git a/tests/test_gds_GDS.py b/tests/test_gds_GDS.py index b247ac05..38b6a9d9 100644 --- a/tests/test_gds_GDS.py +++ b/tests/test_gds_GDS.py @@ -144,7 +144,8 @@ def test_configureKafka_sasl_plaintext(self): buffer_size=4, ) self.assertEqual(loader.kafka_address_consumer, "34.127.11.236:9092") - self.assertEqual(loader._payload["security_protocol"], "SASL_PLAINTEXT") + self.assertEqual( + loader._payload["security_protocol"], "SASL_PLAINTEXT") self.assertEqual(loader._payload["sasl_mechanism"], "PLAIN") self.assertEqual(loader._payload["sasl_username"], "bill") self.assertEqual(loader._payload["sasl_password"], "bill") diff --git a/tests/test_gds_GraphSAGE.py b/tests/test_gds_GraphSAGE.py index a61f6cb9..599909b3 100644 --- a/tests/test_gds_GraphSAGE.py +++ b/tests/test_gds_GraphSAGE.py @@ -1,12 +1,12 @@ import unittest from pyTigerGraphUnitTest import make_connection -import torch import logging import os from pyTigerGraph.gds.trainer import BaseCallback from pyTigerGraph.gds.models.GraphSAGE import GraphSAGEForLinkPrediction, GraphSAGEForVertexClassification, GraphSAGEForVertexRegression + class TestingCallback(BaseCallback): def __init__(self, test_name, output_dir="./logs"): self.output_dir = output_dir @@ -29,6 +29,7 @@ def on_eval_end(self, trainer): def on_epoch_end(self, trainer): trainer.eval() + class TestHomogeneousVertexClassificationGraphSAGE(unittest.TestCase): @classmethod def setUpClass(cls): @@ -43,7 +44,7 @@ def test_fit(self): v_in_feats=["x"], v_out_labels=["y"], num_batches=5, - e_extra_feats=["is_train","is_val"], + e_extra_feats=["is_train", "is_val"], output_format="PyG", num_neighbors=10, num_hops=2, @@ -55,7 +56,7 @@ def test_fit(self): v_in_feats=["x"], v_out_labels=["y"], num_batches=5, - e_extra_feats=["is_train","is_val"], + e_extra_feats=["is_train", "is_val"], output_format="PyG", num_neighbors=10, num_hops=2, @@ -63,59 +64,63 @@ def test_fit(self): shuffle=False, ) - gs = GraphSAGEForVertexClassification(num_layers=2, - out_dim=7, + gs = GraphSAGEForVertexClassification(num_layers=2, + out_dim=7, dropout=.2, hidden_dim=128) - trainer_args = {"callbacks":[TestingCallback("cora_fit")]} + trainer_args = {"callbacks": [TestingCallback("cora_fit")]} gs.fit(train_loader, valid_loader, 2, trainer_kwargs=trainer_args) ifLogged = os.path.isfile("./logs/train_results_cora_fit.log") self.assertEqual(ifLogged, True) + class TestHeterogeneousVertexClassificationGraphSAGE(unittest.TestCase): def test_init(self): - metadata = (['Actor', 'Movie', 'Director'], - [('Actor', 'actor_movie', 'Movie'), - ('Movie', 'movie_actor', 'Actor'), - ('Movie', 'movie_director', 'Director'), - ('Director', 'director_movie', 'Movie')]) + metadata = (['Actor', 'Movie', 'Director'], + [('Actor', 'actor_movie', 'Movie'), + ('Movie', 'movie_actor', 'Actor'), + ('Movie', 'movie_director', 'Director'), + ('Director', 'director_movie', 'Movie')]) model = GraphSAGEForVertexClassification(2, 3, 256, .2, metadata) self.assertEqual(len(list(model.parameters())), 24) + class TestHomogeneousVertexRegression(unittest.TestCase): def test_init(self): model = GraphSAGEForVertexRegression(2, 1, 128, 0.5) self.assertEqual(len(list(model.parameters())), 6) + class TestHeterogeneousVertexRegression(unittest.TestCase): def test_init(self): - metadata = (['Actor', 'Movie', 'Director'], - [('Actor', 'actor_movie', 'Movie'), - ('Movie', 'movie_actor', 'Actor'), - ('Movie', 'movie_director', 'Director'), - ('Director', 
'director_movie', 'Movie')]) + metadata = (['Actor', 'Movie', 'Director'], + [('Actor', 'actor_movie', 'Movie'), + ('Movie', 'movie_actor', 'Actor'), + ('Movie', 'movie_director', 'Director'), + ('Director', 'director_movie', 'Movie')]) model = GraphSAGEForVertexRegression(2, 1, 128, 0.5, metadata) self.assertEqual(len(list(model.parameters())), 24) + class TestHomogeneousLinkPrediction(unittest.TestCase): def test_init(self): model = GraphSAGEForLinkPrediction(2, 128, 128, 0.5) self.assertEqual(len(list(model.parameters())), 6) + class TestHeterogeneousLinkPrediction(unittest.TestCase): def test_init(self): - metadata = (['Actor', 'Movie', 'Director'], - [('Actor', 'actor_movie', 'Movie'), - ('Movie', 'movie_actor', 'Actor'), - ('Movie', 'movie_director', 'Director'), - ('Director', 'director_movie', 'Movie')]) + metadata = (['Actor', 'Movie', 'Director'], + [('Actor', 'actor_movie', 'Movie'), + ('Movie', 'movie_actor', 'Actor'), + ('Movie', 'movie_director', 'Director'), + ('Director', 'director_movie', 'Movie')]) model = GraphSAGEForLinkPrediction(2, 128, 128, 0.5, metadata) self.assertEqual(len(list(model.parameters())), 24) - if __name__ == "__main__": unittest.main(verbosity=2, failfast=True) diff --git a/tests/test_gds_HGTLoader.py b/tests/test_gds_HGTLoader.py index 15c51820..2ca6a287 100644 --- a/tests/test_gds_HGTLoader.py +++ b/tests/test_gds_HGTLoader.py @@ -136,14 +136,15 @@ def test_iterate_pyg(self): add_self_loop=False, loader_id=None, buffer_size=4, - filter_by= {"v2": "train_mask"} + filter_by={"v2": "train_mask"} ) num_batches = 0 for data in loader: # print(num_batches, data) self.assertIsInstance(data, pygHeteroData) self.assertGreater(data["v2"]["x"].shape[0], 0) - self.assertEqual(data["v2"]["x"].shape[0], data["v2"]["is_seed"].shape[0]) + self.assertEqual(data["v2"]["x"].shape[0], + data["v2"]["is_seed"].shape[0]) num_batches += 1 self.assertEqual(num_batches, 6) @@ -163,7 +164,8 @@ def test_fetch(self): buffer_size=4, ) data = loader.fetch( - [{"primary_id": "13", "type": "v2"}, {"primary_id": "28", "type": "v2"}] + [{"primary_id": "13", "type": "v2"}, { + "primary_id": "28", "type": "v2"}] ) self.assertIn("13", data["v2"]["primary_id"]) self.assertIn("28", data["v2"]["primary_id"]) diff --git a/tests/test_gds_NeighborLoader.py b/tests/test_gds_NeighborLoader.py index 2bb9a7c0..911a185d 100644 --- a/tests/test_gds_NeighborLoader.py +++ b/tests/test_gds_NeighborLoader.py @@ -108,7 +108,7 @@ def test_iterate_stop_pyg(self): break rq_id = self.conn.getRunningQueries()["results"] self.assertEqual(len(rq_id), 0) - + def test_whole_graph_pyg(self): loader = NeighborLoader( graph=self.conn, @@ -256,21 +256,21 @@ def run_loader(params: dict) -> int: neighbor_loader = NeighborLoader( graph=params["conn"], batch_size=8, - num_neighbors = 10, - num_hops =2, - v_in_feats = ["x"], - v_out_labels = ["y"], - v_extra_feats = ["train_mask", "val_mask", "test_mask"], - output_format = "PyG", + num_neighbors=10, + num_hops=2, + v_in_feats=["x"], + v_out_labels=["y"], + v_extra_feats=["train_mask", "val_mask", "test_mask"], + output_format="PyG", shuffle=False, - filter_by = "train_mask", + filter_by="train_mask", timeout=300000, kafka_address="kafka:9092", kafka_num_partitions=2, kafka_auto_offset_reset="earliest", kafka_group_id="test_group", - kafka_auto_del_topic = False, - kafka_skip_produce = params["kafka_skip_produce"] + kafka_auto_del_topic=False, + kafka_skip_produce=params["kafka_skip_produce"] ) num_batches = 0 num_seeds = 0 @@ -279,7 +279,7 @@ def run_loader(params: dict) 
-> int: num_batches += 1 num_seeds += data["is_seed"].sum().item() return num_batches, num_seeds - + def test_distributed_loaders(self): params = [ {"conn": self.conn, "kafka_skip_produce": False}, @@ -290,6 +290,7 @@ def test_distributed_loaders(self): self.assertEqual(res[0][0]+res[1][0], 18) self.assertEqual(res[0][1]+res[1][1], 140) + class TestGDSNeighborLoaderREST(unittest.TestCase): @classmethod def setUpClass(cls): @@ -682,17 +683,23 @@ def test_iterate_pyg(self): # print(num_batches, data) self.assertIsInstance(data, pygHeteroData) self.assertGreater(data["v0"]["x"].shape[0], 0) - self.assertEqual(data["v0"]["x"].shape[0], data["v0"]["y"].shape[0]) + self.assertEqual(data["v0"]["x"].shape[0], + data["v0"]["y"].shape[0]) self.assertEqual( data["v0"]["x"].shape[0], data["v0"]["train_mask"].shape[0] ) - self.assertEqual(data["v0"]["x"].shape[0], data["v0"]["test_mask"].shape[0]) - self.assertEqual(data["v0"]["x"].shape[0], data["v0"]["is_seed"].shape[0]) - self.assertEqual(data["v0"]["x"].shape[0], data["v0"]["val_mask"].shape[0]) + self.assertEqual(data["v0"]["x"].shape[0], + data["v0"]["test_mask"].shape[0]) + self.assertEqual(data["v0"]["x"].shape[0], + data["v0"]["is_seed"].shape[0]) + self.assertEqual(data["v0"]["x"].shape[0], + data["v0"]["val_mask"].shape[0]) self.assertGreater(data["v1"]["x"].shape[0], 0) - self.assertEqual(data["v1"]["x"].shape[0], data["v1"]["is_seed"].shape[0]) + self.assertEqual(data["v1"]["x"].shape[0], + data["v1"]["is_seed"].shape[0]) self.assertGreater(data["v2"]["x"].shape[0], 0) - self.assertEqual(data["v2"]["x"].shape[0], data["v2"]["is_seed"].shape[0]) + self.assertEqual(data["v2"]["x"].shape[0], + data["v2"]["is_seed"].shape[0]) self.assertTrue( data["v0v0"]["edge_index"].shape[1] > 0 and data["v0v0"]["edge_index"].shape[1] <= 710 @@ -741,17 +748,23 @@ def test_iterate_pyg_multichar_delimiter(self): # print(num_batches, data) self.assertIsInstance(data, pygHeteroData) self.assertGreater(data["v0"]["x"].shape[0], 0) - self.assertEqual(data["v0"]["x"].shape[0], data["v0"]["y"].shape[0]) + self.assertEqual(data["v0"]["x"].shape[0], + data["v0"]["y"].shape[0]) self.assertEqual( data["v0"]["x"].shape[0], data["v0"]["train_mask"].shape[0] ) - self.assertEqual(data["v0"]["x"].shape[0], data["v0"]["test_mask"].shape[0]) - self.assertEqual(data["v0"]["x"].shape[0], data["v0"]["is_seed"].shape[0]) - self.assertEqual(data["v0"]["x"].shape[0], data["v0"]["val_mask"].shape[0]) + self.assertEqual(data["v0"]["x"].shape[0], + data["v0"]["test_mask"].shape[0]) + self.assertEqual(data["v0"]["x"].shape[0], + data["v0"]["is_seed"].shape[0]) + self.assertEqual(data["v0"]["x"].shape[0], + data["v0"]["val_mask"].shape[0]) self.assertGreater(data["v1"]["x"].shape[0], 0) - self.assertEqual(data["v1"]["x"].shape[0], data["v1"]["is_seed"].shape[0]) + self.assertEqual(data["v1"]["x"].shape[0], + data["v1"]["is_seed"].shape[0]) self.assertGreater(data["v2"]["x"].shape[0], 0) - self.assertEqual(data["v2"]["x"].shape[0], data["v2"]["is_seed"].shape[0]) + self.assertEqual(data["v2"]["x"].shape[0], + data["v2"]["is_seed"].shape[0]) self.assertTrue( data["v0v0"]["edge_index"].shape[1] > 0 and data["v0v0"]["edge_index"].shape[1] <= 710 @@ -795,7 +808,8 @@ def test_fetch(self): buffer_size=4, ) data = loader.fetch( - [{"primary_id": "10", "type": "v0"}, {"primary_id": "55", "type": "v0"}] + [{"primary_id": "10", "type": "v0"}, { + "primary_id": "55", "type": "v0"}] ) self.assertIn("primary_id", data["v0"]) self.assertGreater(data["v0"]["x"].shape[0], 2) @@ -825,7 +839,8 @@ def 
test_fetch_delimiter(self): buffer_size=4, ) data = loader.fetch( - [{"primary_id": "10", "type": "v0"}, {"primary_id": "55", "type": "v0"}] + [{"primary_id": "10", "type": "v0"}, { + "primary_id": "55", "type": "v0"}] ) self.assertIn("primary_id", data["v0"]) self.assertGreater(data["v0"]["x"].shape[0], 2) @@ -837,7 +852,7 @@ def test_fetch_delimiter(self): self.assertTrue(data["v0"]["is_seed"][i].item()) else: self.assertFalse(data["v0"]["is_seed"][i].item()) - + def test_metadata(self): loader = NeighborLoader( graph=self.conn, @@ -861,10 +876,11 @@ def test_metadata(self): ("v2", "v2v0", "v0"), ("v2", "v2v1", "v1"), ("v2", "v2v2", "v2")]) - + metadata = loader.metadata() self.assertEqual(test, metadata) + if __name__ == "__main__": suite = unittest.TestSuite() suite.addTest(TestGDSNeighborLoaderKafka("test_init")) @@ -886,7 +902,8 @@ def test_metadata(self): suite.addTest(TestGDSHeteroNeighborLoaderREST("test_whole_graph_df")) suite.addTest(TestGDSHeteroNeighborLoaderREST("test_whole_graph_pyg")) suite.addTest(TestGDSHeteroNeighborLoaderREST("test_iterate_pyg")) - suite.addTest(TestGDSHeteroNeighborLoaderREST("test_iterate_pyg_multichar_delimiter")) + suite.addTest(TestGDSHeteroNeighborLoaderREST( + "test_iterate_pyg_multichar_delimiter")) suite.addTest(TestGDSHeteroNeighborLoaderREST("test_fetch")) suite.addTest(TestGDSHeteroNeighborLoaderREST("test_fetch_delimiter")) suite.addTest(TestGDSHeteroNeighborLoaderREST("test_metadata")) diff --git a/tests/test_gds_NodePiece.py b/tests/test_gds_NodePiece.py index f4ecd0f6..e19f0528 100644 --- a/tests/test_gds_NodePiece.py +++ b/tests/test_gds_NodePiece.py @@ -1,13 +1,13 @@ import unittest from pyTigerGraphUnitTest import make_connection -import torch import logging import os from pyTigerGraph.gds.models.NodePieceMLP import NodePieceMLPForVertexClassification from pyTigerGraph.gds.trainer import BaseCallback from pyTigerGraph.gds.transforms.nodepiece_transforms import NodePieceMLPTransform + class TestingCallback(BaseCallback): def __init__(self, test_name, output_dir="./logs"): self.output_dir = output_dir @@ -38,42 +38,43 @@ def setUpClass(cls): def test_init(self): model = NodePieceMLPForVertexClassification(num_layers=4, - hidden_dim=128, - out_dim=7, - dropout=0.5, - vocab_size=10, - sequence_length=20) + hidden_dim=128, + out_dim=7, + dropout=0.5, + vocab_size=10, + sequence_length=20) self.assertEqual(len(list(model.parameters())), 7) - self.assertEqual(model.model.base_embedding.embedding.weight.shape[0], 10) - self.assertEqual(model.model.base_embedding.embedding.weight.shape[1], 768) - + self.assertEqual( + model.model.base_embedding.embedding.weight.shape[0], 10) + self.assertEqual( + model.model.base_embedding.embedding.weight.shape[1], 768) def test_fit(self): - t = NodePieceMLPTransform(label = "y") + t = NodePieceMLPTransform(label="y") train_loader, valid_loader = self.conn.gds.nodepieceLoader( - v_feats=["y"], - target_vertex_types="Paper", - clear_cache=True, - compute_anchors=True, - filter_by=["train_mask", "val_mask"], - anchor_percentage=0.1, - max_anchors=10, - max_distance=10, - num_batches=5, - use_cache=False, - shuffle=False, - reverse_edge=True, - callback_fn = lambda x: t(x), - timeout=600_000) + v_feats=["y"], + target_vertex_types="Paper", + clear_cache=True, + compute_anchors=True, + filter_by=["train_mask", "val_mask"], + anchor_percentage=0.1, + max_anchors=10, + max_distance=10, + num_batches=5, + use_cache=False, + shuffle=False, + reverse_edge=True, + callback_fn=lambda x: t(x), + timeout=600_000) model = 
NodePieceMLPForVertexClassification(num_layers=4, - hidden_dim=128, - out_dim=7, - dropout=0.5, - vocab_size=train_loader.num_tokens, - sequence_length=20) + hidden_dim=128, + out_dim=7, + dropout=0.5, + vocab_size=train_loader.num_tokens, + sequence_length=20) - trainer_args = {"callbacks":[TestingCallback("cora_fit_np")]} + trainer_args = {"callbacks": [TestingCallback("cora_fit_np")]} model.fit(train_loader, valid_loader, 2, trainer_kwargs=trainer_args) ifLogged = os.path.isfile("./logs/train_results_cora_fit_np.log") diff --git a/tests/test_gds_NodePieceLoader.py b/tests/test_gds_NodePieceLoader.py index e34e7a84..95133b8a 100644 --- a/tests/test_gds_NodePieceLoader.py +++ b/tests/test_gds_NodePieceLoader.py @@ -3,7 +3,6 @@ from pyTigerGraphUnitTest import make_connection from pandas import DataFrame -from pyTigerGraph import TigerGraphConnection from pyTigerGraph.gds.dataloaders import NodePieceLoader from pyTigerGraph.gds.utilities import is_query_installed @@ -144,6 +143,7 @@ def test_sasl_ssl(self): num_batches += 1 self.assertEqual(num_batches, 9) + class TestGDSNodePieceLoaderREST(unittest.TestCase): @classmethod def setUpClass(cls): @@ -224,7 +224,7 @@ def test_init(self): compute_anchors=True, anchor_percentage=0.5, v_feats={"v0": ["x", "y"], - "v1": ["x"]}, + "v1": ["x"]}, batch_size=20, shuffle=True, filter_by=None, @@ -240,7 +240,7 @@ def test_iterate(self): anchor_percentage=0.5, graph=self.conn, v_feats={"v0": ["x", "y"], - "v1": ["x"]}, + "v1": ["x"]}, batch_size=20, shuffle=True, filter_by=None, @@ -268,7 +268,7 @@ def test_all_vertices(self): compute_anchors=True, anchor_percentage=0.5, v_feats={"v0": ["x", "y"], - "v1": ["x"]}, + "v1": ["x"]}, num_batches=1, shuffle=False, filter_by=None, @@ -293,7 +293,7 @@ def test_all_vertices(self): suite.addTest(TestGDSNodePieceLoader("test_init")) suite.addTest(TestGDSNodePieceLoader("test_iterate")) suite.addTest(TestGDSNodePieceLoader("test_all_vertices")) - #suite.addTest(TestGDSNodePieceLoader("test_sasl_plaintext")) + # suite.addTest(TestGDSNodePieceLoader("test_sasl_plaintext")) # suite.addTest(TestGDSNodePieceLoader("test_sasl_ssl")) suite.addTest(TestGDSNodePieceLoaderREST("test_init")) suite.addTest(TestGDSNodePieceLoaderREST("test_iterate")) diff --git a/tests/test_gds_Trainer.py b/tests/test_gds_Trainer.py index 7c8bd10d..343dada9 100644 --- a/tests/test_gds_Trainer.py +++ b/tests/test_gds_Trainer.py @@ -17,6 +17,7 @@ from torch_geometric.data import HeteroData as pygHeteroData ''' + class TestingCallback(BaseCallback): def __init__(self, test_name, output_dir="./logs"): self.output_dir = output_dir @@ -39,6 +40,7 @@ def on_eval_end(self, trainer): def on_epoch_end(self, trainer): trainer.eval() + class TestGDSTrainer(unittest.TestCase): @classmethod def setUpClass(cls): @@ -77,12 +79,13 @@ def testHomogeneousVertexClassTraining(self): buffer_size=4, ) - gs = GraphSAGEForVertexClassification(num_layers=2, - out_dim=7, + gs = GraphSAGEForVertexClassification(num_layers=2, + out_dim=7, dropout=.2, hidden_dim=128) - trainer = Trainer(gs, train, valid, callbacks=[TestingCallback("cora_class")]) + trainer = Trainer(gs, train, valid, callbacks=[ + TestingCallback("cora_class")]) trainer.train(num_epochs=1) ifLogged = os.path.isfile("./logs/train_results_cora_class.log") @@ -104,17 +107,19 @@ def testHomogeneousVertexClassPredict(self): buffer_size=4, ) - gs = GraphSAGEForVertexClassification(num_layers=2, - out_dim=7, + gs = GraphSAGEForVertexClassification(num_layers=2, + out_dim=7, dropout=.2, hidden_dim=128) - trainer = 
Trainer(gs, train, valid, callbacks=[TestingCallback("cora_class")]) + trainer = Trainer(gs, train, valid, callbacks=[ + TestingCallback("cora_class")]) trainer.train(num_epochs=1) - out, _ = trainer.predict(infer.fetch([{"primary_id": 1, "type": "Paper"}])) + out, _ = trainer.predict(infer.fetch( + [{"primary_id": 1, "type": "Paper"}])) self.assertEqual(out.shape[1], 7) if __name__ == "__main__": - unittest.main(verbosity=2, failfast=True) \ No newline at end of file + unittest.main(verbosity=2, failfast=True) diff --git a/tests/test_gds_VertexLoader.py b/tests/test_gds_VertexLoader.py index 3b76890d..b672c845 100644 --- a/tests/test_gds_VertexLoader.py +++ b/tests/test_gds_VertexLoader.py @@ -126,6 +126,7 @@ def test_sasl_ssl(self): num_batches += 1 self.assertEqual(num_batches, 9) + class TestGDSVertexLoaderREST(unittest.TestCase): @classmethod def setUpClass(cls): @@ -319,12 +320,14 @@ def test_all_vertices_multichar_delimiter(self): suite.addTest(TestGDSVertexLoaderREST("test_init")) suite.addTest(TestGDSVertexLoaderREST("test_iterate")) suite.addTest(TestGDSVertexLoaderREST("test_all_vertices")) - suite.addTest(TestGDSVertexLoaderREST("test_all_vertices_multichar_delimiter")) + suite.addTest(TestGDSVertexLoaderREST( + "test_all_vertices_multichar_delimiter")) suite.addTest(TestGDSVertexLoaderREST("test_string_attr")) suite.addTest(TestGDSHeteroVertexLoaderREST("test_init")) suite.addTest(TestGDSHeteroVertexLoaderREST("test_iterate")) suite.addTest(TestGDSHeteroVertexLoaderREST("test_all_vertices")) - suite.addTest(TestGDSHeteroVertexLoaderREST("test_all_vertices_multichar_delimiter")) + suite.addTest(TestGDSHeteroVertexLoaderREST( + "test_all_vertices_multichar_delimiter")) runner = unittest.TextTestRunner(verbosity=2, failfast=True) runner.run(suite) diff --git a/tests/test_gds_featurizer.py b/tests/test_gds_featurizer.py index b873d404..8c265f87 100644 --- a/tests/test_gds_featurizer.py +++ b/tests/test_gds_featurizer.py @@ -79,14 +79,15 @@ def test_listAlgorithms_category(self, mock_stdout): 10. 
name: tg_harmonic_cent Call runAlgorithm() with the algorithm name to execute it """ - self.maxDiff=None + self.maxDiff = None self.assertEqual(mock_stdout.getvalue(), dedent(truth)) def test_install_query_file(self): query_path = "https://raw.githubusercontent.com/tigergraph/gsql-graph-algorithms/3.7/algorithms/Centrality/pagerank/global/unweighted/tg_pagerank.gsql" - resp = self.featurizer._install_query_file(query_path) + resp = self.featurizer._install_query_file(query_path) self.assertEqual(resp, "tg_pagerank") - self.assertTrue(is_query_installed(self.featurizer.conn, "tg_pagerank")) + self.assertTrue(is_query_installed( + self.featurizer.conn, "tg_pagerank")) def test_get_algo_details(self): path = os.path.dirname(os.path.realpath(__file__)) @@ -95,12 +96,12 @@ def test_get_algo_details(self): res = self.featurizer._get_algo_details(algo_dict["Path"]) self.assertDictEqual( res[0], - {'tg_bfs': ['https://raw.githubusercontent.com/tigergraph/gsql-graph-algorithms/3.8/algorithms/Path/bfs/tg_bfs.gsql'], - 'tg_cycle_detection_count': ['https://raw.githubusercontent.com/tigergraph/gsql-graph-algorithms/3.8/algorithms/Path/cycle_detection/count/tg_cycle_detection_count.gsql'], + {'tg_bfs': ['https://raw.githubusercontent.com/tigergraph/gsql-graph-algorithms/3.8/algorithms/Path/bfs/tg_bfs.gsql'], + 'tg_cycle_detection_count': ['https://raw.githubusercontent.com/tigergraph/gsql-graph-algorithms/3.8/algorithms/Path/cycle_detection/count/tg_cycle_detection_count.gsql'], 'tg_shortest_ss_no_wt': ['https://raw.githubusercontent.com/tigergraph/gsql-graph-algorithms/3.8/algorithms/Path/shortest_path/unweighted/tg_shortest_ss_no_wt.gsql']}) self.assertDictEqual( res[1], - {'tg_bfs': "INT", + {'tg_bfs': "INT", 'tg_shortest_ss_no_wt': "INT"}) def test_get_params(self): @@ -152,14 +153,14 @@ def test_get_params_emtpy(self): def test_getParams(self): params = self.featurizer.getParams("tg_pagerank", printout=False) truth = { - "v_type": None, + "v_type": None, "e_type": None, - "max_change": 0.001, - "maximum_iteration": 25, - "damping": 0.85, + "max_change": 0.001, + "maximum_iteration": 25, + "damping": 0.85, "top_k": 100, - "print_results": True, - "result_attribute": "", + "print_results": True, + "result_attribute": "", "file_path": "", "display_edges": False } @@ -184,41 +185,48 @@ def test_getParams_print(self, mock_stdout): self.assertEqual(mock_stdout.getvalue(), dedent(truth)) def test01_add_attribute(self): - self.assertEqual(add_attribute(self.conn, "VERTEX", "FLOAT", "attr1", global_change=False), 'Schema change succeeded.') + self.assertEqual(add_attribute(self.conn, "VERTEX", "FLOAT", + "attr1", global_change=False), 'Schema change succeeded.') def test02_add_attribute(self): - self.assertEqual(add_attribute(self.conn, "Edge", "BOOL", "attr2", global_change=False), 'Schema change succeeded.') - + self.assertEqual(add_attribute(self.conn, "Edge", "BOOL", + "attr2", global_change=False), 'Schema change succeeded.') + def test03_add_attribute(self): - self.assertEqual(add_attribute(self.conn, "Vertex", "BOOL", "attr1", global_change=False), 'Attribute already exists') + self.assertEqual(add_attribute(self.conn, "Vertex", "BOOL", + "attr1", global_change=False), 'Attribute already exists') def test04_add_attribute(self): with self.assertRaises(Exception) as context: - add_attribute(self.conn, "Something","BOOL","attr3") - self.assertTrue('schema_type has to be VERTEX or EDGE' in str(context.exception)) - + add_attribute(self.conn, "Something", "BOOL", "attr3") + self.assertTrue( + 
'schema_type has to be VERTEX or EDGE' in str(context.exception)) + def test05_add_attribute(self): - self.assertEqual(add_attribute(self.conn, "VERTEX", "BOOL", "attr4", ['Paper'], global_change=False), 'Schema change succeeded.') + self.assertEqual(add_attribute(self.conn, "VERTEX", "BOOL", "attr4", [ + 'Paper'], global_change=False), 'Schema change succeeded.') def test01_installAlgorithm(self): - self.assertEqual(self.featurizer.installAlgorithm("tg_pagerank"), "tg_pagerank") + self.assertEqual(self.featurizer.installAlgorithm( + "tg_pagerank"), "tg_pagerank") def test02_installAlgorithm(self): with self.assertRaises(Exception): self.featurizer.installAlgorithm("someQuery") - + def test01_runAlgorithm(self): params = {'v_type': 'Paper', - 'e_type': 'Cite', - 'max_change': 0.001, - 'maximum_iteration': 25, - 'damping': 0.85, - 'top_k': 100, - 'print_results': True, - 'result_attribute': 'pagerank', - 'file_path': '', - 'display_edges': True} - self.assertIsNotNone(self.featurizer.runAlgorithm("tg_pagerank", params=params)) + 'e_type': 'Cite', + 'max_change': 0.001, + 'maximum_iteration': 25, + 'damping': 0.85, + 'top_k': 100, + 'print_results': True, + 'result_attribute': 'pagerank', + 'file_path': '', + 'display_edges': True} + self.assertIsNotNone(self.featurizer.runAlgorithm( + "tg_pagerank", params=params)) def test02_runAlgorithm(self): with self.assertRaises(ValueError) as error: @@ -226,15 +234,16 @@ def test02_runAlgorithm(self): self.assertIn('Missing mandatory parameters:', str(error.exception)) with self.assertRaises(ValueError) as error: - self.featurizer.runAlgorithm("tg_pagerank", params= {'v_type': 'Paper'}) + self.featurizer.runAlgorithm( + "tg_pagerank", params={'v_type': 'Paper'}) self.assertIn('Missing mandatory parameters:', str(error.exception)) with self.assertRaises(ValueError) as error: - self.featurizer.runAlgorithm("tg_pagerank", params= {'foo': 'bar'}) + self.featurizer.runAlgorithm("tg_pagerank", params={'foo': 'bar'}) self.assertIn("Unknown parameters: ['foo']", str(error.exception)) def test03_runAlgorithm(self): - params={ + params = { "v_type": ["Paper"], "e_type": ["Cite"], "output_v_type": ["Paper"], @@ -253,77 +262,84 @@ def test03_runAlgorithm(self): def test06_installCustomAlgorithm(self): path = os.path.dirname(os.path.realpath(__file__)) fname = os.path.join(path, "fixtures/create_query_simple.gsql") - out = self.featurizer.installAlgorithm("simple_query", query_path=fname) + out = self.featurizer.installAlgorithm( + "simple_query", query_path=fname) self.assertEqual(out, "simple_query") - + def test07_runCustomAlgorithm(self): - out = self.featurizer.runAlgorithm("simple_query", params={}, custom_query=True) + out = self.featurizer.runAlgorithm( + "simple_query", params={}, custom_query=True) self.assertEqual(out[0]['"Hello World!"'], "Hello World!") def test08_runAlgorithm_async_qid(self): params = {'v_type': 'Paper', - 'e_type': 'Cite', - 'max_change': 0.001, - 'maximum_iteration': 25, - 'damping': 0.85, - 'top_k': 100, - 'print_results': True, - 'result_attribute': 'pagerank', - 'file_path': '', - 'display_edges': True} - ret = self.featurizer.runAlgorithm("tg_pagerank", params=params, runAsync=True) + 'e_type': 'Cite', + 'max_change': 0.001, + 'maximum_iteration': 25, + 'damping': 0.85, + 'top_k': 100, + 'print_results': True, + 'result_attribute': 'pagerank', + 'file_path': '', + 'display_edges': True} + ret = self.featurizer.runAlgorithm( + "tg_pagerank", params=params, runAsync=True) self.assertIsNotNone(ret.query_id) def 
test09_runAlgorithm_async_wait(self): params = {'v_type': 'Paper', - 'e_type': 'Cite', - 'max_change': 0.001, - 'maximum_iteration': 25, - 'damping': 0.85, - 'top_k': 100, - 'print_results': True, - 'result_attribute': 'pagerank', - 'file_path': '', - 'display_edges': True} - ret = self.featurizer.runAlgorithm("tg_pagerank", params=params, runAsync=True) + 'e_type': 'Cite', + 'max_change': 0.001, + 'maximum_iteration': 25, + 'damping': 0.85, + 'top_k': 100, + 'print_results': True, + 'result_attribute': 'pagerank', + 'file_path': '', + 'display_edges': True} + ret = self.featurizer.runAlgorithm( + "tg_pagerank", params=params, runAsync=True) self.assertIsNotNone(ret.wait()) def test_get_template_queries(self): if (self.featurizer.major_ver != "master" and ( - int(self.featurizer.major_ver) < 3 or ( + int(self.featurizer.major_ver) < 3 or ( int(self.featurizer.major_ver) == 3 and int(self.featurizer.minor_ver) < 8) - ) + ) ): print("Skip test_get_template_queries as the DB version is not supported.") return self.conn.gsql("IMPORT PACKAGE GDBMS_ALGO") self.featurizer._get_template_queries() self.assertIn("centrality", self.featurizer.template_queries) - self.assertIn("article_rank(string v_type, string e_type, float max_change, int maximum_iteration, float damping, int top_k, bool print_results, string result_attribute, string file_path)", self.featurizer.template_queries["centrality"]) + self.assertIn("article_rank(string v_type, string e_type, float max_change, int maximum_iteration, float damping, int top_k, bool print_results, string result_attribute, string file_path)", + self.featurizer.template_queries["centrality"]) def test_template_query(self): if (self.featurizer.major_ver != "master" and ( - int(self.featurizer.major_ver) < 3 or ( + int(self.featurizer.major_ver) < 3 or ( int(self.featurizer.major_ver) == 3 and int(self.featurizer.minor_ver) < 9) - ) + ) ): print("Skip test_template_query as the DB version is not supported.") return params = {'v_type': 'Paper', - 'e_type': 'Cite', - 'max_change': 0.001, - 'maximum_iteration': 25, - 'damping': 0.85, - 'top_k': 100, - 'print_results': True, - 'result_attribute': 'pagerank', - 'file_path': '', - 'display_edges': False} - - resp = self.featurizer.runAlgorithm("tg_pagerank", params, templateQuery=True) + 'e_type': 'Cite', + 'max_change': 0.001, + 'maximum_iteration': 25, + 'damping': 0.85, + 'top_k': 100, + 'print_results': True, + 'result_attribute': 'pagerank', + 'file_path': '', + 'display_edges': False} + + resp = self.featurizer.runAlgorithm( + "tg_pagerank", params, templateQuery=True) self.assertIn("@@top_scores_heap", resp[0]) self.assertEqual(len(resp[0]["@@top_scores_heap"]), 100) + if __name__ == '__main__': suite = unittest.TestSuite() suite.addTest(test_Featurizer("test_get_db_version")) @@ -345,16 +361,13 @@ def test_template_query(self): suite.addTest(test_Featurizer("test02_installAlgorithm")) suite.addTest(test_Featurizer("test01_runAlgorithm")) suite.addTest(test_Featurizer("test02_runAlgorithm")) - suite.addTest(test_Featurizer("test03_runAlgorithm")) + suite.addTest(test_Featurizer("test03_runAlgorithm")) suite.addTest(test_Featurizer("test06_installCustomAlgorithm")) suite.addTest(test_Featurizer("test07_runCustomAlgorithm")) suite.addTest(test_Featurizer("test08_runAlgorithm_async_qid")) suite.addTest(test_Featurizer("test09_runAlgorithm_async_wait")) suite.addTest(test_Featurizer("test_get_template_queries")) suite.addTest(test_Featurizer("test_template_query")) - + runner = unittest.TextTestRunner(verbosity=2, 
failfast=True) runner.run(suite) - - - \ No newline at end of file diff --git a/tests/test_gds_metrics.py b/tests/test_gds_metrics.py index fcb1bce0..931eb0c7 100644 --- a/tests/test_gds_metrics.py +++ b/tests/test_gds_metrics.py @@ -2,17 +2,17 @@ import numpy as np from pyTigerGraph.gds.metrics import (Accumulator, - Accuracy, - BinaryPrecision, - BinaryRecall, - Recall, - Precision, - MSE, - RMSE, - MAE, - HitsAtK, - RecallAtK, - ConfusionMatrix) + Accuracy, + BinaryPrecision, + BinaryRecall, + Recall, + Precision, + MSE, + RMSE, + MAE, + HitsAtK, + RecallAtK, + ConfusionMatrix) class TestGDSAccumulator(unittest.TestCase): @@ -86,6 +86,7 @@ def test_update(self): measure.update(preds, truth) self.assertEqual(measure.value, 0.5) + class TestGDSRecall(unittest.TestCase): def test_init(self): measure = Recall(num_classes=2) @@ -119,6 +120,7 @@ def test_update(self): measure.update(preds, truth) self.assertEqual(measure.value, 0.5) + class TestGDSMSE(unittest.TestCase): def test_init(self): measure = MSE() @@ -135,6 +137,7 @@ def test_update(self): measure.update(preds, truth) self.assertEqual(measure.value, 0.15) + class TestGDSRMSE(unittest.TestCase): def test_init(self): measure = RMSE() @@ -168,6 +171,7 @@ def test_update(self): measure.update(preds, truth) self.assertEqual(measure.value, 0.3) + class TestGDSHitsAtK(unittest.TestCase): def test_init(self): measure = HitsAtK(k=1) @@ -184,6 +188,7 @@ def test_update(self): measure.update(preds, truth) self.assertEqual(measure.value, 0.5) + class TestGDSRecallAtK(unittest.TestCase): def test_init(self): measure = RecallAtK(k=1) @@ -200,6 +205,7 @@ def test_update(self): measure.update(preds, truth) self.assertEqual(measure.value, 1/3) + class TestGDSConfusionMatrix(unittest.TestCase): def test_init(self): measure = ConfusionMatrix(num_classes=2) @@ -210,22 +216,23 @@ def test_update(self): preds = np.array([1, 1]) truth = np.array([1, 0]) measure.update(preds, truth) - self.assertEqual(measure.value.values[1,1], 1) + self.assertEqual(measure.value.values[1, 1], 1) preds = np.array([1, 1]) truth = np.array([1, 0]) measure.update(preds, truth) - self.assertEqual(measure.value.values[1,1], 2) + self.assertEqual(measure.value.values[1, 1], 2) def test_update_multiclass(self): measure = ConfusionMatrix(num_classes=4) preds = np.array([1, 1, 3, 2]) truth = np.array([1, 0, 3, 0]) measure.update(preds, truth) - self.assertEqual(measure.value.values[3,3], 1) + self.assertEqual(measure.value.values[3, 3], 1) preds = np.array([1, 1]) truth = np.array([1, 0]) measure.update(preds, truth) - self.assertEqual(measure.value.values[1,1], 2) + self.assertEqual(measure.value.values[1, 1], 2) + if __name__ == "__main__": unittest.main(verbosity=2, failfast=True) diff --git a/tests/test_gds_splitters.py b/tests/test_gds_splitters.py index 45fd900c..f4d6e6b5 100644 --- a/tests/test_gds_splitters.py +++ b/tests/test_gds_splitters.py @@ -21,7 +21,8 @@ def test_bad_attr(self): with self.assertRaises(ValueError): splitter = self.conn.gds.vertexSplitter(train_mask=1.1) with self.assertRaises(ValueError): - splitter = self.conn.gds.vertexSplitter(train_mask=0.6, val_mask=0.7) + splitter = self.conn.gds.vertexSplitter( + train_mask=0.6, val_mask=0.7) def test_one_attr(self): splitter = self.conn.gds.vertexSplitter(train_mask=0.6) @@ -85,12 +86,14 @@ def test_override_attr(self): self.assertAlmostEqual(p3_count / num_vertices, 0.7, delta=0.05) def test_v_types(self): - splitter = self.conn.gds.vertexSplitter(v_types=["Paper"],train_mask=0.3) + splitter = 
self.conn.gds.vertexSplitter( + v_types=["Paper"], train_mask=0.3) splitter.run() num_vertices = self.conn.getVertexCount("Paper") p1_count = self.conn.getVertexCount("Paper", where="train_mask!=0") self.assertAlmostEqual(p1_count / num_vertices, 0.3, delta=0.05) + def get_edge_count(conn: TigerGraphConnection, attribute: str): gsql = """ USE GRAPH Cora\n diff --git a/tests/test_gds_transforms.py b/tests/test_gds_transforms.py index df3858d2..e641d4e1 100644 --- a/tests/test_gds_transforms.py +++ b/tests/test_gds_transforms.py @@ -3,32 +3,34 @@ import torch_geometric as pyg import torch + class TestPyGTemporalTransform(unittest.TestCase): def test_init(self): - vertex_start_attrs = {"Customer": "start_dt", "Item": "start_ts", "ItemInstance": "start_dt"} + vertex_start_attrs = {"Customer": "start_dt", + "Item": "start_ts", "ItemInstance": "start_dt"} vertex_end_attrs = {"Item": "end_ts", "ItemInstance": "end_dt"} edge_start_attrs = {("Customer", "PURCHASED", "Item"): "purchase_time"} edge_end_attrs = {("Customer", "PURCHASED", "Item"): "purchase_time"} - feature_transforms = {("ItemInstance", "reverse_DESCRIBED_BY", "Item"): ["x"]} + feature_transforms = { + ("ItemInstance", "reverse_DESCRIBED_BY", "Item"): ["x"]} transform = TemporalPyGTransform(vertex_start_attrs=vertex_start_attrs, - vertex_end_attrs=vertex_end_attrs, - edge_start_attrs=edge_start_attrs, - edge_end_attrs=edge_end_attrs, - start_dt=0, - end_dt=6, - feature_transforms=feature_transforms) + vertex_end_attrs=vertex_end_attrs, + edge_start_attrs=edge_start_attrs, + edge_end_attrs=edge_end_attrs, + start_dt=0, + end_dt=6, + feature_transforms=feature_transforms) self.assertEqual(transform.timestep, 86400) - - + def test_homogeneous_transform(self): data = pyg.data.Data() data.x = torch.randn(4, 3) data.edge_index = torch.tensor([[0, 3, 2], [1, 2, 0]]) - data.vertex_start = torch.tensor([0,0,2,1]) - data.vertex_end = torch.tensor([5,5,5,5]) - data.edge_start = torch.tensor([0,2,3]) - data.edge_end = torch.tensor([5,5,5]) + data.vertex_start = torch.tensor([0, 0, 2, 1]) + data.vertex_end = torch.tensor([5, 5, 5, 5]) + data.edge_start = torch.tensor([0, 2, 3]) + data.edge_end = torch.tensor([5, 5, 5]) vertex_start_attrs = "vertex_start" vertex_end_attrs = "vertex_end" edge_start_attrs = "edge_start" @@ -44,50 +46,56 @@ def test_homogeneous_transform(self): self.assertEqual(seq[0].edge_index.shape[1], 1) self.assertEqual(seq[-1].edge_index.shape[1], 3) self.assertEqual(len(seq), 5) - + def test_heterogeneous_transform(self): data = pyg.data.HeteroData() data["ItemInstance"].x = torch.tensor([5, 6, 3, 2.5]) - data["ItemInstance"].start_dt = torch.tensor([0,4,3,1]) + data["ItemInstance"].start_dt = torch.tensor([0, 4, 3, 1]) data["ItemInstance"].end_dt = torch.tensor([6, 6, 4, 3]) data["ItemInstance"].id = ["2_1", "1_3", "1_2", "1_1"] data["ItemInstance"].is_seed = torch.tensor([True, True, True, True]) - data["Item"].start_ts = torch.tensor([1,0]) + data["Item"].start_ts = torch.tensor([1, 0]) data["Item"].end_ts = torch.tensor([-1, -1]) data["Item"].is_seed = torch.tensor([True, True]) - data["Customer"].start_dt = torch.tensor([0,1]) + data["Customer"].start_dt = torch.tensor([0, 1]) data["Customer"].is_seed = torch.tensor([True, True]) data["ZipCode"].id = torch.tensor([55369]) data["ZipCode"].is_seed = torch.tensor([True]) - data[("ItemInstance", "reverse_DESCRIBED_BY", "Item")].edge_index = torch.tensor([[3,2,1,0], - [0,0,0,1]]) - data[("Item", "DESCRIBED_BY", "ItemInstance")].edge_index = torch.tensor([[0,0,0,1], - 
[2,1,3,0]]) - data[("Customer", "LIVES_IN", "ZipCode")].edge_index = torch.tensor([[0,1], - [0,0]]) - data[("Customer", "PURCHASED", "Item")].edge_index = torch.tensor([[1,1,1], - [1,0,0]]) - data[("Customer", "PURCHASED", "Item")].purchase_time = torch.tensor([3, 1, 4]) + data[("ItemInstance", "reverse_DESCRIBED_BY", "Item")].edge_index = torch.tensor([[3, 2, 1, 0], + [0, 0, 0, 1]]) + data[("Item", "DESCRIBED_BY", "ItemInstance")].edge_index = torch.tensor([[0, 0, 0, 1], + [2, 1, 3, 0]]) + data[("Customer", "LIVES_IN", "ZipCode")].edge_index = torch.tensor([[0, 1], + [0, 0]]) + data[("Customer", "PURCHASED", "Item")].edge_index = torch.tensor([[1, 1, 1], + [1, 0, 0]]) + data[("Customer", "PURCHASED", "Item") + ].purchase_time = torch.tensor([3, 1, 4]) - vertex_start_attrs = {"Customer": "start_dt", "Item": "start_ts", "ItemInstance": "start_dt"} + vertex_start_attrs = {"Customer": "start_dt", + "Item": "start_ts", "ItemInstance": "start_dt"} vertex_end_attrs = {"Item": "end_ts", "ItemInstance": "end_dt"} edge_start_attrs = {("Customer", "PURCHASED", "Item"): "purchase_time"} edge_end_attrs = {("Customer", "PURCHASED", "Item"): "purchase_time"} - feature_transforms = {("ItemInstance", "reverse_DESCRIBED_BY", "Item"): ["x"]} + feature_transforms = { + ("ItemInstance", "reverse_DESCRIBED_BY", "Item"): ["x"]} transform = TemporalPyGTransform(vertex_start_attrs=vertex_start_attrs, - vertex_end_attrs=vertex_end_attrs, - edge_start_attrs=edge_start_attrs, - edge_end_attrs=edge_end_attrs, - start_dt=0, - end_dt=6, - feature_transforms=feature_transforms, - timestep=1) + vertex_end_attrs=vertex_end_attrs, + edge_start_attrs=edge_start_attrs, + edge_end_attrs=edge_end_attrs, + start_dt=0, + end_dt=6, + feature_transforms=feature_transforms, + timestep=1) transformed_data = transform(data) - self.assertEqual(transformed_data[0]["Item"]["ItemInstance_x"][1], 5) # price of second item should be 5 + # price of second item should be 5 + self.assertEqual(transformed_data[0]["Item"]["ItemInstance_x"][1], 5) self.assertEqual(transformed_data[4]["Item"]["ItemInstance_x"][0], 6) - self.assertEqual(transformed_data[5][("Customer", "PURCHASED", "ITEM")], {}) + self.assertEqual(transformed_data[5][( + "Customer", "PURCHASED", "ITEM")], {}) self.assertEqual(len(transformed_data), 6) + if __name__ == "__main__": unittest.main(verbosity=2, failfast=True) diff --git a/tests/test_gds_utilities.py b/tests/test_gds_utilities.py index 5b01b4c5..42362165 100644 --- a/tests/test_gds_utilities.py +++ b/tests/test_gds_utilities.py @@ -19,7 +19,8 @@ def test_is_query_installed(self): def test_install_query_file(self): resp = utils.install_query_file( self.conn, - os.path.join(os.path.dirname(__file__), "fixtures/create_query_simple.gsql") + os.path.join(os.path.dirname(__file__), + "fixtures/create_query_simple.gsql") ) self.assertEqual(resp, "simple_query") self.assertTrue(utils.is_query_installed(self.conn, "simple_query")) @@ -27,14 +28,16 @@ def test_install_query_file(self): def test_install_exist_query(self): resp = utils.install_query_file( self.conn, - os.path.join(os.path.dirname(__file__), "fixtures/create_query_simple.gsql") + os.path.join(os.path.dirname(__file__), + "fixtures/create_query_simple.gsql") ) self.assertEqual(resp, "simple_query") def test_install_query_by_force(self): resp = utils.install_query_file( self.conn, - os.path.join(os.path.dirname(__file__), "fixtures/create_query_simple2.gsql"), + os.path.join(os.path.dirname(__file__), + "fixtures/create_query_simple2.gsql"), force=True ) 
self.assertEqual(resp, "simple_query") @@ -48,12 +51,14 @@ def test_install_query_template(self): } resp = utils.install_query_file( self.conn, - os.path.join(os.path.dirname(__file__), "fixtures/create_query_template.gsql"), + os.path.join(os.path.dirname(__file__), + "fixtures/create_query_template.gsql"), replace ) self.assertEqual(resp, "simple_query_something_special") self.assertTrue( - utils.is_query_installed(self.conn, "simple_query_something_special") + utils.is_query_installed( + self.conn, "simple_query_something_special") ) diff --git a/tests/test_jwtAuth.py b/tests/test_jwtAuth.py index d37d952a..7ba0b58e 100644 --- a/tests/test_jwtAuth.py +++ b/tests/test_jwtAuth.py @@ -4,14 +4,13 @@ from pyTigerGraphUnitTest import make_connection from pyTigerGraph import TigerGraphConnection -from pyTigerGraph.pyTigerGraphException import TigerGraphException +from pyTigerGraph.common.exception import TigerGraphException class TestJWTTokenAuth(unittest.TestCase): @classmethod def setUpClass(cls): - cls.conn = make_connection(graphname="Cora") - + cls.conn = make_connection(graphname="tests") def test_jwtauth(self): dbversion = self.conn.getVer() @@ -23,11 +22,10 @@ def test_jwtauth(self): self._test_jwtauth_4_1_fail() else: pass - def _requestJWTToken(self): # in >=4.1 API all tokens are JWT tokens - if self.conn._versionGreaterThan4_0(): + if self.conn._version_greater_than_4_0(): return self.conn.getToken(self.conn.createSecret())[0] # Define the URL @@ -39,9 +37,9 @@ def _requestJWTToken(self): 'Content-Type': 'application/json' } # Make the POST request with basic authentication - response = requests.post(url, data=payload, headers=headers, auth=(self.conn.username, self.conn.password)) + response = requests.post(url, data=payload, headers=headers, auth=( + self.conn.username, self.conn.password)) return response.json()['token'] - def _test_jwtauth_3_9(self): with self.assertRaises(TigerGraphException) as context: @@ -51,8 +49,8 @@ def _test_jwtauth_3_9(self): ) # Verify the exception message - self.assertIn("switch to API token or username/password.", str(context.exception)) - + self.assertIn("switch to API token or username/password.", + str(context.exception)) def _test_jwtauth_4_1_success(self): jwt_token = self._requestJWTToken() @@ -63,22 +61,23 @@ def _test_jwtauth_4_1_success(self): ) authheader = newconn.authHeader - print (f"authheader from new conn: {authheader}") + print(f"authheader from new conn: {authheader}") # restpp on port 9000 dbversion = newconn.getVer() - print (f"dbversion from new conn: {dbversion}") + print(f"dbversion from new conn: {dbversion}") self.assertIn("4.1", str(dbversion)) # gsql on port 14240 - if self.conn._versionGreaterThan4_0(): - res = newconn._get(f"{newconn.gsUrl}/gsql/v1/auth/simple", authMode="token", resKey=None) + if self.conn._version_greater_than_4_0(): + res = newconn._get( + f"{newconn.gsUrl}/gsql/v1/auth/simple", authMode="token", resKey=None) res = res['results'] else: - res = newconn._get(f"{self.conn.host}:{self.conn.gsPort}/gsqlserver/gsql/simpleauth", authMode="token", resKey=None) + res = newconn._get( + f"{self.conn.host}:{self.conn.gsPort}/gsqlserver/gsql/simpleauth", authMode="token", resKey=None) self.assertIn("privileges", res) - def _test_jwtauth_4_1_fail(self): with self.assertRaises(TigerGraphException) as context: TigerGraphConnection( @@ -95,5 +94,5 @@ def _test_jwtauth_4_1_fail(self): suite = unittest.TestSuite() suite.addTest(TestJWTTokenAuth("test_jwtauth")) - runner = unittest.TextTestRunner(verbosity=2, 
failfast=True) - runner.run(suite) \ No newline at end of file + runner = unittest.TextTestRunner(verbosity=2, failfast=True) + runner.run(suite) diff --git a/tests/test_pyTigerGraphAuth.py b/tests/test_pyTigerGraphAuth.py index db69fb5e..fa12ca36 100644 --- a/tests/test_pyTigerGraphAuth.py +++ b/tests/test_pyTigerGraphAuth.py @@ -2,7 +2,7 @@ from pyTigerGraphUnitTest import make_connection -from pyTigerGraph.pyTigerGraphException import TigerGraphException +from pyTigerGraph.common.exception import TigerGraphException class test_pyTigerGraphAuth(unittest.TestCase): @@ -40,7 +40,8 @@ def test_03_createSecret(self): with self.assertRaises(TigerGraphException) as tge: self.conn.createSecret("secret1") - self.assertEqual("The secret with alias secret1 already exists.", tge.exception.message) + self.assertEqual( + "The secret with alias secret1 already exists.", tge.exception.message) def test_04_dropSecret(self): res = self.conn.showSecrets() @@ -70,13 +71,13 @@ def test_05_getToken(self): ''' def test_06_refreshToken(self): # TG 4.x does not allow refreshing tokens - if self.conn._versionGreaterThan4_0(): + self.conn.getToken(self.conn.createSecret()) + if self.conn._version_greater_than_4_0(): with self.assertRaises(TigerGraphException) as tge: self.conn.refreshToken("secret1") - self.assertEqual("Refreshing tokens is only supported on versions of TigerGraph <= 4.0.0.", tge.exception.message) + self.assertEqual( + "Refreshing tokens is only supported on versions of TigerGraph <= 4.0.0.", tge.exception.message) else: - res = self.conn.createSecret("secret6", True) - token = self.conn.getToken(res["secret6"]) if isinstance(token, str): # handle plaintext tokens from TG 3.x refreshed = self.conn.refreshToken(res["secret6"], token) self.assertIsInstance(refreshed, str) @@ -86,6 +87,7 @@ def test_06_refreshToken(self): self.conn.dropSecret("secret6") ''' def test_07_deleteToken(self): + self.conn.dropSecret("secret7", ignoreErrors=True) res = self.conn.createSecret("secret7", True) token = self.conn.getToken(res["secret7"]) if isinstance(token, str): # handle plaintext tokens from TG 3.x @@ -94,5 +96,6 @@ def test_07_deleteToken(self): self.assertTrue(self.conn.deleteToken(res["secret7"], token[0])) self.conn.dropSecret("secret7") + if __name__ == '__main__': unittest.main() diff --git a/tests/test_pyTigerGraphAuthAsync.py b/tests/test_pyTigerGraphAuthAsync.py new file mode 100644 index 00000000..e9538ca3 --- /dev/null +++ b/tests/test_pyTigerGraphAuthAsync.py @@ -0,0 +1,93 @@ +import unittest + +from pyTigerGraphUnitTestAsync import make_connection + +from pyTigerGraph.common.exception import TigerGraphException + + +class test_pyTigerGraphAuthAsync(unittest.IsolatedAsyncioTestCase): + @classmethod + async def asyncSetUp(self): + self.conn = await make_connection() + + async def test_01_getSecrets(self): + res = await self.conn.showSecrets() + self.assertIsInstance(res, dict) + # self.assertEqual(3, len(res)) # Just in case more secrets than expected + self.assertIn("secret1", res) + self.assertIn("secret2", res) + self.assertIn("secret3", res) + + async def test_02_getSecret(self): + pass + # TODO Implement + + async def test_03_createSecret(self): + res = await self.conn.createSecret("secret4") + self.assertIsInstance(res, str) + + res = await self.conn.createSecret("secret5", True) + self.assertIsInstance(res, dict) + self.assertEqual(1, len(res)) + alias = list(res.keys())[0] + self.assertEqual("secret5", alias) + + res = await self.conn.createSecret(withAlias=True) + 
self.assertIsInstance(res, dict) + self.assertEqual(1, len(res)) + alias = list(res.keys())[0] + self.assertTrue(alias.startswith("AUTO_GENERATED_ALIAS_")) + + with self.assertRaises(TigerGraphException) as tge: + await self.conn.createSecret("secret1") + self.assertEqual( + "The secret with alias secret1 already exists.", tge.exception.message) + + async def test_04_dropSecret(self): + res = await self.conn.showSecrets() + for a in list(res.keys()): + if a.startswith("AUTO_GENERATED_ALIAS"): + res = await self.conn.dropSecret(a) + self.assertTrue("Successfully dropped secrets" in res) + + res = await self.conn.dropSecret(["secret4", "secret5"]) + self.assertTrue("Failed to drop secrets" not in res) + + res = await self.conn.dropSecret("non_existent_secret") + self.assertTrue("Failed to drop secrets" in res) + + with self.assertRaises(TigerGraphException) as tge: + res = await self.conn.dropSecret("non_existent_secret", False) + + async def test_05_getToken(self): + res = await self.conn.createSecret("secret5", True) + token = await self.conn.getToken(res["secret5"]) + self.assertIsInstance(token, tuple) + await self.conn.dropSecret("secret5") + ''' + async def test_06_refreshToken(self): + # TG 4.x does not allow refreshing tokens + await self.conn.getToken(await self.conn.createSecret()) + if await self.conn._version_greater_than_4_0(): + with self.assertRaises(TigerGraphException) as tge: + await self.conn.refreshToken("secret1") + self.assertEqual( + "Refreshing tokens is only supported on versions of TigerGraph <= 4.0.0.", tge.exception.message) + else: + await self.conn.dropSecret("secret6", ignoreErrors=True) + res = await self.conn.createSecret("secret6", True) + token = await self.conn.getToken(res["secret6"]) + refreshed = await self.conn.refreshToken(res["secret6"], token[0]) + self.assertIsInstance(refreshed, tuple) + await self.conn.dropSecret("secret6") + ''' + async def test_07_deleteToken(self): + await self.conn.dropSecret("secret7", ignoreErrors=True) + res = await self.conn.createSecret("secret7", True) + token = await self.conn.getToken(res["secret7"]) + self.assertTrue(await self.conn.deleteToken(res["secret7"], token[0])) + await self.conn.dropSecret("secret7") + + +if __name__ == '__main__': + unittest.main() diff --git a/tests/test_pyTigerGraphBase.py b/tests/test_pyTigerGraphBase.py index e6a0da3f..2fd54ddc 100644 --- a/tests/test_pyTigerGraphBase.py +++ b/tests/test_pyTigerGraphBase.py @@ -3,7 +3,7 @@ from pyTigerGraphUnitTest import make_connection -from pyTigerGraph.pyTigerGraphException import TigerGraphException +from pyTigerGraph.common.exception import TigerGraphException class test_pyTigerGraphBase(unittest.TestCase): @@ -56,17 +56,17 @@ def test_00_errorCheck(self): "results": {} } - self.conn._errorCheck(json_ok1) - self.conn._errorCheck(json_ok2) - self.conn._errorCheck(json_ok3) - self.conn._errorCheck(json_ok4) + self.conn._error_check(json_ok1) + self.conn._error_check(json_ok2) + self.conn._error_check(json_ok3) + self.conn._error_check(json_ok4) with self.assertRaises(TigerGraphException) as tge: - res = self.conn._errorCheck(json_not_ok1) + res = self.conn._error_check(json_not_ok1) self.assertEqual("error message", tge.exception.message) with self.assertRaises(TigerGraphException) as tge: - res = self.conn._errorCheck(json_not_ok2) + res = self.conn._error_check(json_not_ok2) self.assertEqual("JB-007", tge.exception.code) def test_01_req(self): @@ -83,14 +83,16 @@ def test_03_post(self): self.assertEqual(exp, res) data = json.dumps({"function": 
"stat_vertex_attr", "type": "vertex4"}) - exp = [{'attributes': {'a01': {'AVG': 3, 'MAX': 5, 'MIN': 1}}, 'v_type': 'vertex4'}] - res = self.conn._post(self.conn.restppUrl + "/builtins/" + self.conn.graphname, data=data) + exp = [ + {'attributes': {'a01': {'AVG': 3, 'MAX': 5, 'MIN': 1}}, 'v_type': 'vertex4'}] + res = self.conn._post(self.conn.restppUrl + + "/builtins/" + self.conn.graphname, data=data) self.assertEqual(exp, res) def test_04_delete(self): with self.assertRaises(TigerGraphException) as tge: res = self.conn._delete(self.conn.restppUrl + "/graph/" + self.conn.graphname + - "/vertices/non_existent_vertex_type/1") + "/vertices/non_existent_vertex_type/1") self.assertEqual("REST-30000", tge.exception.code) diff --git a/tests/test_pyTigerGraphBaseAsync.py b/tests/test_pyTigerGraphBaseAsync.py new file mode 100644 index 00000000..d0059c6c --- /dev/null +++ b/tests/test_pyTigerGraphBaseAsync.py @@ -0,0 +1,98 @@ +import json +import unittest + +from pyTigerGraphUnitTestAsync import make_connection + +from pyTigerGraph.common.exception import TigerGraphException + + +class test_pyTigerGraphBaseAsync(unittest.IsolatedAsyncioTestCase): + async def asyncSetUp(self): + self.conn = await make_connection() + + def test_00_errorCheck(self): + json_ok1 = { + "error": False, + "message": "", + "results": { + "GraphName": "tests" + } + } + + json_ok2 = { + "error": "false", + "message": "", + "results": { + "GraphName": "tests" + } + } + + json_ok3 = { + "error": "", + "message": "", + "results": { + "GraphName": "tests" + } + } + + json_ok4 = { + "message": "", + "results": { + "GraphName": "tests" + } + } + + json_not_ok1 = { + "error": True, + "message": "error message", + "results": {} + } + + json_not_ok2 = { + "error": "true", + "message": "error message", + "code": "JB-007", + "results": {} + } + + self.conn._error_check(json_ok1) + self.conn._error_check(json_ok2) + self.conn._error_check(json_ok3) + self.conn._error_check(json_ok4) + + with self.assertRaises(TigerGraphException) as tge: + res = self.conn._error_check(json_not_ok1) + self.assertEqual("error message", tge.exception.message) + + with self.assertRaises(TigerGraphException) as tge: + res = self.conn._error_check(json_not_ok2) + self.assertEqual("JB-007", tge.exception.code) + + def test_01_req(self): + pass + + async def test_02_get(self): + exp = {'error': False, 'message': 'Hello GSQL'} + res = await self.conn._get(self.conn.restppUrl + "/echo/", resKey=None) + self.assertEqual(exp, res) + + async def test_03_post(self): + exp = {'error': False, 'message': 'Hello GSQL'} + res = await self.conn._post(self.conn.restppUrl + "/echo/", resKey=None) + self.assertEqual(exp, res) + + data = json.dumps({"function": "stat_vertex_attr", "type": "vertex4"}) + exp = [ + {'attributes': {'a01': {'AVG': 3, 'MAX': 5, 'MIN': 1}}, 'v_type': 'vertex4'}] + res = await self.conn._post(self.conn.restppUrl + "/builtins/" + self.conn.graphname, data=data) + self.assertEqual(exp, res) + + async def test_04_delete(self): + with self.assertRaises(TigerGraphException) as tge: + res = await self.conn._delete(self.conn.restppUrl + "/graph/" + self.conn.graphname + + "/vertices/non_existent_vertex_type/1") + self.assertEqual("REST-30000", tge.exception.code) + + +if __name__ == '__main__': + unittest.main() diff --git a/tests/test_pyTigerGraphEdge.py b/tests/test_pyTigerGraphEdge.py index a52ed06c..a46fd433 100644 --- a/tests/test_pyTigerGraphEdge.py +++ b/tests/test_pyTigerGraphEdge.py @@ -13,14 +13,9 @@ def setUpClass(cls): def 
test_01_getEdgeTypes(self): res = sorted(self.conn.getEdgeTypes()) self.assertEqual(6, len(res)) - exp = [ - "edge1_undirected", - "edge2_directed", - "edge3_directed_with_reverse", - "edge4_many_to_many", - "edge5_all_to_all", - "edge6_loop", - ] + exp = ["edge1_undirected", "edge2_directed", "edge3_directed_with_reverse", + "edge4_many_to_many", "edge5_all_to_all", "edge6_loop"] + self.assertEqual(exp, res) def test_02_getEdgeType(self): @@ -50,9 +45,8 @@ def test_02_getEdgeType(self): self.assertTrue(res["IsDirected"]) self.assertIn("Config", res) self.assertIn("REVERSE_EDGE", res["Config"]) - self.assertEqual( - "edge3_directed_with_reverse_reverse_edge", res["Config"]["REVERSE_EDGE"] - ) + self.assertEqual("edge3_directed_with_reverse_reverse_edge", + res["Config"]["REVERSE_EDGE"]) res = self.conn.getEdgeType("edge4_many_to_many") self.assertIsNotNone(res) @@ -101,10 +95,12 @@ def test_05_isDirected(self): def test_06_getReverseEdge(self): res = self.conn.getReverseEdge("edge1_undirected") self.assertIsInstance(res, str) - self.assertEqual("", res) # TODO Change this to None or something in getReverseEdge()? + # TODO Change this to None or something in getReverseEdge()? + self.assertEqual("", res) res = self.conn.getReverseEdge("edge2_directed") self.assertIsInstance(res, str) - self.assertEqual("", res) # TODO Change this to None or something in getReverseEdge()? + # TODO Change this to None or something in getReverseEdge()? + self.assertEqual("", res) res = self.conn.getReverseEdge("edge3_directed_with_reverse") self.assertIsInstance(res, str) self.assertEqual("edge3_directed_with_reverse_reverse_edge", res) @@ -121,15 +117,14 @@ def test_07_getEdgeCountFrom(self): self.assertIsInstance(res, int) self.assertEqual(8, res) - res = self.conn.getEdgeCountFrom( - sourceVertexType="vertex4", - edgeType="edge4_many_to_many", - targetVertexType="vertex5", - ) + res = self.conn.getEdgeCountFrom(sourceVertexType="vertex4", edgeType="edge4_many_to_many", + targetVertexType="vertex5") + self.assertIsInstance(res, int) self.assertEqual(3, res) - res = self.conn.getEdgeCountFrom(sourceVertexType="vertex4", sourceVertexId=1) + res = self.conn.getEdgeCountFrom( + sourceVertexType="vertex4", sourceVertexId=1) self.assertIsInstance(res, dict) self.assertIn("edge1_undirected", res) self.assertEqual(3, res["edge1_undirected"]) @@ -138,37 +133,24 @@ def test_07_getEdgeCountFrom(self): self.assertIn("edge4_many_to_many", res) self.assertEqual(3, res["edge4_many_to_many"]) - res = self.conn.getEdgeCountFrom( - sourceVertexType="vertex4", sourceVertexId=1, edgeType="edge1_undirected" - ) + res = self.conn.getEdgeCountFrom(sourceVertexType="vertex4", sourceVertexId=1, + edgeType="edge1_undirected") self.assertIsInstance(res, int) self.assertEqual(3, res) - res = self.conn.getEdgeCountFrom( - sourceVertexType="vertex4", - sourceVertexId=1, - edgeType="edge1_undirected", - where="a01=2", - ) + res = self.conn.getEdgeCountFrom(sourceVertexType="vertex4", sourceVertexId=1, + edgeType="edge1_undirected", where="a01=2") self.assertIsInstance(res, int) self.assertEqual(2, res) - res = self.conn.getEdgeCountFrom( - sourceVertexType="vertex4", - sourceVertexId=1, - edgeType="edge1_undirected", - targetVertexType="vertex5", - ) + res = self.conn.getEdgeCountFrom(sourceVertexType="vertex4", sourceVertexId=1, + edgeType="edge1_undirected", targetVertexType="vertex5") self.assertIsInstance(res, int) self.assertEqual(3, res) - res = self.conn.getEdgeCountFrom( - sourceVertexType="vertex4", - sourceVertexId=1, - 
edgeType="edge1_undirected", - targetVertexType="vertex5", - targetVertexId=3, - ) + res = self.conn.getEdgeCountFrom(sourceVertexType="vertex4", sourceVertexId=1, + edgeType="edge1_undirected", targetVertexType="vertex5", targetVertexId=3) + self.assertIsInstance(res, int) self.assertEqual(1, res) @@ -188,7 +170,8 @@ def test_08_getEdgeCount(self): self.assertIsInstance(res, int) self.assertEqual(8, res) - res = self.conn.getEdgeCount("edge4_many_to_many", "vertex4", "vertex5") + res = self.conn.getEdgeCount( + "edge4_many_to_many", "vertex4", "vertex5") self.assertIsInstance(res, int) self.assertEqual(3, res) @@ -212,14 +195,16 @@ def test_08_getEdgeCount(self): And similarly, should the deletion test have a setup stage, when vertices to be deleted are inserted? • Or should these two actions tested together? But that would defeat the idea of unittests. - """ + def test_09_upsertEdge(self): - res = self.conn.upsertEdge("vertex6", 1, "edge4_many_to_many", "vertex7", 1) + res = self.conn.upsertEdge( + "vertex6", 1, "edge4_many_to_many", "vertex7", 1) self.assertIsInstance(res, int) self.assertEqual(1, res) - res = self.conn.upsertEdge("vertex6", 6, "edge4_many_to_many", "vertex7", 6) + res = self.conn.upsertEdge( + "vertex6", 6, "edge4_many_to_many", "vertex7", 6) self.assertIsInstance(res, int) self.assertEqual(1, res) @@ -255,8 +240,15 @@ def test_09_upsertEdge_mustExist(self): # TODO Add MultiEdge edge to schema and add test cases def test_10_upsertEdges(self): - es = [(2, 1), (2, 2), (2, 3), (2, 4)] - res = self.conn.upsertEdges("vertex6", "edge4_many_to_many", "vertex7", es) + es = [ + (2, 1), + (2, 2), + (2, 3), + (2, 4) + ] + res = self.conn.upsertEdges( + "vertex6", "edge4_many_to_many", "vertex7", es) + self.assertIsInstance(res, int) self.assertEqual(4, res) @@ -355,34 +347,40 @@ def test_12_getEdges(self): self.assertIsInstance(res, list) self.assertEqual(5, len(res)) - res = self.conn.getEdges("vertex4", 1, "edge1_undirected", "vertex5", 2) + res = self.conn.getEdges( + "vertex4", 1, "edge1_undirected", "vertex5", 2) self.assertIsInstance(res, list) self.assertEqual(1, len(res)) res = self.conn.getEdges( "vertex4", 1, "edge1_undirected", select="a01", where="a01>1" ) + self.assertIsInstance(res, list) self.assertEqual(2, len(res)) - res = self.conn.getEdges("vertex4", 1, "edge1_undirected", sort="-a01", limit=2) + res = self.conn.getEdges( + "vertex4", 1, "edge1_undirected", sort="-a01", limit=2) self.assertIsInstance(res, list) self.assertEqual(2, len(res)) res = self.conn.getEdges( "vertex4", 1, "edge1_undirected", "vertex5", fmt="json" ) + self.assertIsInstance(res, str) res = json.loads(res) self.assertIsInstance(res, list) self.assertEqual(5, len(res)) - res = self.conn.getEdges("vertex4", 1, "edge1_undirected", "vertex5", fmt="df") + res = self.conn.getEdges( + "vertex4", 1, "edge1_undirected", "vertex5", fmt="df") self.assertIsInstance(res, pd.DataFrame) self.assertEqual(5, len(res.index)) def test_13_getEdgesDataFrame(self): - res = self.conn.getEdgesDataFrame("vertex4", 1, "edge1_undirected", "vertex5") + res = self.conn.getEdgesDataFrame( + "vertex4", 1, "edge1_undirected", "vertex5") self.assertIsInstance(res, pd.DataFrame) self.assertEqual(5, len(res.index)) @@ -403,8 +401,8 @@ def test_16_getEdgeStats(self): self.assertEqual(-18.5, res["edge1_undirected"]["a01"]["AVG"]) res = self.conn.getEdgeStats( - ["edge1_undirected", "edge2_directed", "edge6_loop"] - ) + ["edge1_undirected", "edge2_directed", "edge6_loop"]) + self.assertIsInstance(res, dict) self.assertEqual(3, 
len(res)) self.assertIn("edge1_undirected", res) @@ -417,6 +415,7 @@ def test_16_getEdgeStats(self): res = self.conn.getEdgeStats( ["edge1_undirected", "edge2_directed", "edge6_loop"], skipNA=True ) + self.assertIsInstance(res, dict) self.assertEqual(2, len(res)) self.assertIn("edge1_undirected", res) @@ -448,7 +447,8 @@ def test_17_delEdges(self): self.assertIn("edge4_many_to_many", res) self.assertEqual(0, res["edge4_many_to_many"]) - res = self.conn.delEdges("vertex6", 2, "edge4_many_to_many", "vertex7", 1) + res = self.conn.delEdges( + "vertex6", 2, "edge4_many_to_many", "vertex7", 1) self.assertIsInstance(res, dict) self.assertEqual(1, len(res)) self.assertIn("edge4_many_to_many", res) @@ -462,7 +462,7 @@ def test_17_delEdges(self): def test_18_edgeSetToDataFrame(self): pass - + """ if __name__ == "__main__": unittest.main() diff --git a/tests/test_pyTigerGraphEdgeAsync.py b/tests/test_pyTigerGraphEdgeAsync.py new file mode 100644 index 00000000..ac3b5d44 --- /dev/null +++ b/tests/test_pyTigerGraphEdgeAsync.py @@ -0,0 +1,343 @@ +import json +import unittest + +import pandas +from pyTigerGraphUnitTestAsync import make_connection + +from pyTigerGraph.common.exception import TigerGraphException + +class test_pyTigerGraphEdgeAsync(unittest.IsolatedAsyncioTestCase): + async def asyncSetUp(self): + self.conn = await make_connection() + + async def test_01_getEdgeTypes(self): + res = sorted(await self.conn.getEdgeTypes()) + self.assertEqual(6, len(res)) + exp = ["edge1_undirected", "edge2_directed", "edge3_directed_with_reverse", + "edge4_many_to_many", "edge5_all_to_all", "edge6_loop"] + self.assertEqual(exp, res) + + async def test_02_getEdgeType(self): + res = await self.conn.getEdgeType("edge1_undirected") + self.assertIsNotNone(res) + self.assertIsInstance(res, dict) + self.assertIn("FromVertexTypeName", res) + self.assertEqual("vertex4", res["FromVertexTypeName"]) + self.assertIn("ToVertexTypeName", res) + self.assertEqual("vertex5", res["ToVertexTypeName"]) + self.assertIn("IsDirected", res) + self.assertFalse(res["IsDirected"]) + self.assertNotIn("EdgePairs", res) + + res = await self.conn.getEdgeType("edge2_directed") + self.assertIsNotNone(res) + self.assertIsInstance(res, dict) + self.assertIn("IsDirected", res) + self.assertTrue(res["IsDirected"]) + self.assertIn("Config", res) + self.assertNotIn("REVERSE_EDGE", res["Config"]) + + res = await self.conn.getEdgeType("edge3_directed_with_reverse") + self.assertIsNotNone(res) + self.assertIsInstance(res, dict) + self.assertIn("IsDirected", res) + self.assertTrue(res["IsDirected"]) + self.assertIn("Config", res) + self.assertIn("REVERSE_EDGE", res["Config"]) + self.assertEqual("edge3_directed_with_reverse_reverse_edge", + res["Config"]["REVERSE_EDGE"]) + + res = await self.conn.getEdgeType("edge4_many_to_many") + self.assertIsNotNone(res) + self.assertIsInstance(res, dict) + self.assertIn("ToVertexTypeName", res) + self.assertEqual("*", res["ToVertexTypeName"]) + self.assertIn("FromVertexTypeName", res) + self.assertEqual("*", res["FromVertexTypeName"]) + self.assertIn("EdgePairs", res) + self.assertEqual(5, len(res["EdgePairs"])) + + res = await self.conn.getEdgeType("edge5_all_to_all") + self.assertIsNotNone(res) + self.assertIsInstance(res, dict) + self.assertIn("ToVertexTypeName", res) + self.assertEqual("*", res["ToVertexTypeName"]) + self.assertIn("FromVertexTypeName", res) + self.assertEqual("*", res["FromVertexTypeName"]) + self.assertIn("EdgePairs", res) + self.assertEqual(49, len(res["EdgePairs"])) + + res = await 
self.conn.getEdgeType("non_existing_edge_type") + self.assertEqual({}, res) + # TODO This will need to be reviewed if/when getEdgeType() return value changes from {} in + # case of invalid/non-existing edge type name is specified (e.g. an exception will be + # raised instead of returning {} + + async def test_03_getEdgeSourceVertexType(self): + res = await self.conn.getEdgeSourceVertexType("edge1_undirected") + self.assertIsInstance(res, str) + self.assertEqual("vertex4", res) + + async def test_04_getEdgeTargetVertexType(self): + res = await self.conn.getEdgeTargetVertexType("edge2_directed") + self.assertIsInstance(res, str) + self.assertEqual("vertex5", res) + + async def test_05_isDirected(self): + res = await self.conn.isDirected("edge1_undirected") + self.assertIsInstance(res, bool) + self.assertFalse(res) + res = await self.conn.isDirected("edge2_directed") + self.assertIsInstance(res, bool) + self.assertTrue(res) + + async def test_06_getReverseEdge(self): + res = await self.conn.getReverseEdge("edge1_undirected") + self.assertIsInstance(res, str) + # TODO Change this to None or something in getReverseEdge()? + self.assertEqual("", res) + res = await self.conn.getReverseEdge("edge2_directed") + self.assertIsInstance(res, str) + # TODO Change this to None or something in getReverseEdge()? + self.assertEqual("", res) + res = await self.conn.getReverseEdge("edge3_directed_with_reverse") + self.assertIsInstance(res, str) + self.assertEqual("edge3_directed_with_reverse_reverse_edge", res) + + async def test_07_getEdgeCountFrom(self): + res = await self.conn.getEdgeCountFrom(edgeType="*") + self.assertIsInstance(res, dict) + self.assertIn("edge1_undirected", res) + self.assertEqual(8, res["edge1_undirected"]) + self.assertIn("edge6_loop", res) + self.assertEqual(0, res["edge6_loop"]) + + res = await self.conn.getEdgeCountFrom(edgeType="edge4_many_to_many") + self.assertIsInstance(res, int) + self.assertEqual(8, res) + + res = await self.conn.getEdgeCountFrom(sourceVertexType="vertex4", edgeType="edge4_many_to_many", + targetVertexType="vertex5") + self.assertIsInstance(res, int) + self.assertEqual(3, res) + + res = await self.conn.getEdgeCountFrom(sourceVertexType="vertex4", sourceVertexId=1) + self.assertIsInstance(res, dict) + self.assertIn("edge1_undirected", res) + self.assertEqual(3, res["edge1_undirected"]) + self.assertIn("edge2_directed", res) + self.assertEqual(0, res["edge2_directed"]) + self.assertIn("edge4_many_to_many", res) + self.assertEqual(3, res["edge1_undirected"]) + + res = await self.conn.getEdgeCountFrom(sourceVertexType="vertex4", sourceVertexId=1, + edgeType="edge1_undirected") + self.assertIsInstance(res, int) + self.assertEqual(3, res) + + res = await self.conn.getEdgeCountFrom(sourceVertexType="vertex4", sourceVertexId=1, + edgeType="edge1_undirected", where="a01=2") + self.assertIsInstance(res, int) + self.assertEqual(2, res) + + res = await self.conn.getEdgeCountFrom(sourceVertexType="vertex4", sourceVertexId=1, + edgeType="edge1_undirected", targetVertexType="vertex5") + self.assertIsInstance(res, int) + self.assertEqual(3, res) + + res = await self.conn.getEdgeCountFrom(sourceVertexType="vertex4", sourceVertexId=1, + edgeType="edge1_undirected", targetVertexType="vertex5", targetVertexId=3) + self.assertIsInstance(res, int) + self.assertEqual(1, res) + + async def test_08_getEdgeCount(self): + res = await self.conn.getEdgeCount("*") + self.assertIsInstance(res, dict) + self.assertIn("edge1_undirected", res) + self.assertEqual(8, res["edge1_undirected"]) + 
self.assertIn("edge6_loop", res) + self.assertEqual(0, res["edge6_loop"]) + + res = await self.conn.getEdgeCount("edge4_many_to_many") + self.assertIsInstance(res, int) + self.assertEqual(8, res) + + res = await self.conn.getEdgeCount("edge4_many_to_many", "vertex4") + self.assertIsInstance(res, int) + self.assertEqual(8, res) + + res = await self.conn.getEdgeCount("edge4_many_to_many", "vertex4", "vertex5") + self.assertIsInstance(res, int) + self.assertEqual(3, res) + + """ Commented out because the order of execution is not guaranteed, so the serialized nature doesn't work in async case. + Apparently, the following tests are not structured properly. + The code below first inserts edges in two steps, then retrieves them, and finally, deletes them. + It seems that the order of execution is not guaranteed, so the serialised nature of steps might + not work in some environments/setups. + Also, unittest runs separate tests with fresh instances of the TestCase, so setUp and tearDown + are executed before/after each tests and – importantly – it is not "possible" to persist + information between test cases (i.e. save a piece of information in e.g. a variable of the class + instance in one test and use it in another test) (it is technically possible, but not + recommended due to the aforementioned reasons). + + Luckily, it seems that tests are executed in alphabetical order, so there is a good chance that + in basic testing setups, they will be executed in the desired order. + + TODO How to structure tests so that every step can be executed independently? + E.g. how to test insertion and deletion of edge? + • Should the insertion test have a clean-up stage deleting the newly inserted vertices? + And similarly, should the deletion test have a setup stage, when vertices to be deleted are + inserted? + • Or should these two actions tested together? But that would defeat the idea of unittests. 
+ + + async def test_09_upsertEdge(self): + res = await self.conn.upsertEdge("vertex6", 1, "edge4_many_to_many", "vertex7", 1) + self.assertIsInstance(res, int) + self.assertEqual(1, res) + + res = await self.conn.upsertEdge("vertex6", 6, "edge4_many_to_many", "vertex7", 6) + self.assertIsInstance(res, int) + self.assertEqual(1, res) + + # TODO Tests with ack, new_vertex_only, vertex_must_exist, update_vertex_only and + # atomic_level parameters; when they will be added to pyTigerGraphEdge.upsertEdge() + # TODO Add MultiEdge edge to schema and add test cases + + async def test_10_upsertEdges(self): + es = [ + (2, 1), + (2, 2), + (2, 3), + (2, 4) + ] + res = await self.conn.upsertEdges("vertex6", "edge4_many_to_many", "vertex7", es) + self.assertIsInstance(res, int) + self.assertEqual(4, res) + + res = await self.conn.getEdgeCount("edge4_many_to_many") + self.assertIsInstance(res, int) + self.assertEqual(14, res) + + async def test_11_upsertEdgeDataFrame(self): + # TODO Implement + pass + + async def test_12_getEdges(self): + res = await self.conn.getEdges("vertex4", 1) + self.assertIsInstance(res, list) + self.assertEqual(6, len(res)) + + res = await self.conn.getEdges("vertex4", 1, "edge1_undirected") + self.assertIsInstance(res, list) + self.assertEqual(3, len(res)) + + res = await self.conn.getEdges("vertex4", 1, "edge1_undirected", "vertex5") + self.assertIsInstance(res, list) + self.assertEqual(3, len(res)) + + res = await self.conn.getEdges("vertex4", 1, "edge1_undirected", "vertex5", 2) + self.assertIsInstance(res, list) + self.assertEqual(1, len(res)) + + res = await self.conn.getEdges("vertex4", 1, "edge1_undirected", select="a01", where="a01>1") + self.assertIsInstance(res, list) + self.assertEqual(2, len(res)) + + res = await self.conn.getEdges("vertex4", 1, "edge1_undirected", sort="-a01", limit=2) + self.assertIsInstance(res, list) + self.assertEqual(2, len(res)) + + res = await self.conn.getEdges("vertex4", 1, "edge1_undirected", "vertex5", fmt="json") + self.assertIsInstance(res, str) + res = json.loads(res) + self.assertIsInstance(res, list) + self.assertEqual(3, len(res)) + + res = await self.conn.getEdges("vertex4", 1, "edge1_undirected", "vertex5", fmt="df") + self.assertIsInstance(res, pandas.DataFrame) + self.assertEqual(3, len(res.index)) + + async def test_13_getEdgesDataFrame(self): + res = await self.conn.getEdgesDataFrame("vertex4", 1, "edge1_undirected", "vertex5") + self.assertIsInstance(res, pandas.DataFrame) + self.assertEqual(3, len(res.index)) + + async def test_14_getEdgesByType(self): + res = await self.conn.getEdgesByType("edge1_undirected") + self.assertIsInstance(res, list) + self.assertEqual(8, len(res)) + + async def test_15_getEdgesDataFrameByType(self): + pass + + async def test_16_getEdgeStats(self): + res = await self.conn.getEdgeStats("edge1_undirected") + self.assertIsInstance(res, dict) + self.assertEqual(1, len(res)) + self.assertIn("edge1_undirected", res) + self.assertEqual(2, res["edge1_undirected"]["a01"]["MAX"]) + self.assertEqual(1.875, res["edge1_undirected"]["a01"]["AVG"]) + + res = await self.conn.getEdgeStats(["edge1_undirected", "edge2_directed", "edge6_loop"]) + self.assertIsInstance(res, dict) + self.assertEqual(3, len(res)) + self.assertIn("edge1_undirected", res) + self.assertEqual(2, res["edge1_undirected"]["a01"]["MAX"]) + self.assertIn("edge2_directed", res) + self.assertEqual(2, res["edge2_directed"]["a01"]["AVG"]) + self.assertIn("edge6_loop", res) + self.assertEqual({}, res["edge6_loop"]) + + res = await 
self.conn.getEdgeStats(["edge1_undirected", "edge2_directed", "edge6_loop"], + skipNA=True) + self.assertIsInstance(res, dict) + self.assertEqual(2, len(res)) + self.assertIn("edge1_undirected", res) + self.assertEqual(2, res["edge1_undirected"]["a01"]["MAX"]) + self.assertIn("edge2_directed", res) + self.assertNotIn("edge6_loop", res) + + res = await self.conn.getEdgeStats("*", skipNA=True) + self.assertIsInstance(res, dict) + self.assertIn("edge3_directed_with_reverse", res) + self.assertNotIn("edge4_many_to_many", res) + + async def test_17_delEdges(self): + res = await self.conn.delEdges("vertex6", 1) + self.assertIsInstance(res, dict) + self.assertEqual(7, len(res)) + self.assertIn("edge4_many_to_many", res) + self.assertEqual(1, res["edge4_many_to_many"]) + + res = await self.conn.delEdges("vertex6", 6, "edge4_many_to_many") + self.assertIsInstance(res, dict) + self.assertEqual(1, len(res)) + self.assertIn("edge4_many_to_many", res) + self.assertEqual(1, res["edge4_many_to_many"]) + + res = await self.conn.delEdges("vertex6", 6, "edge4_many_to_many") + self.assertIsInstance(res, dict) + self.assertEqual(1, len(res)) + self.assertIn("edge4_many_to_many", res) + self.assertEqual(0, res["edge4_many_to_many"]) + + res = await self.conn.delEdges("vertex6", 2, "edge4_many_to_many", "vertex7", 1) + self.assertIsInstance(res, dict) + self.assertEqual(1, len(res)) + self.assertIn("edge4_many_to_many", res) + self.assertEqual(1, res["edge4_many_to_many"]) + + res = await self.conn.delEdges("vertex6", 2, "edge4_many_to_many", "vertex7") + self.assertIsInstance(res, dict) + self.assertEqual(1, len(res)) + self.assertIn("edge4_many_to_many", res) + self.assertEqual(3, res["edge4_many_to_many"]) + + def test_18_edgeSetToDataFrame(self): + pass + """ + +if __name__ == '__main__': + unittest.main() diff --git a/tests/test_pyTigerGraphGSQL.py b/tests/test_pyTigerGraphGSQL.py index b81a634a..2ec73f88 100644 --- a/tests/test_pyTigerGraphGSQL.py +++ b/tests/test_pyTigerGraphGSQL.py @@ -13,13 +13,15 @@ def test_01_gsql(self): res = self.conn.gsql("help") self.assertIsInstance(res, str) res = res.split("\n") - self.assertEqual("GSQL Help: Summary of TigerGraph GSQL Shell commands.", res[0]) + self.assertEqual( + "GSQL Help: Summary of TigerGraph GSQL Shell commands.", res[0]) def test_02_gsql(self): res = self.conn.gsql("ls") self.assertIsInstance(res, str) res = res.split("\n")[0] - self.assertIn(res,["---- Global vertices, edges, and all graphs", "---- Graph " + self.conn.graphname]) + self.assertIn(res, ["---- Global vertices, edges, and all graphs", + "---- Graph " + self.conn.graphname]) # def test_03_installUDF(self): # path = os.path.dirname(os.path.realpath(__file__)) @@ -57,4 +59,3 @@ def test_getUDF(self): runner = unittest.TextTestRunner(verbosity=2, failfast=True) runner.run(suite) - diff --git a/tests/test_pyTigerGraphGSQLAsync.py b/tests/test_pyTigerGraphGSQLAsync.py new file mode 100644 index 00000000..e80fac44 --- /dev/null +++ b/tests/test_pyTigerGraphGSQLAsync.py @@ -0,0 +1,64 @@ +import unittest +import os + +from pyTigerGraphUnitTestAsync import make_connection + +from pyTigerGraph.common.exception import TigerGraphException + + +class test_pyTigerGraphGSQLAsync(unittest.IsolatedAsyncioTestCase): + async def asyncSetUp(self): + self.conn = await make_connection() + + async def test_01_gsql(self): + res = await self.conn.gsql("help") + self.assertIsInstance(res, str) + res = res.split("\n") + self.assertEqual( + "GSQL Help: Summary of TigerGraph GSQL Shell commands.", res[0]) + + async 
def test_02_gsql(self): + res = await self.conn.gsql("ls") + self.assertIsInstance(res, str) + res = res.split("\n")[0] + self.assertIn(res, ["---- Global vertices, edges, and all graphs", + "---- Graph " + self.conn.graphname]) + + # def test_03_installUDF(self): + # path = os.path.dirname(os.path.realpath(__file__)) + # ExprFunctions = os.path.join(path, "fixtures", "ExprFunctions.hpp") + # ExprUtil = os.path.join(path, "fixtures", "ExprUtil.hpp") + # self.assertEqual(self.conn.installUDF(ExprFunctions, ExprUtil), 0) + + # def test_04_installUDFRemote(self): + # ExprFunctions = "https://tg-mlworkbench.s3.us-west-1.amazonaws.com/udf/1.0/ExprFunctions.hpp" + # self.assertEqual(self.conn.installUDF(ExprFunctions=ExprFunctions), 0) + + async def test_getUDF(self): + # Don't get anything + res = await self.conn.getUDF(ExprFunctions=False, ExprUtil=False) + self.assertEqual(res, "") + # Get both ExprFunctions and ExprUtil (default) + udf = await self.conn.getUDF() + self.assertIn("init_kafka_producer", udf[0]) + self.assertIn("class KafkaProducer", udf[1]) + # Get ExprFunctions only + udf = await self.conn.getUDF(ExprUtil=False) + self.assertIn("init_kafka_producer", udf) + # Get ExprUtil only + udf = await self.conn.getUDF(ExprFunctions=False) + self.assertIn("class KafkaProducer", udf) + + +if __name__ == "__main__": + # suite = unittest.TestSuite() + # suite.addTest(test_pyTigerGraphGSQL("test_01_gsql")) + # suite.addTest(test_pyTigerGraphGSQL("test_02_gsql")) + # # suite.addTest(test_pyTigerGraphGSQL("test_03_installUDF")) + # # suite.addTest(test_pyTigerGraphGSQL("test_04_installUDFRemote")) + # suite.addTest(test_pyTigerGraphGSQL("test_getUDF")) + + # runner = unittest.TextTestRunner(verbosity=2, failfast=True) + # runner.run(suite) + + unittest.main() diff --git a/tests/test_pyTigerGraphPath.py b/tests/test_pyTigerGraphPath.py index e13b394a..956c947b 100644 --- a/tests/test_pyTigerGraphPath.py +++ b/tests/test_pyTigerGraphPath.py @@ -1,6 +1,8 @@ import json import unittest +from pyTigerGraph.common.path import _prepare_path_params + from pyTigerGraphUnitTest import make_connection @@ -58,9 +60,10 @@ def _check_edges(self, res_es: list, exp_es: list) -> bool: return sorted(es) == sorted(exp_es) def test_01_preparePathParams(self): - res = self.conn._preparePathParams([("srctype1", 1), ("srctype2", 2), ("srctype3", 3)], - [("trgtype1", 1), ("trgtype2", 2), ("trgtype3", 3)], 5, - [("srctype1", "a01>10")], [("trgtype1", "a10<20")], True) + res = _prepare_path_params([("srctype1", 1), ("srctype2", 2), ("srctype3", 3)], + [("trgtype1", 1), ("trgtype2", 2), + ("trgtype3", 3)], 5, + [("srctype1", "a01>10")], [("trgtype1", "a10<20")], True) self.assertIsInstance(res, str) res = json.loads(res) self.assertEqual(6, len(res)) @@ -76,7 +79,7 @@ def test_01_preparePathParams(self): self.assertIn("allShortestPaths", res) self.assertTrue(res["allShortestPaths"]) - res = self.conn._preparePathParams([("srct", 1)], [("trgt", 1)]) + res = _prepare_path_params([("srct", 1)], [("trgt", 1)]) self.assertEqual( '{"sources": [{"type": "srct", "id": 1}], "targets": [{"type": "trgt", "id": 1}]}', res @@ -84,8 +87,10 @@ def test_01_preparePathParams(self): def test_02_shortestPath(self): - self.assertEqual(8, self.conn.getVertexCount("vertex4", where="a01>=900")) - self.assertEqual(11, self.conn.getEdgeCount("edge6_loop", "vertex4", "vertex4")) + self.assertEqual(8, self.conn.getVertexCount( + "vertex4", where="a01>=900")) + self.assertEqual(11, self.conn.getEdgeCount( + "edge6_loop", "vertex4", "vertex4")) res = 
self.conn.shortestPath(("vertex4", 10), ("vertex4", 50)) vs1 = [10, 20, 30, 40, 50] @@ -99,27 +104,30 @@ def test_02_shortestPath(self): self._check_edges(res[0]["edges"], es2)) ) - res = self.conn.shortestPath(("vertex4", 10), ("vertex4", 50), allShortestPaths=True) + res = self.conn.shortestPath( + ("vertex4", 10), ("vertex4", 50), allShortestPaths=True) vs3 = [10, 20, 30, 40, 50, 60, 70] - es3 = [(10, 20), (20, 30), (30, 40), (40, 50), (10, 60), (60, 70), (70, 40)] + es3 = [(10, 20), (20, 30), (30, 40), (40, 50), + (10, 60), (60, 70), (70, 40)] self.assertTrue( (self._check_vertices(res[0]["vertices"], vs3) and self._check_edges(res[0]["edges"], es3)) ) - res = self.conn.shortestPath(("vertex4", 10), ("vertex4", 50), maxLength=3) + res = self.conn.shortestPath( + ("vertex4", 10), ("vertex4", 50), maxLength=3) self.assertEqual([], res[0]["vertices"]) self.assertEqual([], res[0]["edges"]) res = self.conn.shortestPath(("vertex4", 10), ("vertex4", 50), allShortestPaths=True, - vertexFilters=("vertex4", "a01>950")) + vertexFilters=("vertex4", "a01>950")) self.assertTrue( (self._check_vertices(res[0]["vertices"], vs1) and self._check_edges(res[0]["edges"], es1)) ) res = self.conn.shortestPath(("vertex4", 10), ("vertex4", 50), allShortestPaths=True, - edgeFilters=("edge6_loop", "a01<950")) + edgeFilters=("edge6_loop", "a01<950")) self.assertTrue( (self._check_vertices(res[0]["vertices"], vs2) and self._check_edges(res[0]["edges"], es2)) @@ -128,7 +136,8 @@ def test_02_shortestPath(self): def test_03_allPaths(self): res = self.conn.allPaths(("vertex4", 10), ("vertex4", 50), maxLength=4) vs = [10, 20, 30, 40, 50, 60, 70] - es = [(10, 20), (20, 30), (30, 40), (40, 50), (10, 60), (60, 70), (70, 40)] + es = [(10, 20), (20, 30), (30, 40), (40, 50), + (10, 60), (60, 70), (70, 40)] self.assertTrue( (self._check_vertices(res[0]["vertices"], vs) and self._check_edges(res[0]["edges"], es)) @@ -137,7 +146,7 @@ def test_03_allPaths(self): res = self.conn.allPaths(("vertex4", 10), ("vertex4", 50), maxLength=5) vs = [10, 20, 30, 40, 50, 60, 70, 80] es = [(10, 20), (20, 30), (30, 40), (40, 50), (10, 60), (60, 70), (70, 40), (70, 80), - (80, 40)] + (80, 40)] self.assertTrue( (self._check_vertices(res[0]["vertices"], vs) and self._check_edges(res[0]["edges"], es)) @@ -146,14 +155,14 @@ def test_03_allPaths(self): res = self.conn.allPaths(("vertex4", 10), ("vertex4", 50), maxLength=6) vs = [10, 20, 30, 40, 50, 60, 70, 80] es = [(10, 20), (20, 30), (30, 40), (40, 50), (10, 60), (60, 70), (70, 40), (70, 80), - (80, 40), (30, 60)] + (80, 40), (30, 60)] self.assertTrue( (self._check_vertices(res[0]["vertices"], vs) and self._check_edges(res[0]["edges"], es)) ) res = self.conn.allPaths(("vertex4", 10), ("vertex4", 50), maxLength=5, - vertexFilters=("vertex4", "a01>950")) + vertexFilters=("vertex4", "a01>950")) vs = [10, 20, 30, 40, 50] es = [(10, 20), (20, 30), (30, 40), (40, 50)] self.assertTrue( @@ -162,7 +171,7 @@ def test_03_allPaths(self): ) res = self.conn.allPaths(("vertex4", 10), ("vertex4", 50), maxLength=5, - edgeFilters=("edge6_loop", "a01<950")) + edgeFilters=("edge6_loop", "a01<950")) vs = [10, 60, 70, 40, 50] es = [(10, 60), (60, 70), (70, 40), (40, 50)] self.assertTrue( diff --git a/tests/test_pyTigerGraphPathAsync.py b/tests/test_pyTigerGraphPathAsync.py new file mode 100644 index 00000000..dcdf7b35 --- /dev/null +++ b/tests/test_pyTigerGraphPathAsync.py @@ -0,0 +1,179 @@ +import json +import unittest + +from pyTigerGraphUnitTestAsync import make_connection + +from pyTigerGraph.common.path import 
_prepare_path_params + +from pyTigerGraph.common.exception import TigerGraphException + + +class test_pyTigerGraphPathAsync(unittest.IsolatedAsyncioTestCase): + async def asyncSetUp(self): + self.conn = await make_connection() + await self.conn.upsertVertices("vertex4", self.vs) + await self.conn.upsertEdges("vertex4", "edge6_loop", "vertex4", self.es) + + vs = [ + (10, {"a01": 999}), + (20, {"a01": 999}), + (30, {"a01": 999}), + (40, {"a01": 999}), + (50, {"a01": 999}), + (60, {"a01": 900}), + (70, {"a01": 900}), + (80, {"a01": 900}) + ] + es = [ + (10, 20, {"a01": 999}), + (20, 30, {"a01": 999}), + (30, 40, {"a01": 999}), + (40, 50, {"a01": 900}), + (10, 60, {"a01": 900}), + (60, 70, {"a01": 900}), + (70, 80, {"a01": 999}), + (80, 40, {"a01": 999}), + (80, 20, {"a01": 999}), + (70, 40, {"a01": 900}), + (30, 60, {"a01": 999}) + ] + + async def asyncTearDown(self): + for i in self.es: + await self.conn.delEdges("vertex4", i[0], "edge6_loop", "vertex4", i[1]) + for i in self.vs: + await self.conn.delVerticesById("vertex4", i[0]) + + def _check_vertices(self, res_vs: list, exp_vs: list) -> bool: + self.assertEqual(len(exp_vs), len(res_vs)) + vs = [] + for v in res_vs: + vs.append(int(v["v_id"])) + return sorted(vs) == sorted(exp_vs) + + def _check_edges(self, res_es: list, exp_es: list) -> bool: + self.assertEqual(len(exp_es), len(res_es)) + es = [] + for e in res_es: + es.append((int(e["from_id"]), int(e["to_id"]))) + return sorted(es) == sorted(exp_es) + + def test_01_preparePathParams(self): + res = _prepare_path_params([("srctype1", 1), ("srctype2", 2), ("srctype3", 3)], + [("trgtype1", 1), ("trgtype2", 2), + ("trgtype3", 3)], 5, + [("srctype1", "a01>10")], [("trgtype1", "a10<20")], True) + self.assertIsInstance(res, str) + res = json.loads(res) + self.assertEqual(6, len(res)) + self.assertIn("sources", res) + srcs = res["sources"] + self.assertIsInstance(srcs, list) + self.assertEqual(3, len(srcs)) + self.assertEqual('{"type": "srctype1", "id": 1}', json.dumps(srcs[0])) + self.assertIn("targets", res) + self.assertIn("vertexFilters", res) + self.assertIn("edgeFilters", res) + self.assertIn("maxLength", res) + self.assertIn("allShortestPaths", res) + self.assertTrue(res["allShortestPaths"]) + + res = _prepare_path_params([("srct", 1)], [("trgt", 1)]) + self.assertEqual( + '{"sources": [{"type": "srct", "id": 1}], "targets": [{"type": "trgt", "id": 1}]}', + res + ) + + async def test_02_shortestPath(self): + + self.assertEqual(8, await self.conn.getVertexCount("vertex4", where="a01>=900")) + self.assertEqual(11, await self.conn.getEdgeCount("edge6_loop", "vertex4", "vertex4")) + + res = await self.conn.shortestPath(("vertex4", 10), ("vertex4", 50)) + vs1 = [10, 20, 30, 40, 50] + es1 = [(10, 20), (20, 30), (30, 40), (40, 50)] + vs2 = [10, 60, 70, 40, 50] + es2 = [(10, 60), (60, 70), (70, 40), (40, 50)] + self.assertTrue( + (self._check_vertices(res[0]["vertices"], vs1) and + self._check_edges(res[0]["edges"], es1)) or + (self._check_vertices(res[0]["vertices"], vs2) and + self._check_edges(res[0]["edges"], es2)) + ) + + res = await self.conn.shortestPath(("vertex4", 10), ("vertex4", 50), allShortestPaths=True) + vs3 = [10, 20, 30, 40, 50, 60, 70] + es3 = [(10, 20), (20, 30), (30, 40), (40, 50), + (10, 60), (60, 70), (70, 40)] + self.assertTrue( + (self._check_vertices(res[0]["vertices"], vs3) and + self._check_edges(res[0]["edges"], es3)) + ) + + res = await self.conn.shortestPath(("vertex4", 10), ("vertex4", 50), maxLength=3) + self.assertEqual([], res[0]["vertices"]) + 
self.assertEqual([], res[0]["edges"]) + + res = await self.conn.shortestPath(("vertex4", 10), ("vertex4", 50), allShortestPaths=True, + vertexFilters=("vertex4", "a01>950")) + self.assertTrue( + (self._check_vertices(res[0]["vertices"], vs1) and + self._check_edges(res[0]["edges"], es1)) + ) + + res = await self.conn.shortestPath(("vertex4", 10), ("vertex4", 50), allShortestPaths=True, + edgeFilters=("edge6_loop", "a01<950")) + self.assertTrue( + (self._check_vertices(res[0]["vertices"], vs2) and + self._check_edges(res[0]["edges"], es2)) + ) + + async def test_03_allPaths(self): + res = await self.conn.allPaths(("vertex4", 10), ("vertex4", 50), maxLength=4) + vs = [10, 20, 30, 40, 50, 60, 70] + es = [(10, 20), (20, 30), (30, 40), (40, 50), + (10, 60), (60, 70), (70, 40)] + self.assertTrue( + (self._check_vertices(res[0]["vertices"], vs) and + self._check_edges(res[0]["edges"], es)) + ) + + res = await self.conn.allPaths(("vertex4", 10), ("vertex4", 50), maxLength=5) + vs = [10, 20, 30, 40, 50, 60, 70, 80] + es = [(10, 20), (20, 30), (30, 40), (40, 50), (10, 60), (60, 70), (70, 40), (70, 80), + (80, 40)] + self.assertTrue( + (self._check_vertices(res[0]["vertices"], vs) and + self._check_edges(res[0]["edges"], es)) + ) + + res = await self.conn.allPaths(("vertex4", 10), ("vertex4", 50), maxLength=6) + vs = [10, 20, 30, 40, 50, 60, 70, 80] + es = [(10, 20), (20, 30), (30, 40), (40, 50), (10, 60), (60, 70), (70, 40), (70, 80), + (80, 40), (30, 60)] + self.assertTrue( + (self._check_vertices(res[0]["vertices"], vs) and + self._check_edges(res[0]["edges"], es)) + ) + + res = await self.conn.allPaths(("vertex4", 10), ("vertex4", 50), maxLength=5, + vertexFilters=("vertex4", "a01>950")) + vs = [10, 20, 30, 40, 50] + es = [(10, 20), (20, 30), (30, 40), (40, 50)] + self.assertTrue( + (self._check_vertices(res[0]["vertices"], vs) and + self._check_edges(res[0]["edges"], es)) + ) + + res = await self.conn.allPaths(("vertex4", 10), ("vertex4", 50), maxLength=5, + edgeFilters=("edge6_loop", "a01<950")) + vs = [10, 60, 70, 40, 50] + es = [(10, 60), (60, 70), (70, 40), (40, 50)] + self.assertTrue( + (self._check_vertices(res[0]["vertices"], vs) and + self._check_edges(res[0]["edges"], es)) + ) + + +if __name__ == '__main__': + unittest.main() diff --git a/tests/test_pyTigerGraphQuery.py b/tests/test_pyTigerGraphQuery.py index bd98bd61..9f752227 100644 --- a/tests/test_pyTigerGraphQuery.py +++ b/tests/test_pyTigerGraphQuery.py @@ -4,7 +4,7 @@ from pyTigerGraphUnitTest import make_connection -from pyTigerGraph.pyTigerGraphException import TigerGraphException +from pyTigerGraph.common.exception import TigerGraphException class test_pyTigerGraphQuery(unittest.TestCase): @@ -19,7 +19,7 @@ def test_01_getQueries(self): def test_02_getInstalledQueries(self): res = self.conn.getInstalledQueries() self.assertIn("GET /query/tests/query1", res) - # self.assertNotIn("GET /query/tests/query2_not_installed", res) + #self.assertNotIn("GET /query/tests/query2_not_installed", res) self.assertIn("GET /query/tests/query3_installed", res) def test_03_runInstalledQuery(self): @@ -37,7 +37,8 @@ def test_03_runInstalledQuery(self): "p07_vertex": (1, "vertex4"), "p08_vertex_vertex4": 1, "p09_datetime": datetime.now(), - "p10_set_int": [1, 2, 3, 2, 3, 3], # Intentionally bag-like, to see it behaving as set + # Intentionally bag-like, to see it behaving as set + "p10_set_int": [1, 2, 3, 2, 3, 3], "p11_bag_int": [1, 2, 3, 2, 3, 3], "p13_set_vertex": [(1, "vertex4"), (2, "vertex4"), (3, "vertex4")], "p14_set_vertex_vertex4": [1, 
2, 3] @@ -57,7 +58,7 @@ def test_03_runInstalledQuery(self): def test_04_runInterpretedQuery(self): queryText = \ -"""INTERPRET QUERY () FOR GRAPH $graphname { + """INTERPRET QUERY () FOR GRAPH $graphname { SumAccum @@summa; start = {vertex4.*}; res = @@ -71,7 +72,7 @@ def test_04_runInterpretedQuery(self): self.assertEqual(15, res[0]["ret"]) queryText = \ -"""INTERPRET QUERY () FOR GRAPH @graphname@ { + """INTERPRET QUERY () FOR GRAPH @graphname@ { SumAccum @@summa; start = {vertex4.*}; res = @@ -90,7 +91,7 @@ def test_05_runInstalledQueryAsync(self): while trials < 30: job = self.conn.checkQueryStatus(q_id)[0] if job["status"] == "success": - break + break sleep(1) trials += 1 res = self.conn.getQueryResult(q_id) @@ -99,6 +100,7 @@ def test_05_runInstalledQueryAsync(self): def test_06_checkQueryStatus(self): q_id = self.conn.runInstalledQuery("query1", runAsync=True) + print(q_id) res = self.conn.checkQueryStatus(q_id) self.assertIn("requestid", res[0]) self.assertEqual(q_id, res[0]["requestid"]) @@ -107,7 +109,7 @@ def test_07_showQuery(self): query = self.conn.showQuery("query1").split("\n")[1] q1 = """# installed v2""" self.assertEqual(q1, query) - + def test_08_getQueryMetadata(self): query_md = self.conn.getQueryMetadata("query1") self.assertEqual(query_md["output"][0], {"ret": "int"}) @@ -122,7 +124,7 @@ def test_10_abortQuery(self): def test_11_queryDescriptions(self): version = self.conn.getVer().split('.') - if version[0]>="4": # Query descriptions only supported in Tigergraph versions >= 4.x + if version[0] >= "4": # Query descriptions only supported in Tigergraph versions >= 4.x self.conn.dropQueryDescription('query1') desc = self.conn.getQueryDescription('query1') self.assertEqual(desc, [{'queryName': 'query1', 'parameters': []}]) @@ -131,24 +133,31 @@ def test_11_queryDescriptions(self): self.assertEqual(desc[0]['description'], 'This is a description') self.conn.dropQueryDescription('query4_all_param_types') - self.conn.describeQuery('query4_all_param_types', 'this is a query description', - {'p01_int':'this is a parameter description', - 'p02_uint':'this is a second param desc'}) + self.conn.describeQuery('query4_all_param_types', 'this is a query description', + {'p01_int': 'this is a parameter description', + 'p02_uint': 'this is a second param desc'}) desc = self.conn.getQueryDescription('query4_all_param_types') - self.assertEqual(desc[0]['description'], 'this is a query description') - self.assertEqual(desc[0]['parameters'][0]['description'], 'this is a parameter description') - self.assertEqual(desc[0]['parameters'][1]['description'], 'this is a second param desc') + self.assertEqual(desc[0]['description'], + 'this is a query description') + self.assertEqual( + desc[0]['parameters'][0]['description'], 'this is a parameter description') + self.assertEqual(desc[0]['parameters'][1] + ['description'], 'this is a second param desc') else: with self.assertRaises(TigerGraphException) as tge: res = self.conn.dropQueryDescription('query1') - self.assertEqual("This function is only supported on versions of TigerGraph >= 4.0.0.", tge.exception.message) + self.assertEqual( + "This function is only supported on versions of TigerGraph >= 4.0.0.", tge.exception.message) with self.assertRaises(TigerGraphException) as tge: res = self.conn.describeQuery('query1', 'test') - self.assertEqual("This function is only supported on versions of TigerGraph >= 4.0.0.", tge.exception.message) + self.assertEqual( + "This function is only supported on versions of TigerGraph >= 4.0.0.", 
tge.exception.message)
         with self.assertRaises(TigerGraphException) as tge:
             res = self.conn.getQueryDescription('query1')
-            self.assertEqual("This function is only supported on versions of TigerGraph >= 4.0.0.", tge.exception.message)
-
+            self.assertEqual(
+                "This function is only supported on versions of TigerGraph >= 4.0.0.", tge.exception.message)
+
+
 if __name__ == '__main__':
     unittest.main()
diff --git a/tests/test_pyTigerGraphQueryAsync.py b/tests/test_pyTigerGraphQueryAsync.py
new file mode 100644
index 00000000..d983c39c
--- /dev/null
+++ b/tests/test_pyTigerGraphQueryAsync.py
@@ -0,0 +1,166 @@
+import asyncio
+import unittest
+from datetime import datetime
+
+from pyTigerGraphUnitTestAsync import make_connection
+
+from pyTigerGraph.common.exception import TigerGraphException
+
+
+class test_pyTigerGraphQueryAsync(unittest.IsolatedAsyncioTestCase):
+    async def asyncSetUp(self):
+        self.conn = await make_connection()
+
+    async def test_01_getQueries(self):
+        # TODO Once pyTigerGraphQuery.getQueries() is available
+        pass
+
+    async def test_02_getInstalledQueries(self):
+        res = await self.conn.getInstalledQueries()
+        self.assertIn("GET /query/tests/query1", res)
+        # self.assertNotIn("GET /query/tests/query2_not_installed", res)
+        self.assertIn("GET /query/tests/query3_installed", res)
+
+    async def test_03_runInstalledQuery(self):
+        res = await self.conn.runInstalledQuery("query1")
+        self.assertIn("ret", res[0])
+        self.assertEqual(15, res[0]["ret"])
+
+        params = {
+            "p01_int": 1,
+            "p02_uint": 1,
+            "p03_float": 1.1,
+            "p04_double": 1.1,
+            "p05_string": "test <>\"'`\\/{}[]()<>!@£$%^&*-_=+;:|,.§±~` árvíztűrő tükörfúrógép 👍",
+            "p06_bool": True,
+            "p07_vertex": (1, "vertex4"),
+            "p08_vertex_vertex4": 1,
+            "p09_datetime": datetime.now(),
+            # Intentionally bag-like, to see it behaving as set
+            "p10_set_int": [1, 2, 3, 2, 3, 3],
+            "p11_bag_int": [1, 2, 3, 2, 3, 3],
+            "p13_set_vertex": [(1, "vertex4"), (2, "vertex4"), (3, "vertex4")],
+            "p14_set_vertex_vertex4": [1, 2, 3]
+        }
+
+        res = await self.conn.runInstalledQuery("query4_all_param_types", params)
+        self.assertIsInstance(res, list)
+        self.assertIsInstance(res[4], dict)
+        self.assertIn("p05_string", res[4])
+        self.assertEqual(params["p05_string"], res[4]["p05_string"])
+        self.assertIsInstance(res[11], dict)
+        vs = res[11]
+        self.assertIn("p13_set_vertex", vs)
+        vs = sorted(vs["p13_set_vertex"])
+        self.assertIsInstance(vs, list)
+        self.assertEqual(["1", "2", "3"], vs)
+
+    async def test_04_runInterpretedQuery(self):
+        queryText = \
+            """INTERPRET QUERY () FOR GRAPH $graphname {
+  SumAccum<INT> @@summa;
+  start = {vertex4.*};
+  res =
+      SELECT src
+      FROM start:src
+      ACCUM @@summa += src.a01;
+  PRINT @@summa AS ret;
+}"""
+        res = await self.conn.runInterpretedQuery(queryText)
+        self.assertIn("ret", res[0])
+        self.assertEqual(15, res[0]["ret"])
+
+        queryText = \
+            """INTERPRET QUERY () FOR GRAPH @graphname@ {
+  SumAccum<INT> @@summa;
+  start = {vertex4.*};
+  res =
+      SELECT src
+      FROM start:src
+      ACCUM @@summa += src.a01;
+  PRINT @@summa AS ret;
+}"""
+        res = await self.conn.runInterpretedQuery(queryText)
+        self.assertIn("ret", res[0])
+        self.assertEqual(15, res[0]["ret"])
+
+    async def test_05_runInstalledQueryAsync(self):
+        q_id = await self.conn.runInstalledQuery("query1", runAsync=True)
+        trials = 0
+        while trials < 30:
+            job = await self.conn.checkQueryStatus(q_id)
+            job = job[0]
+            if job["status"] == "success":
+                break
+            await asyncio.sleep(1)  # don't block the event loop while polling
+            trials += 1
+        res = await self.conn.getQueryResult(q_id)
+        self.assertIn("ret", res[0])
+        self.assertEqual(15,
res[0]["ret"]) + + async def test_06_checkQueryStatus(self): + q_id = await self.conn.runInstalledQuery("query1", runAsync=True) + print(q_id) + res = await self.conn.checkQueryStatus(q_id) + self.assertIn("requestid", res[0]) + self.assertEqual(q_id, res[0]["requestid"]) + + async def test_07_showQuery(self): + query = await self.conn.showQuery("query1") + query = query.split("\n")[1] + q1 = """# installed v2""" + self.assertEqual(q1, query) + + async def test_08_getQueryMetadata(self): + query_md = await self.conn.getQueryMetadata("query1") + self.assertEqual(query_md["output"][0], {"ret": "int"}) + + async def test_09_getRunningQueries(self): + rq_id = await self.conn.getRunningQueries() + rq_id = rq_id["results"] + self.assertEqual(len(rq_id), 0) + + async def test_10_abortQuery(self): + abort_ret = await self.conn.abortQuery("all") + self.assertEqual(abort_ret["results"], [{'aborted_queries': []}]) + + async def test_11_queryDescriptions(self): + version = await self.conn.getVer() + version = version.split('.') + if version[0] >= "4": # Query descriptions only supported in Tigergraph versions >= 4.x + await self.conn.dropQueryDescription('query1') + desc = await self.conn.getQueryDescription('query1') + self.assertEqual(desc, [{'queryName': 'query1', 'parameters': []}]) + await self.conn.describeQuery('query1', 'This is a description') + desc = await self.conn.getQueryDescription('query1') + self.assertEqual(desc[0]['description'], 'This is a description') + + await self.conn.dropQueryDescription('query4_all_param_types') + await self.conn.describeQuery('query4_all_param_types', 'this is a query description', + {'p01_int': 'this is a parameter description', + 'p02_uint': 'this is a second param desc'}) + desc = await self.conn.getQueryDescription('query4_all_param_types') + self.assertEqual(desc[0]['description'], + 'this is a query description') + self.assertEqual( + desc[0]['parameters'][0]['description'], 'this is a parameter description') + self.assertEqual(desc[0]['parameters'][1] + ['description'], 'this is a second param desc') + + else: + with self.assertRaises(TigerGraphException) as tge: + res = await self.conn.dropQueryDescription('query1') + self.assertEqual( + "This function is only supported on versions of TigerGraph >= 4.0.0.", tge.exception.message) + with self.assertRaises(TigerGraphException) as tge: + res = await self.conn.describeQuery('query1', 'test') + self.assertEqual( + "This function is only supported on versions of TigerGraph >= 4.0.0.", tge.exception.message) + with self.assertRaises(TigerGraphException) as tge: + res = await self.conn.getQueryDescription('query1') + self.assertEqual( + "This function is only supported on versions of TigerGraph >= 4.0.0.", tge.exception.message) + + +if __name__ == '__main__': + unittest.main() diff --git a/tests/test_pyTigerGraphSchema.py b/tests/test_pyTigerGraphSchema.py index 96b8b322..f969918b 100644 --- a/tests/test_pyTigerGraphSchema.py +++ b/tests/test_pyTigerGraphSchema.py @@ -1,6 +1,8 @@ import json import unittest +from pyTigerGraph.common.schema import _upsert_attrs + from pyTigerGraphUnitTest import make_connection @@ -13,16 +15,18 @@ def test_01_getUDTs(self): res = self.conn._getUDTs() self.assertIsInstance(res, list) self.assertEqual(2, len(res)) - self.assertTrue(res[0]["name"] == "tuple1_all_types" or res[0]["name"] == "tuple2_simple") - tuple2_simple = res[0] if (res[0]["name"] == "tuple2_simple") else res[1] + self.assertTrue( + res[0]["name"] == "tuple1_all_types" or res[0]["name"] == "tuple2_simple") + 
tuple2_simple = res[0] if ( + res[0]["name"] == "tuple2_simple") else res[1] self.assertIn('fields', tuple2_simple) fields = tuple2_simple['fields'] - self.assertTrue(fields[0]['fieldName']=='field1') - self.assertTrue(fields[0]['fieldType']=='INT') - self.assertTrue(fields[1]['fieldName']=='field2') - self.assertTrue(fields[1]['fieldType']=='STRING') - self.assertTrue(fields[2]['fieldName']=='field3') - self.assertTrue(fields[2]['fieldType']=='DATETIME') + self.assertTrue(fields[0]['fieldName'] == 'field1') + self.assertTrue(fields[0]['fieldType'] == 'INT') + self.assertTrue(fields[1]['fieldName'] == 'field2') + self.assertTrue(fields[1]['fieldType'] == 'STRING') + self.assertTrue(fields[2]['fieldName'] == 'field3') + self.assertTrue(fields[2]['fieldType'] == 'DATETIME') def test_02_upsertAttrs(self): tests = [ @@ -41,7 +45,7 @@ def test_02_upsertAttrs(self): ] for t in tests: - res = self.conn._upsertAttrs(t[0]) + res = _upsert_attrs(t[0]) self.assertEqual(t[1], res) def test_03_getSchema(self): diff --git a/tests/test_pyTigerGraphSchemaAsync.py b/tests/test_pyTigerGraphSchemaAsync.py new file mode 100644 index 00000000..330b7aa2 --- /dev/null +++ b/tests/test_pyTigerGraphSchemaAsync.py @@ -0,0 +1,380 @@ +import json +import unittest + +from pyTigerGraph.common.schema import _upsert_attrs + +from pyTigerGraphUnitTestAsync import make_connection + +from pyTigerGraph.common.exception import TigerGraphException + + +class test_pyTigerGraphSchemaAsync(unittest.IsolatedAsyncioTestCase): + async def asyncSetUp(self): + self.conn = await make_connection() + + async def test_01_getUDTs(self): + res = await self.conn._getUDTs() + self.assertIsInstance(res, list) + self.assertEqual(2, len(res)) + self.assertTrue( + res[0]["name"] == "tuple1_all_types" or res[0]["name"] == "tuple2_simple") + tuple2_simple = res[0] if ( + res[0]["name"] == "tuple2_simple") else res[1] + self.assertIn('fields', tuple2_simple) + fields = tuple2_simple['fields'] + self.assertTrue(fields[0]['fieldName'] == 'field1') + self.assertTrue(fields[0]['fieldType'] == 'INT') + self.assertTrue(fields[1]['fieldName'] == 'field2') + self.assertTrue(fields[1]['fieldType'] == 'STRING') + self.assertTrue(fields[2]['fieldName'] == 'field3') + self.assertTrue(fields[2]['fieldType'] == 'DATETIME') + + def test_02_upsertAttrs(self): + tests = [ + ({"attr_name": "attr_value"}, {"attr_name": {"value": "attr_value"}}), + ({"attr_name1": "attr_value1", "attr_name2": "attr_value2"}, + {"attr_name1": {"value": "attr_value1"}, "attr_name2": {"value": "attr_value2"}}), + ({"attr_name": ("attr_value", "operator")}, + {"attr_name": {"value": "attr_value", "op": "operator"}}), + ({"attr_name1": ("attr_value1", "+"), "attr_name2": ("attr_value2", "-")}, + {"attr_name1": {"value": "attr_value1", "op": "+"}, + "attr_name2": {"value": "attr_value2", "op": "-"}}), + ("a string", {}), + ({"attr_name"}, {}), + (1, {}), + ({}, {}) + ] + + for t in tests: + res = _upsert_attrs(t[0]) + self.assertEqual(t[1], res) + + async def test_03_getSchema(self): + res = await self.conn.getSchema() + items = [ + ("GraphName"), + ("VertexTypes", "Name", [ + "vertex1_all_types", + "vertex2_primary_key", + "vertex3_primary_key_composite", + "vertex4", + "vertex5", + "vertex6", + "vertex7" + ]), + ("EdgeTypes", "Name", [ + "edge1_undirected", + "edge2_directed", + "edge3_directed_with_reverse", + "edge4_many_to_many", + "edge5_all_to_all", + "edge6_loop" + ]), + ("UDTs", "name", [ + "tuple1_all_types", + "tuple2_simple" + ]) + ] + self.assertEqual(len(items), len(res)) + for 
i in items: + if i == "GraphName": + self.assertEqual(self.conn.graphname, res[i]) + else: + self.assertIn(i[0], res) + t = res[i[0]] + self.assertIsInstance(t, list) + self.assertEqual(len(i[2]), len(t)) + for tt in t: + self.assertIn(i[1], tt) + self.assertIn(tt[i[1]], i[2]) + + async def test_04_upsertData(self): + data = { + "vertices": { + "vertex4": { + "4000": { + "a01": { + "value": 4000 + } + }, + "4001": { + "a01": { + "value": 4001 + } + } + }, + "vertex5": { + "5000": {}, + "5001": {} + } + }, + "edges": { + "vertex4": { + "4000": { + "edge2_directed": { + "vertex5": { + "5000": { + "a01": { + "value": 40005000 + } + }, + "5001": { + "a01": { + "value": 40005001 + } + } + } + } + }, + "4001": { + "edge3_directed_with_reverse": { + "vertex5": { + "5000": { + "a01": { + "value": 40005000 + } + }, + } + } + } + } + } + } + res = await self.conn.upsertData(data) + self.assertEqual({"accepted_vertices": 4, "accepted_edges": 3}, res) + + res = await self.conn.delVertices("vertex4", where="a01>1000") + self.assertEqual(2, res) + + res = await self.conn.delVerticesById("vertex5", [5000, 5001]) + self.assertEqual(2, res) + + """ + v4 v5 + 7000 🔵️———🔵️ + ╲ ╱ + ╳ + ╱ ╲ + 7001 🔵️ 🔵️ + """ + data = { + "vertices": { + "vertex4": { + "7000": { + "a01": { + "value": 7000 + } + }, + "7001": { + "a01": { + "value": 7000 + } + } + }, + "vertex5": { + "7000": {}, + "7001": {} + } + }, + "edges": { + "vertex4": { + "7000": { + "edge2_directed": { + "vertex5": { + "7000": { + "a01": { + "value": 7000 + } + }, + "7001": { + "a01": { + "value": 7000 + } + } + } + } + }, + "7001": { + "edge2_directed": { + "vertex5": { + "7000": { + "a01": { + "value": 7000 + } + } + } + } + } + } + } + } + + res = await self.conn.upsertData(data, atomic=True, ackAll=True) + self.assertEqual({"accepted_vertices": 4, "accepted_edges": 3}, res) + + """ + v4 v5 + 7000 🔴️———🔵️ + ╲ ╱ + ╳ + ╱ ╲ + 7001 🔴️ 🔵 + + 7002 🟢 + """ + data = { + "vertices": { + "vertex4": { + "7000": { + "a01": { + "value": 7010 + } + }, + "7001": { + "a01": { + "value": 7010 + } + }, + "7002": { + "a01": { + "value": 7010 + } + } + } + } + } + exp = { + "accepted_vertices": 1, + "skipped_vertices": 2, + "vertices_already_exist": [ + {"v_type": "vertex4", "v_id": "7000"}, + {"v_type": "vertex4", "v_id": "7001"} + ], + "accepted_edges": 0 + } + res = await self.conn.upsertData(data, newVertexOnly=True) + self.assertEqual(exp, res) + + """ + v4 v5 + 7000 🔵️———🔵 + ╲ ╱ + ╳ + ╱ ╲ + 7001 🔵️ 🔵️ + ╱ + ╱ + ╱ + 7002 🔵⋯⋯⋯🔴 + + 7003 🔴⋯⋯⋯🔴 + """ + data = { + "edges": { + "vertex4": { + "7002": { + "edge2_directed": { + "vertex5": { + "7001": { + "a01": { + "value": 7000 + } + }, + "7002": { + "a01": { + "value": 7000 + } + } + } + } + }, + "7003": { + "edge2_directed": { + "vertex5": { + "7003": { + "a01": { + "value": 7000 + } + }, + } + } + } + } + } + } + exp = { + "accepted_vertices": 0, + "accepted_edges": 1, + "skipped_edges": 2, + "edge_vertices_not_exist": [ + {"v_type": "vertex5", "v_id": "7002"}, + {"v_type": "vertex4", "v_id": "7003"}, + {"v_type": "vertex5", "v_id": "7003"} + ] + } + res = await self.conn.upsertData(data, vertexMustExist=True) + self.assertEqual(exp, res) + + """ + v4 v5 + 7000 🟢️———🔵 + ╲ ╱ + ╳ + ╱ ╲ + 7001 🟢️ 🔵 + ╱ + ╱ + ╱ + 7002 🔵 + + 7003 🔴 + """ + data = { + "vertices": { + "vertex4": { + "7000": { + "a01": { + "value": 7020 + } + }, + "7001": { + "a01": { + "value": 7020 + } + }, + "7003": { + "a01": { + "value": 7020 + } + } + } + } + } + exp = { + "accepted_vertices": 2, + "skipped_vertices": 1, + "vertices_not_exist": [ + 
{"v_type": "vertex4", "v_id": "7003"} + ], + "accepted_edges": 0 + } + res = await self.conn.upsertData(data, updateVertexOnly=True) + self.assertEqual(exp, res) + + res = await self.conn.delVertices("vertex4", where="a01>=7000,a01<8000") + self.assertEqual(3, res) + + res = await self.conn.delVerticesById("vertex5", [7000, 7001]) + self.assertEqual(2, res) + + async def test_05_getEndpoints(self): + res = await self.conn.getEndpoints() + self.assertIsInstance(res, dict) + self.assertIn("GET /endpoints/{graph_name}", res) + + res = await self.conn.getEndpoints(dynamic=True) + self.assertEqual(4, len(res)) + + +if __name__ == '__main__': + unittest.main() diff --git a/tests/test_pyTigerGraphUDT.py b/tests/test_pyTigerGraphUDT.py index 3f61f086..53b4583d 100644 --- a/tests/test_pyTigerGraphUDT.py +++ b/tests/test_pyTigerGraphUDT.py @@ -15,12 +15,12 @@ def test_01_getUDTs(self): def test_02_getUDT(self): res = self.conn.getUDT("tuple2_simple") - self.assertTrue(res[0]['fieldName']=='field1') - self.assertTrue(res[0]['fieldType']=='INT') - self.assertTrue(res[1]['fieldName']=='field2') - self.assertTrue(res[1]['fieldType']=='STRING') - self.assertTrue(res[2]['fieldName']=='field3') - self.assertTrue(res[2]['fieldType']=='DATETIME') + self.assertTrue(res[0]['fieldName'] == 'field1') + self.assertTrue(res[0]['fieldType'] == 'INT') + self.assertTrue(res[1]['fieldName'] == 'field2') + self.assertTrue(res[1]['fieldType'] == 'STRING') + self.assertTrue(res[2]['fieldName'] == 'field3') + self.assertTrue(res[2]['fieldType'] == 'DATETIME') if __name__ == '__main__': diff --git a/tests/test_pyTigerGraphUDTAsync.py b/tests/test_pyTigerGraphUDTAsync.py new file mode 100644 index 00000000..701610b8 --- /dev/null +++ b/tests/test_pyTigerGraphUDTAsync.py @@ -0,0 +1,28 @@ +import unittest + +from pyTigerGraphUnitTestAsync import make_connection + +from pyTigerGraph.common.exception import TigerGraphException + + +class test_pyTigerGraphUDT(unittest.IsolatedAsyncioTestCase): + async def asyncSetUp(self): + self.conn = await make_connection() + + async def test_01_getUDTs(self): + res = await self.conn.getUDTs() + exp = ["tuple1_all_types", "tuple2_simple"] + self.assertEqual(exp, res) + + async def test_02_getUDT(self): + res = await self.conn.getUDT("tuple2_simple") + self.assertTrue(res[0]['fieldName'] == 'field1') + self.assertTrue(res[0]['fieldType'] == 'INT') + self.assertTrue(res[1]['fieldName'] == 'field2') + self.assertTrue(res[1]['fieldType'] == 'STRING') + self.assertTrue(res[2]['fieldName'] == 'field3') + self.assertTrue(res[2]['fieldType'] == 'DATETIME') + + +if __name__ == '__main__': + unittest.main() diff --git a/tests/test_pyTigerGraphUtils.py b/tests/test_pyTigerGraphUtils.py index ed77b70e..156662ac 100644 --- a/tests/test_pyTigerGraphUtils.py +++ b/tests/test_pyTigerGraphUtils.py @@ -2,31 +2,34 @@ import unittest from datetime import datetime +from pyTigerGraph.common.util import _safe_char + from pyTigerGraphUnitTest import make_connection -from pyTigerGraph.pyTigerGraphException import TigerGraphException +from pyTigerGraph.common.exception import TigerGraphException + class test_pyTigerGraphUtils(unittest.TestCase): @classmethod def setUpClass(cls): - cls.conn = make_connection(graphname="Cora") + cls.conn = make_connection(graphname="tests") def test_01_safeChar(self): - res = self.conn._safeChar(" _space") + res = _safe_char(" _space") self.assertEqual("%20_space", res) - res = self.conn._safeChar("/_slash") + res = _safe_char("/_slash") self.assertEqual("%2F_slash", res) - res = 
self.conn._safeChar("ñ_LATIN_SMALL_LETTER_N_WITH_TILDE") + res = _safe_char("ñ_LATIN_SMALL_LETTER_N_WITH_TILDE") self.assertEqual(res, '%C3%B1_LATIN_SMALL_LETTER_N_WITH_TILDE') - res = self.conn._safeChar(12345) + res = _safe_char(12345) self.assertEqual("12345", res) - res = self.conn._safeChar(12.345) + res = _safe_char(12.345) self.assertEqual("12.345", res) now = datetime.now() - res = self.conn._safeChar(now) + res = _safe_char(now) exp = str(now).replace(" ", "%20").replace(":", "%3A") self.assertEqual(exp, res) - res = self.conn._safeChar(True) + res = _safe_char(True) self.assertEqual("True", res) def test_02_echo(self): @@ -54,7 +57,7 @@ def test_05_ping(self): self.assertEqual(res["message"], "pong") def test_06_getSystemMetrics(self): - if self.conn._versionGreaterThan4_0(): + if self.conn._version_greater_than_4_0(): res = self.conn.getSystemMetrics(what="cpu-memory") self.assertIn("CPUMemoryMetrics", res) res = self.conn.getSystemMetrics(what="diskspace") @@ -63,15 +66,17 @@ def test_06_getSystemMetrics(self): self.assertIn("NetworkMetrics", res) res = self.conn.getSystemMetrics(what="qps") self.assertIn("QPSMetrics", res) - + with self.assertRaises(TigerGraphException) as tge: res = self.conn.getSystemMetrics(what="servicestate") - self.assertEqual("This 'what' parameter is only supported on versions of TigerGraph < 4.1.0.", tge.exception.message) - + self.assertEqual( + "This 'what' parameter is only supported on versions of TigerGraph < 4.1.0.", tge.exception.message) + with self.assertRaises(TigerGraphException) as tge: res = self.conn.getSystemMetrics(what="connection") - self.assertEqual("This 'what' parameter is only supported on versions of TigerGraph < 4.1.0.", tge.exception.message) - else: + self.assertEqual( + "This 'what' parameter is only supported on versions of TigerGraph < 4.1.0.", tge.exception.message) + else: res = self.conn.getSystemMetrics(what="mem", latest=10) self.assertEqual(len(res), 10) @@ -80,13 +85,16 @@ def test_07_getQueryPerformance(self): self.assertIn("CompletedRequests", str(res)) def test_08_getServiceStatus(self): - req = {"ServiceDescriptors":[{"ServiceName": "GSQL"}]} + req = {"ServiceDescriptors": [{"ServiceName": "GSQL"}]} res = self.conn.getServiceStatus(req) - self.assertEqual(res["ServiceStatusEvents"][0]["ServiceStatus"], "Online") + self.assertEqual(res["ServiceStatusEvents"][0] + ["ServiceStatus"], "Online") def test_09_rebuildGraph(self): res = self.conn.rebuildGraph() - self.assertEqual(res["message"], "RebuildNow finished, please check details in the folder: /tmp/rebuildnow") + self.assertEqual( + res["message"], "RebuildNow finished, please check details in the folder: /tmp/rebuildnow") + if __name__ == '__main__': unittest.main() diff --git a/tests/test_pyTigerGraphUtilsAsync.py b/tests/test_pyTigerGraphUtilsAsync.py new file mode 100644 index 00000000..9f2df299 --- /dev/null +++ b/tests/test_pyTigerGraphUtilsAsync.py @@ -0,0 +1,99 @@ +import re +import unittest +from datetime import datetime + +from pyTigerGraph.common.util import _safe_char + +from pyTigerGraphUnitTestAsync import make_connection + +from pyTigerGraph.common.exception import TigerGraphException + + +class test_pyTigerGraphUtilsAsync(unittest.IsolatedAsyncioTestCase): + async def asyncSetUp(self): + self.conn = await make_connection() + + def test_01_safeChar(self): + res = _safe_char(" _space") + self.assertEqual("%20_space", res) + res = _safe_char("/_slash") + self.assertEqual("%2F_slash", res) + res = _safe_char("ñ_LATIN_SMALL_LETTER_N_WITH_TILDE") + 
self.assertEqual(res, '%C3%B1_LATIN_SMALL_LETTER_N_WITH_TILDE') + res = _safe_char(12345) + self.assertEqual("12345", res) + res = _safe_char(12.345) + self.assertEqual("12.345", res) + now = datetime.now() + res = _safe_char(now) + exp = str(now).replace(" ", "%20").replace(":", "%3A") + self.assertEqual(exp, res) + res = _safe_char(True) + self.assertEqual("True", res) + + async def test_02_echo(self): + res = await self.conn.echo() + self.assertIsInstance(res, str) + self.assertEqual("Hello GSQL", res) + res = await self.conn.echo(True) + self.assertIsInstance(res, str) + self.assertEqual("Hello GSQL", res) + + async def test_03_getVersion(self): + res = await self.conn.getVersion() + self.assertIsInstance(res, list) + self.assertGreater(len(res), 0) + + async def test_04_getVer(self): + res = await self.conn.getVer() + self.assertIsInstance(res, str) + m = re.match(r"[0-9]+\.[0-9]+\.[0-9]", res) + self.assertIsNotNone(m) + + async def test_05_ping(self): + res = await self.conn.ping() + self.assertIsInstance(res, dict) + self.assertEqual(res["message"], "pong") + + async def test_06_getSystemMetrics(self): + if await self.conn._version_greater_than_4_0(): + res = await self.conn.getSystemMetrics(what="cpu-memory") + self.assertIn("CPUMemoryMetrics", res) + res = await self.conn.getSystemMetrics(what="diskspace") + self.assertIn("DiskMetrics", res) + res = await self.conn.getSystemMetrics(what="network") + self.assertIn("NetworkMetrics", res) + res = await self.conn.getSystemMetrics(what="qps") + self.assertIn("QPSMetrics", res) + + with self.assertRaises(TigerGraphException) as tge: + res = await self.conn.getSystemMetrics(what="servicestate") + self.assertEqual( + "This 'what' parameter is only supported on versions of TigerGraph < 4.1.0.", tge.exception.message) + + with self.assertRaises(TigerGraphException) as tge: + res = await self.conn.getSystemMetrics(what="connection") + self.assertEqual( + "This 'what' parameter is only supported on versions of TigerGraph < 4.1.0.", tge.exception.message) + else: + res = await self.conn.getSystemMetrics(what="mem", latest=10) + self.assertEqual(len(res), 10) + ''' Commented out because the queries are not completed yet + async def test_07_getQueryPerformance(self): + res = await self.conn.getQueryPerformance() + self.assertIn("CompletedRequests", str(res)) + ''' + async def test_08_getServiceStatus(self): + req = {"ServiceDescriptors": [{"ServiceName": "GSQL"}]} + res = await self.conn.getServiceStatus(req) + self.assertEqual(res["ServiceStatusEvents"][0] + ["ServiceStatus"], "Online") + + async def test_09_rebuildGraph(self): + res = await self.conn.rebuildGraph() + self.assertEqual( + res["message"], "RebuildNow finished, please check details in the folder: /tmp/rebuildnow") + + +if __name__ == '__main__': + unittest.main() diff --git a/tests/test_pyTigerGraphVertex.py b/tests/test_pyTigerGraphVertex.py index 47fc00cf..d6698782 100644 --- a/tests/test_pyTigerGraphVertex.py +++ b/tests/test_pyTigerGraphVertex.py @@ -4,7 +4,7 @@ import pandas from pyTigerGraphUnitTest import make_connection -from pyTigerGraph.pyTigerGraphException import TigerGraphException +from pyTigerGraph.common.exception import TigerGraphException class test_pyTigerGraphVertex(unittest.TestCase): @@ -17,7 +17,7 @@ def test_01_getVertexTypes(self): self.assertIsInstance(res, list) self.assertEqual(7, len(res)) exp = ["vertex1_all_types", "vertex2_primary_key", "vertex3_primary_key_composite", - "vertex4", "vertex5", "vertex6", "vertex7"] + "vertex4", "vertex5", "vertex6", 
"vertex7"] self.assertEqual(exp, res) def test_02_getVertexType(self): @@ -64,12 +64,13 @@ def test_03_getVertexCount(self): with self.assertRaises(TigerGraphException) as tge: self.conn.getVertexCount("*", "a01>=3") self.assertEqual("VertexType cannot be \"*\" if where condition is specified.", - tge.exception.message) + tge.exception.message) with self.assertRaises(TigerGraphException) as tge: - self.conn.getVertexCount(["vertex4", "vertex5", "vertex6"], "a01>=3") + self.conn.getVertexCount( + ["vertex4", "vertex5", "vertex6"], "a01>=3") self.assertEqual("VertexType cannot be a list if where condition is specified.", - tge.exception.message) + tge.exception.message) with self.assertRaises(TigerGraphException) as tge: self.conn.getVertexCount("non_existing_vertex_type") @@ -94,11 +95,13 @@ def test_04_upsertVertex(self): self.assertEqual(1, res) with self.assertRaises(TigerGraphException) as tge: - self.conn.upsertVertex("non_existing_vertex_type", 100, {"a01": 100}) + self.conn.upsertVertex( + "non_existing_vertex_type", 100, {"a01": 100}) self.assertEqual("REST-30200", tge.exception.code) with self.assertRaises(TigerGraphException) as tge: - self.conn.upsertVertex("vertex4", 100, {"non_existing_vertex_attribute": 100}) + self.conn.upsertVertex( + "vertex4", 100, {"non_existing_vertex_attribute": 100}) self.assertEqual("REST-30200", tge.exception.code) def test_05_upsertVertices(self): @@ -116,7 +119,8 @@ def test_05_upsertVertices(self): self.assertIsInstance(res, list) v = {} for r in res: - if "v_id" in r and r["v_id"] == '100': # v_id value is returned as str, not int + # v_id value is returned as str, not int + if "v_id" in r and r["v_id"] == '100': v = r self.assertNotEqual({}, v) self.assertIn("attributes", v) @@ -137,14 +141,14 @@ def test_06_upsertVertexDataFrame(self): def test_07_getVertices(self): res = self.conn.getVertices("vertex4", select="a01", where="a01>1,a01<5", sort="-a01", - limit=2) + limit=2) self.assertIsInstance(res, list) self.assertEqual(2, len(res)) self.assertEqual(4, res[0]["attributes"]["a01"]) self.assertEqual(3, res[1]["attributes"]["a01"]) res = self.conn.getVertices("vertex4", select="a01", where="a01>1,a01<5", sort="-a01", - limit=2, fmt="json") + limit=2, fmt="json") self.assertIsInstance(res, str) res = json.loads(res) self.assertIsInstance(res, list) @@ -153,19 +157,20 @@ def test_07_getVertices(self): self.assertEqual(3, res[1]["attributes"]["a01"]) res = self.conn.getVertices("vertex4", select="a01", where="a01>1,a01<5", sort="-a01", - limit=2, fmt="df") + limit=2, fmt="df") self.assertIsInstance(res, pandas.DataFrame) self.assertEqual(2, len(res.index)) def test_08_getVertexDataFrame(self): res = self.conn.getVertexDataFrame("vertex4", select="a01", where="a01>1,a01<5", - sort="-a01", - limit=2) + sort="-a01", + limit=2) self.assertIsInstance(res, pandas.DataFrame) self.assertEqual(2, len(res.index)) def test_09_getVerticesById(self): - res = self.conn.getVerticesById("vertex4", [1, 3, 5], select="a01") # select is ignored + res = self.conn.getVerticesById( + "vertex4", [1, 3, 5], select="a01") # select is ignored self.assertIsInstance(res, list) self.assertEqual(3, len(res)) @@ -237,7 +242,7 @@ def test_15_vertexSetToDataFrame(self): res = self.conn.vertexSetToDataFrame(res) self.assertIsInstance(res, pandas.DataFrame) self.assertEqual(5, len(res.index)) - self.assertEqual(["v_id","a01"], list(res.columns)) + self.assertEqual(["v_id", "a01"], list(res.columns)) def test_16_delVerticesByType(self): res = self.conn.delVerticesByType("vertex4") 
diff --git a/tests/test_pyTigerGraphVertexAsync.py b/tests/test_pyTigerGraphVertexAsync.py
new file mode 100644
index 00000000..abaab125
--- /dev/null
+++ b/tests/test_pyTigerGraphVertexAsync.py
@@ -0,0 +1,242 @@
+import json
+import unittest
+
+import pandas
+from pyTigerGraphUnitTestAsync import make_connection
+
+from pyTigerGraph.common.exception import TigerGraphException
+
+
+class test_pyTigerGraphVertexAsync(unittest.IsolatedAsyncioTestCase):
+    async def asyncSetUp(self):
+        self.conn = await make_connection()
+
+    async def test_01_getVertexTypes(self):
+        res = sorted(await self.conn.getVertexTypes())
+        self.assertIsInstance(res, list)
+        self.assertEqual(7, len(res))
+        exp = ["vertex1_all_types", "vertex2_primary_key", "vertex3_primary_key_composite",
+               "vertex4", "vertex5", "vertex6", "vertex7"]
+        self.assertEqual(exp, res)
+
+    async def test_02_getVertexType(self):
+        res = await self.conn.getVertexType("vertex1_all_types")
+        self.assertIsInstance(res, dict)
+        self.assertIn("PrimaryId", res)
+        self.assertIn("AttributeName", res["PrimaryId"])
+        self.assertEqual("id", res["PrimaryId"]["AttributeName"])
+        self.assertIn("AttributeType", res["PrimaryId"])
+        self.assertIn("Name", res["PrimaryId"]["AttributeType"])
+        self.assertEqual("STRING", res["PrimaryId"]["AttributeType"]["Name"])
+        self.assertIn("IsLocal", res)
+        self.assertTrue(res["IsLocal"])
+
+        res = await self.conn.getVertexType("non_existing_vertex_type")
+        self.assertEqual({}, res)
+        # TODO This will need to be reviewed if/when getVertexType() return value changes from {} in
+        #   case an invalid/non-existing vertex type name is specified (e.g. an exception will be
+        #   raised instead of returning {})
+
+    async def test_03_getVertexCount(self):
+        res = await self.conn.getVertexCount("*")
+        self.assertIsInstance(res, dict)
+        self.assertEqual(7, len(res))
+        self.assertIn("vertex4", res)
+        self.assertEqual(0, res["vertex4"])
+        self.assertIn("vertex1_all_types", res)
+        self.assertEqual(0, res["vertex1_all_types"])
+
+        res = await self.conn.getVertexCount("vertex4")
+        self.assertIsInstance(res, int)
+        self.assertEqual(0, res)  # vertex4 was deleted in non-async test
+
+        res = await self.conn.getVertexCount(["vertex4", "vertex5", "vertex6"])
+        self.assertIsInstance(res, dict)
+        self.assertEqual(3, len(res))
+        self.assertIn("vertex4", res)
+        self.assertEqual(0, res["vertex4"])  # vertex4 was deleted in non-async test
+
+        res = await self.conn.getVertexCount("vertex4", "a01>=3")
+        self.assertIsInstance(res, int)
+        self.assertEqual(0, res)
+
+        with self.assertRaises(TigerGraphException) as tge:
+            await self.conn.getVertexCount("*", "a01>=3")
+        self.assertEqual("VertexType cannot be \"*\" if where condition is specified.",
+                         tge.exception.message)
+
+        with self.assertRaises(TigerGraphException) as tge:
+            await self.conn.getVertexCount(["vertex4", "vertex5", "vertex6"], "a01>=3")
+        self.assertEqual("VertexType cannot be a list if where condition is specified.",
+                         tge.exception.message)
+
+        with self.assertRaises(TigerGraphException) as tge:
+            await self.conn.getVertexCount("non_existing_vertex_type")
+        # self.assertEqual("REST-30000", tge.exception.code)
+        self.assertEqual("GSQL-7004", tge.exception.code)
+
+        res = await self.conn.getVertexCount("*", realtime=True)
+        self.assertIsInstance(res, dict)
+        self.assertEqual(7, len(res))
+        self.assertIn("vertex4", res)
+        self.assertEqual(0, res["vertex4"])
+        self.assertIn("vertex1_all_types", res)
+        self.assertEqual(0, res["vertex1_all_types"])
+
+        res = await
self.conn.getVertexCount("vertex4", realtime=True) + self.assertIsInstance(res, int) + self.assertEqual(0, res) + + async def test_04_upsertVertex(self): + res = await self.conn.upsertVertex("vertex4", 100, {"a01": 100}) + self.assertIsInstance(res, int) + self.assertEqual(1, res) + + with self.assertRaises(TigerGraphException) as tge: + await self.conn.upsertVertex("non_existing_vertex_type", 100, {"a01": 100}) + self.assertEqual("REST-30200", tge.exception.code) + + with self.assertRaises(TigerGraphException) as tge: + await self.conn.upsertVertex("vertex4", 100, {"non_existing_vertex_attribute": 100}) + self.assertEqual("REST-30200", tge.exception.code) + + async def test_05_upsertVertices(self): + vs = [ + (100, {"a01": (11, "+")}), + (200, {"a01": 200}), + (201, {"a01": 201}), + (202, {"a01": 202}) + ] + res = await self.conn.upsertVertices("vertex4", vs) + self.assertIsInstance(res, int) + self.assertEqual(4, res) + + res = await self.conn.getVertices("vertex4", where="a01>100") + self.assertIsInstance(res, list) + v = {} + for r in res: + # v_id value is returned as str, not int + if "v_id" in r and r["v_id"] == '100': + v = r + self.assertNotEqual({}, v) + self.assertIn("attributes", v) + self.assertIn("a01", v["attributes"]) + self.assertEqual(111, v["attributes"]["a01"]) + + res = await self.conn.delVertices("vertex4", "a01>100") + self.assertIsInstance(res, int) + self.assertEqual(4, res) + + res = await self.conn.getVertices("vertex4", where="a01>100") + self.assertIsInstance(res, list) + self.assertEqual(res, []) + + async def test_06_upsertVertexDataFrame(self): + # TODO Implement + pass + ''' + async def test_07_getVertices(self): + res = await self.conn.getVertices("vertex4", select="a01", where="a01>1,a01<5", sort="-a01", + limit=2) + self.assertIsInstance(res, list) + self.assertEqual(2, len(res)) + self.assertEqual(4, res[0]["attributes"]["a01"]) + self.assertEqual(3, res[1]["attributes"]["a01"]) + + res = await self.conn.getVertices("vertex4", select="a01", where="a01>1,a01<5", sort="-a01", + limit=2, fmt="json") + self.assertIsInstance(res, str) + res = json.loads(res) + self.assertIsInstance(res, list) + self.assertEqual(2, len(res)) + self.assertEqual(4, res[0]["attributes"]["a01"]) + self.assertEqual(3, res[1]["attributes"]["a01"]) + + res = await self.conn.getVertices("vertex4", select="a01", where="a01>1,a01<5", sort="-a01", + limit=2, fmt="df") + self.assertIsInstance(res, pandas.DataFrame) + self.assertEqual(2, len(res.index)) + + async def test_08_getVertexDataFrame(self): + res = await self.conn.getVertexDataFrame("vertex4", select="a01", where="a01>1,a01<5", + sort="-a01", + limit=2) + self.assertIsInstance(res, pandas.DataFrame) + self.assertEqual(2, len(res.index)) + + async def test_09_getVerticesById(self): + # select is ignored + res = await self.conn.getVerticesById("vertex4", [1, 3, 5], select="a01") + self.assertIsInstance(res, list) + self.assertEqual(3, len(res)) + + res = await self.conn.getVerticesById("vertex4", [1, 3, 5], fmt="json") + self.assertIsInstance(res, str) + res = json.loads(res) + self.assertIsInstance(res, list) + self.assertEqual(3, len(res)) + + res = await self.conn.getVerticesById("vertex4", [1, 3, 5], fmt="df") + self.assertIsInstance(res, pandas.DataFrame) + + async def test_10_getVertexDataFrameById(self): + res = await self.conn.getVertexDataFrameById("vertex4", [1, 3, 5]) + self.assertIsInstance(res, pandas.DataFrame) + self.assertEqual(3, len(res.index)) + ''' + async def test_11_getVertexStats(self): + res = await 
self.conn.getVertexStats("*", skipNA=True)
+        self.assertIsInstance(res, dict)
+
+        res = await self.conn.getVertexStats("vertex4")
+        self.assertEqual({'vertex4': {}}, res)  # vertex4 was deleted in non-async test
+
+        res = await self.conn.getVertexStats("vertex5", skipNA=True)
+        self.assertEqual({}, res)
+
+    async def test_12_delVertices(self):
+        vs = [
+            (300, {"a01": 300}),
+            (301, {"a01": 301}),
+            (302, {"a01": 302}),
+            (303, {"a01": 303}),
+            (304, {"a01": 304})
+        ]
+        res = await self.conn.upsertVertices("vertex4", vs)
+        self.assertIsInstance(res, int)
+        self.assertEqual(5, res)
+
+        res = await self.conn.getVertices("vertex4", where="a01>=300")
+        self.assertIsInstance(res, list)
+        self.assertEqual(5, len(res))
+
+        res = await self.conn.delVertices("vertex4", where="a01>=303")
+        self.assertIsInstance(res, int)
+        self.assertEqual(2, res)
+
+    async def test_13_delVerticesById(self):
+        res = await self.conn.delVerticesById("vertex4", 300)
+        self.assertIsInstance(res, int)
+        self.assertEqual(1, res)
+
+        res = await self.conn.delVerticesById("vertex4", [301, 302])
+        self.assertIsInstance(res, int)
+        self.assertEqual(2, res)
+
+    async def test_14_delVerticesByType(self):
+        pass
+        # TODO Implement pyTigerGraphVertex.delVerticesByType() for async first
+    '''
+    async def test_15_vertexSetToDataFrame(self):
+        res = await self.conn.getVertices("vertex4")
+        self.assertIsInstance(res, list)
+        self.assertEqual(0, len(res))  # vertex4 was deleted in non-async test
+
+        res = await self.conn.vertexSetToDataFrame(res)
+        self.assertIsInstance(res, pandas.DataFrame)
+        self.assertEqual(0, len(res.index))  # vertex4 was deleted in non-async test
+        self.assertEqual(["v_id", "a01"], list(res.columns))
+    '''
+
+if __name__ == '__main__':
+    unittest.main()
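Editor's note on the new *Async test modules added in this diff: they all share one shape, a `unittest.IsolatedAsyncioTestCase` whose `asyncSetUp()` awaits `make_connection()` from `pyTigerGraphUnitTestAsync`, with every connection call awaited in the test bodies. A self-contained sketch of that pattern follows; `StubConnection` and `make_stub_connection` are hypothetical stand-ins so the sketch runs without a TigerGraph server:

import asyncio
import unittest


class StubConnection:
    # Hypothetical stand-in for AsyncTigerGraphConnection, used only so
    # this sketch is runnable on its own.
    async def echo(self) -> str:
        await asyncio.sleep(0)  # yield to the event loop, like a real request
        return "Hello GSQL"


async def make_stub_connection() -> StubConnection:
    return StubConnection()


class test_asyncPattern(unittest.IsolatedAsyncioTestCase):
    async def asyncSetUp(self):
        # Same shape as the asyncSetUp() methods above: the connection
        # factory is a coroutine, so it must be awaited for each test.
        self.conn = await make_stub_connection()

    async def test_echo(self):
        res = await self.conn.echo()
        self.assertEqual("Hello GSQL", res)


if __name__ == "__main__":
    unittest.main()

Note that `IsolatedAsyncioTestCase` (Python 3.8+) runs each test in its own event loop, which is why these files use a per-test `asyncSetUp()` rather than the `@classmethod setUpClass()` used in the synchronous test files, and why a `@classmethod` decorator on `asyncSetUp()` would be incorrect.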