Skip to content

Commit 1d1411a

Browse files
authored
Merge PR #211 | Add validation, adaptation and halo to prediction Napari widgets
and add halo to CLI & Legacy GUI
2 parents 5d2a22c + c5ec9ee commit 1d1411a

File tree

9 files changed

+100
-18
lines changed

9 files changed

+100
-18
lines changed

plantseg/legacy_gui/gui_widgets.py

+10-2
Original file line numberDiff line numberDiff line change
@@ -206,7 +206,7 @@ def __init__(self, frame, config, col=0, module_name="preprocessing", font=None,
206206

207207

208208
class UnetPredictionFrame(ModuleFramePrototype):
209-
def __init__(self, frame, config, col=0, module_name="preprocessing", font=None, show_all=True):
209+
def __init__(self, frame, config, col=0, module_name="prediction", font=None, show_all=True):
210210
self.prediction_frame = tkinter.Frame(frame)
211211
self.prediction_style = {
212212
"bg": "white",
@@ -278,10 +278,18 @@ def __init__(self, frame, config, col=0, module_name="preprocessing", font=None,
278278
type=int,
279279
font=font,
280280
),
281+
"patch_halo": ListEntry(
282+
self.prediction_frame,
283+
text="Patch Halo: ",
284+
row=5,
285+
column=0,
286+
type=int,
287+
font=font,
288+
),
281289
"device": MenuEntry(
282290
self.prediction_frame,
283291
text="Device Type: ",
284-
row=5,
292+
row=6,
285293
column=0,
286294
menu=["cuda", "cpu"],
287295
default=config[self.module]["device"],

plantseg/pipeline/raw2seg.py

+3-1
Original file line numberDiff line numberDiff line change
@@ -41,8 +41,10 @@ def configure_cnn_step(input_paths, config):
4141
device = config.get('device', 'cuda')
4242
state = config.get('state', True)
4343
model_update = config.get('model_update', False)
44+
patch_halo = config.get('patch_halo', None)
4445
return UnetPredictions(input_paths, model_name=model_name, input_key=input_key, input_channel=input_channel,
45-
patch=patch, stride_ratio=stride_ratio, device=device, model_update=model_update, state=state)
46+
patch=patch, stride_ratio=stride_ratio, device=device, model_update=model_update,
47+
state=state, patch_halo=patch_halo)
4648

4749

4850
def configure_cnn_postprocessing_step(input_paths, config):

plantseg/predictions/functional/predictions.py

+4-2
Original file line numberDiff line numberDiff line change
@@ -30,7 +30,7 @@ def unet_predictions(raw: np.array, model_name: str, patch: Tuple[int, int, int]
3030
Defaults to 'cuda'.
3131
model_update (bool, optional): if True will update the model to the latest version. Defaults to False.
3232
disable_tqdm (bool, optional): if True will disable tqdm progress bar. Defaults to False.
33-
output_ndim (int, optional): output ndim, must be one of [3, 4]. Only use `4` if network output is
33+
output_ndim (int, optional): output ndim, must be one of [3, 4]. Only use `4` if network output is
3434
multi-channel 3D pmap. Now `4` only used in `widget_unet_predictions()`.
3535
3636
Returns:
@@ -45,7 +45,9 @@ def unet_predictions(raw: np.array, model_name: str, patch: Tuple[int, int, int]
4545
state = state['model_state_dict']
4646
model.load_state_dict(state)
4747

48-
patch_halo = get_patch_halo(model_name)
48+
patch_halo = kwargs.get('patch_halo', None)
49+
if patch_halo is None:
50+
patch_halo = get_patch_halo(model_name)
4951
predictor = ArrayPredictor(model=model, in_channels=model_config['in_channels'],
5052
out_channels=model_config['out_channels'], device=device, patch=patch,
5153
patch_halo=patch_halo, single_batch_mode=single_batch_mode, headless=False,

plantseg/predictions/predict.py

+3-2
Original file line numberDiff line numberDiff line change
@@ -36,7 +36,7 @@ def _check_patch_size(paths, patch_size):
3636

3737
class UnetPredictions(GenericPipelineStep):
3838
def __init__(self, input_paths, model_name, input_key=None, input_channel=None, patch=(80, 160, 160), stride_ratio=0.75, device='cuda',
39-
model_update=False, input_type="data_float32", output_type="data_float32", out_ext=".h5", state=True):
39+
model_update=False, input_type="data_float32", output_type="data_float32", out_ext=".h5", state=True, patch_halo=None):
4040
self.patch = patch
4141
self.model_name = model_name
4242
self.stride_ratio = stride_ratio
@@ -64,7 +64,8 @@ def __init__(self, input_paths, model_name, input_key=None, input_channel=None,
6464

6565
model.load_state_dict(state)
6666

67-
patch_halo = get_patch_halo(model_name)
67+
if patch_halo is None:
68+
patch_halo = get_patch_halo(model_name)
6869
is_embedding = not model_config.get('is_segmentation', True)
6970
self.predictor = ArrayPredictor(model=model, in_channels=model_config['in_channels'],
7071
out_channels=model_config['out_channels'], device=device, patch=self.patch,

plantseg/resources/config_predict_template.yaml

-3
Original file line numberDiff line numberDiff line change
@@ -42,6 +42,3 @@ loaders:
4242
- name: Standardize
4343
- name: ToTensor
4444
expand_dims: true
45-
46-
47-

plantseg/resources/config_train_example.yaml

+1-1
Original file line numberDiff line numberDiff line change
@@ -8,4 +8,4 @@ training:
88
max_num_iters: 50000
99
dimensionality: 3D
1010
sparse: false
11-
device: cuda
11+
device: cuda

plantseg/resources/config_train_template.yaml

+1-1
Original file line numberDiff line numberDiff line change
@@ -77,4 +77,4 @@ loaders:
7777
# minimum volume of the labels in the patch
7878
threshold: 0.1
7979
# probability of accepting patches which do not fulfil the threshold criterion
80-
slack_acceptance: 0.01
80+
slack_acceptance: 0.01

plantseg/viewer/widget/predictions.py

+33-6
Original file line numberDiff line numberDiff line change
@@ -19,6 +19,7 @@
1919
from plantseg.viewer.widget.segmentation import widget_agglomeration, widget_lifted_multicut, widget_simple_dt_ws
2020
from plantseg.viewer.widget.utils import return_value_if_widget
2121
from plantseg.viewer.widget.utils import start_threading_process, start_prediction_process, create_layer_name, layer_properties
22+
from plantseg.viewer.widget.validation import _on_prediction_input_image_change, widgets_inactive
2223

2324
ALL_CUDA_DEVICES = [f'cuda:{i}' for i in range(torch.cuda.device_count())]
2425
MPS = ['mps'] if torch.backends.mps.is_available() else []
@@ -61,6 +62,8 @@ def unet_predictions_wrapper(raw, device, **kwargs):
6162
'choices': LIST_ALL_MODELS},
6263
patch_size={'label': 'Patch size',
6364
'tooltip': 'Patch size use to processed the data.'},
65+
patch_halo={'label': 'Patch halo',
66+
'tooltip': 'Patch halo is extra padding for correct prediction on image boarder.'},
6467
single_patch={'label': 'Single Patch',
6568
'tooltip': 'If True, a single patch will be processed at a time to save memory.'},
6669
device={'label': 'Device',
@@ -73,6 +76,7 @@ def widget_unet_predictions(viewer: Viewer,
7376
modality: str = 'All',
7477
output_type: str = 'All',
7578
patch_size: Tuple[int, int, int] = (80, 170, 170),
79+
patch_halo: Tuple[int, int, int] = (8, 16, 16),
7680
single_patch: bool = True,
7781
device: str = ALL_DEVICES[0], ) -> Future[LayerDataTuple]:
7882
out_name = create_layer_name(image.name, model_name)
@@ -85,7 +89,7 @@ def widget_unet_predictions(viewer: Viewer,
8589
layer_kwargs['metadata']['pmap'] = True # this is used to warn the user that the layer is a pmap
8690

8791
layer_type = 'image'
88-
step_kwargs = dict(model_name=model_name, patch=patch_size, single_batch_mode=single_patch)
92+
step_kwargs = dict(model_name=model_name, patch=patch_size, patch_halo=patch_halo, single_batch_mode=single_patch)
8993

9094
return start_prediction_process(unet_predictions_wrapper,
9195
runtime_kwargs={'raw': image.data,
@@ -105,6 +109,11 @@ def widget_unet_predictions(viewer: Viewer,
105109
)
106110

107111

112+
@widget_unet_predictions.image.changed.connect
113+
def _on_widget_unet_predictions_image_change(image: Image):
114+
_on_prediction_input_image_change(widget_unet_predictions, image)
115+
116+
108117
def _on_any_metadata_changed(dimensionality, modality, output_type):
109118
dimensionality = [dimensionality] if dimensionality != 'All' else None
110119
modality = [modality] if modality != 'All' else None
@@ -152,7 +161,7 @@ def _on_model_name_changed(model_name: str):
152161
widget_unet_predictions.model_name.tooltip = f'Select a pretrained model. Current model description: {description}'
153162

154163

155-
def _compute_multiple_predictions(image, patch_size, device, use_custom_models=True):
164+
def _compute_multiple_predictions(image, patch_size, patch_halo, device, use_custom_models=True):
156165
out_layers = []
157166
model_list = list_models(use_custom_models=use_custom_models)
158167
for i, model_name in enumerate(model_list):
@@ -168,7 +177,7 @@ def _compute_multiple_predictions(image, patch_size, device, use_custom_models=T
168177
layer_type = 'image'
169178
try:
170179
pmap = unet_predictions(raw=image.data, model_name=model_name, patch=patch_size, single_batch_mode=True,
171-
device=device)
180+
device=device, patch_halo=patch_halo)
172181
out_layers.append((pmap, layer_kwargs, layer_type))
173182

174183
except Exception as e:
@@ -182,18 +191,22 @@ def _compute_multiple_predictions(image, patch_size, device, use_custom_models=T
182191
'tooltip': 'Raw image to be processed with a neural network.'},
183192
patch_size={'label': 'Patch size',
184193
'tooltip': 'Patch size use to processed the data.'},
194+
patch_halo={'label': 'Patch halo',
195+
'tooltip': 'Patch halo is extra padding for correct prediction on image boarder.'},
185196
device={'label': 'Device',
186197
'choices': ALL_DEVICES},
187198
use_custom_models={'label': 'Use custom models',
188199
'tooltip': 'If True, custom models will also be used.'}
189200
)
190201
def widget_test_all_unet_predictions(image: Image,
191202
patch_size: Tuple[int, int, int] = (80, 170, 170),
203+
patch_halo: Tuple[int, int, int] = (2, 4, 4),
192204
device: str = ALL_DEVICES[0],
193205
use_custom_models: bool = True) -> Future[List[LayerDataTuple]]:
194206
func = thread_worker(partial(_compute_multiple_predictions,
195207
image=image,
196208
patch_size=patch_size,
209+
patch_halo=patch_halo,
197210
device=device,
198211
use_custom_models=use_custom_models,))
199212

@@ -208,9 +221,14 @@ def on_done(result):
208221
return future
209222

210223

211-
def _compute_iterative_predictions(pmap, model_name, num_iterations, sigma, patch_size, single_batch_mode, device):
212-
func = partial(unet_predictions, model_name=model_name, patch=patch_size, single_batch_mode=single_batch_mode,
213-
device=device)
224+
@widget_test_all_unet_predictions.image.changed.connect
225+
def _on_widget_test_all_unet_predictions_image_change(image: Image):
226+
_on_prediction_input_image_change(widget_test_all_unet_predictions, image)
227+
228+
229+
def _compute_iterative_predictions(pmap, model_name, num_iterations, sigma, patch_size, patch_halo, single_batch_mode, device):
230+
func = partial(unet_predictions, model_name=model_name, patch=patch_size, patch_halo=patch_halo,
231+
single_batch_mode=single_batch_mode, device=device)
214232
for i in range(num_iterations - 1):
215233
pmap = func(pmap)
216234
pmap = image_gaussian_smoothing(image=pmap, sigma=sigma)
@@ -235,6 +253,8 @@ def _compute_iterative_predictions(pmap, model_name, num_iterations, sigma, patc
235253
'min': 0.},
236254
patch_size={'label': 'Patch size',
237255
'tooltip': 'Patch size use to processed the data.'},
256+
patch_halo={'label': 'Patch halo',
257+
'tooltip': 'Patch halo is extra padding for correct prediction on image boarder.'},
238258
single_patch={'label': 'Single Patch',
239259
'tooltip': 'If True, a single patch will be processed at a time to save memory.'},
240260
device={'label': 'Device',
@@ -245,6 +265,7 @@ def widget_iterative_unet_predictions(image: Image,
245265
num_iterations: int = 2,
246266
sigma: float = 1.0,
247267
patch_size: Tuple[int, int, int] = (80, 170, 170),
268+
patch_halo: Tuple[int, int, int] = (8, 16, 16),
248269
single_patch: bool = True,
249270
device: str = ALL_DEVICES[0]) -> Future[LayerDataTuple]:
250271
out_name = create_layer_name(image.name, f'iterative-{model_name}-x{num_iterations}')
@@ -258,6 +279,7 @@ def widget_iterative_unet_predictions(image: Image,
258279
num_iterations=num_iterations,
259280
sigma=sigma,
260281
patch_size=patch_size,
282+
patch_halo=patch_halo,
261283
single_batch_mode=single_patch,
262284
device=device)
263285

@@ -280,6 +302,11 @@ def _on_model_name_changed_iterative(model_name: str):
280302
widget_iterative_unet_predictions.patch_size.value = tuple(patch_size)
281303

282304

305+
@widget_iterative_unet_predictions.image.changed.connect
306+
def _on_widget_iterative_unet_predictions_image_change(image: Image):
307+
_on_prediction_input_image_change(widget_iterative_unet_predictions, image)
308+
309+
283310
@magicgui(call_button='Add Custom Model',
284311
new_model_name={'label': 'New model name'},
285312
model_location={'label': 'Model location',

plantseg/viewer/widget/validation.py

+45
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,45 @@
1+
"""Widget input validation"""
2+
3+
from napari.layers import Image
4+
from magicgui.widgets import Widget
5+
6+
7+
def widgets_inactive(*widgets, active):
8+
"""Toggle visibility of widgets."""
9+
for widget in widgets:
10+
widget.visible = active
11+
12+
13+
def widgets_valid(*widgets, valid):
14+
"""Toggle background warning color of widgets."""
15+
for widget in widgets:
16+
widget.native.setStyleSheet("" if valid else "background-color: lightcoral")
17+
18+
19+
def get_image_volume_from_layer(image):
20+
"""Used for widget parameter validation in change-handlers."""
21+
image = image.data[0] if image.multiscale else image.data
22+
if not all(hasattr(image, attr) for attr in ("shape", "ndim", "__getitem__")):
23+
from numpy import asanyarray
24+
25+
image = asanyarray(image)
26+
return image
27+
28+
29+
def _on_prediction_input_image_change(widget: Widget, image: Image):
30+
shape = get_image_volume_from_layer(image).shape
31+
ndim = len(shape)
32+
widget.image.tooltip = f"Shape: {shape}"
33+
34+
size_z = widget.patch_size[0]
35+
halo_z = widget.patch_halo[0]
36+
if ndim == 2 or (ndim == 3 and shape[0] == 1): # 2D image imported by Napari thus no Z, or by PlantSeg widget
37+
size_z.value = 0
38+
halo_z.value = 0
39+
widgets_inactive(size_z, halo_z, active=False)
40+
elif ndim == 3 and shape[0] > 1: # 3D
41+
size_z.value = min(64, shape[0]) # TODO: fetch model default
42+
halo_z.value = 8
43+
widgets_inactive(size_z, halo_z, active=True)
44+
else:
45+
raise ValueError(f"Unsupported number of dimensions: {ndim}")

0 commit comments

Comments
 (0)