Skip to content

Commit f552464

Browse files
Make panel backend work with multi-thread/multi-process (#240)
* wip - bidirectional communication * Add iframe detection and bi-directional curation communication * Send loaded=true to parent when ready and clean up set curation * Fix mergeview when include delete is checked * Panel fixes for multi-process backend * Modify panel refresh to schedule updates (for multi-proc/multi-threds serves) * oups * Select next pair after accepting a merge * Fixes to panel to work with multi-threads/processes * cleanup * Remove set_curation_data * use mode from settings Co-authored-by: Chris Halcrow <57948917+chrishalcrow@users.noreply.github.com> * Move enable/disable active scroll * Avoid double refresh in SelectableTabulator and fix sorting of spikelist view * Speed up panel scalebars and remove iframe test * Move on_selection_changed callback to SelectableTabulator * Remove iframe-y stuff * Set self.curation=True after curation is successfully loaded * revert self.curation changes * revert self.curation changes1 * remove comment and _updating_from_controller * Remove listener_pane * Remove submit and set curation functions * extra indent --------- Co-authored-by: Chris Halcrow <57948917+chrishalcrow@users.noreply.github.com>
1 parent 8bfe93d commit f552464

18 files changed

Lines changed: 508 additions & 763 deletions

spikeinterface_gui/basescatterview.py

Lines changed: 46 additions & 24 deletions
Original file line numberDiff line numberDiff line change
@@ -17,7 +17,7 @@ class BaseScatterView(ViewBase):
1717
{'name': 'display_high_percentiles', 'type': 'float', 'value' : 98.0, 'limits':(50, 100), 'step':0.5},
1818
]
1919
_need_compute = False
20-
20+
2121
def __init__(self, spike_data, y_label, controller=None, parent=None, backend="qt"):
2222

2323
# compute data bounds
@@ -191,6 +191,9 @@ def on_unit_visibility_changed(self):
191191
self._current_selected = self.controller.get_indices_spike_selected().size
192192
self.refresh(set_scatter_range=True)
193193

194+
def on_spike_selection_changed(self):
195+
self.refresh()
196+
194197
def on_use_times_updated(self):
195198
self.refresh(set_scatter_range=True)
196199

@@ -322,7 +325,7 @@ def _qt_refresh(self, set_scatter_range=False):
322325

323326
# set x range to time range of the current segment for scatter, and max count for histogram
324327
# set y range to min and max of visible spike amplitudes
325-
if set_scatter_range or not self._first_refresh_done:
328+
if len(ymins) > 0 and (set_scatter_range or not self._first_refresh_done):
326329
ymin = np.min(ymins)
327330
ymax = np.max(ymaxs)
328331
t_start, t_stop = self.controller.get_t_start_t_stop()
@@ -498,6 +501,7 @@ def _panel_make_layout(self):
498501
self.plotted_inds = []
499502

500503
def _panel_refresh(self, set_scatter_range=False):
504+
import panel as pn
501505
from bokeh.models import FixedTicker
502506

503507
self.plotted_inds = []
@@ -565,28 +569,45 @@ def _panel_refresh(self, set_scatter_range=False):
565569
# handle selected spikes
566570
self._panel_update_selected_spikes()
567571

568-
# set y range to min and max of visible spike amplitudes
572+
# Defer Range updates to avoid nested document lock issues
573+
# def update_ranges():
569574
if set_scatter_range or not self._first_refresh_done:
570575
self.y_range.start = np.min(ymins)
571576
self.y_range.end = np.max(ymaxs)
572577
self._first_refresh_done = True
573578
self.hist_fig.x_range.end = max_count
574579
self.hist_fig.xaxis.ticker = FixedTicker(ticks=[0, max_count // 2, max_count])
575580

581+
# Schedule the update to run after the current event loop iteration
582+
# pn.state.execute(update_ranges, schedule=True)
583+
576584
def _panel_on_select_button(self, event):
577-
if self.select_toggle_button.value:
578-
self.scatter_fig.toolbar.active_drag = self.lasso_tool
579-
else:
580-
self.scatter_fig.toolbar.active_drag = None
581-
self.scatter_source.selected.indices = []
585+
import panel as pn
586+
587+
value = self.select_toggle_button.value
588+
589+
def _do_update():
590+
if value:
591+
self.scatter_fig.toolbar.active_drag = self.lasso_tool
592+
else:
593+
self.scatter_fig.toolbar.active_drag = None
594+
self.scatter_source.selected.indices = []
595+
596+
pn.state.execute(_do_update, schedule=True)
582597

583598
def _panel_change_segment(self, event):
599+
import panel as pn
600+
584601
self._current_selected = 0
585602
segment_index = int(self.segment_selector.value.split()[-1])
586603
self.controller.set_time(segment_index=segment_index)
587604
t_start, t_end = self.controller.get_t_start_t_stop()
588-
self.scatter_fig.x_range.start = t_start
589-
self.scatter_fig.x_range.end = t_end
605+
606+
def _do_update():
607+
self.scatter_fig.x_range.start = t_start
608+
self.scatter_fig.x_range.end = t_end
609+
610+
pn.state.execute(_do_update, schedule=True)
590611
self.refresh(set_scatter_range=True)
591612
self.notify_time_info_updated()
592613

@@ -628,9 +649,17 @@ def _panel_split(self, event):
628649
self.split()
629650

630651
def _panel_update_selected_spikes(self):
652+
import panel as pn
653+
631654
# handle selected spikes
632655
selected_spike_indices = self.controller.get_indices_spike_selected()
633656
selected_spike_indices = np.intersect1d(selected_spike_indices, self.plotted_inds)
657+
if len(selected_spike_indices) == 1:
658+
selected_segment = self.controller.spikes[selected_spike_indices[0]]['segment_index']
659+
segment_index = self.controller.get_time()[1]
660+
if selected_segment != segment_index:
661+
self.segment_selector.value = f"Segment {selected_segment}"
662+
self._panel_change_segment(None)
634663
if len(selected_spike_indices) > 0:
635664
# map absolute indices to visible spikes
636665
segment_index = self.controller.get_time()[1]
@@ -644,23 +673,16 @@ def _panel_update_selected_spikes(self):
644673
# set selected spikes in scatter plot
645674
if self.settings["auto_decimate"] and len(selected_indices) > 0:
646675
selected_indices, = np.nonzero(np.isin(self.plotted_inds, selected_spike_indices))
647-
self.scatter_source.selected.indices = list(selected_indices)
648676
else:
649-
self.scatter_source.selected.indices = []
677+
selected_indices = []
678+
679+
def _do_update():
680+
self.scatter_source.selected.indices = list(selected_indices)
681+
682+
pn.state.execute(_do_update, schedule=True)
650683

651684
def _panel_on_spike_selection_changed(self):
652-
# set selection in scatter plot
653-
selected_indices = self.controller.get_indices_spike_selected()
654-
if len(selected_indices) == 0:
655-
self.scatter_source.selected.indices = []
656-
return
657-
elif len(selected_indices) == 1:
658-
selected_segment = self.controller.spikes[selected_indices[0]]['segment_index']
659-
segment_index = self.controller.get_time()[1]
660-
if selected_segment != segment_index:
661-
self.segment_selector.value = f"Segment {selected_segment}"
662-
self._panel_change_segment(None)
663-
# update selected spikes
685+
# update selected spikes (scheduled via pn.state.execute inside)
664686
self._panel_update_selected_spikes()
665687

666688
def _panel_handle_shortcut(self, event):

spikeinterface_gui/controller.py

Lines changed: 13 additions & 31 deletions
Original file line numberDiff line numberDiff line change
@@ -9,10 +9,7 @@
99
from spikeinterface.widgets.utils import get_unit_colors
1010
from spikeinterface import compute_sparsity
1111
from spikeinterface.core import get_template_extremum_channel
12-
import spikeinterface.postprocessing
13-
import spikeinterface.qualitymetrics
1412
from spikeinterface.core.sorting_tools import spike_vector_to_indices
15-
from spikeinterface.core.core_tools import check_json
1613
from spikeinterface.curation import validate_curation_dict
1714
from spikeinterface.curation.curation_model import CurationModel
1815
from spikeinterface.widgets.utils import make_units_table_from_analyzer
@@ -340,7 +337,6 @@ def __init__(
340337
self.update_time_info()
341338

342339
self.curation = curation
343-
# TODO: Reload the dictionary if it already exists
344340
if self.curation:
345341
# rules:
346342
# * if user sends curation_data, then it is used
@@ -349,6 +345,7 @@ def __init__(
349345

350346
if curation_data is not None:
351347
# validate the curation data
348+
curation_data = deepcopy(curation_data)
352349
format_version = curation_data.get("format_version", None)
353350
# assume version 2 if not present
354351
if format_version is None:
@@ -358,24 +355,6 @@ def __init__(
358355
except Exception as e:
359356
raise ValueError(f"Invalid curation data.\nError: {e}")
360357

361-
if curation_data.get("merges") is None:
362-
curation_data["merges"] = []
363-
else:
364-
# here we reset the merges for better formatting (str)
365-
existing_merges = curation_data["merges"]
366-
new_merges = []
367-
for m in existing_merges:
368-
if "unit_ids" not in m:
369-
continue
370-
if len(m["unit_ids"]) < 2:
371-
continue
372-
new_merges = add_merge(new_merges, m["unit_ids"])
373-
curation_data["merges"] = new_merges
374-
if curation_data.get("splits") is None:
375-
curation_data["splits"] = []
376-
if curation_data.get("removed") is None:
377-
curation_data["removed"] = []
378-
379358
elif self.analyzer.format == "binary_folder":
380359
json_file = self.analyzer.folder / "spikeinterface_gui" / "curation_data.json"
381360
if json_file.exists():
@@ -390,24 +369,27 @@ def __init__(
390369

391370
if curation_data is None:
392371
curation_data = deepcopy(empty_curation_data)
372+
curation_data["unit_ids"] = self.unit_ids.tolist()
393373

394-
self.curation_data = curation_data
395-
396-
self.has_default_quality_labels = False
397-
if "label_definitions" not in self.curation_data:
374+
if "label_definitions" not in curation_data:
398375
if label_definitions is not None:
399-
self.curation_data["label_definitions"] = label_definitions
376+
curation_data["label_definitions"] = label_definitions
400377
else:
401-
self.curation_data["label_definitions"] = default_label_definitions.copy()
378+
curation_data["label_definitions"] = default_label_definitions.copy()
402379

403-
if "quality" in self.curation_data["label_definitions"]:
404-
curation_dict_quality_labels = self.curation_data["label_definitions"]["quality"]["label_options"]
380+
# This will enable the default shortcuts if has default quality labels
381+
self.has_default_quality_labels = False
382+
if "quality" in curation_data["label_definitions"]:
383+
curation_dict_quality_labels = curation_data["label_definitions"]["quality"]["label_options"]
405384
default_quality_labels = default_label_definitions["quality"]["label_options"]
406385
if set(curation_dict_quality_labels) == set(default_quality_labels):
407386
if self.verbose:
408387
print('Curation quality labels are the default ones')
409388
self.has_default_quality_labels = True
410389

390+
curation_data = CurationModel(**curation_data).model_dump()
391+
self.curation_data = curation_data
392+
411393
def check_is_view_possible(self, view_name):
412394
from .viewlist import get_all_possible_views
413395
possible_class_views = get_all_possible_views()
@@ -849,7 +831,7 @@ def compute_auto_merge(self, **params):
849831
)
850832

851833
return merge_unit_groups, extra
852-
834+
853835
def curation_can_be_saved(self):
854836
return self.analyzer.format != "memory"
855837

0 commit comments

Comments
 (0)