Skip to content

Commit e47ff04

Browse files
authored
Merge pull request #47 from matplotlib/scatter-tidy
Tidy up scatter code
2 parents 081d4b6 + 359b104 commit e47ff04

File tree

1 file changed

+22
-37
lines changed

1 file changed

+22
-37
lines changed

Diff for: src/napari_matplotlib/scatter.py

+22-37
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,4 @@
1-
from typing import List, Tuple, Union
1+
from typing import List, Optional, Tuple
22

33
import matplotlib.colors as mcolor
44
import napari
@@ -22,21 +22,21 @@ class ScatterBaseWidget(NapariMPLWidget):
2222
# the scatter is plotted as a 2dhist
2323
_threshold_to_switch_to_histogram = 500
2424

25-
def __init__(
26-
self,
27-
napari_viewer: napari.viewer.Viewer,
28-
):
25+
def __init__(self, napari_viewer: napari.viewer.Viewer):
2926
super().__init__(napari_viewer)
3027

3128
self.axes = self.canvas.figure.subplots()
3229
self.update_layers(None)
3330

3431
def clear(self) -> None:
32+
"""
33+
Clear the axes.
34+
"""
3535
self.axes.clear()
3636

3737
def draw(self) -> None:
3838
"""
39-
Clear the axes and scatter the currently selected layers.
39+
Scatter the currently selected layers.
4040
"""
4141
data, x_axis_name, y_axis_name = self._get_data()
4242

@@ -86,14 +86,6 @@ class ScatterWidget(ScatterBaseWidget):
8686

8787
n_layers_input = 2
8888

89-
def __init__(
90-
self,
91-
napari_viewer: napari.viewer.Viewer,
92-
):
93-
super().__init__(
94-
napari_viewer,
95-
)
96-
9789
def _get_data(self) -> Tuple[List[np.ndarray], str, str]:
9890
"""Get the plot data.
9991
@@ -116,42 +108,34 @@ def _get_data(self) -> Tuple[List[np.ndarray], str, str]:
116108
class FeaturesScatterWidget(ScatterBaseWidget):
117109
n_layers_input = 1
118110

119-
def __init__(
120-
self,
121-
napari_viewer: napari.viewer.Viewer,
122-
key_selection_gui: bool = True,
123-
):
124-
self._key_selection_widget = None
125-
super().__init__(
126-
napari_viewer,
111+
def __init__(self, napari_viewer: napari.viewer.Viewer):
112+
super().__init__(napari_viewer)
113+
self._key_selection_widget = magicgui(
114+
self._set_axis_keys,
115+
x_axis_key={"choices": self._get_valid_axis_keys},
116+
y_axis_key={"choices": self._get_valid_axis_keys},
117+
call_button="plot",
127118
)
128119

129-
if key_selection_gui is True:
130-
self._key_selection_widget = magicgui(
131-
self._set_axis_keys,
132-
x_axis_key={"choices": self._get_valid_axis_keys},
133-
y_axis_key={"choices": self._get_valid_axis_keys},
134-
call_button="plot",
135-
)
136-
self.layout().addWidget(self._key_selection_widget.native)
120+
self.layout().addWidget(self._key_selection_widget.native)
137121

138122
@property
139-
def x_axis_key(self) -> Union[None, str]:
123+
def x_axis_key(self) -> Optional[str]:
140124
"""Key to access x axis data from the FeaturesTable"""
141125
return self._x_axis_key
142126

143127
@x_axis_key.setter
144-
def x_axis_key(self, key: Union[None, str]):
128+
def x_axis_key(self, key: Optional[str]):
145129
self._x_axis_key = key
146130
self._draw()
147131

148132
@property
149-
def y_axis_key(self) -> Union[None, str]:
133+
def y_axis_key(self) -> Optional[str]:
150134
"""Key to access y axis data from the FeaturesTable"""
151135
return self._y_axis_key
152136

153137
@y_axis_key.setter
154-
def y_axis_key(self, key: Union[None, str]):
138+
def y_axis_key(self, key: Optional[str]):
155139
self._y_axis_key = key
156140
self._draw()
157141

@@ -214,10 +198,11 @@ def _get_data(self) -> Tuple[List[np.ndarray], str, str]:
214198
return data, x_axis_name, y_axis_name
215199

216200
def _on_update_layers(self) -> None:
217-
"""This is called when the layer selection changes
218-
by self.update_layers().
219201
"""
220-
if self._key_selection_widget is not None:
202+
This is called when the layer selection changes by
203+
``self.update_layers()``.
204+
"""
205+
if hasattr(self, "_key_selection_widget"):
221206
self._key_selection_widget.reset_choices()
222207

223208
# reset the axis keys

0 commit comments

Comments
 (0)