11import json
22import numpy as np
33from spikeinterface .core import BaseRecording , Motion
4- from pynwb import get_class
4+ from spikeinterface .extractors .nwbextractors import _retrieve_electrical_series_pynwb
5+ from pynwb import get_class , NWBFile
6+ from pynwb .ecephys import ElectricalSeries
57
68
7- def from_spikeinterface (recording : BaseRecording , motion : Motion , ** kwargs ):
8- assert isinstance (recording , BaseRecording ), f"The input must be a SpikeInterface BaseRecording object, not { type (recording )} ."
9+ def from_spikeinterface (
10+ motion : Motion , nwbfile : NWBFile ,
11+ source_electrical_series : ElectricalSeries | str , electrodes ,
12+ ** kwargs
13+ ):
14+ assert isinstance (motion , Motion ), f"The input must be a SpikeInterface Motion object, not { type (motion )} ."
15+ assert isinstance (nwbfile , NWBFile ), f"The input must be a PyNWB NWBFile object, not { type (nwbfile )} ."
16+ assert isinstance (source_electrical_series , (ElectricalSeries , str )), f"The source_electrical_series must be either a PyNWB ElectricalSeries object or the path of an ElectricalSeries in the NWBFile, not { type (source_electrical_series )} ."
917
10- recording_tree = [recording ]
11- while recording_tree [- 1 ].get_parent () is not None :
12- recording_tree .append (recording_tree [- 1 ].get_parent ())
13- has_electrical_series = [hasattr (r , "electrical_series" ) for r in recording_tree ]
14- assert any (has_electrical_series ), "The recording must have an ElectricalSeries to link the DisplacementSeries to."
18+ if isinstance (source_electrical_series , str ):
19+ source_electrical_series = _retrieve_electrical_series_pynwb (nwbfile , source_electrical_series )
1520
16- assert isinstance (motion , Motion ), f"The input must be a SpikeInterface Motion object, not { type (motion )} ."
21+ if isinstance (electrodes , (list , np .ndarray )):
22+ electrodes = nwbfile .create_electrode_table_region (
23+ description = "Electrodes for which the motion drift was estimated." ,
24+ region = np .where (np .isin (nwbfile .electrodes ["channel_name" ].data [:], electrodes ))[0 ].tolist ()
25+ )
1726
18- recording_nwb = recording_tree [has_electrical_series .index (True )]
19- nwbfile = recording_nwb ._nwbfile
20-
2127 DisplacementSeries = get_class ("DisplacementSeries" , "ndx-motion-drift" )
2228
2329 displacement_series = DisplacementSeries (
@@ -32,11 +38,8 @@ def from_spikeinterface(recording: BaseRecording, motion: Motion, **kwargs):
3238 conversion = kwargs .get ("conversion" , 1.0 ),
3339 offset = kwargs .get ("offset" , 0.0 ),
3440 resolution = kwargs .get ("resolution" , - 1.0 ),
35- source_electricalseries = recording_nwb .electrical_series ,
36- electrodes = nwbfile .create_electrode_table_region (
37- description = "Electrodes for which the motion drift was estimated." ,
38- region = np .where (np .isin (nwbfile .electrodes ["channel_name" ].data [:], recording .channel_ids ))[0 ].tolist ()
39- ),
41+ source_electricalseries = source_electrical_series ,
42+ electrodes = electrodes ,
4043 )
4144
4245 return displacement_series
0 commit comments