Skip to content

Commit b65ae75

Browse files
committed
Improve prediction widget input validation and add halo to iterative prediction
1 parent 39e6b47 commit b65ae75

File tree

2 files changed

+50
-52
lines changed

2 files changed

+50
-52
lines changed

plantseg/viewer/widget/predictions.py

+21-22
Original file line numberDiff line numberDiff line change
@@ -19,7 +19,7 @@
1919
from plantseg.viewer.widget.segmentation import widget_agglomeration, widget_lifted_multicut, widget_simple_dt_ws
2020
from plantseg.viewer.widget.utils import return_value_if_widget
2121
from plantseg.viewer.widget.utils import start_threading_process, start_prediction_process, create_layer_name, layer_properties
22-
from plantseg.viewer.widget.validation import change_handler, get_image_volume_from_layer, widgets_inactive
22+
from plantseg.viewer.widget.validation import _on_prediction_input_image_change, widgets_inactive
2323

2424
ALL_CUDA_DEVICES = [f'cuda:{i}' for i in range(torch.cuda.device_count())]
2525
MPS = ['mps'] if torch.backends.mps.is_available() else []
@@ -109,24 +109,9 @@ def widget_unet_predictions(viewer: Viewer,
109109
)
110110

111111

112-
@change_handler(widget_unet_predictions.image, init=False)
113-
def _image_change(image: Image):
114-
shape = get_image_volume_from_layer(image).shape
115-
ndim = len(shape)
116-
widget_unet_predictions.image.tooltip = f"Shape: {shape}"
117-
118-
size_z = widget_unet_predictions.patch_size[0]
119-
halo_z = widget_unet_predictions.patch_halo[0]
120-
if ndim == 2 or (ndim == 3 and shape[0] == 1): # 2D image imported by Napari thus no Z, or by PlantSeg widget
121-
size_z.value = 0
122-
halo_z.value = 0
123-
widgets_inactive(size_z, halo_z, active=False)
124-
elif ndim == 3 and shape[0] > 1: # 3D
125-
size_z.value = min(64, shape[0]) # TODO: fetch model default
126-
halo_z.value = 8
127-
widgets_inactive(size_z, halo_z, active=True)
128-
else:
129-
raise ValueError(f"Unsupported number of dimensions: {ndim}")
112+
@widget_unet_predictions.image.changed.connect
113+
def _on_widget_unet_predictions_image_change(image: Image):
114+
_on_prediction_input_image_change(widget_unet_predictions, image)
130115

131116

132117
def _on_any_metadata_changed(dimensionality, modality, output_type):
@@ -231,9 +216,14 @@ def on_done(result):
231216
return future
232217

233218

234-
def _compute_iterative_predictions(pmap, model_name, num_iterations, sigma, patch_size, single_batch_mode, device):
235-
func = partial(unet_predictions, model_name=model_name, patch=patch_size, single_batch_mode=single_batch_mode,
236-
device=device)
219+
@widget_test_all_unet_predictions.image.changed.connect
220+
def _on_widget_test_all_unet_predictions_image_change(image: Image):
221+
_on_prediction_input_image_change(widget_test_all_unet_predictions, image)
222+
223+
224+
def _compute_iterative_predictions(pmap, model_name, num_iterations, sigma, patch_size, patch_halo, single_batch_mode, device):
225+
func = partial(unet_predictions, model_name=model_name, patch=patch_size, patch_halo=patch_halo,
226+
single_batch_mode=single_batch_mode, device=device)
237227
for i in range(num_iterations - 1):
238228
pmap = func(pmap)
239229
pmap = image_gaussian_smoothing(image=pmap, sigma=sigma)
@@ -258,6 +248,8 @@ def _compute_iterative_predictions(pmap, model_name, num_iterations, sigma, patc
258248
'min': 0.},
259249
patch_size={'label': 'Patch size',
260250
'tooltip': 'Patch size use to processed the data.'},
251+
patch_halo={'label': 'Patch halo',
252+
'tooltip': 'Patch halo is extra padding for correct prediction on image boarder.'},
261253
single_patch={'label': 'Single Patch',
262254
'tooltip': 'If True, a single patch will be processed at a time to save memory.'},
263255
device={'label': 'Device',
@@ -268,6 +260,7 @@ def widget_iterative_unet_predictions(image: Image,
268260
num_iterations: int = 2,
269261
sigma: float = 1.0,
270262
patch_size: Tuple[int, int, int] = (80, 170, 170),
263+
patch_halo: Tuple[int, int, int] = (8, 16, 16),
271264
single_patch: bool = True,
272265
device: str = ALL_DEVICES[0]) -> Future[LayerDataTuple]:
273266
out_name = create_layer_name(image.name, f'iterative-{model_name}-x{num_iterations}')
@@ -281,6 +274,7 @@ def widget_iterative_unet_predictions(image: Image,
281274
num_iterations=num_iterations,
282275
sigma=sigma,
283276
patch_size=patch_size,
277+
patch_halo=patch_halo,
284278
single_batch_mode=single_patch,
285279
device=device)
286280

@@ -303,6 +297,11 @@ def _on_model_name_changed_iterative(model_name: str):
303297
widget_iterative_unet_predictions.patch_size.value = tuple(patch_size)
304298

305299

300+
@widget_iterative_unet_predictions.image.changed.connect
301+
def _on_widget_iterative_unet_predictions_image_change(image: Image):
302+
_on_prediction_input_image_change(widget_iterative_unet_predictions, image)
303+
304+
306305
@magicgui(call_button='Add Custom Model',
307306
new_model_name={'label': 'New model name'},
308307
model_location={'label': 'Model location',

plantseg/viewer/widget/validation.py

+29-30
Original file line numberDiff line numberDiff line change
@@ -1,35 +1,7 @@
11
"""Widget input validation"""
22

3-
from psygnal import Signal
4-
from functools import wraps
5-
6-
def change_handler(*widgets, init=True, debug=False):
7-
def decorator_change_handler(handler):
8-
@wraps(handler)
9-
def wrapper(*args):
10-
source = Signal.sender()
11-
emitter = Signal.current_emitter()
12-
if debug:
13-
# print(f"{emitter}: {source} = {args!r}")
14-
print(f"EVENT '{str(emitter.name)}': {source.name:>20} = {args!r}")
15-
# print(f" {source.name:>14}.value = {source.value}")
16-
return handler(*args)
17-
18-
for widget in widgets:
19-
widget.changed.connect(wrapper)
20-
if init:
21-
widget.changed(widget.value)
22-
return wrapper
23-
24-
return decorator_change_handler
25-
26-
27-
def get_image_volume_from_layer(image):
28-
"""Used for widget parameter validation in `change_handler`s."""
29-
image = image.data[0] if image.multiscale else image.data
30-
if not all(hasattr(image, attr) for attr in ("shape", "ndim", "__getitem__")):
31-
image = np.asanyarray(image)
32-
return image
3+
from napari.layers import Image
4+
from magicgui.widgets import Widget
335

346

357
def widgets_inactive(*widgets, active):
@@ -44,3 +16,30 @@ def widgets_valid(*widgets, valid):
4416
widget.native.setStyleSheet("" if valid else "background-color: lightcoral")
4517

4618

19+
def get_image_volume_from_layer(image):
20+
"""Used for widget parameter validation in change-handlers."""
21+
image = image.data[0] if image.multiscale else image.data
22+
if not all(hasattr(image, attr) for attr in ("shape", "ndim", "__getitem__")):
23+
from numpy import asanyarray
24+
25+
image = asanyarray(image)
26+
return image
27+
28+
29+
def _on_prediction_input_image_change(widget: Widget, image: Image):
30+
shape = get_image_volume_from_layer(image).shape
31+
ndim = len(shape)
32+
widget.image.tooltip = f"Shape: {shape}"
33+
34+
size_z = widget.patch_size[0]
35+
halo_z = widget.patch_halo[0]
36+
if ndim == 2 or (ndim == 3 and shape[0] == 1): # 2D image imported by Napari thus no Z, or by PlantSeg widget
37+
size_z.value = 0
38+
halo_z.value = 0
39+
widgets_inactive(size_z, halo_z, active=False)
40+
elif ndim == 3 and shape[0] > 1: # 3D
41+
size_z.value = min(64, shape[0]) # TODO: fetch model default
42+
halo_z.value = 8
43+
widgets_inactive(size_z, halo_z, active=True)
44+
else:
45+
raise ValueError(f"Unsupported number of dimensions: {ndim}")

0 commit comments

Comments
 (0)