Skip to content

Add train configuration for aura GraphSage #441

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Draft
wants to merge 12 commits into
base: main
Choose a base branch
from
59 changes: 59 additions & 0 deletions examples/python-runtime.ipynb
Original file line number Diff line number Diff line change
@@ -0,0 +1,59 @@
{
"cells": [
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"DBID = \"beefbeef\"\n",
"ENVIRONMENT = \"\"\n",
"PASSWORD = \"\"\n",
"\n",
"from graphdatascience import GraphDataScience\n",
"\n",
"gds = GraphDataScience(f\"neo4j+s://{DBID}-{ENVIRONMENT}.databases.neo4j-dev.io/\", auth=(\"neo4j\", PASSWORD))\n",
"gds.set_database(\"neo4j\")\n",
"\n",
"gds.gnn.nodeClassification.train(\"cora\", \"model\", [\"features\"], \"subject\", node_labels=[\"Paper\"])"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"try:\n",
" gds.graph.load_cora()\n",
"except Exception:\n",
" pass"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"gds.gnn.nodeClassification.train(\"cora\", \"model\", [\"features\"], \"subject\", node_labels=[\"Paper\"])"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"gds.gnn.nodeClassification.predict(\"cora\", \"model\", [\"features\"], \"subject\", node_labels=[\"Paper\"])"
]
}
],
"metadata": {
"language_info": {
"name": "python"
}
},
"nbformat": 4,
"nbformat_minor": 0
}
5 changes: 4 additions & 1 deletion graphdatascience/endpoints.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,6 @@
from .algo.single_mode_algo_endpoints import SingleModeAlgoEndpoints
from .call_builder import IndirectAlphaCallBuilder, IndirectBetaCallBuilder
from .gnn.gnn_endpoints import GnnEndpoints
from .graph.graph_endpoints import (
GraphAlphaEndpoints,
GraphBetaEndpoints,
Expand Down Expand Up @@ -32,7 +33,9 @@
"""


class DirectEndpoints(DirectSystemEndpoints, DirectUtilEndpoints, GraphEndpoints, PipelineEndpoints, ModelEndpoints):
class DirectEndpoints(
DirectSystemEndpoints, DirectUtilEndpoints, GraphEndpoints, PipelineEndpoints, ModelEndpoints, GnnEndpoints
):
def __init__(self, query_runner: QueryRunner, namespace: str, server_version: ServerVersion):
    # Cooperative constructor: forwards to every endpoint mixin in the MRO
    # (system, util, graph, pipeline, model, gnn) via super().
    super().__init__(query_runner, namespace, server_version)

Expand Down
Empty file.
18 changes: 18 additions & 0 deletions graphdatascience/gnn/gnn_endpoints.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,18 @@
from ..caller_base import CallerBase
from ..error.illegal_attr_checker import IllegalAttrChecker
from ..error.uncallable_namespace import UncallableNamespace
from .gnn_nc_runner import GNNNodeClassificationRunner


class GNNRunner(UncallableNamespace, IllegalAttrChecker):
    """Namespace object for the GNN procedure families.

    Not callable itself; currently exposes only node classification.
    """

    @property
    def nodeClassification(self) -> GNNNodeClassificationRunner:
        """Return a runner scoped to ``<namespace>.nodeClassification``."""
        child_namespace = f"{self._namespace}.nodeClassification"
        return GNNNodeClassificationRunner(self._query_runner, child_namespace, self._server_version)


class GnnEndpoints(CallerBase):
    """Mixin that adds the ``gds.gnn`` entry point to the client object."""

    @property
    def gnn(self) -> GNNRunner:
        """Return the GNN namespace rooted at ``<namespace>.gnn``."""
        gnn_namespace = f"{self._namespace}.gnn"
        return GNNRunner(self._query_runner, gnn_namespace, self._server_version)
79 changes: 79 additions & 0 deletions graphdatascience/gnn/gnn_nc_runner.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,79 @@
import json
from typing import Any, Dict, List, Optional

from ..error.illegal_attr_checker import IllegalAttrChecker
from ..error.uncallable_namespace import UncallableNamespace


class GNNNodeClassificationRunner(UncallableNamespace, IllegalAttrChecker):
    """Runner for GNN node-classification training and prediction.

    Each call serializes its settings into a JSON ``mlTrainingConfig`` and
    submits it through the ``gds.upload.graph`` procedure; the arrow query
    runner injects the auth token and arrow endpoint URI into the config
    before the query is executed.
    """

    # Defaults for the GraphSAGE model configuration. User-supplied keys
    # must be a subset of these.
    _GRAPH_SAGE_DEFAULT_CONFIG: Dict[str, Any] = {
        "layer_config": {},
        "num_neighbors": [25, 10],
        "dropout": 0.5,
        "hidden_channels": 256,
        "learning_rate": 0.003,
    }

    def make_graph_sage_config(self, graph_sage_config: Optional[Dict[str, Any]]) -> Dict[str, Any]:
        """Merge user overrides into the GraphSAGE defaults.

        Args:
            graph_sage_config: optional mapping of GraphSAGE options; keys
                must be among the recognized defaults.

        Returns:
            A new dict with defaults overlaid by the user's values.

        Raises:
            ValueError: if ``graph_sage_config`` contains unknown keys.
        """
        # Copy so the shared defaults are never mutated by the update below.
        final_sage_config = dict(self._GRAPH_SAGE_DEFAULT_CONFIG)
        if graph_sage_config:
            bad_keys = [key for key in graph_sage_config if key not in self._GRAPH_SAGE_DEFAULT_CONFIG]
            if bad_keys:
                raise ValueError(f"Argument graph_sage_config contains invalid keys {', '.join(bad_keys)}.")
            final_sage_config.update(graph_sage_config)
        return final_sage_config

    def train(
        self,
        graph_name: str,
        model_name: str,
        feature_properties: List[str],
        target_property: str,
        relationship_types: Optional[List[str]] = None,
        target_node_label: Optional[str] = None,
        node_labels: Optional[List[str]] = None,
        graph_sage_config: Optional[Dict[str, Any]] = None,
    ) -> "Series[Any]":  # noqa: F821
        """Train a node-classification model on a projected graph.

        Args:
            graph_name: name of the projected graph to train on.
            model_name: name under which the trained model is stored.
            feature_properties: node properties used as input features.
            target_property: node property holding the class label.
            relationship_types: relationship types to include; when omitted
                the key is not sent (NOTE(review): server-side default
                assumed to cover all types — confirm).
            target_node_label: optional label restricting training targets.
            node_labels: optional labels restricting the considered nodes.
            graph_sage_config: optional GraphSAGE overrides, validated by
                :meth:`make_graph_sage_config`.

        Returns:
            The result of the underlying ``gds.upload.graph`` call.
        """
        mlConfigMap: Dict[str, Any] = {
            "featureProperties": feature_properties,
            "targetProperty": target_property,
            "job_type": "train",
            "nodeProperties": feature_properties + [target_property],
            "graph_sage_config": self.make_graph_sage_config(graph_sage_config),
        }

        if relationship_types is not None:
            mlConfigMap["relationshipTypes"] = relationship_types
        if target_node_label:
            mlConfigMap["targetNodeLabel"] = target_node_label
        if node_labels:
            mlConfigMap["nodeLabels"] = node_labels

        mlTrainingConfig = json.dumps(mlConfigMap)

        # token and uri will be injected by arrow_query_runner
        return self._query_runner.run_query(
            "CALL gds.upload.graph($config)",
            params={
                "config": {"mlTrainingConfig": mlTrainingConfig, "graphName": graph_name, "modelName": model_name},
            },
        )

    def predict(
        self,
        graph_name: str,
        model_name: str,
        mutateProperty: str,
        predictedProbabilityProperty: Optional[str] = None,
    ) -> "Series[Any]":  # noqa: F821
        """Predict classes with a trained model, mutating the in-memory graph.

        Args:
            graph_name: name of the projected graph to predict on.
            model_name: name of the previously trained model.
            mutateProperty: node property that receives the predicted class.
            predictedProbabilityProperty: optional property receiving the
                class probabilities.

        Returns:
            The result of the underlying ``gds.upload.graph`` call.
        """
        mlConfigMap: Dict[str, Any] = {
            "job_type": "predict",
            "mutateProperty": mutateProperty,
        }
        if predictedProbabilityProperty:
            mlConfigMap["predictedProbabilityProperty"] = predictedProbabilityProperty

        mlTrainingConfig = json.dumps(mlConfigMap)
        # token and uri will be injected by arrow_query_runner
        return self._query_runner.run_query(
            "CALL gds.upload.graph($config)",
            params={
                "config": {"mlTrainingConfig": mlTrainingConfig, "graphName": graph_name, "modelName": model_name},
            },
        )
1 change: 1 addition & 0 deletions graphdatascience/ignored_server_endpoints.py
Original file line number Diff line number Diff line change
Expand Up @@ -47,6 +47,7 @@
"gds.alpha.pipeline.nodeRegression.predict.stream",
"gds.alpha.pipeline.nodeRegression.selectFeatures",
"gds.alpha.pipeline.nodeRegression.train",
"gds.gnn.nc",
"gds.similarity.cosine",
"gds.similarity.euclidean",
"gds.similarity.euclideanDistance",
Expand Down
21 changes: 19 additions & 2 deletions graphdatascience/query_runner/arrow_query_runner.py
Original file line number Diff line number Diff line change
Expand Up @@ -29,6 +29,9 @@ def __init__(
):
self._fallback_query_runner = fallback_query_runner
self._server_version = server_version
# FIXME handle version were tls cert is given
self._auth = auth
self._uri = uri

host, port_string = uri.split(":")

Expand All @@ -39,8 +42,9 @@ def __init__(
)

client_options: Dict[str, Any] = {"disable_server_verification": disable_server_verification}
self._auth_factory = AuthFactory(auth)
if auth:
client_options["middleware"] = [AuthFactory(auth)]
client_options["middleware"] = [self._auth_factory]
if tls_root_certs:
client_options["tls_root_certs"] = tls_root_certs

Expand Down Expand Up @@ -129,6 +133,10 @@ def run_query(
endpoint = "gds.beta.graph.relationships.stream"

return self._run_arrow_property_get(graph_name, endpoint, {"relationship_types": relationship_types})
elif "gds.upload.graph" in query:
# inject parameters
params["config"]["token"] = self._get_or_request_token()
params["config"]["arrowEndpoint"] = self._uri

return self._fallback_query_runner.run_query(query, params, database, custom_error)

Expand Down Expand Up @@ -184,6 +192,10 @@ def create_graph_constructor(
database, graph_name, self._flight_client, concurrency, undirected_relationship_types
)

def _get_or_request_token(self) -> str:
self._flight_client.authenticate_basic_token(self._auth[0], self._auth[1])
return self._auth_factory.token()


class AuthFactory(ClientMiddlewareFactory): # type: ignore
def __init__(self, auth: Tuple[str, str], *args: Any, **kwargs: Any) -> None:
Expand Down Expand Up @@ -217,9 +229,14 @@ def __init__(self, factory: AuthFactory, *args: Any, **kwargs: Any) -> None:
self._factory = factory

def received_headers(self, headers: Dict[str, Any]) -> None:
    """Capture the server-issued bearer token from response headers.

    gRPC normalizes header keys to lower case, so we look up
    ``"authorization"``. After ``authenticate_basic_token()`` the value may
    be a list of header values rather than a single string; we scan for the
    first ``Bearer`` entry (resolving the old TODO that blindly took
    element ``[0]``) and store its token on the middleware factory.
    """
    auth_header = headers.get("authorization", None)
    if not auth_header:
        return
    values = auth_header if isinstance(auth_header, list) else [auth_header]
    for value in values:
        # partition() never raises, unlike split-with-unpack on a
        # space-less value.
        auth_type, _, token = value.partition(" ")
        if auth_type == "Bearer" and token:
            self._factory.set_token(token)
            return
Expand Down