diff --git a/LICENSE b/LICENSE deleted file mode 100644 index ccc6935..0000000 --- a/LICENSE +++ /dev/null @@ -1,28 +0,0 @@ -BSD 3-Clause License - -Copyright (c) 2024, CellMap Project Team - -Redistribution and use in source and binary forms, with or without -modification, are permitted provided that the following conditions are met: - -1. Redistributions of source code must retain the above copyright notice, this - list of conditions and the following disclaimer. - -2. Redistributions in binary form must reproduce the above copyright notice, - this list of conditions and the following disclaimer in the documentation - and/or other materials provided with the distribution. - -3. Neither the name of the copyright holder nor the names of its - contributors may be used to endorse or promote products derived from - this software without specific prior written permission. - -THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" -AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE -IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE -DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE -FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL -DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR -SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER -CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, -OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE -OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. diff --git a/README.md b/README.md new file mode 100644 index 0000000..6019f08 --- /dev/null +++ b/README.md @@ -0,0 +1,50 @@ +## Still in development + +# cellmap-flow + +```bash +$ cellmap_flow_server + +Usage: cellmap_flow_server [OPTIONS] COMMAND [ARGS]... 
+ + Examples: + To use Dacapo run the following commands: + cellmap_flow_server dacapo -r my_run -i iteration -d data_path + + To use custom script + cellmap_flow_server script -s script_path -d data_path + + To use bioimage-io model + cellmap_flow_server bioimage -m model_path -d data_path + + +Commands: + bioimage Run the CellMapFlow server with a bioimage-io model. + dacapo Run the CellMapFlow server with a DaCapo model. + script Run the CellMapFlow server with a custom script. +``` + + +Currently available: +## Using custom script: +which enables using any model by providing a script e.g. [example/model_spec.py](example/model_spec.py) +e.g. +```bash +cellmap_flow script -c /groups/cellmap/cellmap/zouinkhim/cellmap-flow/example/model_spec.py -d /nrs/cellmap/data/jrc_mus-cerebellum-1/jrc_mus-cerebellum-1.zarr/recon-1/em/fibsem-uint8/s0 +``` + +## Using Dacapo model: +which enables inference using a Dacapo model by providing the run name and iteration number +e.g. +```bash +cellmap_flow dacapo -r 20241204_finetune_mito_affs_task_datasplit_v3_u21_kidney_mito_default_cache_8_1 -i 700000 -d /nrs/cellmap/data/jrc_ut21-1413-003/jrc_ut21-1413-003.zarr/recon-1/em/fibsem-uint8/s0 +``` + +## Using bioimage-io model: +still in development + + +## Limitation: +Currently only supports data located in /nrs/cellmap or /groups/cellmap because there is a data server already implemented for them. 
+ + diff --git a/cellmap_flow/__init__.py b/cellmap_flow/__init__.py new file mode 100644 index 0000000..254e8b2 --- /dev/null +++ b/cellmap_flow/__init__.py @@ -0,0 +1,2 @@ +__version__ = "0.1.0" +__version_info__ = tuple(int(i) for i in __version__.split(".")) diff --git a/cellmap_flow/bioimagezoo_processor.py b/cellmap_flow/bioimagezoo_processor.py new file mode 100644 index 0000000..8b1886e --- /dev/null +++ b/cellmap_flow/bioimagezoo_processor.py @@ -0,0 +1,69 @@ +# %% +from bioimageio.core import load_description +from bioimageio.core import predict # , predict_many + +# %% +from bioimageio.core import Tensor +import numpy as np +from bioimageio.core import Sample +from bioimageio.core.digest_spec import get_member_ids + + +# create random input tensor +def process_chunk(model, idi, roi): + input_image = idi.to_ndarray_ts(roi) + # add dimensions to start of input image + # print(input_image.shape) + # input_image = block_reduce(input_image, (4,1,1), np.mean, cval=0) + # print(input_image.shape) + # input_image = (input_image - np.min(input_image))/np.max(input_image) + # input_image = np.random.rand(1, 1, 64, 64, 64).astype(np.float32) + if len(model.outputs[0].axes) == 5: + input_image = input_image[np.newaxis, np.newaxis, ...].astype(np.float32) + test_input_tensor = Tensor.from_numpy( + input_image, dims=["batch", "c", "z", "y", "x"] + ) + else: + input_image = input_image[:, np.newaxis, ...].astype(np.float32) + + test_input_tensor = Tensor.from_numpy( + input_image, dims=["batch", "c", "y", "x"] + ) + sample_input_id = get_member_ids(model.inputs)[0] + sample_output_id = get_member_ids(model.outputs)[0] + + sample = Sample( + members={sample_input_id: test_input_tensor}, stat={}, id="sample-from-numpy" + ) + prediction: Sample = predict( + model=model, inputs=sample, skip_preprocessing=sample.stat is not None + ) + ndim = prediction.members[sample_output_id].data.ndim + output = prediction.members[sample_output_id].data.to_numpy() + if ndim < 5 and 
len(model.outputs) > 1: + if len(model.outputs) > 1: + outputs = [] + for id in get_member_ids(model.outputs): + output = prediction.members[id].data.to_numpy() + if output.ndim == 3: + output = output[:, np.newaxis, ...] + outputs.append(output) + output = np.concatenate(outputs, axis=1) + output = np.ascontiguousarray(np.swapaxes(output, 1, 0)) + + else: + output = output[0, ...] + # if ndim == 5: + # # then is b,c,z,y,x, and only want z,y,x + # output = 255 * output[0, ...] + # elif ndim == 4: + # # then is b,c,y,x since it is 2d which is really z,c,y,x + # output = 255 * output[0, 0, ...] + output = 255 * output + output = output.astype(np.uint8) + return output + + +# %% +# import torch +# from pydeepimagej.yaml_utils import create_model_yaml diff --git a/cellmap_flow/cli.py b/cellmap_flow/cli.py new file mode 100644 index 0000000..ed81a92 --- /dev/null +++ b/cellmap_flow/cli.py @@ -0,0 +1,314 @@ +import subprocess +import logging +import neuroglancer +import os +import sys +import signal +import select +import itertools +import click + +logging.basicConfig() + +logger = logging.getLogger(__name__) + +processes = [] +job_ids = [] +hosts = [] +security = "http" +import subprocess + +neuroglancer.set_server_bind_address("0.0.0.0") + + +def is_bsub_available(): + try: + # Run 'which bsub' to check if bsub is available in PATH + result = subprocess.run( + ["which", "bsub"], stdout=subprocess.PIPE, stderr=subprocess.PIPE + ) + + if result.stdout: + return True + else: + return False + except Exception as e: + print("Error:", e) + + +def cleanup(signum, frame): + print(f"Script is being killed. 
Received signal: {signum}") + + if is_bsub_available(): + # Run your command here + for job_id in job_ids: + print(f"Killing job {job_id}") + os.system(f"bkill {job_id}") + else: + for process in processes: + process.kill() + sys.exit(0) + + +# Attach signal handlers +signal.signal(signal.SIGINT, cleanup) # Handle Ctrl+C +signal.signal(signal.SIGTERM, cleanup) # Handle termination + + +def get_host_from_stdout(output): + logger.error(f"Output: {output}") + + # Print or parse the output line-by-line + if "Host name: " in output and f"* Running on {security}://" in output: + print("Host found!") + host_name = output.split("Host name: ")[1].split("\n")[0].strip() + port = output.split(f"* Running on {security}://127.0.0.1:")[1].split("\n")[0] + + hosts.append(f"{security}://{host_name}:{port}") + print(f"{hosts=}") + return True + return False + + +def parse_bpeek_output(job_id): + # Run bpeek to get the job's real-time output + command = f"bpeek {job_id}" + + try: + # Process the output in real-time + while True: + # logger.error(f"Running command: {command}") + process = subprocess.Popen( + command, + shell=True, + stdout=subprocess.PIPE, + stderr=subprocess.PIPE, + text=True, + ) + + output = process.stdout.read() + error_output = process.stderr.read() + # logger.error(f"Output: {output} {error_output}") + if ( + output == "" + and process.poll() is not None + and f"Job <{job_id}> : Not yet started." 
not in error_output + ): + logger.error(f"Job <{job_id}> has finished.") + break # End of output + if output: + # Print or parse the output line-by-line + # logger.error(output) + if get_host_from_stdout(output): + + break + + # Example: Parse a specific pattern (e.g., errors or warnings) + if "error" in output.lower(): + print(f"Error found: {output.strip()}") + + # Capture any error output + error_output = process.stderr.read() + if error_output: + print(f"Error: {error_output.strip()}") + + except Exception as e: + print(f"Error while executing bpeek: {e}") + + +def submit_bsub_job( + sc, + job_name="my_job", +): + # Create the bsub command + bsub_command = [ + "bsub", + "-J", + job_name, + # "-o", + # "/dev/stdout", + # "-e", + # "/dev/stderr", + "-P", + "cellmap", + "-q", + "gpu_h100", + "-gpu", + "num=1", + "bash", + "-c", + sc, + ] + + # Submit the job + print("Submitting job with the following command:") + + try: + result = subprocess.run( + bsub_command, capture_output=True, text=True, check=True + ) + print("Job submitted successfully:") + print(result.stdout) + except subprocess.CalledProcessError as e: + print("Error submitting job:") + print(e.stderr) + + return result + + +def generate_neuroglancer_link(dataset_path, inference_dict, output_channels): + # Create a new viewer + viewer = neuroglancer.UnsynchronizedViewer() + + # Add a layer to the viewer + with viewer.txn() as s: + # if multiscale dataset + if ( + dataset_path.split("/")[-1].startswith("s") + and dataset_path.split("/")[-1][1:].isdigit() + ): + dataset_path = dataset_path.rsplit("/", 1)[0] + if ".zarr" in dataset_path: + filetype = "zarr" + elif ".n5" in dataset_path: + filetype = "n5" + else: + filetype = "precomputed" + if dataset_path.startswith("/"): + dataset_path = dataset_path.replace("/nrs/cellmap/", "nrs/").replace( + "/groups/cellmap/cellmap/", "dm11/" + ) + s.layers["raw"] = neuroglancer.ImageLayer( + 
source=f"{filetype}://{security}://cellmap-vm1.int.janelia.org/{dataset_path}", + ) + else: + s.layers["raw"] = neuroglancer.ImageLayer( + source=f"{filetype}://{dataset_path}", + ) + colors = [ + "red", + "green", + "blue", + "yellow", + "purple", + "orange", + "cyan", + "magenta", + ] + color_cycle = itertools.cycle(colors) + for host, model in inference_dict.items(): + color = next(color_cycle) + s.layers[model] = neuroglancer.ImageLayer( + source=f"n5://{host}/{model}", + shader=f"""#uicontrol invlerp normalized(range=[0, 255], window=[0, 255]); +#uicontrol vec3 color color(default="{color}"); +void main(){{emitRGB(color * normalized());}}""", + ) + print(viewer) # neuroglancer.to_url(viewer.state)) + logger.error(f"link : {viewer}") + while True: + pass + + +def run_locally(sc): + # Command to execute + command = sc.split(" ") + + # Start the subprocess + process = subprocess.Popen( + command, stdout=subprocess.PIPE, stderr=subprocess.PIPE, text=True + ) + + # Use select to check for output without blocking + # NOTE: For some reason the output is considered stderr here but is stdout ion cluster + # try: + output = "" + while True: + # Check if there is data available to read from stdout and stderr + rlist, _, _ = select.select( + [process.stdout, process.stderr], [], [], 0.1 + ) # Timeout is 0.1s + + # Read from stdout if data is available + if process.stdout in rlist: + output += process.stdout.readline() + if get_host_from_stdout(output): + break + + # Read from stderr if data is available + if process.stderr in rlist: + output += process.stderr.readline() + if get_host_from_stdout(output): + break + # Check if the process has finished and no more output is available + if process.poll() is not None and not rlist: + break + processes.append(process) + # except KeyboardInterrupt: + # print("Process interrupted.") + # finally: + # process.stdout.close() + # process.stderr.close() + # process.wait() # Wait for the process to terminate + + +def 
start_hosts(dataset, script_path, num_hosts=1): + if security == "https": + sc = f"cellmap_flow_server -d {dataset} -c {script_path} --certfile=host.cert --keyfile=host.key" + else: + sc = f"cellmap_flow_server -d {dataset} -c {script_path}" + + if is_bsub_available(): + for _ in range(num_hosts): + result = submit_bsub_job(sc, job_name="example_job") + job_ids.append(result.stdout.split()[1][1:-1]) + for job_id in job_ids: + parse_bpeek_output(job_id) + else: + run_locally(sc) + + +@click.command() +@click.option( + "-d", + "--dataset_path", + type=str, + help="Data path, including scale", + required=True, +) +@click.option( + "-c", "--code", type=str, help="Path to the script to run", required=False +) +@click.option( + "-ch", + "--output_channels", + type=int, + help="Number of output channels", + required=False, + default=0, +) +def main(dataset_path, code, output_channels): + + if dataset_path.endswith("/"): + dataset_path = dataset_path[:-1] + # print(f"Dataset: {dataset_path}, Scale: {scale}, Models: {models}") + logging.info("Starting hosts...") + start_hosts(dataset_path, code, num_hosts=1) + + logging.info("Starting hosts completed!") + inference_dict = {} + models = ["model"] + if len(hosts) != len(models): + raise ValueError( + "Number of hosts and models should be the same, but something went wrong" + ) + + # print(hosts, models) + for host, model in zip(hosts, models): + inference_dict[host] = f"{dataset_path}" + + generate_neuroglancer_link(dataset_path, inference_dict, output_channels) + + +if __name__ == "__main__": + main() diff --git a/example_virtual_n5.py b/cellmap_flow/example_virtual_n5.py similarity index 90% rename from example_virtual_n5.py rename to cellmap_flow/example_virtual_n5.py index 317ce81..517817b 100644 --- a/example_virtual_n5.py +++ b/cellmap_flow/example_virtual_n5.py @@ -55,6 +55,10 @@ CHUNK_ENCODER = N5ChunkWrapper(np.float32, BLOCK_SHAPE, compressor=numcodecs.GZip()) +import os + +environ = os.environ + def main(): parser 
= argparse.ArgumentParser() @@ -70,13 +74,33 @@ def main(): ) +@app.route("/tester") +def tester(): + return render_template("iframe.html", url="https://wikipedia.org") + + @app.route("/home") def home(): return render_template("iframe.html") +@app.route("/glb_loader") +def glb_loader(): + return render_template("glbs.html") + + @app.route("/attributes.json") def top_level_attributes(): + # if environ["REQUEST_METHOD"] == "OPTIONS": + # # Respuesta para las solicitudes preflight (OPTIONS) + # response_headers = [ + # ("Access-Control-Allow-Origin", "*"), + # ("Access-Control-Allow-Methods", "POST, OPTIONS"), + # ("Access-Control-Allow-Headers", "Content-Type"), + # ] + # start_response("200 OK", response_headers) + # print("a responder") + # return [] scales = [[2**s, 2**s, 2**s, 1] for s in range(MAX_SCALE + 1)] attr = { "pixelResolution": {"dimensions": [1.0, 1.0, 1.0, 1.0], "unit": "nm"}, diff --git a/cellmap_flow/example_virtual_n5_generic.py b/cellmap_flow/example_virtual_n5_generic.py new file mode 100644 index 0000000..3fd3bba --- /dev/null +++ b/cellmap_flow/example_virtual_n5_generic.py @@ -0,0 +1,209 @@ +""" +# Example Virtual N5 + +Example service showing how to host a virtual N5, +suitable for browsing in neuroglancer. + +Neuroglancer is capable of browsing N5 files, as long as you store them on +disk and then host those files over http (with a CORS-friendly http server). +But what if your data doesn't exist on disk yet? + +This server hosts a "virtual" N5. Nothing is stored on disk, +but neuroglancer doesn't need to know that. This server provides the +necessary attributes.json files and chunk files on-demand, in the +"locations" (url patterns) that neuroglancer expects. + +For simplicity, this file uses Flask. In a production system, +you'd probably want to use something snazzier, like FastAPI. 
+ +To run the example, install a few dependencies: + + conda create -n example-virtual-n5 -c conda-forge zarr flask flask-cors + conda activate example-virtual-n5 + +Then just execute the file: + + python example_virtual_n5.py + +Or, for better performance, use a proper http server: + + conda install -c conda-forge gunicorn + gunicorn --bind 0.0.0.0:8000 --workers 8 --threads 1 example_virtual_n5:app + +You can browse the data in neuroglancer after configuring the viewer with the appropriate layer [settings][1]. + +[1]: http://neuroglancer-demo.appspot.com/#!%7B%22dimensions%22:%7B%22x%22:%5B1e-9%2C%22m%22%5D%2C%22y%22:%5B1e-9%2C%22m%22%5D%2C%22z%22:%5B1e-9%2C%22m%22%5D%7D%2C%22position%22:%5B5000.5%2C7500.5%2C10000.5%5D%2C%22crossSectionScale%22:25%2C%22projectionScale%22:32767.999999999996%2C%22layers%22:%5B%7B%22type%22:%22image%22%2C%22source%22:%7B%22url%22:%22n5://http://127.0.0.1:8000%22%2C%22transform%22:%7B%22outputDimensions%22:%7B%22x%22:%5B1e-9%2C%22m%22%5D%2C%22y%22:%5B1e-9%2C%22m%22%5D%2C%22z%22:%5B1e-9%2C%22m%22%5D%2C%22c%5E%22:%5B1%2C%22%22%5D%7D%7D%7D%2C%22tab%22:%22rendering%22%2C%22opacity%22:0.42%2C%22shader%22:%22void%20main%28%29%20%7B%5Cn%20%20emitRGB%28%5Cn%20%20%20%20vec3%28%5Cn%20%20%20%20%20%20getDataValue%280%29%2C%5Cn%20%20%20%20%20%20getDataValue%281%29%2C%5Cn%20%20%20%20%20%20getDataValue%282%29%5Cn%20%20%20%20%29%5Cn%20%20%29%3B%5Cn%7D%5Cn%22%2C%22channelDimensions%22:%7B%22c%5E%22:%5B1%2C%22%22%5D%7D%2C%22name%22:%22colorful-data%22%7D%5D%2C%22layout%22:%224panel%22%7D +""" + +# %% +# NOTE: To generate host key and host cert do the following: https://serverfault.com/questions/224122/what-is-crt-and-key-files-and-how-to-generate-them +# openssl genrsa 2048 > host.key +# chmod 400 host.key +# openssl req -new -x509 -nodes -sha256 -days 365 -key host.key -out host.cert +# Then can run like this: +# gunicorn --certfile=host.cert --keyfile=host.key --bind 0.0.0.0:8000 --workers 1 --threads 1 example_virtual_n5:app +# NOTE: You will 
probably have to access the host:8000 separately and say it is safe to go there + +# %% load image zoo +from bioimageio.core import load_description +from bioimageio.core import predict # , predict_many + + +import argparse +from http import HTTPStatus +from flask import Flask, jsonify +from flask_cors import CORS + +import numpy as np +import numcodecs +from scipy import spatial +from zarr.n5 import N5ChunkWrapper +from funlib.geometry import Roi +import numpy as np + + +# NOTE: Normally we would just load in run but here we have to recreate it to save time since our run has so many points +import socket +from image_data_interface import ImageDataInterface + +app = Flask(__name__) +CORS(app) + +from inferencer import Inferencer + +import socket + +# Get the hostname +hostname = socket.gethostname() + +# Get the local IP address + +print(f"Host name: {hostname}", flush=True) + + +def main(): + parser = argparse.ArgumentParser() + parser.add_argument("-d", "--debug", action="store_true") + parser.add_argument("-p", "--port", default=8000) + args = parser.parse_args() + app.run( + host="0.0.0.0", + # port=args.port, + debug=args.debug, + threaded=not args.debug, + use_reloader=args.debug, + ) + + +# %% +SCALE_LEVEL = None +IDI_RAW = None +OUTPUT_VOXEL_SIZE = None +VOL_SHAPE_ZYX = None +VOL_SHAPE = None +VOL_SHAPE_ZYX_IN_BLOCKS = None +VOXEL_SIZE = None +BLOCK_SHAPE = None +MAX_SCALE = None +CHUNK_ENCODER = None +EQUIVALENCES = None +DS = None +INFERENCER = Inferencer() + + +# determined-chimpmunk is edges +# kind-seashell is mito +# happy-elephant is cells +@app.route("//attributes.json") +def top_level_attributes(dataset): + if "__" not in dataset: + return jsonify({"n5": "2.1.0"}), HTTPStatus.OK + + if not (dataset.startswith("gs://") or dataset.startswith("s3://")): + dataset = "/" + dataset + + dataset_name, s, BMZ_MODEL_ID = dataset.split("__") + + global OUTPUT_VOXEL_SIZE, BLOCK_SHAPE, VOL_SHAPE, CHUNK_ENCODER, IDI_RAW, INFERENCER + print(dataset_name, s, 
BMZ_MODEL_ID) + SCALE_LEVEL = int(s[1:]) + INFERENCER = Inferencer(BMZ_MODEL_ID) + # MODEL = Inferencer(BMZ_MODEL_ID) + + IDI_RAW = ImageDataInterface(f"{dataset_name}/s{SCALE_LEVEL}") + OUTPUT_VOXEL_SIZE = IDI_RAW.voxel_size + + # %% + BLOCK_SHAPE = np.array([128, 128, 128, 9]) + MAX_SCALE = 0 + + VOL_SHAPE_ZYX = np.array(IDI_RAW.shape) + VOL_SHAPE = np.array([*VOL_SHAPE_ZYX[::-1], 9]) + # VOL_SHAPE_ZYX_IN_BLOCKS = np.ceil(VOL_SHAPE_ZYX / BLOCK_SHAPE[:3]).astype(int) + # VOXEL_SIZE = IDI_RAW.voxel_size + + CHUNK_ENCODER = N5ChunkWrapper(np.uint8, BLOCK_SHAPE, compressor=numcodecs.GZip()) + + scales = [[2**s, 2**s, 2**s, 1] for s in range(MAX_SCALE + 1)] + attr = { + "pixelResolution": { + "dimensions": [*OUTPUT_VOXEL_SIZE, 1], + "unit": "nm", + }, + "ordering": "C", + "scales": scales, + "axes": ["x", "y", "z", "c^"], + "units": ["nm", "nm", "nm", ""], + "translate": [0, 0, 0, 0], + } + return jsonify(attr), HTTPStatus.OK + + +@app.route("//s/attributes.json") +def attributes(dataset, scale): + attr = { + "transform": { + "ordering": "C", + "axes": ["x", "y", "z", "c^"], + "scale": [ + *OUTPUT_VOXEL_SIZE, + 1, + ], + "units": ["nm", "nm", "nm", ""], + "translate": [0.0, 0.0, 0.0, 0.0], + }, + "compression": {"type": "gzip", "useZlib": False, "level": -1}, + "blockSize": BLOCK_SHAPE[:].tolist(), + "dataType": "uint8", + "dimensions": VOL_SHAPE.tolist(), + } + return jsonify(attr), HTTPStatus.OK + + +@app.route( + "//s/////" +) +def chunk(dataset, scale, chunk_x, chunk_y, chunk_z, chunk_c): + """ + Serve up a single chunk at the requested scale and location. + + This 'virtual N5' will just display a color gradient, + fading from black at (0,0,0) to white at (max,max,max). 
+ """ + # assert chunk_c == 0, "neuroglancer requires that all blocks include all channels" + corner = BLOCK_SHAPE[:3] * np.array([chunk_z, chunk_y, chunk_x]) + box = np.array([corner, BLOCK_SHAPE[:3]]) * OUTPUT_VOXEL_SIZE + roi = Roi(box[0], box[1]) + print("about_to_process_chunk") + chunk = INFERENCER.process_chunk(IDI_RAW, roi) + print(chunk) + return ( + # Encode to N5 chunk format (header + compressed data) + CHUNK_ENCODER.encode(chunk), + HTTPStatus.OK, + {"Content-Type": "application/octet-stream"}, + ) + + +if __name__ == "__main__": + main() diff --git a/cellmap_flow/example_virtual_n5_generic_using_python_script copy.py b/cellmap_flow/example_virtual_n5_generic_using_python_script copy.py new file mode 100644 index 0000000..48e7bbe --- /dev/null +++ b/cellmap_flow/example_virtual_n5_generic_using_python_script copy.py @@ -0,0 +1,116 @@ +# %% load image zoo + +from cellmap_flow.utils import load_safe_config +import argparse +from http import HTTPStatus +from flask import Flask, jsonify +from flask_cors import CORS + +import numpy as np +import numcodecs +from scipy import spatial +from zarr.n5 import N5ChunkWrapper +from funlib.geometry import Roi +import numpy as np +import logging + +logger = logging.getLogger(__name__) + +# NOTE: Normally we would just load in run but here we have to recreate it to save time since our run has so many points +import socket +from cellmap_flow.image_data_interface import ImageDataInterface + + +from cellmap_flow.inferencer import Inferencer + + +# %% +script_path = "/groups/cellmap/cellmap/zouinkhim/cellmap-flow/example/model_spec.py" +config = load_safe_config(script_path) +SCALE_LEVEL = 0 +IDI_RAW = None +OUTPUT_VOXEL_SIZE = None +VOL_SHAPE_ZYX = None +VOL_SHAPE = None +VOL_SHAPE_ZYX_IN_BLOCKS = None +VOXEL_SIZE = None +BLOCK_SHAPE = config.block_shape +MAX_SCALE = None +CHUNK_ENCODER = None +EQUIVALENCES = None +DS = None +INFERENCER = Inferencer(script_path=script_path) +dataset_name = 
"/nrs/cellmap/data/jrc_mus-cerebellum-1/jrc_mus-cerebellum-1.zarr/recon-1/em/fibsem-uint8" +IDI_RAW = ImageDataInterface(f"{dataset_name}/s{SCALE_LEVEL}") +OUTPUT_VOXEL_SIZE = config.output_voxel_size + +# determined-chimpmunk is edges + + +# %% +MAX_SCALE = 0 + +VOL_SHAPE_ZYX = np.array(IDI_RAW.shape) +VOL_SHAPE = np.array([*VOL_SHAPE_ZYX[::-1], 8]) +# VOL_SHAPE_ZYX_IN_BLOCKS = np.ceil(VOL_SHAPE_ZYX / BLOCK_SHAPE[:3]).astype(int) +# VOXEL_SIZE = IDI_RAW.voxel_size + +CHUNK_ENCODER = N5ChunkWrapper(np.uint8, BLOCK_SHAPE, compressor=numcodecs.GZip()) + +scales = [[2**s, 2**s, 2**s, 1] for s in range(MAX_SCALE + 1)] +attr = { + "pixelResolution": { + "dimensions": [*OUTPUT_VOXEL_SIZE, 1], + "unit": "nm", + }, + "ordering": "C", + "scales": scales, + "axes": ["x", "y", "z", "c^"], + "units": ["nm", "nm", "nm", ""], + "translate": [0, 0, 0, 0], +} + + +attr + +# %% + +attr = { + "transform": { + "ordering": "C", + "axes": ["x", "y", "z", "c^"], + "scale": [ + *OUTPUT_VOXEL_SIZE, + 1, + ], + "units": ["nm", "nm", "nm", ""], + "translate": [0.0, 0.0, 0.0, 0.0], + }, + "compression": {"type": "gzip", "useZlib": False, "level": -1}, + "blockSize": BLOCK_SHAPE[:].tolist(), + "dataType": "uint8", + "dimensions": VOL_SHAPE.tolist(), +} +attr + +chunk_x = 2 +chunk_y = 2 +chunk_z = 2 + + +corner = BLOCK_SHAPE[:3] * np.array([chunk_z, chunk_y, chunk_x]) +box = np.array([corner, BLOCK_SHAPE[:3]]) * OUTPUT_VOXEL_SIZE +roi = Roi(box[0], box[1]) +print("about_to_process_chunk") +chunk = INFERENCER.process_chunk_basic(IDI_RAW, roi) +# logger.error(f"chunk {chunk}") +print(chunk) +x = ( + # Encode to N5 chunk format (header + compressed data) + CHUNK_ENCODER.encode(chunk), + HTTPStatus.OK, + {"Content-Type": "application/octet-stream"}, +) + + +# %% diff --git a/cellmap_flow/example_virtual_n5_generic_using_python_script.py b/cellmap_flow/example_virtual_n5_generic_using_python_script.py new file mode 100644 index 0000000..9147d6c --- /dev/null +++ 
b/cellmap_flow/example_virtual_n5_generic_using_python_script.py @@ -0,0 +1,217 @@ +""" +# Example Virtual N5 + +Example service showing how to host a virtual N5, +suitable for browsing in neuroglancer. + +Neuroglancer is capable of browsing N5 files, as long as you store them on +disk and then host those files over http (with a CORS-friendly http server). +But what if your data doesn't exist on disk yet? + +This server hosts a "virtual" N5. Nothing is stored on disk, +but neuroglancer doesn't need to know that. This server provides the +necessary attributes.json files and chunk files on-demand, in the +"locations" (url patterns) that neuroglancer expects. + +For simplicity, this file uses Flask. In a production system, +you'd probably want to use something snazzier, like FastAPI. + +To run the example, install a few dependencies: + + conda create -n example-virtual-n5 -c conda-forge zarr flask flask-cors + conda activate example-virtual-n5 + +Then just execute the file: + + python example_virtual_n5.py + +Or, for better performance, use a proper http server: + + conda install -c conda-forge gunicorn + gunicorn --bind 0.0.0.0:8000 --workers 8 --threads 1 example_virtual_n5:app + +You can browse the data in neuroglancer after configuring the viewer with the appropriate layer [settings][1]. 
+""" + +# %% +# NOTE: To generate host key and host cert do the following: https://serverfault.com/questions/224122/what-is-crt-and-key-files-and-how-to-generate-them +# openssl genrsa 2048 > host.key +# chmod 400 host.key +# openssl req -new -x509 -nodes -sha256 -days 365 -key host.key -out host.cert +# Then can run like this: +# gunicorn --certfile=host.cert --keyfile=host.key --bind 0.0.0.0:8000 --workers 1 --threads 1 example_virtual_n5:app +# NOTE: You will probably have to access the host:8000 separately and say it is safe to go there + +# %% load image zoo + +from cellmap_flow.utils import load_safe_config +import argparse +from http import HTTPStatus +from flask import Flask, jsonify +from flask_cors import CORS + +import numpy as np +import numcodecs +from scipy import spatial +from zarr.n5 import N5ChunkWrapper +from funlib.geometry import Roi +import numpy as np +import logging + +logger = logging.getLogger(__name__) + +# NOTE: Normally we would just load in run but here we have to recreate it to save time since our run has so many points +import socket +from cellmap_flow.image_data_interface import ImageDataInterface + +app = Flask(__name__) +CORS(app) + +from cellmap_flow.inferencer import Inferencer + +import socket + +# Get the hostname +hostname = socket.gethostname() + +# Get the local IP address + +print(f"Host name: {hostname}", flush=True) + + +def main(): + parser = argparse.ArgumentParser() + parser.add_argument("-d", "--debug", action="store_true") + parser.add_argument("-p", "--port", default=8000) + args = parser.parse_args() + app.run( + host="0.0.0.0", + # port=args.port, + debug=args.debug, + threaded=not args.debug, + use_reloader=args.debug, + ) + + +# %% +script_path = "/groups/cellmap/cellmap/zouinkhim/cellmap-flow/example/model_spec.py" +config = load_safe_config(script_path) +SCALE_LEVEL = None +IDI_RAW = None +OUTPUT_VOXEL_SIZE = None +VOL_SHAPE_ZYX = None +VOL_SHAPE = None +VOL_SHAPE_ZYX_IN_BLOCKS = None +VOXEL_SIZE = None 
+BLOCK_SHAPE = config.block_shape +MAX_SCALE = None +CHUNK_ENCODER = None +EQUIVALENCES = None +DS = None +INFERENCER = Inferencer(script_path=script_path) + + +# determined-chimpmunk is edges +# kind-seashell is mito +# happy-elephant is cells +@app.route("//attributes.json") +def top_level_attributes(dataset): + if "__" not in dataset: + return jsonify({"n5": "2.1.0"}), HTTPStatus.OK + + if not (dataset.startswith("gs://") or dataset.startswith("s3://")): + dataset = "/" + dataset + + dataset_name, s, BMZ_MODEL_ID = dataset.split("__") + + global OUTPUT_VOXEL_SIZE, BLOCK_SHAPE, VOL_SHAPE, CHUNK_ENCODER, IDI_RAW, INFERENCER + print(dataset_name, s, BMZ_MODEL_ID) + + # self.read_shape = config.read_shape + # self.write_shape = config.write_shape + + # self.context = (self.read_shape - self.write_shape) / 2 + SCALE_LEVEL = int(s[1:]) + + # MODEL = Inferencer(BMZ_MODEL_ID) + + IDI_RAW = ImageDataInterface(f"{dataset_name}/s{SCALE_LEVEL}") + OUTPUT_VOXEL_SIZE = config.output_voxel_size + + # %% + MAX_SCALE = 0 + + VOL_SHAPE_ZYX = np.array(IDI_RAW.shape) + VOL_SHAPE = np.array([*VOL_SHAPE_ZYX[::-1], 8]) + # VOL_SHAPE_ZYX_IN_BLOCKS = np.ceil(VOL_SHAPE_ZYX / BLOCK_SHAPE[:3]).astype(int) + # VOXEL_SIZE = IDI_RAW.voxel_size + + CHUNK_ENCODER = N5ChunkWrapper(np.uint8, BLOCK_SHAPE, compressor=numcodecs.GZip()) + + scales = [[2**s, 2**s, 2**s, 1] for s in range(MAX_SCALE + 1)] + attr = { + "pixelResolution": { + "dimensions": [*OUTPUT_VOXEL_SIZE, 1], + "unit": "nm", + }, + "ordering": "C", + "scales": scales, + "axes": ["x", "y", "z", "c^"], + "units": ["nm", "nm", "nm", ""], + "translate": [0, 0, 0, 0], + } + return jsonify(attr), HTTPStatus.OK + + +@app.route("//s/attributes.json") +def attributes(dataset, scale): + attr = { + "transform": { + "ordering": "C", + "axes": ["x", "y", "z", "c^"], + "scale": [ + *OUTPUT_VOXEL_SIZE, + 1, + ], + "units": ["nm", "nm", "nm", ""], + "translate": [0.0, 0.0, 0.0, 0.0], + }, + "compression": {"type": "gzip", "useZlib": False, "level": 
-1}, + "blockSize": BLOCK_SHAPE[:].tolist(), + "dataType": "uint8", + "dimensions": VOL_SHAPE.tolist(), + } + return jsonify(attr), HTTPStatus.OK + + +@app.route( + "//s/////" +) +def chunk(dataset, scale, chunk_x, chunk_y, chunk_z, chunk_c): + """ + Serve up a single chunk at the requested scale and location. + + This 'virtual N5' will just display a color gradient, + fading from black at (0,0,0) to white at (max,max,max). + """ + try: + # assert chunk_c == 0, "neuroglancer requires that all blocks include all channels" + corner = BLOCK_SHAPE[:3] * np.array([chunk_z, chunk_y, chunk_x]) + box = np.array([corner, BLOCK_SHAPE[:3]]) * OUTPUT_VOXEL_SIZE + roi = Roi(box[0], box[1]) + print("about_to_process_chunk") + chunk = INFERENCER.process_chunk_basic(IDI_RAW, roi) + # logger.error(f"chunk {chunk}") + print(chunk) + return ( + # Encode to N5 chunk format (header + compressed data) + CHUNK_ENCODER.encode(chunk), + HTTPStatus.OK, + {"Content-Type": "application/octet-stream"}, + ) + except Exception as e: + return jsonify(error=str(e)), HTTPStatus.INTERNAL_SERVER_ERROR + + +if __name__ == "__main__": + main() diff --git a/cellmap_flow/image_data_interface.py b/cellmap_flow/image_data_interface.py new file mode 100644 index 0000000..87ecfcb --- /dev/null +++ b/cellmap_flow/image_data_interface.py @@ -0,0 +1,389 @@ +# %% +import logging +import tensorstore as ts +import numpy as np +from funlib.geometry import Coordinate +from funlib.geometry import Roi +import os +import s3fs +import re +import zarr + +# Ensure tensorstore does not attempt to use GCE credentials +os.environ["GCE_METADATA_ROOT"] = "metadata.google.internal.invalid" + +from funlib.persistence import open_ds +from skimage.measure import block_reduce + +# Much below taken from flyemflows: https://github.com/janelia-flyem/flyemflows/blob/master/flyemflows/util/util.py +logging.basicConfig(format="%(levelname)s:%(message)s", level=logging.INFO) +logger = logging.getLogger(__name__) + + +def 
ends_with_scale(string): + pattern = ( + r"s\d+$" # Matches 's' followed by one or more digits at the end of the string + ) + return bool(re.search(pattern, string)) + + +def split_dataset_path(dataset_path, scale=None) -> tuple[str, str]: + """Split the dataset path into the filename and dataset + + Args: + dataset_path ('str'): Path to the dataset + scale ('int'): Scale to use, if present + + Returns: + Tuple of filename and dataset + """ + + # split at .zarr or .n5, whichever comes last + splitter = ( + ".zarr" if dataset_path.rfind(".zarr") > dataset_path.rfind(".n5") else ".n5" + ) + + filename, dataset = dataset_path.split(splitter) + if dataset.startswith("/"): + dataset = dataset[1:] + # include scale if present + if scale is not None: + dataset += f"/s{scale}" + + return filename + splitter, dataset + + +def open_ds_tensorstore(dataset_path: str, mode="r", concurrency_limit=None): + # open with zarr or n5 depending on extension + filetype = ( + "zarr" if dataset_path.rfind(".zarr") > dataset_path.rfind(".n5") else "n5" + ) + extra_args = {} + + if dataset_path.startswith("s3://"): + kvstore = { + "driver": "s3", + "bucket": dataset_path.split("/")[2], + "path": "/".join(dataset_path.split("/")[3:]), + "aws_credentials": { + "anonymous": True, + }, + } + elif dataset_path.startswith("gs://"): + # check if path ends with s#int + if ends_with_scale(dataset_path): + scale_index = int(dataset_path.rsplit("/s")[1]) + dataset_path = dataset_path.rsplit("/s")[0] + else: + scale_index = 0 + filetype = "neuroglancer_precomputed" + kvstore = dataset_path + extra_args = {"scale_index": scale_index} + else: + kvstore = { + "driver": "file", + "path": os.path.normpath(dataset_path), + } + + if concurrency_limit: + spec = { + "driver": filetype, + "context": { + "data_copy_concurrency": {"limit": concurrency_limit}, + "file_io_concurrency": {"limit": concurrency_limit}, + }, + "kvstore": kvstore, + **extra_args, + } + else: + spec = {"driver": filetype, "kvstore": 
kvstore, **extra_args} + + if mode == "r": + dataset_future = ts.open(spec, read=True, write=False) + else: + dataset_future = ts.open(spec, read=False, write=True) + + if dataset_path.startswith("gs://"): + # NOTE: Currently a hack since google store is for some reason stored as mutlichannel + return dataset_future.result()[ts.d["channel"][0]] + else: + return dataset_future.result() + + +def to_ndarray_tensorstore( + dataset, + roi=None, + voxel_size=None, + offset=None, + output_voxel_size=None, + swap_axes=False, + custom_fill_value=None, +): + """Read a region of a tensorstore dataset and return it as a numpy array + + Args: + dataset ('tensorstore.dataset'): Tensorstore dataset + roi ('funlib.geometry.Roi'): Region of interest to read + + Returns: + Numpy array of the region + """ + if swap_axes: + print("Swapping axes") + if roi: + roi = Roi(roi.begin[::-1], roi.shape[::-1]) + if offset: + offset = Coordinate(offset[::-1]) + + if roi is None: + with ts.Transaction() as txn: + return dataset.with_transaction(txn).read().result() + + if offset is None: + offset = Coordinate(np.zeros(roi.dims, dtype=int)) + + if output_voxel_size is None: + output_voxel_size = voxel_size + + rescale_factor = 1 + if voxel_size != output_voxel_size: + # in the case where there is a mismatch in voxel sizes, we may need to extra pad to ensure that the output is a multiple of the output voxel size + original_roi = roi + roi = original_roi.snap_to_grid(voxel_size) + rescale_factor = voxel_size[0] / output_voxel_size[0] + snapped_offset = (original_roi.begin - roi.begin) / output_voxel_size + snapped_end = (original_roi.end - roi.begin) / output_voxel_size + snapped_slices = tuple( + slice(snapped_offset[i], snapped_end[i]) for i in range(3) + ) + + roi -= offset + roi /= voxel_size + + # Specify the range + roi_slices = roi.to_slices() + + domain = dataset.domain + # Compute the valid range + valid_slices = tuple( + slice(max(s.start, inclusive_min), min(s.stop, exclusive_max)) + for 
s, inclusive_min, exclusive_max in zip( + roi_slices, domain.inclusive_min, domain.exclusive_max + ) + ) + + # Create an array to hold the requested data, filled with a default value (e.g., zeros) + # output_shape = [s.stop - s.start for s in roi_slices] + + if not dataset.fill_value: + fill_value = 0 + if custom_fill_value: + fill_value = custom_fill_value + with ts.Transaction() as txn: + data = dataset.with_transaction(txn)[valid_slices].read().result() + pad_width = [ + [valid_slice.start - s.start, s.stop - valid_slice.stop] + for s, valid_slice in zip(roi_slices, valid_slices) + ] + if np.any(np.array(pad_width)): + if fill_value == "edge": + data = np.pad( + data, + pad_width=pad_width, + mode="edge", + ) + else: + data = np.pad( + data, + pad_width=pad_width, + mode="constant", + constant_values=fill_value, + ) + # else: + # padded_data = ( + # np.ones(output_shape, dtype=dataset.dtype.numpy_dtype) * fill_value + # ) + # padded_slices = tuple( + # slice(valid_slice.start - s.start, valid_slice.stop - s.start) + # for s, valid_slice in zip(roi_slices, valid_slices) + # ) + + # # Read the region of interest from the dataset + # padded_data[padded_slices] = dataset[valid_slices].read().result() + + if rescale_factor > 1: + rescale_factor = voxel_size[0] / output_voxel_size[0] + data = ( + data.repeat(rescale_factor, 0) + .repeat(rescale_factor, 1) + .repeat(rescale_factor, 2) + ) + data = data[snapped_slices] + + elif rescale_factor < 1: + data = block_reduce(data, block_size=int(1 / rescale_factor), func=np.median) + data = data[snapped_slices] + + if swap_axes: + data = np.swapaxes(data, 0, 2) + + return data + + +def get_ds_info(path): + swap_axes = False + if path.startswith("s3://"): + ts_info = open_ds_tensorstore(path) + shape = ts_info.shape + path, filename = split_dataset_path(path) + filename, scale = filename.rsplit("/s") + scale = int(scale) + fs = s3fs.S3FileSystem( + anon=True + ) # Set anon=True if you don't need authentication + store = 
s3fs.S3Map(root=path, s3=fs) + zarr_dataset = zarr.open( + store, + mode="r", + ) + multiscale_attrs = zarr_dataset[filename].attrs.asdict() + if "multiscales" in multiscale_attrs: + multiscales = multiscale_attrs["multiscales"][0] + axes = [axis["name"] for axis in multiscales["axes"]] + for scale_info in multiscale_attrs["multiscales"][0]["datasets"]: + if scale_info["path"] == f"s{scale}": + voxel_size = Coordinate( + scale_info["coordinateTransformations"][0]["scale"] + ) + if axes[:3] == ["x", "y", "z"]: + swap_axes = True + chunk_shape = Coordinate(ts_info.chunk_layout.read_chunk.shape) + roi = Roi((0, 0, 0), Coordinate(shape) * voxel_size) + elif path.startswith("gs://"): + ts_info = open_ds_tensorstore(path) + shape = ts_info.shape + voxel_size = Coordinate( + (d.to_json()[0] if d is not None else 1 for d in ts_info.dimension_units) + ) + if ts_info.spec().transform.input_labels[:3] == ("x", "y", "z"): + swap_axes = True + chunk_shape = Coordinate(ts_info.chunk_layout.read_chunk.shape) + roi = Roi([0] * len(shape), Coordinate(shape) * voxel_size) + else: + path, filename = split_dataset_path(path) + ds = open_ds(path, filename) + voxel_size = ds.voxel_size + chunk_shape = ds.chunk_shape + roi = ds.roi + shape = ds.shape + if swap_axes: + voxel_size = Coordinate(voxel_size[::-1]) + chunk_shape = Coordinate(chunk_shape[::-1]) + shape = shape[::-1] + roi = Roi(roi.begin[::-1], roi.shape[::-1]) + return voxel_size, chunk_shape, shape, roi, swap_axes + + +class ImageDataInterface: + def __init__( + self, + dataset_path, + mode="r", + output_voxel_size=None, + custom_fill_value=None, + concurrency_limit=1, + ): + self.path = dataset_path + self.filetype = ( + "zarr" if dataset_path.rfind(".zarr") > dataset_path.rfind(".n5") else "n5" + ) + self.swap_axes = self.filetype == "n5" + self.ts = None + self.voxel_size, self.chunk_shape, self.shape, self.roi, self.swap_axes = ( + get_ds_info(dataset_path) + ) + self.offset = self.roi.offset + self.custom_fill_value = 
custom_fill_value + self.concurrency_limit = concurrency_limit + if output_voxel_size is not None: + self.output_voxel_size = output_voxel_size + else: + self.output_voxel_size = self.voxel_size + + def to_ndarray_ts(self, roi=None): + if not self.ts: + self.ts = open_ds_tensorstore( + self.path, concurrency_limit=self.concurrency_limit + ) + res = to_ndarray_tensorstore( + self.ts, + roi, + self.voxel_size, + self.offset, + self.output_voxel_size, + self.swap_axes, + self.custom_fill_value, + ) + self.ts = None + return res + + # ds not found + # def to_ndarray_ds(self, roi=None): + # return self.ds.to_ndarray(roi) + + +# %% +# multiscale_attrs = zarr_dataset["/recon-1/em/fibsem-uint8"].attrs.asdict() +# if "multiscales" in multiscale_attrs: +# multiscales = multiscale_attrs["multiscales"][0] +# axes = [axis["name"] for axis in multiscales["axes"]] +# for scale_info in multiscale_attrs["multiscales"][0]["datasets"]: +# if scale_info["path"] == "s2": +# voxel_size = Coordinate(scale_info["coordinateTransformations"][0]["scale"]) +# print(multiscales) +# chunk_shape = Coordinate(ts_info.chunk_layout.read_chunk.shape) +# roi = Roi((0, 0, 0), Coordinate(ts_info.shape) * voxel_size) + +# # %% +# idi = ImageDataInterface( +# "gs://neuroglancer-janelia-flyem-hemibrain/emdata/clahe_yz/jpeg/" +# ) +# # %% + +# em_8nm = ts.open( +# { +# "driver": "neuroglancer_precomputed", +# "kvstore": "gs://neuroglancer-janelia-flyem-hemibrain/emdata/clahe_yz/jpeg", +# }, +# read=True, +# dimension_units=["16 nm", "16 nm", "16 nm", None], +# ).result() +# em_8nm.spec().transform.input_labels +# # # %% +# # from tensorstore import Unit + +# # u = Unit(8, "nm") +# # print(ts_info.dimension_units[0].to_json()) + +# # # %% +# idi = ImageDataInterface( +# "gs://neuroglancer-janelia-flyem-hemibrain/emdata/clahe_yz/jpeg/" +# ) +# print(idi.voxel_size) +# idi = ImageDataInterface( +# "s3://janelia-cosem-datasets/jrc_hela-2/jrc_hela-2.zarr/recon-1/em/fibsem-uint8/s0", # 
/recon-1/em/fibsem-uint8/s0", +# ) +# print(idi.voxel_size) +idi = ImageDataInterface( + "/nrs/cellmap/data/jrc_mus-liver-zon-1/jrc_mus-liver-zon-1.zarr/recon-1/em/fibsem-uint8/s1", +) + +# print(idi.voxel_size) +# # %% +# id +# %% +# %% +# voxel_size = Coordinate(1, 1, 3) +# voxel_size = voxel_size[::-1] +# type(voxel_size) +# %% diff --git a/cellmap_flow/inferencer.py b/cellmap_flow/inferencer.py new file mode 100644 index 0000000..f2e80f9 --- /dev/null +++ b/cellmap_flow/inferencer.py @@ -0,0 +1,157 @@ +import numpy as np +import torch +from cellmap_flow.utils.data import ( + ModelConfig, + BioModelConfig, + DaCapoModelConfig, + ScriptModelConfig, +) +from funlib.persistence import Array + + +class Inferencer: + def __init__(self, model_config: ModelConfig): + self.model_config = model_config + self.load_model(model_config) + if torch.cuda.is_available(): + self.device = torch.device("cuda") + else: + self.device = torch.device("cpu") + self.model.to(self.device) + print(f"Using device: {self.device}") + + def process_chunk(self, idi, roi): + if isinstance(self.config, BioModelConfig): + return self.process_chunk_bioimagezoo(idi, roi) + elif isinstance(self.config, DaCapoModelConfig) or isinstance( + self.config, ScriptModelConfig + ): + return self.process_chunk_basic(idi, roi) + else: + raise ValueError(f"Invalid model config type {type(self.config)}") + + def process_chunk_basic(self, idi, roi): + output_roi = roi + + input_roi = output_roi.grow(self.context, self.context) + # input_roi = output_roi + context + raw_input = idi.to_ndarray_ts(input_roi).astype(np.float32) / 255.0 + raw_input = np.expand_dims(raw_input, (0, 1)) + + with torch.no_grad(): + predictions = Array( + self.model.forward(torch.from_numpy(raw_input).float().to(self.device)) + .detach() + .cpu() + .numpy()[0], + output_roi, + self.output_voxel_size, + ) + write_data = predictions.to_ndarray(output_roi).clip(-1, 1) + write_data = (write_data + 1) * 255.0 / 2.0 + return 
write_data.astype(np.uint8) + + # create random input tensor + def process_chunk_bioimagezoo(self, idi, roi): + input_image = idi.to_ndarray_ts(roi) + if len(self.model.outputs[0].axes) == 5: + input_image = input_image[np.newaxis, np.newaxis, ...].astype(np.float32) + test_input_tensor = Tensor.from_numpy( + input_image, dims=["batch", "c", "z", "y", "x"] + ) + else: + input_image = input_image[:, np.newaxis, ...].astype(np.float32) + + test_input_tensor = Tensor.from_numpy( + input_image, dims=["batch", "c", "y", "x"] + ) + sample_input_id = get_member_ids(self.model.inputs)[0] + sample_output_id = get_member_ids(self.model.outputs)[0] + + sample = Sample( + members={sample_input_id: test_input_tensor}, + stat={}, + id="sample-from-numpy", + ) + prediction: Sample = predict( + model=self.model, inputs=sample, skip_preprocessing=sample.stat is not None + ) + ndim = prediction.members[sample_output_id].data.ndim + output = prediction.members[sample_output_id].data.to_numpy() + if ndim < 5 and len(self.model.outputs) > 1: + if len(self.model.outputs) > 1: + outputs = [] + for id in get_member_ids(self.model.outputs): + output = prediction.members[id].data.to_numpy() + if output.ndim == 3: + output = output[:, np.newaxis, ...] + outputs.append(output) + output = np.concatenate(outputs, axis=1) + output = np.ascontiguousarray(np.swapaxes(output, 1, 0)) + + else: + output = output[0, ...] 
+ + output = 255 * output + output = output.astype(np.uint8) + return output + + def load_model(self, config: ModelConfig): + if isinstance(config, DaCapoModelConfig): + # self.load_dacapo_model(config.run_name, iteration=config.iteration) + self.load_script_model(config) + elif isinstance(config, ScriptModelConfig): + self.load_script_model(config) + elif isinstance(config, BioModelConfig): + self.load_bio_model(config.model_name) + else: + raise ValueError(f"Invalid model config type {type(config)}") + + def load_dacapo_model(self, bio_model_name, iteration="best"): + from dacapo.store.create_store import create_config_store, create_weights_store + from dacapo.experiments import Run + + config_store = create_config_store() + + weights_store = create_weights_store() + run_config = config_store.retrieve_run_config(bio_model_name) + + run = Run(run_config) + self.model = run.model + + weights = weights_store.retrieve_weights( + bio_model_name, + iteration, + ) + self.model.load_state_dict(weights.model) + self.model.eval() + # output_voxel_size = self.model.scale(input_voxel_size) + # input_shape = Coordinate(model.eval_input_shape) + # output_size = output_voxel_size * model.compute_output_shape(input_shape)[1] + + # context = (input_size - output_size) / 2 + # TODO load this part from dacapo + # self.read_shape = config.read_shape + # self.write_shape = config.write_shape + # self.output_voxel_size = config.output_voxel_size + # self.context = (self.read_shape - self.write_shape) / 2 + + def load_bio_model(self, bio_model_name): + from bioimageio.core import load_description + from bioimageio.core import predict # , predict_many + from bioimageio.core import Tensor + from bioimageio.core import Sample + from bioimageio.core.digest_spec import get_member_ids + + self.model = load_description(bio_model_name) + + def load_script_model(self, model_config: ScriptModelConfig): + config = model_config.config + self.model = config.model + self.read_shape = 
config.read_shape + self.write_shape = config.write_shape + self.output_voxel_size = config.output_voxel_size + self.context = (self.read_shape - self.write_shape) / 2 + + +# %% diff --git a/cellmap_flow/n_cli.py b/cellmap_flow/n_cli.py new file mode 100644 index 0000000..161c1ee --- /dev/null +++ b/cellmap_flow/n_cli.py @@ -0,0 +1,133 @@ +import click +import logging + +import neuroglancer +import os +import sys +import signal +import select +import itertools +import click + +from cellmap_flow.utils.bsub_utils import is_bsub_available, submit_bsub_job, parse_bpeek_output, run_locally, job_ids, security +from cellmap_flow.utils.neuroglancer_utils import generate_neuroglancer_link + +logging.basicConfig() + +logger = logging.getLogger(__name__) + + + +SERVER_COMMAND = "cellmap_flow_server" + + +@click.group() +@click.option( + "--log-level", + type=click.Choice( + ["DEBUG", "INFO", "WARNING", "ERROR", "CRITICAL"], case_sensitive=False + ), + default="INFO", +) +def cli(log_level): + """ + Command-line interface for the Cellmap flo application. + + Args: + log_level (str): The desired log level for the application. + Examples: + To use Dacapo run the following commands: + ``` + cellmap_flow dacapo -r my_run -i iteration -d data_path + ``` + + To use custom script + ``` + cellmap_flow script -s script_path -d data_path + ``` + + To use bioimage-io model + ``` + cellmap_flow bioimage -m model_path -d data_path + ``` + """ + logging.basicConfig(level=getattr(logging, log_level.upper())) + + +logger = logging.getLogger(__name__) + + +@cli.command() +@click.option( + "-r", "--run-name", required=True, type=str, help="The NAME of the run to train." +) +@click.option( + "-i", + "--iteration", + required=False, + type=int, + help="The iteration at which to train the run.", + default=0, +) +@click.option( + "-d", "--data_path", required=True, type=str, help="The path to the data." 
+) +def dacapo(run_name, iteration, data_path): + command = f"{SERVER_COMMAND} dacapo -r {run_name} -i {iteration} -d {data_path}" + run(command,data_path) + raise NotImplementedError("This command is not yet implemented.") + + +@cli.command() +@click.option( + "-s", + "--script_path", + required=True, + type=str, + help="The path to the script to run.", +) +@click.option( + "-d", "--data_path", required=True, type=str, help="The path to the data." +) +def script(script_path, data_path): + command = f"{SERVER_COMMAND} script -s {script_path} -d {data_path}" + run(command,data_path) + raise NotImplementedError("This command is not yet implemented.") + + +@cli.command() +@click.option( + "-m", "--model_path", required=True, type=str, help="The path to the model." +) +@click.option( + "-d", "--data_path", required=True, type=str, help="The path to the data." +) +def bioimage(model_path, data_path): + raise NotImplementedError("This command is not yet implemented.") + + + +def run(command,dataset_path): + + host = start_hosts(command) + if host is None: + raise Exception("Could not start host") + + inference_dict = {host:"prediction"} + + generate_neuroglancer_link(dataset_path, inference_dict) + +def start_hosts(command): + if security == "https": + command = f"{command} --certfile=host.cert --keyfile=host.key" + + + if is_bsub_available(): + result = submit_bsub_job(command, job_name="example_job") + job_id = result.stdout.split()[1][1:-1] + job_ids.append(job_id) + host = parse_bpeek_output(job_id) + else: + host= run_locally(command) + + return host diff --git a/cellmap_flow/neuroglancer_link_generator.py b/cellmap_flow/neuroglancer_link_generator.py new file mode 100644 index 0000000..e69de29 diff --git a/cellmap_flow/server.py b/cellmap_flow/server.py new file mode 100644 index 0000000..fbd8be7 --- /dev/null +++ b/cellmap_flow/server.py @@ -0,0 +1,199 @@ +import argparse +import logging +import socket +from http import HTTPStatus + +import numpy as np +import 
numcodecs +from flask import Flask, jsonify +from flask_cors import CORS +from zarr.n5 import N5ChunkWrapper +from funlib.geometry import Roi + +from cellmap_flow.image_data_interface import ImageDataInterface +from cellmap_flow.inferencer import Inferencer +from funlib.geometry.coordinate import Coordinate +from cellmap_flow.utils.data import ( + ModelConfig, + BioModelConfig, + DaCapoModelConfig, + ScriptModelConfig, + IP_PATTERN, +) +from cellmap_flow.utils.web_utils import get_free_port, get_public_ip +import click + +logger = logging.getLogger(__name__) + + +# ------------------------------------------------------------------------------ +# Example usage: +# conda install -c conda-forge gunicorn +# gunicorn --bind 0.0.0.0:8000 --workers 8 --threads 1 example_virtual_n5:app +# +# Or just run: +# python example_virtual_n5.py +# ------------------------------------------------------------------------------ + + +class CellMapFlowServer: + """ + Flask application hosting a "virtual N5" for neuroglancer. + + Attributes: + script_path (str): Path to a Python script containing the model specification + block_shape (tuple): The block shape for chunking + app (Flask): The Flask application instance + inferencer (Inferencer): Your CellMapFlow inferencer object + """ + + def __init__(self, dataset_name: str, model_config: ModelConfig): + """ + Initialize the server. 
+ + Args: + script_path (str): Path to the Python script containing model specification + block_shape (tuple): Shape of the blocks used for chunking + """ + self.block_shape = [(int(x)) for x in model_config.config.block_shape] + self.output_voxel_size = Coordinate(model_config.config.output_voxel_size) + self.output_channels = model_config.config.output_channels + + self.inferencer = Inferencer(model_config) + + self.idi_raw = ImageDataInterface(dataset_name) + if ".zarr" in dataset_name: + self.vol_shape = np.array( + [*np.array(self.idi_raw.shape)[::-1], self.output_channels] + ) # converting from z,y,x order to x,y,z order zarr to n5 + self.axis = ["x", "y", "z", "c^"] + else: + self.vol_shape = np.array( + [*np.array(self.idi_raw.shape), self.output_channels] + ) + self.axis = ["z", "y", "x", "c^"] + self.chunk_encoder = N5ChunkWrapper( + np.uint8, self.block_shape, compressor=numcodecs.GZip() + ) + + # Create and configure Flask + self.app = Flask(__name__) + CORS(self.app) + + # To help debug which machine we're on + hostname = socket.gethostname() + print(f"Host name: {hostname}", flush=True) + + # Register Flask routes + self._register_routes() + + def _register_routes(self): + """ + Register all routes for the Flask application. + """ + + # Top-level attributes (and dataset-level attributes) + self.app.add_url_rule( + "//attributes.json", + view_func=self.top_level_attributes, + methods=["GET"], + ) + + # Scale-level attributes + self.app.add_url_rule( + "//s/attributes.json", + view_func=self.attributes, + methods=["GET"], + ) + + # Chunk data route + chunk_route = "//s/////" + self.app.add_url_rule(chunk_route, view_func=self.chunk, methods=["GET"]) + + def top_level_attributes(self, dataset): + """ + Return top-level N5 attributes, or dataset-level attributes. + + The Neuroglancer N5 data source expects '/attributes.json' at the dataset root. 
+ """ + # For simplicity, let's say we only allow a single scale (s0), or up to some MAX_SCALE + max_scale = 0 + # We define the chunk encoder + # Prepare scales array + scales = [[2**s, 2**s, 2**s, 1] for s in range(max_scale + 1)] + + # Construct top-level attributes + attr = { + "pixelResolution": { + "dimensions": [*self.output_voxel_size, 1], + "unit": "nm", + }, + "ordering": "C", + "scales": scales, + "axes": self.axis, + "units": ["nm", "nm", "nm", ""], + "translate": [0, 0, 0, 0], + } + + return jsonify(attr), HTTPStatus.OK + + def attributes(self, dataset, scale): + """ + Return the attributes of a specific scale (like /s0/attributes.json). + """ + attr = { + "transform": { + "ordering": "C", + "axes": self.axis, + "scale": [*self.output_voxel_size, 1], + "units": ["nm", "nm", "nm", ""], + "translate": [0.0, 0.0, 0.0, 0.0], + }, + "compression": {"type": "gzip", "useZlib": False, "level": -1}, + "blockSize": list(self.block_shape), + "dataType": "uint8", + "dimensions": self.vol_shape.tolist(), + } + print(f"Attributes: {attr}", flush=True) + return jsonify(attr), HTTPStatus.OK + + def chunk(self, dataset, scale, chunk_x, chunk_y, chunk_z, chunk_c): + """ + Serve up a single chunk at the requested scale and location. + This 'virtual N5' will just run an inference function and return the data. 
+ """ + # try: + # assert chunk_c == 0, "neuroglancer requires that all blocks include all channels" + corner = self.block_shape[:3] * np.array([chunk_z, chunk_y, chunk_x]) + box = np.array([corner, self.block_shape[:3]]) * self.output_voxel_size + roi = Roi(box[0], box[1]) + + chunk = self.inferencer.process_chunk_basic(self.idi_raw, roi) + return ( + # Encode to N5 chunk format (header + compressed data) + self.chunk_encoder.encode(chunk), + HTTPStatus.OK, + {"Content-Type": "application/octet-stream"}, + ) + # except Exception as e: + # return jsonify(error=str(e)), HTTPStatus.INTERNAL_SERVER_ERROR + + def run(self, debug=False, port=8000, certfile=None, keyfile=None): + """ + Run the Flask development server with optional SSL cert/key. + """ + ssl_context = None + if certfile and keyfile: + # (certfile, keyfile) tuple enables HTTPS in the built-in dev server + ssl_context = (certfile, keyfile) + + self.app.run( + host="0.0.0.0", + port=port, + debug=debug, + use_reloader=debug, + ssl_context=ssl_context, # <-- pass SSL context to Flask dev server + ) + address = f"{'https' if ssl_context else 'http'}://{get_public_ip()}:{port}" + logger.error(IP_PATTERN.format(ip_address=address)) + print(IP_PATTERN.format(ip_address=address), flush=True) diff --git a/cellmap_flow/server_cli.py b/cellmap_flow/server_cli.py new file mode 100644 index 0000000..7f6761e --- /dev/null +++ b/cellmap_flow/server_cli.py @@ -0,0 +1,117 @@ +import click +import logging + +from cellmap_flow.utils.data import ScriptModelConfig, DaCapoModelConfig, BioModelConfig +from cellmap_flow.server import CellMapFlowServer +from cellmap_flow.utils.web_utils import get_free_port + + +@click.group() +@click.option( + "--log-level", + type=click.Choice( + ["DEBUG", "INFO", "WARNING", "ERROR", "CRITICAL"], case_sensitive=False + ), + default="INFO", +) +def cli(log_level): + """ + Command-line interface for the Cellmap flo application. 
+ + Args: + log_level (str): The desired log level for the application. + Examples: + To use Dacapo run the following commands: + ``` + cellmap_flow_server dacapo -r my_run -i iteration -d data_path + ``` + + To use custom script + ``` + cellmap_flow_server script -s script_path -d data_path + ``` + + To use bioimage-io model + ``` + cellmap_flow_server bioimage -m model_path -d data_path + ``` + """ + logging.basicConfig(level=getattr(logging, log_level.upper())) + + +logger = logging.getLogger(__name__) + + +@cli.command() +@click.option( + "-r", "--run-name", required=True, type=str, help="The NAME of the run to train." +) +@click.option( + "-i", + "--iteration", + required=False, + type=int, + help="The iteration at which to train the run.", + default=0, +) +@click.option( + "-d", "--data_path", required=True, type=str, help="The path to the data." +) +@click.option("--debug", is_flag=True, help="Run in debug mode.") +@click.option("-p", "--port", default=0, type=int, help="Port to listen on.") +@click.option("--certfile", default=None, help="Path to SSL certificate file.") +@click.option("--keyfile", default=None, help="Path to SSL private key file.") +def dacapo(run_name, iteration, data_path, debug, port, certfile, keyfile): + """Run the CellMapFlow server with a DaCapo model.""" + model_config = DaCapoModelConfig(run_name=run_name, iteration=iteration) + run_server(model_config, data_path, debug, port, certfile, keyfile) + + +@cli.command() +@click.option( + "-s", + "--script_path", + required=True, + type=str, + help="The path to the script to run.", +) +@click.option( + "-d", "--data_path", required=True, type=str, help="The path to the data." 
+) +@click.option("--debug", is_flag=True, help="Run in debug mode.") +@click.option("-p", "--port", default=0, type=int, help="Port to listen on.") +@click.option("--certfile", default=None, help="Path to SSL certificate file.") +@click.option("--keyfile", default=None, help="Path to SSL private key file.") +def script(script_path, data_path, debug, port, certfile, keyfile): + """Run the CellMapFlow server with a custom script.""" + model_config = ScriptModelConfig(script_path=script_path) + run_server(model_config, data_path, debug, port, certfile, keyfile) + + +@cli.command() +@click.option( + "-m", "--model_path", required=True, type=str, help="The path to the model." +) +@click.option( + "-d", "--data_path", required=True, type=str, help="The path to the data." +) +@click.option("--debug", is_flag=True, help="Run in debug mode.") +@click.option("-p", "--port", default=0, type=int, help="Port to listen on.") +@click.option("--certfile", default=None, help="Path to SSL certificate file.") +@click.option("--keyfile", default=None, help="Path to SSL private key file.") +def bioimage(model_path, data_path, debug, port, certfile, keyfile): + """Run the CellMapFlow server with a bioimage-io model.""" + raise NotImplementedError("This command is not yet implemented.") + + +def run_server(model_config, data_path, debug, port, certfile, keyfile): + server = CellMapFlowServer(data_path, model_config) + if port == 0: + port = get_free_port() + + server.run( + debug=debug, + port=port, + certfile=certfile, + keyfile=keyfile, + ) diff --git a/templates/iframe.html b/cellmap_flow/templates/iframe.html similarity index 100% rename from templates/iframe.html rename to cellmap_flow/templates/iframe.html diff --git a/cellmap_flow/utils/__init__.py b/cellmap_flow/utils/__init__.py new file mode 100644 index 0000000..e69de29 diff --git a/cellmap_flow/utils/bsub_utils.py b/cellmap_flow/utils/bsub_utils.py new file mode 100644 index 0000000..625d904 --- /dev/null +++ 
b/cellmap_flow/utils/bsub_utils.py @@ -0,0 +1,175 @@ +import subprocess +import logging +import neuroglancer +import os +import sys +import signal +import select +import itertools +import click + +from cellmap_flow.utils.data import IP_PATTERN + +processes = [] +job_ids = [] +security = "http" + +def cleanup(signum, frame): + print(f"Script is being killed. Received signal: {signum}") + if is_bsub_available(): + for job_id in job_ids: + print(f"Killing job {job_id}") + os.system(f"bkill {job_id}") + else: + for process in processes: + process.kill() + sys.exit(0) + +signal.signal(signal.SIGINT, cleanup) # Handle Ctrl+C +signal.signal(signal.SIGTERM, cleanup) + +logger = logging.getLogger(__name__) + +def is_bsub_available(): + try: + # Run 'which bsub' to check if bsub is available in PATH + result = subprocess.run( + ["which", "bsub"], stdout=subprocess.PIPE, stderr=subprocess.PIPE + ) + + if result.stdout: + return True + else: + return False + except Exception as e: + print("Error:", e) + + +def submit_bsub_job( + command, + job_name="my_job", +): + bsub_command = [ + "bsub", + "-J", + job_name, + "-P", + "cellmap", + "-q", + "gpu_h100", + "-gpu", + "num=1", + "bash", + "-c", + command, + ] + + print("Submitting job with the following command:") + + try: + result = subprocess.run( + bsub_command, capture_output=True, text=True, check=True + ) + print("Job submitted successfully:") + print(result.stdout) + except subprocess.CalledProcessError as e: + print("Error submitting job:") + print(e.stderr) + + return result + + +def parse_bpeek_output(job_id): + command = f"bpeek {job_id}" + host = None + try: + # Process the output in real-time + while True: + # logger.error(f"Running command: {command}") + process = subprocess.Popen( + command, + shell=True, + stdout=subprocess.PIPE, + stderr=subprocess.PIPE, + text=True, + ) + + output = process.stdout.read() + error_output = process.stderr.read() + if ( + output == "" + and process.poll() is not None + and f"Job 
<{job_id}> : Not yet started." not in error_output + ): + logger.error(f"Job <{job_id}> has finished.") + break # End of output + if output: + host = get_host_from_stdout(output) + if host: + break + if "error" in output.lower(): + print(f"Error found: {output.strip()}") + + error_output = process.stderr.read() + if error_output: + print(f"Error: {error_output.strip()}") + + except Exception as e: + print(f"Error while executing bpeek: {e}") + + return host + + +# def get_host_from_stdout(output): +# parts = IP_PATTERN.split("ip_address") + +# if parts[0] in output and parts[1] in output: +# host = output.split(parts[0])[1].split(parts[1])[0] +# return host +# return None + +def get_host_from_stdout(output): + if "Host name: " in output and f"* Running on {security}://" in output: + host_name = output.split("Host name: ")[1].split("\n")[0].strip() + port = output.split(f"* Running on {security}://127.0.0.1:")[1].split("\n")[0] + + host = f"{security}://{host_name}:{port}" + print(f"{host}") + return host + return None + + +def run_locally(sc): + command = sc.split(" ") + + process = subprocess.Popen( + command, stdout=subprocess.PIPE, stderr=subprocess.PIPE, text=True + ) + + output = "" + while True: + # Check if there is data available to read from stdout and stderr + rlist, _, _ = select.select( + [process.stdout, process.stderr], [], [], 0.1 + ) # Timeout is 0.1s + + # Read from stdout if data is available + if process.stdout in rlist: + output += process.stdout.readline() + host = get_host_from_stdout(output) + if host: + break + + # Read from stderr if data is available + if process.stderr in rlist: + output += process.stderr.readline() + host = get_host_from_stdout(output) + if host: + break + # Check if the process has finished and no more output is available + if process.poll() is not None and not rlist: + break + processes.append(process) + return host + + diff --git a/cellmap_flow/utils/data.py b/cellmap_flow/utils/data.py new file mode 100644 index 
# ---- cellmap_flow/utils/data.py (new file in this patch) ----
import logging

# Template with an ``ip_address`` placeholder. The disabled parser in
# bsub_utils splits it on that placeholder to find an address in job output.
IP_PATTERN = "CELLMAP_FLOW_SERVER_IP(ip_address)CELLMAP_FLOW_SERVER_IP"

logger = logging.getLogger(__name__)

# (attribute, message) pairs: every loaded config must expose these.
_REQUIRED_CONFIG_FIELDS = (
    ("model", "Model not found in config"),
    ("read_shape", "read_shape not found in config"),
    ("write_shape", "write_shape not found in config"),
    ("output_voxel_size", "output_voxel_size not found in config"),
    ("output_channels", "output_channels not found in config"),
    ("block_shape", "block_shape not found in config"),
)


class ModelConfig:
    """Base class that lazily builds, validates and caches a model config."""

    def __init__(self):
        self._config = None

    def _get_config(self):
        """Build the config object; subclasses must override."""
        raise NotImplementedError(
            f"{type(self).__name__} does not implement _get_config()"
        )

    @property
    def config(self):
        """The validated config, built on first access and cached afterwards."""
        if self._config is None:
            config = self._get_config()
            # Validate before caching, so a rejected config is never handed
            # out unchecked on the next access.
            check_config(config)
            self._config = config
        return self._config


class BioModelConfig(ModelConfig):
    """Config for a bioimage-io model.

    ``_get_config`` is not overridden, so ``config`` raises
    ``NotImplementedError``.
    """

    def __init__(self, model_name: str):
        super().__init__()
        self.model_name = model_name


class ScriptModelConfig(ModelConfig):
    """Config whose attributes come from executing a user-supplied script."""

    def __init__(self, script_path):
        super().__init__()
        self.script_path = script_path

    def _get_config(self):
        # Imported lazily, as in the original.
        from cellmap_flow.utils.load_py import load_safe_config

        return load_safe_config(self.script_path)


class DaCapoModelConfig(ModelConfig):
    """Config derived from a stored DaCapo run at a given iteration."""

    def __init__(self, run_name: str, iteration: int):
        super().__init__()
        self.run_name = run_name
        self.iteration = iteration

    def _get_config(self):
        # Heavy optional dependencies are only imported when actually needed.
        from daisy.coordinate import Coordinate
        import numpy as np
        import torch

        config = Config()

        run = get_dacapo_run_model(self.run_name, self.iteration)
        if torch.cuda.is_available():
            device = torch.device("cuda")
        else:
            device = torch.device("cpu")
        print("device:", device)

        run.model.to(device)
        run.model.eval()
        config.model = run.model

        in_shape = run.model.eval_input_shape
        out_shape = run.model.compute_output_shape(in_shape)[1]

        # Voxel size is read from the raw array of the first training dataset.
        voxel_size = run.datasplit.train[0].raw.voxel_size
        config.input_voxel_size = voxel_size
        config.read_shape = Coordinate(in_shape) * Coordinate(voxel_size)
        config.write_shape = Coordinate(out_shape) * Coordinate(voxel_size)
        config.output_voxel_size = Coordinate(run.model.scale(voxel_size))

        num_channels = len(get_dacapo_channels(run.task))
        config.output_channels = num_channels
        # Spatial output shape followed by the channel count.
        config.block_shape = np.array(tuple(out_shape) + (num_channels,))

        return config


def check_config(config):
    """Raise ``AssertionError`` if ``config`` lacks a required attribute.

    Raised explicitly rather than with ``assert`` so the check also runs
    under ``python -O``; exception type and messages are unchanged.
    """
    for attribute, message in _REQUIRED_CONFIG_FIELDS:
        if not hasattr(config, attribute):
            raise AssertionError(message)


class Config:
    """Plain attribute container filled in by ``DaCapoModelConfig``."""

    pass


def get_dacapo_channels(task):
    """Return the output channel names of a DaCapo task.

    Uses ``task.channels`` when present, three affinity axes for an
    ``AffinitiesTask`` (matched by class name, so dacapo need not be
    imported), and a single ``"membrane"`` channel otherwise.
    """
    if hasattr(task, "channels"):
        return task.channels
    if type(task).__name__ == "AffinitiesTask":
        return ["x", "y", "z"]
    return ["membrane"]


def get_dacapo_run_model(run_name, iteration):
    """Load the DaCapo run ``run_name``; weights are loaded only if ``iteration > 0``."""
    from dacapo.experiments import Run
    from dacapo.store.create_store import create_config_store, create_weights_store

    config_store = create_config_store()
    run_config = config_store.retrieve_run_config(run_name)
    run = Run(run_config)
    if iteration > 0:
        weights_store = create_weights_store()
        weights = weights_store.retrieve_weights(run, iteration)
        run.model.load_state_dict(weights.model)

    return run


# ---- cellmap_flow/utils/load_py.py (new file in this patch) ----
## copied from https://github.com/janelia-cellmap/cellmap-segmentation-challenge/blob/a9525b31502abb7ea01e10c16340bbc1056cf1fc/src/cellmap_segmentation_challenge/utils/security.py

import ast
import importlib
from importlib.machinery import SourceFileLoader
import os

from upath import UPath

# Define restricted imports and functions
DISALLOWED_IMPORTS = {"os", "subprocess", "sys"}
# "open" is deliberately not in the list below (it was in an earlier version).
DISALLOWED_FUNCTIONS = {"eval", "exec", "compile", "__import__"}

# Values of FORCE_SAFE_CONFIG that mean "off"; any other non-empty value
# enables strict mode, as every non-empty string did before.
_FALSY_ENV_VALUES = {"", "0", "false", "no", "off"}


def _root_module(dotted_name):
    """Return the top-level package of a dotted module name."""
    return dotted_name.split(".", 1)[0]


def analyze_script(filepath):
    """
    Analyzes the script at `filepath` using `ast` for potentially unsafe imports and function calls.
    Returns a boolean indicating whether the script is safe and a list of detected issues.
    """
    with open(filepath, "r") as file:
        source_code = file.read()

    # Parse the code into an AST
    tree = ast.parse(source_code, filename=str(filepath))

    issues = []
    for node in ast.walk(tree):
        if isinstance(node, ast.Import):
            for alias in node.names:
                # Compare the top-level package so "import os.path" is caught.
                if _root_module(alias.name) in DISALLOWED_IMPORTS:
                    issues.append(f"Disallowed import detected: {alias.name}")

        elif isinstance(node, ast.ImportFrom):
            # node.module is None for "from . import x".
            if node.module and _root_module(node.module) in DISALLOWED_IMPORTS:
                issues.append(f"Disallowed import detected: {node.module}")

        elif isinstance(node, ast.Call):
            # Direct call by name, e.g. `eval()`
            if isinstance(node.func, ast.Name) and node.func.id in DISALLOWED_FUNCTIONS:
                issues.append(f"Disallowed function call detected: {node.func.id}")
            # Attribute call, e.g. `builtins.eval()`
            elif (
                isinstance(node.func, ast.Attribute)
                and node.func.attr in DISALLOWED_FUNCTIONS
            ):
                issues.append(f"Disallowed function call detected: {node.func.attr}")

    return not issues, issues


def load_safe_config(config_path, force_safe=None):
    """
    Loads the configuration script at `config_path` after analyzing its safety.

    If `force_safe` is true, raises ValueError when the script is deemed
    unsafe; otherwise the issues are only printed and the script still runs.
    When `force_safe` is None, the FORCE_SAFE_CONFIG environment variable is
    read at call time ("", "0", "false", "no", "off" mean off).

    Returns a `Config` whose attributes are the script's global names.
    Raises RuntimeError if executing the script fails.
    """
    if force_safe is None:
        force_safe = os.getenv("FORCE_SAFE_CONFIG", "")
    if isinstance(force_safe, str):
        force_safe = force_safe.strip().lower() not in _FALSY_ENV_VALUES

    is_safe, issues = analyze_script(config_path)
    if not is_safe:
        print("Script contains unsafe elements:")
        for issue in issues:
            print(f" - {issue}")
        if force_safe:
            raise ValueError("Unsafe script detected; loading aborted.")

    config_path = UPath(config_path)
    # Dedicated namespace for the executed script
    config_namespace = {}
    try:
        with open(config_path, "r") as config_file:
            code = config_file.read()
        # SECURITY: this executes arbitrary code. The AST scan above is a
        # heuristic, not a sandbox; only load scripts from trusted sources.
        exec(code, config_namespace)
        config = Config(**config_namespace)
    except Exception as e:
        print(e)
        raise RuntimeError(
            f"Failed to execute configuration file: {config_path}"
        ) from e

    return config


class Config:
    """Attribute bag: every keyword argument becomes an instance attribute."""

    def __init__(self, **kwargs):
        self.__dict__.update(kwargs)


# ---- cellmap_flow/utils/neuroglancer_utils.py (new file in this patch) ----
import neuroglancer
import itertools
import logging
import time

neuroglancer.set_server_bind_address("0.0.0.0")

logger = logging.getLogger(__name__)

# Host that serves the two supported filesystem roots over HTTP(S).
_DATA_SERVER = "cellmap-vm1.int.janelia.org"
# (filesystem prefix, URL scheme, path alias on the data server)
_SERVED_ROOTS = (
    ("/nrs/cellmap/", "https", "nrs/"),
    ("/groups/cellmap/cellmap/", "http", "dm11/"),
)
_LAYER_COLORS = (
    "red",
    "green",
    "blue",
    "yellow",
    "purple",
    "orange",
    "cyan",
    "magenta",
)


def _raw_layer_source(dataset_path):
    """Translate a dataset path into a neuroglancer source URL.

    A trailing scale level (``s0``, ``s1``, ...) is dropped so the multiscale
    group is referenced. Absolute paths must live under one of the served
    roots; anything else raises ValueError.
    """
    dataset_path = dataset_path.rstrip("/") or dataset_path
    leaf = dataset_path.split("/")[-1]
    if leaf.startswith("s") and leaf[1:].isdigit():
        dataset_path = dataset_path.rsplit("/", 1)[0]

    if ".zarr" in dataset_path:
        filetype = "zarr"
    elif ".n5" in dataset_path:
        filetype = "n5"
    else:
        filetype = "precomputed"

    if not dataset_path.startswith("/"):
        return f"{filetype}://{dataset_path}"

    for root, scheme, alias in _SERVED_ROOTS:
        # Test the same prefix that gets replaced.
        if dataset_path.startswith(root):
            served_path = alias + dataset_path[len(root):]
            return f"{filetype}://{scheme}://{_DATA_SERVER}/{served_path}"
    raise ValueError("Currently only supporting nrs/cellmap and /groups/cellmap/cellmap")


def generate_neuroglancer_link(dataset_path, inference_dict):
    """Print a viewer link for the raw data plus one layer per inference host.

    Args:
        dataset_path: path of the raw dataset (optionally ending in a scale
            level such as ``s0``).
        inference_dict: mapping of server host URL -> model/layer name.

    Never returns: it keeps the process alive after printing the link, until
    the process is killed.
    """
    viewer = neuroglancer.UnsynchronizedViewer()

    with viewer.txn() as s:
        s.layers["raw"] = neuroglancer.ImageLayer(
            source=_raw_layer_source(dataset_path),
        )
        # Colours repeat when there are more models than palette entries.
        for (host, model), color in zip(
            inference_dict.items(), itertools.cycle(_LAYER_COLORS)
        ):
            s.layers[model] = neuroglancer.ImageLayer(
                source=f"n5://{host}/{model}",
                shader=f"""#uicontrol invlerp normalized(range=[0, 255], window=[0, 255]);
#uicontrol vec3 color color(default="{color}");
void main(){{emitRGB(color * normalized());}}""",
            )
    show(str(viewer))
    # Block forever without burning a CPU core (was `while True: pass`).
    while True:
        time.sleep(1)


def show(viewer):
    """Print the viewer link inside a highly visible banner."""
    print()
    print()
    print("**********************************************")
    print("LINK:")
    print(viewer)
    print("**********************************************")
    print()
    print()


# ---- cellmap_flow/utils/web_utils.py (new file in this patch) ----
def get_free_port():
    """Return a TCP port that was free on localhost at the time of the call."""
    import socket

    # The context manager closes the socket even if bind() raises.
    with socket.socket(socket.AF_INET, socket.SOCK_STREAM) as sock:
        sock.bind(("localhost", 0))  # port 0: let the OS choose
        return sock.getsockname()[1]


def get_public_ip(timeout=5):
    """Return this machine's public IP as reported by api.ipify.org, or None.

    Args:
        timeout: seconds to wait for the service before giving up.
    """
    import requests

    try:
        return requests.get("https://api.ipify.org", timeout=timeout).text
    except requests.RequestException:
        # Best effort: no network or service down means "unknown".
        return None


# ---- example/cellmap/20240114_fly.sh (new file, mode 100755; command continues on the next line) ----
# cellmap_flow -c /groups/cellmap/cellmap/zouinkhim/cellmap-flow/example/model_spec.py -d
/nrs/cellmap/data/jrc_mus-salivary-1/jrc_mus-salivary-1.zarr/recon-1/em/fibsem-uint8/s0 +cellmap_flow -c /groups/cellmap/cellmap/zouinkhim/cellmap-flow/example/model_spec.py -d /nrs/cellmap/data/jrc_mus-pancreas-5/jrc_mus-pancreas-5.zarr/recon-1/em/fibsem-uint8/s0 +cellmap_flow -c /groups/cellmap/cellmap/zouinkhim/cellmap-flow/example/model_spec.py -d /nrs/cellmap/data/jrc_mus-pancreas-6/jrc_mus-pancreas-6.zarr/recon-1/em/fibsem-uint8/s0 diff --git a/example/cellmap/20240114_submit_fly.sh b/example/cellmap/20240114_submit_fly.sh new file mode 100755 index 0000000..0c080c9 --- /dev/null +++ b/example/cellmap/20240114_submit_fly.sh @@ -0,0 +1,3 @@ +bsub -P cellmap -J jrc_mus-salivary-1 -We 120 cellmap_flow -c /groups/cellmap/cellmap/zouinkhim/cellmap-flow/example/model_spec.py -d /nrs/cellmap/data/jrc_mus-salivary-1/jrc_mus-salivary-1.zarr/recon-1/em/fibsem-uint8/s0 +bsub -P cellmap -J jrc_mus-pancreas-5 -We 120 cellmap_flow -c /groups/cellmap/cellmap/zouinkhim/cellmap-flow/example/model_spec.py -d /nrs/cellmap/data/jrc_mus-pancreas-5/jrc_mus-pancreas-5.zarr/recon-1/em/fibsem-uint8/s0 +bsub -P cellmap -J jrc_mus-pancreas-6 -We 120 cellmap_flow -c /groups/cellmap/cellmap/zouinkhim/cellmap-flow/example/model_spec.py -d /nrs/cellmap/data/jrc_mus-pancreas-6/jrc_mus-pancreas-6.zarr/recon-1/em/fibsem-uint8/s0 diff --git a/example/dacapo_run.py b/example/dacapo_run.py new file mode 100644 index 0000000..6224143 --- /dev/null +++ b/example/dacapo_run.py @@ -0,0 +1,84 @@ +#%% +from daisy.coordinate import Coordinate +import numpy as np +import torch + +channels = ['ecs', 'pm', 'mito', 'mito_mem', 'ves', 'ves_mem', 'endo', 'endo_mem', 'er', 'er_mem', 'eres', 'nuc', 'mt', 'mt_out'] + +# %% +from dacapo.experiments.tasks import DistanceTaskConfig + +task_config = DistanceTaskConfig( + name="tmp_cosem_distance", + channels=channels , + clip_distance=40.0, + tol_distance=40.0, + scale_factor=80.0, +) + +# %% +from dacapo.experiments.architectures import CNNectomeUNetConfig + 
+architecture_config = CNNectomeUNetConfig( + name="upsample_unet", + input_shape=Coordinate(216, 216, 216), + eval_shape_increase=Coordinate(72, 72, 72), + fmaps_in=1, + num_fmaps=12, + fmaps_out=72, + fmap_inc_factor=6, + downsample_factors=[(2, 2, 2), (3, 3, 3), (3, 3, 3)], + constant_upsample=True, + upsample_factors=[(2, 2, 2)], +) + +#%% +from dacapo.experiments.starts import CosemStartConfig + +start_config = CosemStartConfig("setup04", "1820500") +# %% +from dacapo.experiments import RunConfig +from dacapo.experiments.run import Run + + +run_config = RunConfig( + task_config=task_config, + architecture_config=architecture_config, + start_config=start_config, +) + + + +run = Run(run_config) +model = run.model +if torch.cuda.is_available(): + device = torch.device("cuda") +else: + device = torch.device("cpu") + +model.to(device) +model.eval() + +#%% +# pip install fly-organelles + + +output_voxel_size = Coordinate((4, 4, 4)) +voxel_size = Coordinate((8, 8, 8)) + +read_shape = Coordinate((216, 216, 216)) * Coordinate(voxel_size) +write_shape = Coordinate((68, 68, 68)) * Coordinate(output_voxel_size) + + + +output_channels = 14 # 0:all_mem,1:organelle,2:mito,3:er,4:nucleus,5:pm,6:vs,7:ld +block_shape = np.array((68, 68, 68,14)) + +# %% +start_config. 
+# %% +from dacapo.store.create_store import create_config_store +# %% +config_store = create_config_store() +# %% +config_store.retrieve_run_config_names \ No newline at end of file diff --git a/example/dacapo_run_retrieve.py b/example/dacapo_run_retrieve.py new file mode 100644 index 0000000..2fc6582 --- /dev/null +++ b/example/dacapo_run_retrieve.py @@ -0,0 +1,52 @@ +#%% + +from dacapo.experiments import Run +from dacapo.store.create_store import create_config_store, create_weights_store +from daisy.coordinate import Coordinate +import numpy as np + +def get_dacapo_run_model(run_name, iteration): + config_store = create_config_store() + run_config = config_store.retrieve_run_config(run_name) + run = Run(run_config) + if iteration > 0: + weights_store = create_weights_store() + weights = weights_store.retrieve_weights(run, iteration) + run.model.load_state_dict(weights.model) + + return run + +model_name = "20241204_finetune_mito_affs_task_datasplit_v3_u21_kidney_mito_default_cache_8_1" +channels = ["mito"] +checkpoint= 700000 +run = get_dacapo_run_model(model_name, checkpoint) +model = run.model + +in_shape = model.eval_input_shape +out_shape = model.compute_output_shape(in_shape)[1] + +voxel_size = run.datasplit.train[0].raw.voxel_size +read_shape = Coordinate(in_shape) * Coordinate(voxel_size) +write_shape = Coordinate(out_shape) * Coordinate(voxel_size) +output_voxel_size = Coordinate(model.scale(voxel_size)) + +#%% +import torch +#%% + + + +if torch.cuda.is_available(): + device = torch.device("cuda") +else: + device = torch.device("cpu") +print("device:", device) + +model.to(device) +model.eval() + + +output_channels = len(channels) # 0:all_mem,1:organelle,2:mito,3:er,4:nucleus,5:pm,6:vs,7:ld +block_shape = np.array(tuple(out_shape) +(output_channels,)) +# %% + diff --git a/example/generate_slider.py b/example/generate_slider.py new file mode 100644 index 0000000..a2dd12a --- /dev/null +++ b/example/generate_slider.py @@ -0,0 +1,53 @@ +#%% +def 
generate_script(num_channels=2): + """ + Generates a script with the following pattern: + - For each channel i in [0..num_channels-1]: + #uicontrol bool show_ch{i} checkbox() + #uicontrol invlerp ch{i}(channel={i}, range=[0, 255], window=[0, 255]); + #uicontrol vec3 color{i} color(default="red") + - main() function that emits the sum of each channel's color * channel value * bool show. + + :param num_channels: The number of channels to generate UI controls and code for. + :return: A string containing the generated script. + """ + lines = [] + + # Generate #uicontrol lines + # for i in range(num_channels): + # lines.append(f"#uicontrol bool show_ch{i} checkbox()") + for i in range(num_channels): + lines.append(f"#uicontrol invlerp ch{i}(channel={i}, range=[0, 255], window=[0, 255]);") + for i in range(num_channels): + lines.append(f"#uicontrol vec3 color{i} color(default=\"red\")") + + # Build the emit line + # Example piece for channel i: "color{i} * ch{i}() * float(show_ch{i})" + # emit_parts = [f"color{i} * ch{i}() * float(show_ch{i})/float({num_channels})" for i in range(num_channels)] + emit_parts = [f"color{i} * ch{i}()/float({num_channels})" for i in range(num_channels)] + + # Wrap up in the main + lines.append("") + lines.append("void main() {") + lines.append(f" emitRGB({' + '.join(emit_parts)});") + lines.append("}") + + # Return the generated script as a single string + return "\n".join(lines) + + +# Example usage: +if __name__ == "__main__": + # Generate a script for 2 channels (default) + script_2_channels = generate_script(8) + print("Generated script for 2 channels:") + print(script_2_channels) + + # print("\n" + "-"*50 + "\n") + + # # Generate a script for, say, 4 channels + # script_4_channels = generate_script(4) + # print("Generated script for 4 channels:") + # print(script_4_channels) + +# %% diff --git a/example/model_setup04.py b/example/model_setup04.py new file mode 100644 index 0000000..c41bee0 --- /dev/null +++ b/example/model_setup04.py @@ 
-0,0 +1,30 @@ +#%% +# pip install fly-organelles +from daisy.coordinate import Coordinate +import numpy as np + +output_voxel_size = Coordinate((4, 4, 4)) +voxel_size = Coordinate((8, 8, 8)) + +read_shape = Coordinate((216, 216, 216)) * Coordinate(voxel_size) +write_shape = Coordinate((68, 68, 68)) * Coordinate(output_voxel_size) + + +#%% +import torch +import cellmap_models.pytorch.cosem as cosem_models +model = cosem_models.load_model('setup04/1820500') + +#%% +if torch.cuda.is_available(): + device = torch.device("cuda") +else: + device = torch.device("cpu") + +model.to(device) +model.eval() + + +output_channels = 14 # 0:all_mem,1:organelle,2:mito,3:er,4:nucleus,5:pm,6:vs,7:ld +block_shape = np.array((68, 68, 68,14)) +# # %% diff --git a/example/model_setup04_dacapo.py b/example/model_setup04_dacapo.py new file mode 100644 index 0000000..9739a23 --- /dev/null +++ b/example/model_setup04_dacapo.py @@ -0,0 +1,43 @@ +#%% +# pip install fly-organelles +from daisy.coordinate import Coordinate +import numpy as np + +output_voxel_size = Coordinate((4, 4, 4)) +voxel_size = Coordinate((8, 8, 8)) + +read_shape = Coordinate((216, 216, 216)) * Coordinate(voxel_size) +write_shape = Coordinate((68, 68, 68)) * Coordinate(output_voxel_size) + + +#%% +import torch +from dacapo.experiments.starts import CosemStartConfig +start_config = CosemStartConfig("setup04", "1820500") +# model = cosem_models.load_model('setup04/1820500') + +#%% + + + +output_channels = 14 # 0:all_mem,1:organelle,2:mito,3:er,4:nucleus,5:pm,6:vs,7:ld +block_shape = np.array((68, 68, 68,14)) +# %% +start = start_config.start_type(start_config) +# %% +start.run +# %% +from cellmap_models import cosem +model = cosem.load_model("setup04") +# %% +model +# %% +start.initialize_weights(model) +# %% +if torch.cuda.is_available(): + device = torch.device("cuda") +else: + device = torch.device("cpu") + +model.to(device) +model.eval() \ No newline at end of file diff --git a/example/model_spec.py b/example/model_spec.py 
new file mode 100644 index 0000000..fbe2dc2 --- /dev/null +++ b/example/model_spec.py @@ -0,0 +1,34 @@ +#%% +# pip install fly-organelles +from daisy.coordinate import Coordinate +import numpy as np +voxel_size = (8, 8, 8) +read_shape = Coordinate((178, 178, 178)) * Coordinate(voxel_size) +write_shape = Coordinate((56, 56, 56)) * Coordinate(voxel_size) +output_voxel_size = Coordinate((8, 8, 8)) + +#%% +import torch +from fly_organelles.model import StandardUnet +#%% +def load_eval_model(num_labels, checkpoint_path): + model_backbone = StandardUnet(num_labels) + if torch.cuda.is_available(): + device = torch.device("cuda") + else: + device = torch.device("cpu") + print("device:", device) + checkpoint = torch.load(checkpoint_path, weights_only=True, map_location=device) + model_backbone.load_state_dict(checkpoint["model_state_dict"]) + model = torch.nn.Sequential(model_backbone, torch.nn.Sigmoid()) + model.to(device) + model.eval() + return model + +CHECKPOINT_PATH = "/nrs/saalfeld/heinrichl/fly_organelles/run08/model_checkpoint_438000" +output_channels = 8 # 0:all_mem,1:organelle,2:mito,3:er,4:nucleus,5:pm,6:vs,7:ld +model = load_eval_model(output_channels, CHECKPOINT_PATH) +block_shape = np.array((56, 56, 56,8)) +# %% +# print("model loaded",model) +# %% diff --git a/example/server_check.py b/example/server_check.py new file mode 100644 index 0000000..04a6e3a --- /dev/null +++ b/example/server_check.py @@ -0,0 +1,20 @@ +#%% +from cellmap_flow.server import CellMapFlowServer +from cellmap_flow.utils.data import ModelConfig, BioModelConfig, DaCapoModelConfig, ScriptModelConfig +#%% +# dataset = "/nrs/cellmap/data/jrc_mus-cerebellum-1/jrc_mus-cerebellum-1.zarr/recon-1/em/fibsem-uint8/s0" +# script_path = "/groups/cellmap/cellmap/zouinkhim/cellmap-flow/example/model_setup04.py" +# script_path = "/groups/cellmap/cellmap/zouinkhim/cellmap-flow/example/dacapo_run_retrieve.py" + +script_path = "/groups/cellmap/cellmap/zouinkhim/cellmap-flow/example/model_spec.py" +dataset 
= "/groups/cellmap/cellmap/ackermand/for_hideo/jrc_pri_neuron_0710Dish4/jrc_pri_neuron_0710Dish4.n5/em/fibsem-uint8/s0" + +model_config = ScriptModelConfig(script_path=script_path) +server = CellMapFlowServer(dataset, model_config) +# %% +chunk_x = 2 +chunk_y = 2 +chunk_z = 2 + +server.chunk(None, None, chunk_x, chunk_y, chunk_z, None) +# %% diff --git a/example_virtual_n5_david.py b/example_virtual_n5_david.py deleted file mode 100644 index ad0e3ae..0000000 --- a/example_virtual_n5_david.py +++ /dev/null @@ -1,404 +0,0 @@ -""" -# Example Virtual N5 - -Example service showing how to host a virtual N5, -suitable for browsing in neuroglancer. - -Neuroglancer is capable of browsing N5 files, as long as you store them on -disk and then host those files over http (with a CORS-friendly http server). -But what if your data doesn't exist on disk yet? - -This server hosts a "virtual" N5. Nothing is stored on disk, -but neuroglancer doesn't need to know that. This server provides the -necessary attributes.json files and chunk files on-demand, in the -"locations" (url patterns) that neuroglancer expects. - -For simplicity, this file uses Flask. In a production system, -you'd probably want to use something snazzier, like FastAPI. - -To run the example, install a few dependencies: - - conda create -n example-virtual-n5 -c conda-forge zarr flask flask-cors - conda activate example-virtual-n5 - -Then just execute the file: - - python example_virtual_n5.py - -Or, for better performance, use a proper http server: - - conda install -c conda-forge gunicorn - gunicorn --bind 0.0.0.0:8000 --workers 8 --threads 1 example_virtual_n5:app - -You can browse the data in neuroglancer after configuring the viewer with the appropriate layer [settings][1]. 
- -[1]: http://neuroglancer-demo.appspot.com/#!%7B%22dimensions%22:%7B%22x%22:%5B1e-9%2C%22m%22%5D%2C%22y%22:%5B1e-9%2C%22m%22%5D%2C%22z%22:%5B1e-9%2C%22m%22%5D%7D%2C%22position%22:%5B5000.5%2C7500.5%2C10000.5%5D%2C%22crossSectionScale%22:25%2C%22projectionScale%22:32767.999999999996%2C%22layers%22:%5B%7B%22type%22:%22image%22%2C%22source%22:%7B%22url%22:%22n5://http://127.0.0.1:8000%22%2C%22transform%22:%7B%22outputDimensions%22:%7B%22x%22:%5B1e-9%2C%22m%22%5D%2C%22y%22:%5B1e-9%2C%22m%22%5D%2C%22z%22:%5B1e-9%2C%22m%22%5D%2C%22c%5E%22:%5B1%2C%22%22%5D%7D%7D%7D%2C%22tab%22:%22rendering%22%2C%22opacity%22:0.42%2C%22shader%22:%22void%20main%28%29%20%7B%5Cn%20%20emitRGB%28%5Cn%20%20%20%20vec3%28%5Cn%20%20%20%20%20%20getDataValue%280%29%2C%5Cn%20%20%20%20%20%20getDataValue%281%29%2C%5Cn%20%20%20%20%20%20getDataValue%282%29%5Cn%20%20%20%20%29%5Cn%20%20%29%3B%5Cn%7D%5Cn%22%2C%22channelDimensions%22:%7B%22c%5E%22:%5B1%2C%22%22%5D%7D%2C%22name%22:%22colorful-data%22%7D%5D%2C%22layout%22:%224panel%22%7D -""" - -# %% -# NOTE: To generate host key and host cert do the following: https://serverfault.com/questions/224122/what-is-crt-and-key-files-and-how-to-generate-them -# openssl genrsa 2048 > host.key -# chmod 400 host.key -# openssl req -new -x509 -nodes -sha256 -days 365 -key host.key -out host.cert -# Then can run like this: -# gunicorn --certfile=host.cert --keyfile=host.key --bind 0.0.0.0:8000 --workers 1 --threads 1 example_virtual_n5:app -# NOTE: You will probably have to access the host:8000 separately and say it is safe to go there - -import argparse -from http import HTTPStatus -from flask import Flask, jsonify -from flask_cors import CORS - -import numpy as np -import numcodecs -from scipy import spatial -from zarr.n5 import N5ChunkWrapper -import torch -from funlib.persistence import open_ds -from funlib.geometry import Roi -import numpy as np -from dacapo.store.create_store import create_config_store, create_weights_store - -# NOTE: Normally we would just load in 
run but here we have to recreate it to save time since our run has so many points -from funlib.geometry import Coordinate -from dacapo.experiments.tasks import AffinitiesTaskConfig, AffinitiesTask, DistanceTask -from dacapo.experiments.architectures import CNNectomeUNetConfig, CNNectomeUNet - -from dacapo.experiments.tasks import DistanceTaskConfig -import gc -from dacapo.blockwise.watershed_function import segment_function -import neuroglancer -import socket - -app = Flask(__name__) -CORS(app) - -# This demo produces an RGB volume for aesthetic purposes. -# Note that this is 3 (virtual) teravoxels per channel. -SEGMENTATION = True -NUM_OUT_CHANNELS = 9 -if SEGMENTATION: - NUM_OUT_CHANNELS = 1 -BLOCK_SHAPE = np.array([36, 36, 36, NUM_OUT_CHANNELS]) -MAX_SCALE = 0 - -CHUNK_ENCODER = N5ChunkWrapper(np.uint64, BLOCK_SHAPE, compressor=numcodecs.GZip()) -EQUIVALENCES = neuroglancer.equivalence_map.EquivalenceMap() - -MODEL = None -DS = None - - -def main(): - parser = argparse.ArgumentParser() - parser.add_argument("-d", "--debug", action="store_true") - parser.add_argument("-p", "--port", default=8000) - args = parser.parse_args() - app.run( - host="0.0.0.0", - port=args.port, - debug=args.debug, - threaded=not args.debug, - use_reloader=args.debug, - ) - - -# global DS, CONFIG_STORE, WEIGHTS_STORE, MODEL -CONFIG_STORE = create_config_store() -WEIGHTS_STORE = create_weights_store() -ZARR_PATH = "/nrs/cellmap/data/jrc_22ak351-leaf-3m/jrc_22ak351-leaf-3m.zarr" -DATASET = "recon-1/em/fibsem-uint8" - -# load raw data -DS = open_ds( - ZARR_PATH, - f"{DATASET}/s0", -) -VOL_SHAPE_ZYX = np.array(DS.shape) -VOL_SHAPE = np.array([*VOL_SHAPE_ZYX[::-1], NUM_OUT_CHANNELS]) -VOL_SHAPE_ZYX_IN_BLOCKS = np.ceil(VOL_SHAPE_ZYX / BLOCK_SHAPE[:3]).astype(int) -BIAS = 0 -INPUT_VOXEL_SIZE = [8, 8, 8] -OUTPUT_VOXEL_SIZE = [8, 8, 8] - -OFFSETS = [ - Coordinate(1, 0, 0), - Coordinate(0, 1, 0), - Coordinate(0, 0, 1), - Coordinate(3, 0, 0), - Coordinate(0, 3, 0), - Coordinate(0, 0, 3), - 
Coordinate(9, 0, 0), - Coordinate(0, 9, 0), - Coordinate(0, 0, 9), -] -task_config = AffinitiesTaskConfig( - name=f"tmp", - neighborhood=OFFSETS, - lsds=True, - lsds_to_affs_weight_ratio=0.5, -) -VOXEL_SIZE = DS.voxel_size - -architecture_config = CNNectomeUNetConfig( - name="unet", - input_shape=Coordinate(216, 216, 216), - eval_shape_increase=Coordinate(72, 72, 72), - fmaps_in=1, - num_fmaps=12, - fmaps_out=72, - fmap_inc_factor=6, - downsample_factors=[(2, 2, 2), (3, 3, 3), (3, 3, 3)], - use_attention=False, - batch_norm=False, -) - -task = AffinitiesTask(task_config) -architecture = CNNectomeUNet(architecture_config) -MODEL = task.create_model(architecture) -weights = WEIGHTS_STORE.retrieve_weights( - "finetuned_3d_lsdaffs_weight_ratio_0.5_jrc_22ak351-leaf-3m_plasmodesmata_all_training_points_unet_default_trainer_lr_0.00015_bs_6__0", - 400000, -) -MODEL.load_state_dict(weights.model) -print("loaded weights") -MODEL.to("cuda") -MODEL.eval() - -# %% -neuroglancer.set_server_bind_address("0.0.0.0") -VIEWER = neuroglancer.Viewer() -ip_address = socket.getfqdn() - -with VIEWER.txn() as s: - s.layers["raw"] = neuroglancer.ImageLayer( - source=f'zarr://http://cellmap-vm1.int.janelia.org/{ZARR_PATH.replace("/nrs/cellmap", "/nrs")}/{DATASET}', - shader="""#uicontrol invlerp normalized -#uicontrol float bias slider(min=-1, max=1, step=0.01) - -void main() { -emitGrayscale(normalized()); -}""", - shaderControls={"bias": BIAS}, - ) - s.layers[f"inference_and_postprocessing_{BIAS}"] = neuroglancer.SegmentationLayer( - source=f"n5://http://{ip_address}:8000/test.n5/inference_and_postprocessing_{BIAS}", - equivalences=EQUIVALENCES.to_json(), - ) - s.cross_section_scale = 1e-9 - s.projection_scale = 500e-9 -print(VIEWER) - -PREVIOUS_UPDATE_TIME = 0 -import time - -EDGE_VOXEL_POSITION_TO_VAL_DICT = {} -EQUIVALENCES = neuroglancer.equivalence_map.EquivalenceMap() - - -def update_state(): - global PREVIOUS_UPDATE_TIME, BIAS, EQUIVALENCES, EDGE_VOXEL_POSITION_TO_VAL_DICT - 
current_bias = VIEWER.state.layers["raw"].shaderControls.get("bias", 0) - if current_bias != BIAS: - with VIEWER.txn() as s: - s.layers.__delitem__(f"inference_and_postprocessing_{BIAS}") - BIAS = current_bias - EDGE_VOXEL_POSITION_TO_VAL_DICT = {} - EQUIVALENCES = neuroglancer.equivalence_map.EquivalenceMap() - s.layers[f"inference_and_postprocessing_{BIAS}"] = ( - neuroglancer.SegmentationLayer( - source=f"n5://http://{ip_address}:8000/test.n5/inference_and_postprocessing_{BIAS}", - ) - ) - - with VIEWER.txn() as s: - print(f"{EQUIVALENCES.to_json()=}") - s.layers[f"inference_and_postprocessing_{BIAS}"].equivalences = ( - EQUIVALENCES.to_json() - ) - # LOCAL_VOLUME.invalidate() - PREVIOUS_UPDATE_TIME = time.time() - - -# %% -@app.route("/test.n5//attributes.json") -def top_level_attributes(dataset): - scales = [[2**s, 2**s, 2**s, 1] for s in range(MAX_SCALE + 1)] - attr = { - "pixelResolution": {"dimensions": [*OUTPUT_VOXEL_SIZE, 1.0], "unit": "nm"}, - "ordering": "C", - "scales": scales, - "axes": ["x", "y", "z", "c^"], - "units": ["nm", "nm", "nm", ""], - "translate": [0, 0, 0, 0], - } - return jsonify(attr), HTTPStatus.OK - - -@app.route("/test.n5//s/attributes.json") -def attributes(dataset, scale): - attr = { - "transform": { - "ordering": "C", - "axes": ["x", "y", "z", "c^"], - "scale": [ - *OUTPUT_VOXEL_SIZE, - 1, - ], - "units": ["nm", "nm", "nm"], - "translate": [0.0, 0.0, 0.0], - }, - "compression": {"type": "gzip", "useZlib": False, "level": -1}, - "blockSize": BLOCK_SHAPE.tolist(), - "dataType": "uint64", - "dimensions": (VOL_SHAPE[:3] // 2**scale).tolist() + [int(VOL_SHAPE[3])], - } - return jsonify(attr), HTTPStatus.OK - - -@app.route( - "/test.n5//s////" -) -def chunk(dataset, scale, chunk_x, chunk_y, chunk_z, chunk_c): - """ - Serve up a single chunk at the requested scale and location. - - This 'virtual N5' will just display a color gradient, - fading from black at (0,0,0) to white at (max,max,max). 
- """ - - assert chunk_c == 0, "neuroglancer requires that all blocks include all channels" - if dataset == f"inference_and_postprocessing_{BIAS}": - print(dataset, chunk_x, chunk_y, chunk_z, chunk_c, 0) - corner = BLOCK_SHAPE[:3] * np.array([chunk_z, chunk_y, chunk_x]) - box = np.array([corner, BLOCK_SHAPE[:3]]) * OUTPUT_VOXEL_SIZE - print(dataset, chunk_x, chunk_y, chunk_z, chunk_c, 1) - global_id_offset = np.prod(BLOCK_SHAPE[:3]) * ( - VOL_SHAPE_ZYX_IN_BLOCKS[0] * VOL_SHAPE_ZYX_IN_BLOCKS[1] * chunk_x - + VOL_SHAPE_ZYX_IN_BLOCKS[0] * chunk_y - + chunk_z - ) - print(dataset, chunk_x, chunk_y, chunk_z, chunk_c, 2) - block_vol = postprocess_for_chunk( - dataset, inference_for_chunk(scale, box), global_id_offset, corner - ) - print(dataset, chunk_x, chunk_y, chunk_z, chunk_c, 3) - print(corner / 8) - return ( - # Encode to N5 chunk format (header + compressed data) - CHUNK_ENCODER.encode(block_vol), - HTTPStatus.OK, - {"Content-Type": "application/octet-stream"}, - ) - else: - return jsonify(dataset), HTTPStatus.OK - - -# %% -def inference_for_chunk(scale, box): - global OFFSETS - # Compute the portion of the box that is actually populated. - # It will differ from [(0,0,0), BLOCK_SHAPE] at higher scales, - # where the chunk may extend beyond the bounding box of the entire volume. - box = box.copy() - # box[1] = np.minimum(box[0] + box[1], VOL_SHAPE[:3] // 2**scale) - print(f"{box=}") - grow_by = 90 * INPUT_VOXEL_SIZE[0] - roi = Roi(box[0], box[1]).grow(grow_by, grow_by) - print(f"{(roi/8)=} after grow") - data = DS.to_ndarray(roi) / 255.0 - # create random array with floats between 0 and 1 - # prepend batch and channel dimensions - data = data[np.newaxis, np.newaxis, ...].astype(np.float32) - # move to cuda - data = torch.from_numpy(data).to("cuda") - with torch.no_grad(): - block_vol_czyx = MODEL(data) - block_vol_czyx = block_vol_czyx.cpu().numpy() - block_vol_czyx = block_vol_czyx[0, : len(OFFSETS), ...] 
- print(np.unique(block_vol_czyx)) - # block_vol_czyx = np.swapaxes(block_vol_czyx, 1, 3).copy() - del data - return block_vol_czyx - - -import numpy_indexed as npi -import mwatershed as mws -from scipy.ndimage import measurements - - -def postprocess_for_chunk(dataset, chunk, global_id_offset, corner): - global BIAS, OFFSETS - affs = chunk.astype(np.float64) - segmentation = mws.agglom( - affs - BIAS, - OFFSETS, - ) - # filter fragments - average_affs = np.mean(affs, axis=0) - - filtered_fragments = [] - - fragment_ids = np.unique(segmentation) - - for fragment, mean in zip( - fragment_ids, measurements.mean(average_affs, segmentation, fragment_ids) - ): - if mean < BIAS: - filtered_fragments.append(fragment) - - filtered_fragments = np.array(filtered_fragments, dtype=segmentation.dtype) - replace = np.zeros_like(filtered_fragments) - - # DGA: had to add in flatten and reshape since remap (in particular indices) didn't seem to work with ndarrays for the input - if filtered_fragments.size > 0: - segmentation = npi.remap( - segmentation.flatten(), filtered_fragments, replace - ).reshape(segmentation.shape) - segmentation_squeezed = segmentation.copy() - segmentation = segmentation[np.newaxis, ...] 
- - # get touching voxels - mask = np.zeros_like(segmentation_squeezed, dtype=bool) - mask[1:-1, 1:-1, 1:-1] = True - segmentation_squeezed_ma = np.ma.masked_array(segmentation_squeezed, mask) - z, y, x = np.ma.where(segmentation_squeezed_ma > 0) - values = segmentation_squeezed_ma[z, y, x] - EDGE_VOXEL_POSITION_TO_VAL_DICT.update( - dict( - zip( - zip( - z + corner[0], - y + corner[1], - x + corner[2], - ), - values, - ) - ) - ) - update_equivalences() - return segmentation - - -# %% -def update_equivalences(): - global PREVIOUS_UPDATE_TIME, EDGE_VOXEL_POSITION_TO_VAL_DICT, EQUIVALENCES - if time.time() - PREVIOUS_UPDATE_TIME > 5: - print("Updating equivalences") - positions = list(EDGE_VOXEL_POSITION_TO_VAL_DICT.keys()) - ids = list(EDGE_VOXEL_POSITION_TO_VAL_DICT.values()) - tree = spatial.cKDTree(positions) - neighbors = tree.query_ball_tree(tree, 1) # distance of 1 voxel - for i in range(len(neighbors)): - for j in neighbors[i]: - EQUIVALENCES.union(ids[i], ids[j]) - update_state() - PREVIOUS_UPDATE_TIME = time.time() - print("Updated equivalences") - - -# %% -if __name__ == "__main__": - main() - -# %% diff --git a/example_virtual_n5_marwan.py b/example_virtual_n5_marwan.py deleted file mode 100644 index 20df701..0000000 --- a/example_virtual_n5_marwan.py +++ /dev/null @@ -1,375 +0,0 @@ -# %% -""" -# Example Virtual N5 - -Example service showing how to host a virtual N5, -suitable for browsing in neuroglancer. - -Neuroglancer is capable of browsing N5 files, as long as you store them on -disk and then host those files over http (with a CORS-friendly http server). -But what if your data doesn't exist on disk yet? - -This server hosts a "virtual" N5. Nothing is stored on disk, -but neuroglancer doesn't need to know that. This server provides the -necessary attributes.json files and chunk files on-demand, in the -"locations" (url patterns) that neuroglancer expects. - -For simplicity, this file uses Flask. 
In a production system, -you'd probably want to use something snazzier, like FastAPI. - -To run the example, install a few dependencies: - - conda create -n example-virtual-n5 -c conda-forge zarr flask flask-cors - conda activate example-virtual-n5 - -Then just execute the file: - - python example_virtual_n5.py - -Or, for better performance, use a proper http server: - - conda install -c conda-forge gunicorn - gunicorn --bind 0.0.0.0:8000 --workers 8 --threads 1 example_virtual_n5:app - -You can browse the data in neuroglancer after configuring the viewer with the appropriate layer [settings][1]. - -[1]: http://neuroglancer-demo.appspot.com/#!%7B%22dimensions%22:%7B%22x%22:%5B1e-9%2C%22m%22%5D%2C%22y%22:%5B1e-9%2C%22m%22%5D%2C%22z%22:%5B1e-9%2C%22m%22%5D%7D%2C%22position%22:%5B5000.5%2C7500.5%2C10000.5%5D%2C%22crossSectionScale%22:25%2C%22projectionScale%22:32767.999999999996%2C%22layers%22:%5B%7B%22type%22:%22image%22%2C%22source%22:%7B%22url%22:%22n5://http://127.0.0.1:8000%22%2C%22transform%22:%7B%22outputDimensions%22:%7B%22x%22:%5B1e-9%2C%22m%22%5D%2C%22y%22:%5B1e-9%2C%22m%22%5D%2C%22z%22:%5B1e-9%2C%22m%22%5D%2C%22c%5E%22:%5B1%2C%22%22%5D%7D%7D%7D%2C%22tab%22:%22rendering%22%2C%22opacity%22:0.42%2C%22shader%22:%22void%20main%28%29%20%7B%5Cn%20%20emitRGB%28%5Cn%20%20%20%20vec3%28%5Cn%20%20%20%20%20%20getDataValue%280%29%2C%5Cn%20%20%20%20%20%20getDataValue%281%29%2C%5Cn%20%20%20%20%20%20getDataValue%282%29%5Cn%20%20%20%20%29%5Cn%20%20%29%3B%5Cn%7D%5Cn%22%2C%22channelDimensions%22:%7B%22c%5E%22:%5B1%2C%22%22%5D%7D%2C%22name%22:%22colorful-data%22%7D%5D%2C%22layout%22:%224panel%22%7D -""" - -# NOTE: To generate host key and host cert do the following: https://serverfault.com/questions/224122/what-is-crt-and-key-files-and-how-to-generate-them -# openssl genrsa 2048 > host.key -# chmod 400 host.key -# openssl req -new -x509 -nodes -sha256 -days 365 -key host.key -out host.cert -# Then can run like this: -# gunicorn --certfile=host.cert --keyfile=host.key --bind 
0.0.0.0:8000 --workers 2 --threads 1 example_virtual_n5:app -# NOTE: You will probably have to access the host:8000 separately and say it is safe to go there -# use this works with tensorstore: python example_virtual_n5_marwan.py - -import argparse -from http import HTTPStatus -from flask import Flask, jsonify, render_template -from flask_cors import CORS - -import numpy as np -import numcodecs -from zarr.n5 import N5ChunkWrapper -import torch -from funlib.persistence import open_ds -from funlib.geometry import Roi -import numpy as np - -from skimage.measure import label -from scipy import spatial -import neuroglancer -import time - -# NOTE: Normally we would just load in run but here we have to recreate it to save time since our run has so many points -import gc -import socket - -app = Flask(__name__) -CORS(app) -DEBUG = False - -# This demo produces an RGB volume for aesthetic purposes. -# Note that this is 3 (virtual) teravoxels per channel. -NUM_CHANNELS = 1 -BLOCK_SHAPE = np.array([36, 36, 36, NUM_CHANNELS]) -MAX_SCALE = 0 - -CHUNK_ENCODER = N5ChunkWrapper(np.uint64, BLOCK_SHAPE, compressor=numcodecs.GZip()) - -MODEL = None -DS = None -EDGE_VOXEL_POSITION_TO_VAL_DICT = {} -EQUIVALENCES = neuroglancer.equivalence_map.EquivalenceMap() - -# %% -ZARR_PATH = "/nrs/cellmap/data/jrc_c-elegans-bw-1/jrc_c-elegans-bw-1_normalized.zarr" -DATASET = "recon-1/em/fibsem-uint8" -# load raw data -DS = open_ds( - ZARR_PATH, - f"{DATASET}/s2", -) -VOL_SHAPE_ZYX = np.array(DS.shape) -VOL_SHAPE = np.array([*VOL_SHAPE_ZYX[::-1], NUM_CHANNELS]) -VOL_SHAPE_ZYX_IN_BLOCKS = np.ceil(VOL_SHAPE_ZYX / BLOCK_SHAPE[:3]).astype(int) -PREDICTION_THRESHOLD = 0 - -# update_state() -# global DS, CONFIG_STORE, WEIGHTS_STORE, MODEL - -INPUT_VOXEL_SIZE = [16, 16, 16] -OUTPUT_VOXEL_SIZE = [16, 16, 16] - -if not DEBUG: - from dacapo.store.create_store import create_config_store, create_weights_store - from dacapo.experiments import Run - - CONFIG_STORE = create_config_store() - WEIGHTS_STORE = 
create_weights_store() - run_name = "20240924_mito_setup04_no_upsample_16_16_0" - run_config = CONFIG_STORE.retrieve_run_config(run_name) - - run = Run(run_config) # , load_starter_model=False) - task = run.task - MODEL = run.model - # print(MODEL.architecture) - # path_to_weights = "/nrs/cellmap/zouinkhim/crop_num_experiment_v2/v21_mito_attention_finetuned_distances_8nm_mito_jrc_mus-livers_mito_8nm_attention-upsample-unet_default_one_label_1/checkpoints/iterations/345000" - # weights = torch.load(path_to_weights, map_location="cuda") - weights = WEIGHTS_STORE.retrieve_weights( - run_name, - 80000, - ) - MODEL.load_state_dict(weights.model) - # MODEL.load_state_dict(weights.model) - MODEL.to("cuda") - MODEL.eval() - -neuroglancer.set_server_bind_address("0.0.0.0") -VIEWER = neuroglancer.Viewer() -ip_address = socket.getfqdn() - -with VIEWER.txn() as s: - s.layers["raw"] = neuroglancer.ImageLayer( - source=f'zarr://http://cellmap-vm1.int.janelia.org/{ZARR_PATH.replace("/nrs/cellmap", "/nrs")}/{DATASET}', - shader="""#uicontrol invlerp normalized -#uicontrol float prediction_threshold slider(min=-1, max=1, step=0.1) - -void main() { -emitGrayscale(normalized()); -}""", - shaderControls={"prediction_threshold": PREDICTION_THRESHOLD}, - ) - s.layers[f"inference_and_postprocessing_{PREDICTION_THRESHOLD}"] = ( - neuroglancer.SegmentationLayer( - source=f"n5://http://{ip_address}:8000/test.n5/inference_and_postprocessing_{PREDICTION_THRESHOLD}", - equivalences=EQUIVALENCES.to_json(), - ) - ) - s.cross_section_scale = 1e-9 - s.projection_scale = 500e-9 -print(VIEWER) - - -PREVIOUS_UPDATE_TIME = 0 - - -def update_state(): - global PREVIOUS_UPDATE_TIME, PREDICTION_THRESHOLD, EQUIVALENCES, EDGE_VOXEL_POSITION_TO_VAL_DICT - current_prediction_threshold = VIEWER.state.layers["raw"].shaderControls.get( - "prediction_threshold", 0 - ) - if current_prediction_threshold != PREDICTION_THRESHOLD: - with VIEWER.txn() as s: - 
s.layers.__delitem__(f"inference_and_postprocessing_{PREDICTION_THRESHOLD}") - PREDICTION_THRESHOLD = current_prediction_threshold - EDGE_VOXEL_POSITION_TO_VAL_DICT = {} - EQUIVALENCES = neuroglancer.equivalence_map.EquivalenceMap() - s.layers[f"inference_and_postprocessing_{PREDICTION_THRESHOLD}"] = ( - neuroglancer.SegmentationLayer( - source=f"n5://http://{ip_address}:8000/test.n5/inference_and_postprocessing_{PREDICTION_THRESHOLD}", - ) - ) - - with VIEWER.txn() as s: - print(f"{EQUIVALENCES.to_json()=}") - s.layers[ - f"inference_and_postprocessing_{PREDICTION_THRESHOLD}" - ].equivalences = EQUIVALENCES.to_json() - # LOCAL_VOLUME.invalidate() - PREVIOUS_UPDATE_TIME = time.time() - - -# %% - - -# @app.route("/home") -# def home(): - -# print(VIEWER) -# # print(neuroglancer.to_url(viewer.state)) -# # s.position = VOL_SHAPE_ZYX[::-1] / 2 - -# return render_template("iframe.html", url=VIEWER) - - -def main(): - parser = argparse.ArgumentParser() - parser.add_argument("-d", "--debug", action="store_true") - parser.add_argument("-p", "--port", default=8000) - args = parser.parse_args() - app.run( - host="0.0.0.0", - port=args.port, - debug=args.debug, - threaded=not args.debug, - use_reloader=args.debug, - ) - - -@app.route("/test.n5//attributes.json") -def top_level_attributes(dataset): - scales = [[2**s, 2**s, 2**s, 1] for s in range(MAX_SCALE + 1)] - attr = { - "pixelResolution": {"dimensions": [*OUTPUT_VOXEL_SIZE, 1.0], "unit": "nm"}, - "ordering": "C", - "scales": scales, - "axes": ["x", "y", "z", "c^"], - "units": ["nm", "nm", "nm", ""], - "translate": [0, 0, 0, 0], - } - return jsonify(attr), HTTPStatus.OK - - -@app.route("/test.n5//s/attributes.json") -def attributes(dataset, scale): - attr = { - "transform": { - "ordering": "C", - "axes": ["x", "y", "z", "c^"], - "scale": [ - *OUTPUT_VOXEL_SIZE, - 1, - ], - "units": ["nm", "nm", "nm"], - "translate": [0.0, 0.0, 0.0], - }, - "compression": {"type": "gzip", "useZlib": False, "level": -1}, - "blockSize": 
BLOCK_SHAPE.tolist(), - "dataType": "uint64", - "dimensions": (VOL_SHAPE[:3] // 2**scale).tolist() + [int(VOL_SHAPE[3])], - } - return jsonify(attr), HTTPStatus.OK - - -@app.route( - "/test.n5//s////" -) -def chunk(dataset, scale, chunk_x, chunk_y, chunk_z, chunk_c): - """ - Serve up a single chunk at the requested scale and location. - - This 'virtual N5' will just display a color gradient, - fading from black at (0,0,0) to white at (max,max,max). - """ - - assert chunk_c == 0, "neuroglancer requires that all blocks include all channels" - if dataset == f"inference_and_postprocessing_{PREDICTION_THRESHOLD}": - print(dataset, chunk_x, chunk_y, chunk_z, chunk_c, 0) - corner = BLOCK_SHAPE[:3] * np.array([chunk_z, chunk_y, chunk_x]) - box = np.array([corner, BLOCK_SHAPE[:3]]) * OUTPUT_VOXEL_SIZE - print(dataset, chunk_x, chunk_y, chunk_z, chunk_c, 1) - global_id_offset = np.prod(BLOCK_SHAPE[:3]) * ( - VOL_SHAPE_ZYX_IN_BLOCKS[0] * VOL_SHAPE_ZYX_IN_BLOCKS[1] * chunk_x - + VOL_SHAPE_ZYX_IN_BLOCKS[0] * chunk_y - + chunk_z - ) - print(dataset, chunk_x, chunk_y, chunk_z, chunk_c, 2) - block_vol = postprocess_for_chunk( - dataset, inference_for_chunk(scale, box), global_id_offset, corner - ) - print(dataset, chunk_x, chunk_y, chunk_z, chunk_c, 3) - print(corner / 16) - return ( - # Encode to N5 chunk format (header + compressed data) - CHUNK_ENCODER.encode(block_vol), - HTTPStatus.OK, - {"Content-Type": "application/octet-stream"}, - ) - else: - return jsonify(dataset), HTTPStatus.OK - - -def inference_for_chunk(scale, box): - if not DEBUG: - - # Compute the portion of the box that is actually populated. - # It will differ from [(0,0,0), BLOCK_SHAPE] at higher scales, - # where the chunk may extend beyond the bounding box of the entire volume. 
- box = box.copy() - # box[1] = np.minimum(box[0] + box[1], VOL_SHAPE[:3] // 2**scale) - print(f"{box=}") - grow_by = 90 * INPUT_VOXEL_SIZE[0] - roi = Roi(box[0], box[1]).grow(grow_by, grow_by) - print(f"{(roi/16)=} after grow") - data = DS.to_ndarray(roi) / 255.0 - # create random array with floats between 0 and 1 - # prepend batch and channel dimensions - data = data[np.newaxis, np.newaxis, ...].astype(np.float32) - # move to cuda - data = torch.from_numpy(data).to("cuda") - with torch.no_grad(): - block_vol_czyx = MODEL(data) - block_vol_czyx = block_vol_czyx.cpu().numpy() - block_vol_czyx = block_vol_czyx[0, :NUM_CHANNELS, ...] - # block_vol_czyx = np.swapaxes(block_vol_czyx, 1, 3).copy() - del data - - torch.cuda.empty_cache() - gc.collect() - return block_vol_czyx - else: - print("inference_for_chunk", 4) - return np.random.random((1, *BLOCK_SHAPE[:3])) * 2 - 1 - - -# %% -def postprocess_for_chunk(dataset, chunk, id_offset, corner): - global EDGE_VOXEL_POSITION_TO_VAL_DICT, PREDICTION_THRESHOLD - if dataset == f"inference_and_postprocessing_{PREDICTION_THRESHOLD}": - # do connected components on thresholded chunk - thresholded = chunk > PREDICTION_THRESHOLD - postprocessed, num = label(thresholded, return_num=True) - postprocessed = postprocessed.astype(np.uint64) - if num == 0: - return postprocessed.astype(np.uint64) - - postprocessed[postprocessed > 0] += id_offset - - postprocessed_squeezed = postprocessed[0, ...] 
- mask = np.zeros_like(postprocessed_squeezed, dtype=bool) - mask[1:-1, 1:-1, 1:-1] = True - postprocessed_squeezed_ma = np.ma.masked_array(postprocessed_squeezed, mask) - z, y, x = np.ma.where(postprocessed_squeezed_ma > 0) - values = postprocessed_squeezed_ma[z, y, x] - EDGE_VOXEL_POSITION_TO_VAL_DICT.update( - dict( - zip( - zip( - z + corner[0], - y + corner[1], - x + corner[2], - ), - values, - ) - ) - ) - update_equivalences() - return postprocessed - else: - print("postprocessed", 5) - return np.zeros((1, *BLOCK_SHAPE[:3]), dtype=np.uint64) - - -# %% -def update_equivalences(): - global PREVIOUS_UPDATE_TIME, EDGE_VOXEL_POSITION_TO_VAL_DICT, EQUIVALENCES - if time.time() - PREVIOUS_UPDATE_TIME > 5: - print("Updating equivalences") - positions = list(EDGE_VOXEL_POSITION_TO_VAL_DICT.keys()) - ids = list(EDGE_VOXEL_POSITION_TO_VAL_DICT.values()) - tree = spatial.cKDTree(positions) - neighbors = tree.query_ball_tree(tree, 1) # distance of 1 voxel - for i in range(len(neighbors)): - for j in neighbors[i]: - EQUIVALENCES.union(ids[i], ids[j]) - update_state() - PREVIOUS_UPDATE_TIME = time.time() - print("Updated equivalences") - - -# %% -if __name__ == "__main__": - main() diff --git a/pyproject.toml b/pyproject.toml new file mode 100644 index 0000000..3a0456a --- /dev/null +++ b/pyproject.toml @@ -0,0 +1,190 @@ +# https://peps.python.org/pep-0517/ +[build-system] +requires = ["setuptools", "wheel"] +build-backend = "setuptools.build_meta" + +# https://peps.python.org/pep-0621/ +[project] +name = "cellmap-flow" +description = "Realtime prediction using neuroglancer" +readme = "README.md" +requires-python = ">=3.10" +# license = { text = "BSD 3-Clause License" } +authors = [ + { email = "ackermand@hhmi.org", name = "David Ackerman" }, +] +classifiers = [ + "Development Status :: 3 - Alpha", + "License :: OSI Approved :: BSD License", + "Natural Language :: English", + "Programming Language :: Python :: 3", + "Programming Language :: Python :: 3.10", + "Programming 
Language :: Python :: 3.11", + "Typing :: Typed", +] +dynamic = ["version"] +dependencies = [ + "dacapo-ml @ git+https://github.com/janelia-cellmap/dacapo.git@cellmap_stable", + "funlib.persistence==0.4.0", + "numpy", + # "PyYAML", + "gunicorn", + "flask", + "flask-cors", + "tensorstore", + "daisy", + "bioimageio.core[onnx,pytorch]==0.6.1", + "marshmallow", + "scikit-image", + "funlib.segment @ git+https://github.com/funkelab/funlib.segment.git", + "funlib.show.neuroglancer @ git+https://github.com/funkelab/funlib.show.neuroglancer.git@0c609d2cbf09af976bda998fff17a0454e2782ee" + ] + +# extras +# https://peps.python.org/pep-0621/#dependencies-optional-dependencies +[project.optional-dependencies] +test = ["pytest", "pytest-cov", "pytest-lazy-fixtures"] +dev = [ + "black", + "mypy", + "pdbpp", + "rich", + "ruff", + "pre-commit", +] +docs = [ + "sphinx-autodoc-typehints", + "sphinx-autoapi", + "sphinx-click", + "sphinx-rtd-theme", + "myst-parser", +] + +[project.urls] +homepage = "https://github.io/janelia-cellmap/process_blockwise" +repository = "https://github.com/janelia-cellmap/process_blockwise" + +# https://hatch.pypa.io/latest/config/metadata/ +[tool.hatch.version] +source = "vcs" + +# https://hatch.pypa.io/latest/config/build/#file-selection +# [tool.hatch.build.targets.sdist] +# include = ["/src", "/tests"] +[tool.hatch.metadata] +allow-direct-references = true + +[tool.hatch.build.targets.wheel] +packages = ["cellmap_flow"] + +[project.scripts] +cellmap_flow = "cellmap_flow.n_cli:cli" +cellmap_flow_server = "cellmap_flow.server_cli:cli" + +# https://github.com/charliermarsh/ruff +[tool.ruff] +line-length = 88 +target-version = "py310" +src = ["cellmap_flow"] + +[tool.ruff.lint] +# https://beta.ruff.rs/docs/rules/ +# We may want to enable some of these options later +select = [ + "E", # style errors +# "W", # style warnings + "F", # flakes +# "D", # pydocstyle +# "I", # isort +# "UP", # pyupgrade +# "C4", # flake8-comprehensions +# "B", # flake8-bugbear +# 
"A001", # flake8-builtins +# "RUF", # ruff-specific rules +] +extend-ignore = ["E501"] + +[tool.ruff.lint.per-file-ignores] +"tests/*.py" = ["D", "S"] +"__init__.py" = ["F401"] + +# https://docs.pytest.org/en/6.2.x/customize.html +[tool.pytest.ini_options] +minversion = "6.0" +testpaths = ["tests"] +filterwarnings = [ + "error", + "ignore::DeprecationWarning", + ] + +# https://mypy.readthedocs.io/en/stable/config_file.html +[tool.mypy] +files = "cellmap_flow/**/" +strict = false +disallow_any_generics = false +disallow_subclassing_any = false +show_error_codes = true +pretty = true +exclude = [ + "scratch/*", + "examples/*", +] + + +# # module specific overrides +[[tool.mypy.overrides]] +module = [ + "cellmap_models.*", + "funlib.*", + "toml.*", + "gunpowder.*", + "scipy.*", + "augment.*", + "tifffile.*", + "daisy.*", + "lazy_property.*", + "skimage.*", + "fibsem_tools.*", + "neuroglancer.*", + "tqdm.*", + "zarr.*", + "pymongo.*", + "bson.*", + "affogato.*", + "SimpleITK.*", + "bokeh.*", + "lsd.*", + "yaml.*", + "pytest_lazyfixture.*", + "neuclease.dvid.*", + "mwatershed.*", + "numpy_indexed.*", + "empanada_napari.*", + "napari.*", + "empanada.*", + "IPython.*", +] +ignore_missing_imports = true + +# https://coverage.readthedocs.io/en/6.4/config.html +[tool.coverage.report] +exclude_lines = [ + "pragma: no cover", + "if TYPE_CHECKING:", + "@overload", + "except ImportError", + "\\.\\.\\.", + "raise NotImplementedError()", +] +[tool.coverage.run] +source = ["process_blockwise"] + +# https://github.com/mgedmin/check-manifest#configuration +[tool.check-manifest] +ignore = [ + ".github_changelog_generator", + ".pre-commit-config.yaml", + ".ruff_cache/**/*", + "tests/**/*", +] +