diff --git a/docs/api.rst b/docs/api.rst index dba583af..ae0d78d2 100644 --- a/docs/api.rst +++ b/docs/api.rst @@ -8,3 +8,5 @@ plots though the napari user interface. .. automodapi:: napari_matplotlib .. automodapi:: napari_matplotlib.base + +.. automodapi:: napari_matplotlib.features diff --git a/src/napari_matplotlib/base.py b/src/napari_matplotlib/base.py index 8c717d6a..5687895e 100644 --- a/src/napari_matplotlib/base.py +++ b/src/napari_matplotlib/base.py @@ -281,7 +281,9 @@ def __init__( napari_viewer: napari.viewer.Viewer, parent: Optional[QWidget] = None, ): - super().__init__(napari_viewer=napari_viewer, parent=parent) + NapariMPLWidget.__init__( + self, napari_viewer=napari_viewer, parent=parent + ) self.add_single_axes() def clear(self) -> None: diff --git a/src/napari_matplotlib/features.py b/src/napari_matplotlib/features.py new file mode 100644 index 00000000..3e1eb9ba --- /dev/null +++ b/src/napari_matplotlib/features.py @@ -0,0 +1,153 @@ +from typing import Any, Dict, List, Optional, Tuple + +import napari +import napari.layers +import numpy as np +import numpy.typing as npt +import pandas as pd +from qtpy.QtWidgets import QComboBox, QLabel, QVBoxLayout + +from napari_matplotlib.base import NapariMPLWidget +from napari_matplotlib.util import Interval + +__all__ = ["FeaturesMixin"] + + +class FeaturesMixin(NapariMPLWidget): + """ + Mixin to help with widgets that plot data from a features table stored + in a single napari layer. + + This provides: + + - Setup for one or two combo boxes to select features to be plotted. + - An ``on_update_layers()`` callback that updates the combo box options + when the napari layer selection changes. + """ + + n_layers_input = Interval(1, 1) + # All layers that have a .features attributes + input_layer_types = ( + napari.layers.Labels, + napari.layers.Points, + napari.layers.Shapes, + napari.layers.Tracks, + napari.layers.Vectors, + ) + + def __init__(self, *, ndim: int) -> None: + """ + Parameters + ---------- + ndim : int + Number of dimensions that are plotted by the widget. + Must be 1 or 2. + """ + assert ndim in [1, 2] + self.dims = ["x", "y"][:ndim] + # Set up selection boxes + self.layout().addLayout(QVBoxLayout()) + + self._selectors: Dict[str, QComboBox] = {} + for dim in self.dims: + self._selectors[dim] = QComboBox() + # Re-draw when combo boxes are updated + self._selectors[dim].currentTextChanged.connect(self._draw) + + self.layout().addWidget(QLabel(f"{dim}-axis:")) + self.layout().addWidget(self._selectors[dim]) + + def get_key(self, dim: str) -> Optional[str]: + """ + Get key for a given dimension. + + Parameters + ---------- + dim : str + "x" or "y" + """ + if self._selectors[dim].count() == 0: + return None + else: + return self._selectors[dim].currentText() + + def set_key(self, dim: str, value: str) -> None: + """ + Set key for a given dimension. + + Parameters + ---------- + dim : str + "x" or "y" + value : str + Value to set. + """ + assert value in self._get_valid_axis_keys(), ( + "value must be on of the columns " + "of the feature table on the currently seleted layer" + ) + self._selectors[dim].setCurrentText(value) + self._draw() + + def _get_valid_axis_keys(self) -> List[str]: + """ + Get the valid axis keys from the features table column names. + + Returns + ------- + axis_keys : List[str] + The valid axis keys in the FeatureTable. If the table is empty + or there isn't a table, returns an empty list. + """ + if len(self.layers) == 0 or not (hasattr(self.layers[0], "features")): + return [] + else: + return self.layers[0].features.keys() + + def _ready_to_plot(self) -> bool: + """ + Return True if selected layer has a feature table we can plot with, + and the columns to plot have been selected. + """ + if not hasattr(self.layers[0], "features"): + return False + + feature_table = self.layers[0].features + valid_keys = self._get_valid_axis_keys() + return ( + feature_table is not None + and len(feature_table) > 0 + and all([self.get_key(dim) in valid_keys for dim in self.dims]) + ) + + def _get_data_names( + self, + ) -> Tuple[List[npt.NDArray[Any]], List[str]]: + """ + Get the plot data from the ``features`` attribute of the first + selected layer. + + Returns + ------- + data : List[np.ndarray] + List contains X and Y columns from the FeatureTable. Returns + an empty array if nothing to plot. + names : List[str] + Names for each axis. + """ + feature_table: pd.DataFrame = self.layers[0].features + + names = [str(self.get_key(dim)) for dim in self.dims] + data = [np.array(feature_table[key]) for key in names] + return data, names + + def on_update_layers(self) -> None: + """ + Called when the layer selection changes by ``self.update_layers()``. + """ + # Clear combobox + for dim in self.dims: + while self._selectors[dim].count() > 0: + self._selectors[dim].removeItem(0) + # Add keys for newly selected layer + self._selectors[dim].addItems(self._get_valid_axis_keys()) diff --git a/src/napari_matplotlib/scatter.py b/src/napari_matplotlib/scatter.py index 334f941c..4fa45798 100644 --- a/src/napari_matplotlib/scatter.py +++ b/src/napari_matplotlib/scatter.py @@ -1,10 +1,11 @@ -from typing import Any, Dict, List, Optional, Tuple, Union +from typing import Any, Optional, Tuple import napari import numpy.typing as npt -from qtpy.QtWidgets import QComboBox, QLabel, QVBoxLayout, QWidget +from qtpy.QtWidgets import QWidget from .base import SingleAxesWidget +from .features import FeaturesMixin from .util import Interval __all__ = ["ScatterBaseWidget", "ScatterWidget", "FeaturesScatterWidget"] @@ -85,144 +86,27 @@ def _get_data(self) -> Tuple[npt.NDArray[Any], npt.NDArray[Any], str, str]: return x, y, x_axis_name, y_axis_name -class FeaturesScatterWidget(ScatterBaseWidget): +class FeaturesScatterWidget(ScatterBaseWidget, FeaturesMixin): """ Widget to scatter data stored in two layer feature attributes. """ - n_layers_input = Interval(1, 1) - # All layers that have a .features attributes - input_layer_types = ( - napari.layers.Labels, - napari.layers.Points, - napari.layers.Shapes, - napari.layers.Tracks, - napari.layers.Vectors, - ) - def __init__( self, napari_viewer: napari.viewer.Viewer, parent: Optional[QWidget] = None, ): - super().__init__(napari_viewer, parent=parent) - - self.layout().addLayout(QVBoxLayout()) - - self._selectors: Dict[str, QComboBox] = {} - for dim in ["x", "y"]: - self._selectors[dim] = QComboBox() - # Re-draw when combo boxes are updated - self._selectors[dim].currentTextChanged.connect(self._draw) - - self.layout().addWidget(QLabel(f"{dim}-axis:")) - self.layout().addWidget(self._selectors[dim]) - + ScatterBaseWidget.__init__(self, napari_viewer, parent=parent) + FeaturesMixin.__init__(self, ndim=2) self._update_layers(None) - @property - def x_axis_key(self) -> Union[str, None]: - """ - Key for the x-axis data. - """ - if self._selectors["x"].count() == 0: - return None - else: - return self._selectors["x"].currentText() - - @x_axis_key.setter - def x_axis_key(self, key: str) -> None: - self._selectors["x"].setCurrentText(key) - self._draw() - - @property - def y_axis_key(self) -> Union[str, None]: - """ - Key for the y-axis data. - """ - if self._selectors["y"].count() == 0: - return None - else: - return self._selectors["y"].currentText() - - @y_axis_key.setter - def y_axis_key(self, key: str) -> None: - self._selectors["y"].setCurrentText(key) - self._draw() - - def _get_valid_axis_keys(self) -> List[str]: - """ - Get the valid axis keys from the layer FeatureTable. - - Returns - ------- - axis_keys : List[str] - The valid axis keys in the FeatureTable. If the table is empty - or there isn't a table, returns an empty list. - """ - if len(self.layers) == 0 or not (hasattr(self.layers[0], "features")): - return [] - else: - return self.layers[0].features.keys() - - def _ready_to_scatter(self) -> bool: - """ - Return True if selected layer has a feature table we can scatter with, - and the two columns to be scatterd have been selected. - """ - if not hasattr(self.layers[0], "features"): - return False - - feature_table = self.layers[0].features - valid_keys = self._get_valid_axis_keys() - return ( - feature_table is not None - and len(feature_table) > 0 - and self.x_axis_key in valid_keys - and self.y_axis_key in valid_keys - ) - def draw(self) -> None: """ Scatter two features from the currently selected layer. """ - if self._ready_to_scatter(): + if self._ready_to_plot(): super().draw() def _get_data(self) -> Tuple[npt.NDArray[Any], npt.NDArray[Any], str, str]: - """ - Get the plot data from the ``features`` attribute of the first - selected layer. - - Returns - ------- - data : List[np.ndarray] - List contains X and Y columns from the FeatureTable. Returns - an empty array if nothing to plot. - x_axis_name : str - The title to display on the x axis. Returns - an empty string if nothing to plot. - y_axis_name: str - The title to display on the y axis. Returns - an empty string if nothing to plot. - """ - feature_table = self.layers[0].features - - x = feature_table[self.x_axis_key] - y = feature_table[self.y_axis_key] - - x_axis_name = str(self.x_axis_key) - y_axis_name = str(self.y_axis_key) - - return x, y, x_axis_name, y_axis_name - - def on_update_layers(self) -> None: - """ - Called when the layer selection changes by ``self.update_layers()``. - """ - # Clear combobox - for dim in ["x", "y"]: - while self._selectors[dim].count() > 0: - self._selectors[dim].removeItem(0) - # Add keys for newly selected layer - self._selectors[dim].addItems(self._get_valid_axis_keys()) + data, names = self._get_data_names() + return data[0], data[1], names[0], names[1] diff --git a/src/napari_matplotlib/tests/scatter/test_scatter_features.py b/src/napari_matplotlib/tests/scatter/test_scatter_features.py index c211a064..0b3f7638 100644 --- a/src/napari_matplotlib/tests/scatter/test_scatter_features.py +++ b/src/napari_matplotlib/tests/scatter/test_scatter_features.py @@ -25,8 +25,8 @@ def test_features_scatter_widget_2D( # Select points data and chosen features viewer.layers.selection.add(viewer.layers[0]) # images need to be selected - widget.x_axis_key = "feature_0" - widget.y_axis_key = "feature_1" + widget.set_key("x", "feature_0") + widget.set_key("y", "feature_1") fig = widget.figure @@ -64,9 +64,9 @@ def test_features_scatter_get_data(make_napari_viewer): viewer.layers.selection = [labels_layer] x_column = "feature_0" - scatter_widget.x_axis_key = x_column y_column = "feature_2" - scatter_widget.y_axis_key = y_column + scatter_widget.set_key("x", x_column) + scatter_widget.set_key("y", y_column) x, y, x_axis_name, y_axis_name = scatter_widget._get_data() np.testing.assert_allclose(x, feature_table[x_column])