Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
8 changes: 7 additions & 1 deletion spikeinterface_gui/controller.py
Original file line number Diff line number Diff line change
Expand Up @@ -706,6 +706,12 @@ def get_traces(self, trace_source='preprocessed', **kargs):
def get_contact_location(self):
location = self.analyzer.get_channel_locations()
return location

def get_channel_groups(self):
if self.has_extension("recording"):
return self.analyzer.recording.get_channel_groups()
else:
return np.zeros(self.analyzer.get_num_channels(), dtype=int)

def get_waveform_sweep(self):
return self.nbefore, self.nafter
Expand All @@ -717,7 +723,7 @@ def get_waveforms(self, unit_id):
wfs = self.waveforms_ext.get_waveforms_one_unit(unit_id, force_dense=False)
if self.analyzer.sparsity is None:
# dense waveforms
chan_inds = np.arange(self.analyzer.recording.get_num_channels(), dtype='int64')
chan_inds = np.arange(self.analyzer.get_num_channels(), dtype='int64')
else:
# sparse waveforms
chan_inds = self.analyzer.sparsity.unit_id_to_channel_indices[unit_id]
Expand Down
18 changes: 16 additions & 2 deletions spikeinterface_gui/tracemapview.py
Original file line number Diff line number Diff line change
Expand Up @@ -24,8 +24,14 @@ class TraceMapView(ViewBase, MixinViewTrace):

def __init__(self, controller=None, parent=None, backend="qt"):
pos = controller.get_contact_location()
self.channel_order = np.lexsort((-pos[:, 0], pos[:, 1], ))
channel_groups = controller.get_channel_groups()
self.channel_order = np.lexsort((-pos[:, 0], pos[:, 1], channel_groups))
self.channel_order_reverse = np.argsort(self.channel_order, kind="stable")
if len(np.unique(channel_groups)) > 1:
self.chan_group_offsets, = np.nonzero(np.diff(np.sort(channel_groups)))
self.chan_group_offsets = self.chan_group_offsets + 1
else:
self.chan_group_offsets = None
self.color_limit = None
self.last_data_curves = None
self.factor = None
Expand Down Expand Up @@ -163,6 +169,11 @@ def _qt_seek(self, t):
self.plot.setXRange(t1, t2, padding=0.0)
self.plot.setYRange(0, num_chans, padding=0.0)

if self.chan_group_offsets is not None:
for ch in self.chan_group_offsets:
hline = pg.InfiniteLine(pos=ch, angle=0, movable=False, pen=pg.mkPen("black"))
self.plot.addItem(hline)

def _qt_on_time_info_updated(self):
# Update segment and time slider range
time, segment_index = self.controller.get_time()
Expand Down Expand Up @@ -224,7 +235,7 @@ def _panel_make_layout(self):
self.figure.xaxis.major_tick_line_color = "white"
self.figure.yaxis.visible = False
self.figure.x_range = Range1d(start=0, end=0.5)
self.figure.y_range = Range1d(start=0, end=1)
self.figure.y_range = Range1d(start=0, end=self.controller.num_channels)


# Add data sources
Expand All @@ -245,6 +256,9 @@ def _panel_make_layout(self):
x="x", y="y", size=10, fill_color="color", fill_alpha=self.settings['alpha'], source=self.spike_source
)

if self.chan_group_offsets is not None:
self.figure.hspan(y=list(self.chan_group_offsets), line_color="yellow")

# # Add hover tool for spikes
# hover_spikes = HoverTool(renderers=[self.spike_renderer], tooltips=[("Unit", "@unit_id")])
# self.figure.add_tools(hover_spikes)
Expand Down
2 changes: 1 addition & 1 deletion spikeinterface_gui/waveformview.py
Original file line number Diff line number Diff line change
Expand Up @@ -939,7 +939,7 @@ def _panel_clear_scalebars(self):
self.scalebar_labels = []

def _panel_add_scalebars(self):
from bokeh.models import Span, Label
from bokeh.models import Label

if not self.settings["x_scalebar"] and not self.settings["y_scalebar"]:
return
Expand Down