Skip to content
This repository was archived by the owner on Feb 26, 2025. It is now read-only.
Draft
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
138 changes: 128 additions & 10 deletions examples/load_graph_archngv.py
Original file line number Diff line number Diff line change
@@ -1,21 +1,116 @@
#!/usr/bin/env python
# coding: utf-8

import argparse
import base64
import getpass
import glob
import multiprocessing
import pickle
from functools import partial
from os import environ
from pathlib import Path

import numpy as np
import pandas as pd
import psutil
import requests
from archngv import NGVCircuit
from joblib import Parallel, delayed, parallel_config
from kgforge.core import KnowledgeGraphForge, Resource
from kgforge.specializations.resources import Dataset
from tqdm import tqdm

from astrovascpy import bloodflow
from astrovascpy.exceptions import BloodFlowError
from astrovascpy.utils import Graph


def get_nexus_token(
client_id="bbp-molsys-sa",
environ_name="KCS",
nexus_url="https://bbpauth.epfl.ch/auth/realms/BBP/protocol/openid-connect/token",
):
"""
retrieve a Nexus Token from keycloak
param:
client_id(str): the keycloak client id
environ_name(str): the name of the environment variable that holds the keycloak secret
nexus_url(str)
"""
try:
client_secret = environ[environ_name]

# bbp keycloack token endpoint
url = nexus_url

payload = "grant_type=client_credentials&scope=openid"
authorization = base64.b64encode(f"{client_id}:{client_secret}".encode("utf-8")).decode(
"ascii"
)
headers = {
"Content-Type": "application/x-www-form-urlencoded",
"Authorization": f"Basic {authorization}",
}

# request the token
r = requests.request(
"POST",
url=url,
headers=headers,
data=payload,
)

# get access token
mexus_token = r.json()["access_token"]
return mexus_token
except Exception as error:
print(f"Error: {error}")
return None


def get_nexus_circuit_conf(
circuit_name, nexus_org="bbp", nexus_project="mmb-neocortical-regions-ngv"
):
"""
retrieve nexus NGV config entry
param:
circuit_name(str): the Nexus name of the NGV circuit to load
nexus_org(str): The Nexus organisation
nexus_projec(str): The Nexus project that holds the circuit
"""

nexus_token = get_nexus_token()
if nexus_token is None:
print("Error: Cannot get a valid Nexus token")
return None

nexus_endpoint = "https://bbp.epfl.ch/nexus/v1" # production environment

forge = KnowledgeGraphForge(
"https://raw.githubusercontent.com/BlueBrain/nexus-forge/master/examples/notebooks/use-cases/prod-forge-nexus.yml",
endpoint=nexus_endpoint,
bucket=f"{nexus_org}/{nexus_project}",
token=nexus_token,
debug=True,
)

p = forge.paths("Dataset")
resources = forge.search(p.type == "DetailedCircuit", p.name == circuit_name, limit=30)

forge.as_dataframe(resources)
if len(resources) != 1:
print("There are several NGV circuit with this name")
return None
else:
circuit = resources[0]
circuitConfigPath = circuit.circuitConfigPath.url
circuitConfigPath = circuitConfigPath[len("file://") :]
print(f"circuitConfigPath: {circuitConfigPath}")

return circuitConfigPath


def load_graph_archngv_parallel(
filename, n_workers, n_astro=None, parallelization_backend="multiprocessing"
):
Expand Down Expand Up @@ -63,7 +158,11 @@ def load_graph_archngv_parallel(
with multiprocessing.Pool(n_workers) as pool:
for result_ids, result_endfeet in zip(
tqdm(
pool.imap(worker, args, chunksize=max(1, int(len(endfoot_ids) / n_workers))),
pool.imap(
worker,
args,
chunksize=max(1, int(len(endfoot_ids) / n_workers)),
),
total=len(endfoot_ids),
),
endfoot_ids,
Expand All @@ -75,7 +174,10 @@ def load_graph_archngv_parallel(

elif parallelization_backend == "joblib":
with parallel_config(
backend="loky", prefer="processes", n_jobs=n_workers, inner_max_num_threads=1
backend="loky",
prefer="processes",
n_jobs=n_workers,
inner_max_num_threads=1,
):
parallel = Parallel(return_as="generator", batch_size="auto")
parallelized_region = parallel(
Expand All @@ -101,15 +203,30 @@ def main():
print = partial(print, flush=True)

parser = argparse.ArgumentParser(description="File paths for NGVCircuits and output graph.")
parser.add_argument("--circuit-name", type=str, required=False, help="NGV circuits nexus name")
parser.add_argument("--circuit-path", type=str, required=False, help="Path to the NGV circuits")
parser.add_argument(
"--filename_ngv", type=str, required=True, help="Path to the NGV circuits file"
)
parser.add_argument(
"--output_graph", type=str, required=True, help="Path to the output graph file"
"--output-graph", type=str, required=True, help="Path to the output graph file"
)
args = parser.parse_args()

filename_ngv = args.filename_ngv
if args.circuit_name is not None:
circuit_name = args.circuit_name
filename_ngv = get_nexus_circuit_conf(circuit_name)
if filename_ngv is None:
print("Error: Could not obtain a valid file path for the NGV circuit")
return -1

elif args.circuit_path is not None:
filename_ngv = args.circuit_path
# filename_ngv = args.filename_ngv

else:
print("ERROR: circuit-name or circuit-path must be provided")
return -1

print(f"INFO: filename_ngv {filename_ngv} ")

output_graph = args.output_graph

n_cores = psutil.cpu_count(logical=False)
Expand All @@ -123,11 +240,12 @@ def main():
print("loading circuit : finish")

print("pickle graph : start")
filehandler = open(output_graph, "wb")
pickle.dump(graph, filehandler)
print("pickle graph : finish")
with open(output_graph, "wb") as filehandler:
pickle.dump(graph, filehandler)
print("pickle graph : finish")
print(f"Graph file: {output_graph}")


if __name__ == "__main__":
print("INFO: Start load_graph_archngv.py")
main()
24 changes: 18 additions & 6 deletions examples/load_graph_archngv.sbatch
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,7 @@
#SBATCH --job-name="archngv"
#SBATCH --nodes=1

#SBATCH --account=proj16
#SBATCH --account=proj137
#SBATCH --partition=prod
#SBATCH --constraint=cpu
#SBATCH --time=00:30:00
Expand All @@ -15,6 +15,8 @@

JOB_SCRIPT=$(scontrol show job ${SLURM_JOB_ID} | awk -F= '/Command=/{print $2}')
JOB_SCRIPT_DIR=$(dirname ${JOB_SCRIPT})
JOB_SCRIPT_DIR='/gpfs/bbp.cscs.ch/data/scratch/proj137/jacquemi/Nexus'


SETUP_SCRIPT="${JOB_SCRIPT_DIR}/../setup.sh"
if [[ ! -f ${SETUP_SCRIPT} ]]; then
Expand All @@ -24,13 +26,23 @@ fi

source ${SETUP_SCRIPT}

FILENAME_NGV="/gpfs/bbp.cscs.ch/project/proj137/NGVCircuits/rat_O1"
FILENAME_NGV=""

GRAPH_PATH="./data/graphs_folder/dumped_graph.bin"

echo
echo "### Loading graph"
echo
#CIRCUIT_NAME="tiny_real_spine_morph"
CIRCUIT_NAME="NGV O1.v5 (Rat)"

# It is imperative to use srun and dplace, otherwise the Python processes
# do not work properly (possible deadlocks and/or performance degradation)
time srun -n 1 --mpi=none dplace python ${JOB_SCRIPT_DIR}/load_graph_archngv.py --filename_ngv ${FILENAME_NGV} --output_graph ${GRAPH_PATH}
if [[ -z $FILENAME_NGV ]]; then
echo "### Loading graph from name: $CIRCUIT_NAME from Nexus"
echo "Execute ${JOB_SCRIPT_DIR}/load_graph_archngv.py with srun"
time srun -n 1 --mpi=none dplace python ${JOB_SCRIPT_DIR}/load_graph_archngv.py --circuit-name "${CIRCUIT_NAME}" --output-graph ${GRAPH_PATH}
else
echo "### Loading graph from filename: $FILENAME_NGV"
echo "Execute ${JOB_SCRIPT_DIR}/load_graph_archngv.py with srun"
time srun -n 1 --mpi=none dplace python ${JOB_SCRIPT_DIR}/load_graph_archngv.py --circuit-path "${FILENAME_NGV}" --output-graph ${GRAPH_PATH}
fi

echo
Loading
Loading