Skip to content

Commit a1e407b

Browse files
Rename DriftCorrection to MotionDrift + add decorators + refactor from_spikeinterface()
1 parent f019736 commit a1e407b

File tree

4 files changed

+42
-24
lines changed

4 files changed

+42
-24
lines changed

spec/ndx-motion-drift.extensions.yaml

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -42,8 +42,9 @@ groups:
4242
target_type: ElectricalSeries
4343
doc: Link to the ElectricalSeries that is the source of the motion drift
4444
estimation.
45-
- neurodata_type_def: DriftCorrection
45+
- neurodata_type_def: MotionDrift
4646
neurodata_type_inc: NWBDataInterface
47+
default_name: MotionDrift
4748
doc: Displacement data from one or more channels. This can be used to store
4849
the output of motion drift estimation and correction algorithms.
4950
groups:

src/pynwb/ndx_motion_drift/__init__.py

Lines changed: 16 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,6 @@
11
from importlib.resources import files
2-
from pynwb import load_namespaces, get_class
2+
from pynwb import load_namespaces, get_class, register_class
3+
from pynwb.core import MultiContainerInterface
34

45
# Get path to the namespace.yaml file with the expected location when installed not in editable mode
56
__location_of_this_file = files(__name__)
@@ -13,11 +14,23 @@
1314
load_namespaces(str(__spec_path))
1415

1516
DisplacementSeries = get_class("DisplacementSeries", "ndx-motion-drift")
16-
DriftCorrection = get_class("DriftCorrection", "ndx-motion-drift")
17+
@register_class("MotionDrift", "ndx-motion-drift")
18+
class MotionDrift(MultiContainerInterface):
19+
"""
20+
Displacement data from one or more channels. This can be used to store the output of motion drift estimation and correction algorithms.
21+
"""
22+
23+
__clsconf__ = [{
24+
'attr': 'displacement_series',
25+
'type': DisplacementSeries,
26+
'add': 'add_displacement_series',
27+
'get': 'get_displacement_series',
28+
'create': 'create_displacement_series',
29+
}]
1730

1831
__all__ = [
1932
"DisplacementSeries",
20-
"DriftCorrection",
33+
"MotionDrift",
2134
]
2235

2336
# Remove these functions/modules from the package

src/pynwb/ndx_motion_drift/io.py

Lines changed: 20 additions & 17 deletions
Original file line numberDiff line numberDiff line change
@@ -1,23 +1,29 @@
11
import json
22
import numpy as np
33
from spikeinterface.core import BaseRecording, Motion
4-
from pynwb import get_class
4+
from spikeinterface.extractors.nwbextractors import _retrieve_electrical_series_pynwb
5+
from pynwb import get_class, NWBFile
6+
from pynwb.ecephys import ElectricalSeries
57

68

7-
def from_spikeinterface(recording: BaseRecording, motion: Motion, **kwargs):
8-
assert isinstance(recording, BaseRecording), f"The input must be a SpikeInterface BaseRecording object, not {type(recording)}."
9+
def from_spikeinterface(
10+
motion: Motion, nwbfile: NWBFile,
11+
source_electrical_series: ElectricalSeries | str, electrodes,
12+
**kwargs
13+
):
14+
assert isinstance(motion, Motion), f"The input must be a SpikeInterface Motion object, not {type(motion)}."
15+
assert isinstance(nwbfile, NWBFile), f"The input must be a PyNWB NWBFile object, not {type(nwbfile)}."
16+
assert isinstance(source_electrical_series, (ElectricalSeries, str)), f"The source_electrical_series must be either a PyNWB ElectricalSeries object or the path of an ElectricalSeries in the NWBFile, not {type(source_electrical_series)}."
917

10-
recording_tree = [recording]
11-
while recording_tree[-1].get_parent() is not None:
12-
recording_tree.append(recording_tree[-1].get_parent())
13-
has_electrical_series = [hasattr(r, "electrical_series") for r in recording_tree]
14-
assert any(has_electrical_series), "The recording must have an ElectricalSeries to link the DisplacementSeries to."
18+
if isinstance(source_electrical_series, str):
19+
source_electrical_series = _retrieve_electrical_series_pynwb(nwbfile, source_electrical_series)
1520

16-
assert isinstance(motion, Motion), f"The input must be a SpikeInterface Motion object, not {type(motion)}."
21+
if isinstance(electrodes, (list, np.ndarray)):
22+
electrodes = nwbfile.create_electrode_table_region(
23+
description="Electrodes for which the motion drift was estimated.",
24+
region=np.where(np.isin(nwbfile.electrodes["channel_name"].data[:], electrodes))[0].tolist()
25+
)
1726

18-
recording_nwb = recording_tree[has_electrical_series.index(True)]
19-
nwbfile = recording_nwb._nwbfile
20-
2127
DisplacementSeries = get_class("DisplacementSeries", "ndx-motion-drift")
2228

2329
displacement_series = DisplacementSeries(
@@ -32,11 +38,8 @@ def from_spikeinterface(recording: BaseRecording, motion: Motion, **kwargs):
3238
conversion=kwargs.get("conversion", 1.0),
3339
offset=kwargs.get("offset", 0.0),
3440
resolution=kwargs.get("resolution", -1.0),
35-
source_electricalseries=recording_nwb.electrical_series,
36-
electrodes=nwbfile.create_electrode_table_region(
37-
description="Electrodes for which the motion drift was estimated.",
38-
region=np.where(np.isin(nwbfile.electrodes["channel_name"].data[:], recording.channel_ids))[0].tolist()
39-
),
41+
source_electricalseries=source_electrical_series,
42+
electrodes=electrodes,
4043
)
4144

4245
return displacement_series

src/spec/create_extension_spec.py

Lines changed: 4 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -72,9 +72,10 @@ def main():
7272
],
7373
)
7474

75-
drift_correction = NWBGroupSpec(
76-
neurodata_type_def="DriftCorrection",
75+
motion_drift = NWBGroupSpec(
76+
neurodata_type_def="MotionDrift",
7777
neurodata_type_inc="NWBDataInterface",
78+
default_name="MotionDrift",
7879
doc="Displacement data from one or more channels. This can be used to store the output of motion drift estimation and correction algorithms.",
7980
groups=[
8081
NWBGroupSpec(
@@ -85,7 +86,7 @@ def main():
8586
]
8687
)
8788

88-
new_data_types = [displacement_series, drift_correction]
89+
new_data_types = [displacement_series, motion_drift]
8990

9091
# export the spec to yaml files in the root spec folder
9192
output_dir = str((Path(__file__).parent.parent.parent / "spec").absolute())

0 commit comments

Comments
 (0)