Skip to content

Commit a555bfc

Browse files
committed
ENH: Add from_image/from_header methods to bring logic out of tests
1 parent 9c4958d commit a555bfc

File tree

2 files changed

+45
-15
lines changed

2 files changed

+45
-15
lines changed

nibabel/coordimage.py

+40
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,6 @@
1+
import nibabel as nib
12
from nibabel.fileslice import fill_slicer
3+
import nibabel.pointset as ps
24

35

46
class CoordinateImage:
@@ -14,6 +16,22 @@ def __init__(self, data, coordaxis, header=None):
1416
self.coordaxis = coordaxis
1517
self.header = header
1618

19+
@classmethod
20+
def from_image(klass, img):
21+
coordaxis = CoordinateAxis.from_header(img.header)
22+
if isinstance(img, nib.Cifti2Image):
23+
if img.ndim != 2:
24+
raise ValueError("Can only interpret 2D images")
25+
for i in img.header.mapped_indices:
26+
if isinstance(img.header.get_axis(i), nib.cifti2.BrainModelAxis):
27+
break
28+
# Reinterpret data ordering based on location of coordinate axis
29+
data = img.dataobj.copy()
30+
data.order = ['F', 'C'][i]
31+
if i == 1:
32+
data._shape = data._shape[::-1]
33+
return klass(data, coordaxis, img.header)
34+
1735

1836
class CoordinateAxis:
1937
"""
@@ -81,6 +99,28 @@ def get_indices(self, parcel, indices=None):
8199
def __len__(self):
82100
return sum(len(parcel) for parcel in self.parcels)
83101

102+
# Hacky factory method for now
103+
@classmethod
104+
def from_header(klass, hdr):
105+
parcels = []
106+
if isinstance(hdr, nib.Cifti2Header):
107+
axes = [hdr.get_axis(i) for i in hdr.mapped_indices]
108+
for ax in axes:
109+
if isinstance(ax, nib.cifti2.BrainModelAxis):
110+
break
111+
else:
112+
raise ValueError("No BrainModelAxis, cannot create CoordinateAxis")
113+
for name, slicer, struct in ax.iter_structures():
114+
if struct.volume_shape:
115+
substruct = ps.NdGrid(struct.volume_shape, struct.affine)
116+
indices = struct.voxel
117+
else:
118+
substruct = None
119+
indices = struct.vertex
120+
parcels.append(Parcel(name, substruct, indices))
121+
122+
return klass(parcels)
123+
84124

85125
class Parcel:
86126
"""

nibabel/tests/test_coordimage.py

+5-15
Original file line numberDiff line numberDiff line change
@@ -49,22 +49,12 @@ def from_spec(klass, pathlike):
4949

5050
def test_Cifti2Image_as_CoordImage():
5151
ones = nb.load(CIFTI2_DATA / "ones.dscalar.nii")
52-
axes = [ones.header.get_axis(i) for i in range(ones.ndim)]
53-
54-
parcels = []
55-
for name, slicer, bma in axes[1].iter_structures():
56-
if bma.volume_shape:
57-
substruct = ps.NdGrid(bma.volume_shape, bma.affine)
58-
indices = bma.voxel
59-
else:
60-
substruct = None
61-
indices = bma.vertex
62-
parcels.append(ci.Parcel(name, None, indices))
63-
caxis = ci.CoordinateAxis(parcels)
64-
dobj = ones.dataobj.copy()
65-
dobj.order = 'C' # Hack for image with BMA as the last axis
66-
cimg = ci.CoordinateImage(dobj, caxis, ones.header)
52+
assert ones.shape == (1, 91282)
53+
cimg = ci.CoordinateImage.from_image(ones)
54+
assert cimg.shape == (91282, 1)
6755

56+
caxis = cimg.coordaxis
57+
assert len(caxis) == 91282
6858
assert caxis[...] is caxis
6959
assert caxis[:] is caxis
7060

0 commit comments

Comments
 (0)