Skip to content

Construct SPE from SpatialData #18

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Open
wants to merge 3 commits into
base: main
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
4 changes: 3 additions & 1 deletion setup.cfg
Original file line number Diff line number Diff line change
Expand Up @@ -56,7 +56,6 @@ install_requires =
pillow>=11.0
requests


[options.packages.find]
where = src
exclude =
Expand All @@ -66,6 +65,9 @@ exclude =
# Add here additional requirements for extra features, to install with:
# `pip install SpatialExperiment[PDF]` like:
# PDF = ReportLab; RXP
extra =
spatialdata
anndata

# Add here test requirements (semicolon/line-separated)
testing =
Expand Down
103 changes: 103 additions & 0 deletions src/spatialexperiment/SpatialExperiment.py
Original file line number Diff line number Diff line change
Expand Up @@ -861,6 +861,109 @@ def mirror_img(self, sample_id=None, image_id=None, axis=("h", "v")):
def to_spatial_experiment():
    """Placeholder for a conversion to a ``SpatialExperiment``; not implemented.

    Raises:
        NotImplementedError: Always.
    """
    # NOTE(review): declared without `self`/`cls`. Unless a `@staticmethod`
    # decorator sits above this definition (not visible here), calling it on
    # an instance fails with a TypeError before reaching this body -- confirm.
    raise NotImplementedError()

################################
####>> SpatialData interop <<###
################################

@classmethod
def from_spatialdata(cls, input: "spatialdata.SpatialData", points_key: str = "") -> "SpatialExperiment":
    """Create a ``SpatialExperiment`` from :py:class:`~spatialdata.SpatialData`.

    When building ``SpatialExperiment``'s `img_data`, if the image is stored as a
    :py:class:`~xarray.DataArray`, the corresponding key will be used as the
    `sample_id`, `DataArray.name` will be used for the `image_id`, and
    `DataArray.attrs['scale_factor']` for the `scale_factor` (``NaN`` when the
    attribute is absent).

    For when the images are stored as a :py:class:`~xarray.DataTree`, see
    :py:func:`~spatialexperiment._sdatautils.build_img_data` for details.

    **NOTE**: This is a lossy conversion. The resulting ``SpatialExperiment``
    only preserves a subset of the data from the incoming `SpatialData` object.

    Args:
        input:
            Input data.
        points_key:
            The key corresponding to the DataFrame that should be used for
            constructing spatial coordinates. Defaults to the first entry.

    Returns:
        A ``SpatialExperiment`` object. Spatial coordinates are taken from the
        ``x``, ``y`` (and, when present, ``z``) columns of the selected points
        element, in that order. If the element has no ``x``/``y`` columns, no
        ``spatial_coords`` argument is passed to the constructor.

    Raises:
        TypeError: If ``input`` is not a ``SpatialData`` object, or an image
            is neither a ``DataArray`` nor a ``DataTree``.
        KeyError: If ``points_key`` is given but not present in ``input.points``.
        ValueError: If ``input`` has no points element, or the table and the
            selected points element differ in number of observations.
    """
    from spatialdata import SpatialData
    from xarray import DataArray, DataTree

    from ._sdatautils import build_img_data

    if not isinstance(input, SpatialData):
        raise TypeError("Input must be a `SpatialData` object.")

    # validate that the incoming SpatialData can be converted to a SpatialExperiment
    points = input.points
    if points_key:
        points_elem = points[points_key]
    else:
        try:
            points_elem = next(iter(points.values()))
        except StopIteration:
            # Surface a descriptive error instead of leaking StopIteration.
            raise ValueError("Input must contain at least one Points element.") from None

    adata = input.table
    if adata.shape[0] != len(points_elem):
        raise ValueError("Table and Points must have the same number of observations.")

    sce = super().from_anndata(adata)

    # build spatial coordinates
    # Use explicit lists so the column order is deterministic; converting a
    # set to a list would leave the order up to set iteration.
    points_cols = set(points_elem.columns)
    if {"x", "y", "z"} <= points_cols:
        coords_cols = ["x", "y", "z"]
    elif {"x", "y"} <= points_cols:
        coords_cols = ["x", "y"]
    else:
        coords_cols = []

    # Only forward `spatial_coords` when coordinates were actually found, so
    # the constructor's own default applies otherwise (previously this path
    # referenced an unbound local and crashed).
    coords_kwargs = {}
    if coords_cols:
        coords_kwargs["spatial_coords"] = BiocFrame.from_pandas(points_elem[coords_cols].compute())

    # build image data
    img_data = BiocFrame(
        {
            "sample_id": [],
            "image_id": [],
            "data": [],
            "scale_factor": [],
        }
    )
    for name, image in input.images.items():
        if isinstance(image, DataArray):
            curr_img = construct_spatial_image_class(np.array(image))
            curr_img_data = BiocFrame(
                {
                    "sample_id": [name],
                    "image_id": [image.name],
                    "data": [curr_img],
                    "scale_factor": [image.attrs.get("scale_factor", np.nan)],
                }
            )
        elif isinstance(image, DataTree):
            curr_img_data = build_img_data(image, name)
        else:
            raise TypeError(f"Cannot build image data from {type(image)}")

        img_data = img_data.combine_rows(curr_img_data)

    return cls(
        assays=sce.assays,
        row_ranges=sce.row_ranges,
        row_data=sce.row_data,
        column_data=sce.col_data,
        row_names=sce.row_names,
        column_names=sce.column_names,
        metadata=sce.metadata,
        reduced_dims=sce.reduced_dims,
        main_experiment_name=sce.main_experiment_name,
        alternative_experiments=sce.alternative_experiments,
        row_pairs=sce.row_pairs,
        column_pairs=sce.column_pairs,
        img_data=img_data,
        **coords_kwargs,
    )

################################
#######>> combine ops <<########
################################
75 changes: 75 additions & 0 deletions src/spatialexperiment/_sdatautils.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,75 @@
import numpy as np
from biocframe import BiocFrame
from xarray import DataArray, DataTree, Variable

from .SpatialImage import construct_spatial_image_class


def process_dataset_images(dt: DataTree, root_name: str) -> BiocFrame:
    """Compile the images held directly by one :py:class:`~xarray.DataTree` node.

    Every entry of ``dt.dataset`` becomes one row of a
    :py:class:`~biocframe.BiocFrame` laid out like a ``SpatialExperiment``'s
    `img_data` (columns `sample_id`, `image_id`, `data`, `scale_factor`).

    Args:
        dt: A DataTree object containing datasets with image data.
        root_name: An identifier for the highest ancestor of the DataTree to
            which this subtree belongs.

    Returns:
        A BiocFrame with one row per entry in ``dt.dataset``. The `sample_id`
        is ``"{root_name}::{dt.name}"``, the `image_id` is the entry's key, and
        the `scale_factor` comes from the entry's ``attrs`` (``NaN`` if unset).
    """
    # Every image of this node shares the same sample identifier.
    sample_id = f"{root_name}::{dt.name}"

    compiled = BiocFrame({"sample_id": [], "image_id": [], "data": [], "scale_factor": []})

    for image_id, entry in dt.dataset.items():
        if not isinstance(entry, (DataArray, Variable)):
            # Tuple-style entry: (dims, data[, attrs, ...]) -> wrap it in a Variable.
            dims, data, *extras = entry
            entry = Variable(dims=dims, data=data, attrs=extras[0] if extras else None)

        factor = entry.attrs.get("scale_factor", np.nan)
        image = construct_spatial_image_class(np.array(entry))

        row = BiocFrame(
            {
                "sample_id": [sample_id],
                "image_id": [image_id],
                "data": [image],
                "scale_factor": [factor],
            }
        )
        compiled = compiled.combine_rows(row)

    return compiled


def build_img_data(dt: DataTree, root_name: str) -> BiocFrame:
    """Recursively compile image data from a :py:class:`~xarray.DataTree` into a :py:class:`~biocframe.BiocFrame`.

    The node's own dataset is processed first, followed by each child subtree
    in the iteration order of ``dt.children``. The resulting `BiocFrame`
    adheres to the standards required for a ``SpatialExperiment``'s `img_data`.

    The following conditions are assumed:
        - The `sample_id` of every image is ``"{root_name}::{node.name}"``,
          where ``node`` is the tree node that directly holds the image.
        - The keys of each node's dataset are used as the `image_id`'s.
        - The `scale_factor` is read from the ``attrs`` of each image and is
          ``NaN`` when not set.

    Args:
        dt: A DataTree object containing datasets with image data.
        root_name: An identifier for the highest ancestor of the DataTree to
            which this subtree belongs.

    Returns:
        A BiocFrame containing compiled image data for the entire DataTree.
    """
    img_data = process_dataset_images(dt, root_name)

    # A leaf node has no children, so the loop is a no-op and no separate
    # base case is needed.
    for child in dt.children.values():
        img_data = img_data.combine_rows(build_img_data(child, root_name))

    return img_data
42 changes: 41 additions & 1 deletion tests/conftest.py
Original file line number Diff line number Diff line change
@@ -1,8 +1,11 @@
import pytest
from random import random
import numpy as np
from biocframe import BiocFrame
import anndata as ad
import spatialdata as sd
from spatialexperiment import SpatialExperiment, construct_spatial_image_class
from random import random
from spatialdata.models import Image2DModel, PointsModel


@pytest.fixture
Expand Down Expand Up @@ -70,3 +73,40 @@ def spe():
)

return spe_instance


@pytest.fixture
def sdata():
    """Build a small ``SpatialData`` object with one image, one points element and one table.

    The image is registered under the key ``"sample01"``, named ``"image01"``
    and given a ``scale_factor`` attribute of 1. The points element
    (``"coords"``) is parsed from 25 random two-column coordinates, and the
    table is a 25 x 10 ``AnnData`` of random values.
    """
    # Random calls keep the original order: image, x, y, table.
    image = Image2DModel.parse(
        data=np.random.randint(0, 256, size=(50, 50, 3), dtype=np.uint8)
    )
    image.name = "image01"
    image.attrs['scale_factor'] = 1

    n_obs = 25
    xs = np.random.uniform(low=0.0, high=100.0, size=n_obs)
    ys = np.random.uniform(low=0.0, high=100.0, size=n_obs)
    coords = PointsModel.parse(np.column_stack((xs, ys)))

    n_vars = 10
    table = ad.AnnData(X=np.random.random((n_obs, n_vars)))

    return sd.SpatialData(
        images={"sample01": image},
        points={"coords": coords},
        tables=table,
    )


@pytest.fixture
def sdata_tree():
img_1 = np.random.randint(0, 256, size=(50, 50, 3), dtype=np.uint8)
img_1 = Image2DModel.parse(data=img_1)
img_1.name = "image01"

img_2 = np.random.randint(0, 256, size=(50, 50, 3), dtype=np.uint8)
img_2 = Image2DModel.parse(data=img_2)
img_2.name = "image02"
24 changes: 24 additions & 0 deletions tests/test_sdata_interop.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,24 @@
import pytest
from spatialexperiment import SpatialExperiment

def test_from_sdata(sdata):
    """Check the shape, spatial coordinates and image data produced by ``from_spatialdata``."""
    converted = SpatialExperiment.from_spatialdata(sdata)
    assert isinstance(converted, SpatialExperiment)

    # The experiment is features x observations, i.e. the table transposed.
    n_obs, n_vars = sdata['table'].shape
    assert converted.shape == (n_vars, n_obs)

    first_points = next(iter(sdata.points.values()))
    coords = converted.spatial_coords
    assert coords.shape == (len(first_points), first_points.shape[1])
    assert sorted(coords.columns.as_list()) == ['x', 'y']

    images = converted.img_data
    assert images.shape == (1, 4)
    assert images["sample_id"] == ["sample01"]
    assert images["image_id"] == ["image01"]
    assert images["scale_factor"] == [1]


def test_invalid_input():
    """A non-``SpatialData`` argument must be rejected with a ``TypeError``."""
    not_sdata = "Not a SpatialData object!"
    with pytest.raises(TypeError):
        SpatialExperiment.from_spatialdata(not_sdata)
1 change: 1 addition & 0 deletions tox.ini
Original file line number Diff line number Diff line change
Expand Up @@ -17,6 +17,7 @@ passenv =
SETUPTOOLS_*
extras =
testing
extra
commands =
pytest {posargs}

Expand Down
Loading