19
19
from plantseg .viewer .widget .segmentation import widget_agglomeration , widget_lifted_multicut , widget_simple_dt_ws
20
20
from plantseg .viewer .widget .utils import return_value_if_widget
21
21
from plantseg .viewer .widget .utils import start_threading_process , start_prediction_process , create_layer_name , layer_properties
22
- from plantseg .viewer .widget .validation import change_handler , get_image_volume_from_layer , widgets_inactive
22
+ from plantseg .viewer .widget .validation import _on_prediction_input_image_change , widgets_inactive
23
23
24
24
ALL_CUDA_DEVICES = [f'cuda:{ i } ' for i in range (torch .cuda .device_count ())]
25
25
MPS = ['mps' ] if torch .backends .mps .is_available () else []
@@ -109,24 +109,9 @@ def widget_unet_predictions(viewer: Viewer,
109
109
)
110
110
111
111
112
- @change_handler (widget_unet_predictions .image , init = False )
113
- def _image_change (image : Image ):
114
- shape = get_image_volume_from_layer (image ).shape
115
- ndim = len (shape )
116
- widget_unet_predictions .image .tooltip = f"Shape: { shape } "
117
-
118
- size_z = widget_unet_predictions .patch_size [0 ]
119
- halo_z = widget_unet_predictions .patch_halo [0 ]
120
- if ndim == 2 or (ndim == 3 and shape [0 ] == 1 ): # 2D image imported by Napari thus no Z, or by PlantSeg widget
121
- size_z .value = 0
122
- halo_z .value = 0
123
- widgets_inactive (size_z , halo_z , active = False )
124
- elif ndim == 3 and shape [0 ] > 1 : # 3D
125
- size_z .value = min (64 , shape [0 ]) # TODO: fetch model default
126
- halo_z .value = 8
127
- widgets_inactive (size_z , halo_z , active = True )
128
- else :
129
- raise ValueError (f"Unsupported number of dimensions: { ndim } " )
112
+ @widget_unet_predictions .image .changed .connect
113
+ def _on_widget_unet_predictions_image_change (image : Image ):
114
+ _on_prediction_input_image_change (widget_unet_predictions , image )
130
115
131
116
132
117
def _on_any_metadata_changed (dimensionality , modality , output_type ):
@@ -231,9 +216,14 @@ def on_done(result):
231
216
return future
232
217
233
218
234
- def _compute_iterative_predictions (pmap , model_name , num_iterations , sigma , patch_size , single_batch_mode , device ):
235
- func = partial (unet_predictions , model_name = model_name , patch = patch_size , single_batch_mode = single_batch_mode ,
236
- device = device )
219
+ @widget_test_all_unet_predictions .image .changed .connect
220
+ def _on_widget_test_all_unet_predictions_image_change (image : Image ):
221
+ _on_prediction_input_image_change (widget_test_all_unet_predictions , image )
222
+
223
+
224
+ def _compute_iterative_predictions (pmap , model_name , num_iterations , sigma , patch_size , patch_halo , single_batch_mode , device ):
225
+ func = partial (unet_predictions , model_name = model_name , patch = patch_size , patch_halo = patch_halo ,
226
+ single_batch_mode = single_batch_mode , device = device )
237
227
for i in range (num_iterations - 1 ):
238
228
pmap = func (pmap )
239
229
pmap = image_gaussian_smoothing (image = pmap , sigma = sigma )
@@ -258,6 +248,8 @@ def _compute_iterative_predictions(pmap, model_name, num_iterations, sigma, patc
258
248
'min' : 0. },
259
249
patch_size = {'label' : 'Patch size' ,
260
250
'tooltip' : 'Patch size use to processed the data.' },
251
+ patch_halo = {'label' : 'Patch halo' ,
252
+ 'tooltip' : 'Patch halo is extra padding for correct prediction on image boarder.' },
261
253
single_patch = {'label' : 'Single Patch' ,
262
254
'tooltip' : 'If True, a single patch will be processed at a time to save memory.' },
263
255
device = {'label' : 'Device' ,
@@ -268,6 +260,7 @@ def widget_iterative_unet_predictions(image: Image,
268
260
num_iterations : int = 2 ,
269
261
sigma : float = 1.0 ,
270
262
patch_size : Tuple [int , int , int ] = (80 , 170 , 170 ),
263
+ patch_halo : Tuple [int , int , int ] = (8 , 16 , 16 ),
271
264
single_patch : bool = True ,
272
265
device : str = ALL_DEVICES [0 ]) -> Future [LayerDataTuple ]:
273
266
out_name = create_layer_name (image .name , f'iterative-{ model_name } -x{ num_iterations } ' )
@@ -281,6 +274,7 @@ def widget_iterative_unet_predictions(image: Image,
281
274
num_iterations = num_iterations ,
282
275
sigma = sigma ,
283
276
patch_size = patch_size ,
277
+ patch_halo = patch_halo ,
284
278
single_batch_mode = single_patch ,
285
279
device = device )
286
280
@@ -303,6 +297,11 @@ def _on_model_name_changed_iterative(model_name: str):
303
297
widget_iterative_unet_predictions .patch_size .value = tuple (patch_size )
304
298
305
299
300
+ @widget_iterative_unet_predictions .image .changed .connect
301
+ def _on_widget_iterative_unet_predictions_image_change (image : Image ):
302
+ _on_prediction_input_image_change (widget_iterative_unet_predictions , image )
303
+
304
+
306
305
@magicgui (call_button = 'Add Custom Model' ,
307
306
new_model_name = {'label' : 'New model name' },
308
307
model_location = {'label' : 'Model location' ,
0 commit comments