Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 1 addition & 0 deletions README.md
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,7 @@ SPIMquant is a Snakebids app for quantitative analysis of SPIM (lightsheet) brai
Features include:
- Deformable registration to a template
- Atlas-based quantification of pathology
- High-resolution Imaris dataset creation from atlas region bounding boxes
- Group-level statistical analysis with contrast comparisons
- Coarse-grained and fine-grained parallelization using Snakemake and Dask
- Support for reading BIDS datasets directly from cloud-based object storage
Expand Down
11 changes: 11 additions & 0 deletions spimquant/config/snakebids.yml
Original file line number Diff line number Diff line change
Expand Up @@ -202,6 +202,17 @@ parse_args:
default: 42
type: int

--crop_labels:
help: "List of atlas label names, abbreviations, or indices to extract as Imaris crops. If not specified, all labels are cropped."
default: null
nargs: '+'

--crop_atlas_segs:
help: "Atlas segmentations to use for extracting Imaris crops (default: roi22)"
default:
- roi22
action: store
nargs: '+'
--contrast_column:
help: "Column name in participants.tsv to use for defining group contrasts (e.g., 'treatment', 'genotype'). Required for group-level statistical analysis."
default: null
Expand Down
28 changes: 28 additions & 0 deletions spimquant/workflow/Snakefile
Original file line number Diff line number Diff line change
Expand Up @@ -72,6 +72,15 @@ for seg in config["patch_atlas_segs"]:
else:
patch_atlas_segs.append(seg)

# atlas segmentations to use for Imaris crops (defaults to roi22)
# Each requested segmentation must exist in the chosen template.
crop_atlas_segs = []
for seg in config["crop_atlas_segs"]:
    if seg in all_atlas_segs:
        crop_atlas_segs.append(seg)
    else:
        raise ValueError(
            f"Chosen crop segmentation {seg} was not found in the template {config['template']}"
        )
# Validate that contrast arguments are provided when using group analysis level
if config.get("analysis_level") == "group":
if config.get("contrast_column") is None or config.get("contrast_values") is None:
Expand Down Expand Up @@ -335,6 +344,25 @@ rule all_spim_patches:
),


rule all_imaris_crops:
    # Aggregate target: requests one Imaris crop directory per subject for
    # each configured crop segmentation, at level 0 (full resolution).
    input:
        inputs["spim"].expand(
            bids(
                root=root,
                datatype="micr",
                seg="{seg}",
                from_="{template}",
                level="{level}",
                desc="crop",
                suffix="SPIM.imaris",
                **inputs["spim"].wildcards,
            ),
            seg=crop_atlas_segs,
            template=config["template"],
            level=0,
        ),


rule all:
default_target: True
input:
Expand Down
49 changes: 49 additions & 0 deletions spimquant/workflow/rules/patches.smk
Original file line number Diff line number Diff line change
Expand Up @@ -175,3 +175,52 @@ rule create_corrected_spim_patches:
runtime=30,
script:
"../scripts/create_patches.py"


rule create_imaris_crops:
    """Create high-resolution Imaris datasets from SPIM data based on atlas region bounding boxes.

    This rule extracts crops from SPIM zarr data based on bounding boxes of
    specified atlas regions. Crops are saved as Imaris datasets using zarrnii's
    to_imaris() function. Level defaults to 0 for high-resolution output.
    """
    input:
        spim=inputs["spim"].path,
        # atlas segmentation deformed into subject space at the registration level
        dseg=bids(
            root=root,
            datatype="micr",
            seg="{seg}",
            desc="deform",
            level=config["registration_level"],
            from_="{template}",
            suffix="dseg.nii.gz",
            **inputs["spim"].wildcards,
        ),
        # label lookup table (index / name / abbreviation columns)
        label_tsv=bids_tpl(
            root=root, template="{template}", seg="{seg}", suffix="dseg.tsv"
        ),
    params:
        # labels to crop; None means crop every non-background label
        crop_labels=config.get("crop_labels", None),
        hires_level=0,  # input is the raw data
        zarrnii_kwargs={"orientation": config["orientation"]},
    output:
        # directory holding one .ims file per cropped atlas region
        crops_dir=directory(
            bids(
                root=root,
                datatype="micr",
                seg="{seg}",
                from_="{template}",
                level="{level}",
                desc="crop",
                suffix="SPIM.imaris",
                **inputs["spim"].wildcards,
            )
        ),
    group:
        "subj"
    threads: 32
    resources:
        mem_mb=32000,
        runtime=60,
    script:
        "../scripts/create_imaris_crops.py"
144 changes: 144 additions & 0 deletions spimquant/workflow/scripts/create_imaris_crops.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,144 @@
"""Create Imaris datasets from zarr data based on atlas region bounding boxes.

This script extracts crops from SPIM zarr data based on bounding boxes of
specified atlas regions. It uses the zarrnii library to load the atlas with
its label lookup table, get bounding boxes for specified regions, crop the
image data, and save as Imaris datasets.

The output is a directory containing Imaris datasets named:
seg-{atlas_seg}_label-{labelabbrev}.ims
"""

import logging
import re
from pathlib import Path
import numpy as np

import dask
from dask.diagnostics import ProgressBar
from zarrnii import ZarrNii, ZarrNiiAtlas

# Set up logging for snakemake scripts
logging.basicConfig(level=logging.INFO, format="%(levelname)s: %(message)s")

# Get input from spim
input_zarr = snakemake.input.spim

input_dseg = snakemake.input.dseg
input_tsv = snakemake.input.label_tsv
output_dir = snakemake.output.crops_dir

# Parameters
crop_labels = snakemake.params.crop_labels
zarrnii_kwargs = snakemake.params.zarrnii_kwargs

# Get wildcards
atlas_seg = snakemake.wildcards.seg

# Get level from wildcards if available
target_level = int(snakemake.wildcards.level)
hires_level = int(snakemake.params.hires_level)

downsampling_level = target_level - hires_level
if downsampling_level < 0:
raise ValueError(
"Target level for create_imaris_crops is smaller than the input level!"
)


# Set up dask for parallel processing
dask.config.set(scheduler="threads", num_workers=snakemake.threads)

# Create output directory
Path(output_dir).mkdir(parents=True, exist_ok=True)

# Load the atlas with labels
atlas = ZarrNiiAtlas.from_files(
input_dseg,
input_tsv,
**{k: v for k, v in zarrnii_kwargs.items() if v is not None},
)

# Load the image data
image = ZarrNii.from_ome_zarr(
input_zarr,
level=downsampling_level,
**{k: v for k, v in zarrnii_kwargs.items() if v is not None},
)

# Determine which labels to use for crops
if crop_labels is None:
# Use all non-background labels from the atlas
labels_to_use = atlas.labels_df[atlas.labels_df[atlas.label_column] > 0][
[atlas.label_column, atlas.abbrev_column]
].values.tolist()
else:
# Use specified labels - can be indices or names/abbreviations
labels_to_use = []
for label in crop_labels:
if isinstance(label, int):
# Label index provided
row = atlas.labels_df[atlas.labels_df[atlas.label_column] == label]
if not row.empty:
labels_to_use.append([label, row[atlas.abbrev_column].values[0]])
else:
# Label name or abbreviation provided
row = atlas.labels_df[
(atlas.labels_df[atlas.name_column] == label)
| (atlas.labels_df[atlas.abbrev_column] == label)
]
if not row.empty:
labels_to_use.append(
[
row[atlas.label_column].values[0],
row[atlas.abbrev_column].values[0],
]
)

# Extract crops for each label
with ProgressBar():
for label_idx, label_abbrev in labels_to_use:
try:
# Get bounding box for this region
bbox_min, bbox_max = atlas.get_region_bounding_box(
region_ids=int(label_idx)
)
# Crop the image using the bounding box
cropped = image.crop(bbox_min, bbox_max, physical_coords=True)

logging.info(f"cropped shape for {label_abbrev} is {cropped.shape}")
if any(d > 5000 for d in cropped.shape):
raise ValueError(
f"Cropped image too large, shape={cropped.shape}, skipping"
)

# Clean label namn for filename: remove non-alphanumeric chars
clean_abbrev = re.sub(r"[^a-zA-Z0-9]+", "", str(label_abbrev))
# Fallback to label index if name would be empty
if not clean_abbrev:
clean_abbrev = f"idx{label_idx}"
subject = snakemake.wildcards.subject
out_file = Path(output_dir) / (
f"sub-{subject}_seg-{atlas_seg}_label-{clean_abbrev}_SPIM.ims"
)

# Save as Imaris dataset
cropped.to_imaris(str(out_file))

logging.info(
f"Created Imaris crop for label {label_abbrev} (index {label_idx}): {out_file}"
)

except ValueError as e:
# ValueError from get_region_bounding_box when region has no voxels
logging.warning(
f"Skipping label {label_abbrev} (index {label_idx}): "
f"no voxels in region - {e}"
)
continue
except Exception as e:
# Catch any other errors
logging.error(
f"Error processing label {label_abbrev} (index {label_idx}): {e}"
)
continue