Skip to content

Commit 93f3da7

Browse files
authored
Merge pull request #43 from matplotlib/layer-avlid
Add layer number/type validation
2 parents e47ff04 + 5dbde28 commit 93f3da7

File tree

6 files changed

+78
-10
lines changed

6 files changed

+78
-10
lines changed

Diff for: src/napari_matplotlib/base.py

+12-4
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,6 @@
11
import os
22
from pathlib import Path
3+
from typing import Tuple
34

45
import matplotlib as mpl
56
import napari
@@ -10,6 +11,8 @@
1011
from qtpy.QtGui import QIcon
1112
from qtpy.QtWidgets import QVBoxLayout, QWidget
1213

14+
from .util import Interval
15+
1316
mpl.rc("axes", edgecolor="white")
1417
mpl.rc("axes", facecolor="#262930")
1518
mpl.rc("axes", labelcolor="white")
@@ -65,6 +68,11 @@ def __init__(self, napari_viewer: napari.viewer.Viewer):
6568

6669
self.setup_callbacks()
6770

71+
# Accept any number of input layers by default
72+
n_layers_input = Interval(None, None)
73+
# Accept any type of input layer by default
74+
input_layer_types: Tuple[napari.layers.Layer, ...] = (napari.layers.Layer,)
75+
6876
@property
6977
def n_selected_layers(self) -> int:
7078
"""
@@ -104,10 +112,10 @@ def _draw(self) -> None:
104112
figure if so.
105113
"""
106114
self.clear()
107-
if self.n_selected_layers != self.n_layers_input:
108-
self.canvas.draw()
109-
return
110-
self.draw()
115+
if self.n_selected_layers in self.n_layers_input and all(
116+
isinstance(layer, self.input_layer_types) for layer in self.layers
117+
):
118+
self.draw()
111119
self.canvas.draw()
112120

113121
def clear(self) -> None:

Diff for: src/napari_matplotlib/histogram.py

+4-1
Original file line numberDiff line numberDiff line change
@@ -6,6 +6,8 @@
66

77
import napari
88

9+
from .util import Interval
10+
911
_COLORS = {"r": "tab:red", "g": "tab:green", "b": "tab:blue"}
1012

1113

@@ -14,7 +16,8 @@ class HistogramWidget(NapariMPLWidget):
1416
Display a histogram of the currently selected layer.
1517
"""
1618

17-
n_layers_input = 1
19+
n_layers_input = Interval(1, 1)
20+
input_layer_types = (napari.layers.Image,)
1821

1922
def __init__(self, napari_viewer: napari.viewer.Viewer):
2023
super().__init__(napari_viewer)

Diff for: src/napari_matplotlib/scatter.py

+14-3
Original file line numberDiff line numberDiff line change
@@ -6,6 +6,7 @@
66
from magicgui import magicgui
77

88
from .base import NapariMPLWidget
9+
from .util import Interval
910

1011
__all__ = ["ScatterWidget", "FeaturesScatterWidget"]
1112

@@ -84,7 +85,8 @@ class ScatterWidget(ScatterBaseWidget):
8485
of a scatter plot, to avoid too many scatter points.
8586
"""
8687

87-
n_layers_input = 2
88+
n_layers_input = Interval(2, 2)
89+
input_layer_types = (napari.layers.Image,)
8890

8991
def _get_data(self) -> Tuple[List[np.ndarray], str, str]:
9092
"""Get the plot data.
@@ -106,7 +108,15 @@ def _get_data(self) -> Tuple[List[np.ndarray], str, str]:
106108

107109

108110
class FeaturesScatterWidget(ScatterBaseWidget):
109-
n_layers_input = 1
111+
n_layers_input = Interval(1, 1)
112+
# All layers that have a .features attributes
113+
input_layer_types = (
114+
napari.layers.Labels,
115+
napari.layers.Points,
116+
napari.layers.Shapes,
117+
napari.layers.Tracks,
118+
napari.layers.Vectors,
119+
)
110120

111121
def __init__(self, napari_viewer: napari.viewer.Viewer):
112122
super().__init__(napari_viewer)
@@ -146,7 +156,8 @@ def _set_axis_keys(self, x_axis_key: str, y_axis_key: str):
146156
self._draw()
147157

148158
def _get_valid_axis_keys(self, combo_widget=None) -> List[str]:
149-
"""Get the valid axis keys from the layer FeatureTable.
159+
"""
160+
Get the valid axis keys from the layer FeatureTable.
150161
151162
Returns
152163
-------

Diff for: src/napari_matplotlib/slice.py

+4-2
Original file line numberDiff line numberDiff line change
@@ -4,7 +4,8 @@
44
import numpy as np
55
from qtpy.QtWidgets import QComboBox, QHBoxLayout, QLabel, QSpinBox
66

7-
from napari_matplotlib.base import NapariMPLWidget
7+
from .base import NapariMPLWidget
8+
from .util import Interval
89

910
__all__ = ["SliceWidget"]
1011

@@ -17,7 +18,8 @@ class SliceWidget(NapariMPLWidget):
1718
Plot a 1D slice along a given dimension.
1819
"""
1920

20-
n_layers_input = 1
21+
n_layers_input = Interval(1, 1)
22+
input_layer_types = (napari.layers.Image,)
2123

2224
def __init__(self, napari_viewer: napari.viewer.Viewer):
2325
# Setup figure/axes

Diff for: src/napari_matplotlib/tests/test_util.py

+15
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,15 @@
1+
import pytest
2+
3+
from napari_matplotlib.util import Interval
4+
5+
6+
def test_interval():
7+
interval = Interval(4, 9)
8+
for i in range(4, 10):
9+
assert i in interval
10+
11+
assert 3 not in interval
12+
assert 10 not in interval
13+
14+
with pytest.raises(ValueError, match="must be an integer"):
15+
"string" in interval

Diff for: src/napari_matplotlib/util.py

+29
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,29 @@
1+
from typing import Optional
2+
3+
4+
class Interval:
5+
def __init__(self, lower_bound: Optional[int], upper_bound: Optional[int]):
6+
"""
7+
Parameters
8+
----------
9+
lower_bound, upper_bound:
10+
Bounds. Use `None` to specify an open bound.
11+
"""
12+
if (
13+
lower_bound is not None
14+
and upper_bound is not None
15+
and lower_bound > upper_bound
16+
):
17+
raise ValueError("lower_bound must be <= upper_bound")
18+
19+
self.lower = lower_bound
20+
self.upper = upper_bound
21+
22+
def __contains__(self, val):
23+
if not isinstance(val, int):
24+
raise ValueError("variable must be an integer")
25+
if self.lower is not None and val < self.lower:
26+
return False
27+
if self.upper is not None and val > self.upper:
28+
return False
29+
return True

0 commit comments

Comments
 (0)