19
19
from plantseg .viewer .widget .segmentation import widget_agglomeration , widget_lifted_multicut , widget_simple_dt_ws
20
20
from plantseg .viewer .widget .utils import return_value_if_widget
21
21
from plantseg .viewer .widget .utils import start_threading_process , start_prediction_process , create_layer_name , layer_properties
22
+ from plantseg .viewer .widget .validation import _on_prediction_input_image_change , widgets_inactive
22
23
23
24
ALL_CUDA_DEVICES = [f'cuda:{ i } ' for i in range (torch .cuda .device_count ())]
24
25
MPS = ['mps' ] if torch .backends .mps .is_available () else []
@@ -61,6 +62,8 @@ def unet_predictions_wrapper(raw, device, **kwargs):
61
62
'choices' : LIST_ALL_MODELS },
62
63
patch_size = {'label' : 'Patch size' ,
63
64
'tooltip' : 'Patch size use to processed the data.' },
65
+ patch_halo = {'label' : 'Patch halo' ,
66
+ 'tooltip' : 'Patch halo is extra padding for correct prediction on image boarder.' },
64
67
single_patch = {'label' : 'Single Patch' ,
65
68
'tooltip' : 'If True, a single patch will be processed at a time to save memory.' },
66
69
device = {'label' : 'Device' ,
@@ -73,6 +76,7 @@ def widget_unet_predictions(viewer: Viewer,
73
76
modality : str = 'All' ,
74
77
output_type : str = 'All' ,
75
78
patch_size : Tuple [int , int , int ] = (80 , 170 , 170 ),
79
+ patch_halo : Tuple [int , int , int ] = (8 , 16 , 16 ),
76
80
single_patch : bool = True ,
77
81
device : str = ALL_DEVICES [0 ], ) -> Future [LayerDataTuple ]:
78
82
out_name = create_layer_name (image .name , model_name )
@@ -85,7 +89,7 @@ def widget_unet_predictions(viewer: Viewer,
85
89
layer_kwargs ['metadata' ]['pmap' ] = True # this is used to warn the user that the layer is a pmap
86
90
87
91
layer_type = 'image'
88
- step_kwargs = dict (model_name = model_name , patch = patch_size , single_batch_mode = single_patch )
92
+ step_kwargs = dict (model_name = model_name , patch = patch_size , patch_halo = patch_halo , single_batch_mode = single_patch )
89
93
90
94
return start_prediction_process (unet_predictions_wrapper ,
91
95
runtime_kwargs = {'raw' : image .data ,
@@ -105,6 +109,11 @@ def widget_unet_predictions(viewer: Viewer,
105
109
)
106
110
107
111
112
+ @widget_unet_predictions .image .changed .connect
113
+ def _on_widget_unet_predictions_image_change (image : Image ):
114
+ _on_prediction_input_image_change (widget_unet_predictions , image )
115
+
116
+
108
117
def _on_any_metadata_changed (dimensionality , modality , output_type ):
109
118
dimensionality = [dimensionality ] if dimensionality != 'All' else None
110
119
modality = [modality ] if modality != 'All' else None
@@ -152,7 +161,7 @@ def _on_model_name_changed(model_name: str):
152
161
widget_unet_predictions .model_name .tooltip = f'Select a pretrained model. Current model description: { description } '
153
162
154
163
155
- def _compute_multiple_predictions (image , patch_size , device , use_custom_models = True ):
164
+ def _compute_multiple_predictions (image , patch_size , patch_halo , device , use_custom_models = True ):
156
165
out_layers = []
157
166
model_list = list_models (use_custom_models = use_custom_models )
158
167
for i , model_name in enumerate (model_list ):
@@ -168,7 +177,7 @@ def _compute_multiple_predictions(image, patch_size, device, use_custom_models=T
168
177
layer_type = 'image'
169
178
try :
170
179
pmap = unet_predictions (raw = image .data , model_name = model_name , patch = patch_size , single_batch_mode = True ,
171
- device = device )
180
+ device = device , patch_halo = patch_halo )
172
181
out_layers .append ((pmap , layer_kwargs , layer_type ))
173
182
174
183
except Exception as e :
@@ -182,18 +191,22 @@ def _compute_multiple_predictions(image, patch_size, device, use_custom_models=T
182
191
'tooltip' : 'Raw image to be processed with a neural network.' },
183
192
patch_size = {'label' : 'Patch size' ,
184
193
'tooltip' : 'Patch size use to processed the data.' },
194
+ patch_halo = {'label' : 'Patch halo' ,
195
+ 'tooltip' : 'Patch halo is extra padding for correct prediction on image boarder.' },
185
196
device = {'label' : 'Device' ,
186
197
'choices' : ALL_DEVICES },
187
198
use_custom_models = {'label' : 'Use custom models' ,
188
199
'tooltip' : 'If True, custom models will also be used.' }
189
200
)
190
201
def widget_test_all_unet_predictions (image : Image ,
191
202
patch_size : Tuple [int , int , int ] = (80 , 170 , 170 ),
203
+ patch_halo : Tuple [int , int , int ] = (2 , 4 , 4 ),
192
204
device : str = ALL_DEVICES [0 ],
193
205
use_custom_models : bool = True ) -> Future [List [LayerDataTuple ]]:
194
206
func = thread_worker (partial (_compute_multiple_predictions ,
195
207
image = image ,
196
208
patch_size = patch_size ,
209
+ patch_halo = patch_halo ,
197
210
device = device ,
198
211
use_custom_models = use_custom_models ,))
199
212
@@ -208,9 +221,14 @@ def on_done(result):
208
221
return future
209
222
210
223
211
- def _compute_iterative_predictions (pmap , model_name , num_iterations , sigma , patch_size , single_batch_mode , device ):
212
- func = partial (unet_predictions , model_name = model_name , patch = patch_size , single_batch_mode = single_batch_mode ,
213
- device = device )
224
+ @widget_test_all_unet_predictions .image .changed .connect
225
+ def _on_widget_test_all_unet_predictions_image_change (image : Image ):
226
+ _on_prediction_input_image_change (widget_test_all_unet_predictions , image )
227
+
228
+
229
+ def _compute_iterative_predictions (pmap , model_name , num_iterations , sigma , patch_size , patch_halo , single_batch_mode , device ):
230
+ func = partial (unet_predictions , model_name = model_name , patch = patch_size , patch_halo = patch_halo ,
231
+ single_batch_mode = single_batch_mode , device = device )
214
232
for i in range (num_iterations - 1 ):
215
233
pmap = func (pmap )
216
234
pmap = image_gaussian_smoothing (image = pmap , sigma = sigma )
@@ -235,6 +253,8 @@ def _compute_iterative_predictions(pmap, model_name, num_iterations, sigma, patc
235
253
'min' : 0. },
236
254
patch_size = {'label' : 'Patch size' ,
237
255
'tooltip' : 'Patch size use to processed the data.' },
256
+ patch_halo = {'label' : 'Patch halo' ,
257
+ 'tooltip' : 'Patch halo is extra padding for correct prediction on image boarder.' },
238
258
single_patch = {'label' : 'Single Patch' ,
239
259
'tooltip' : 'If True, a single patch will be processed at a time to save memory.' },
240
260
device = {'label' : 'Device' ,
@@ -245,6 +265,7 @@ def widget_iterative_unet_predictions(image: Image,
245
265
num_iterations : int = 2 ,
246
266
sigma : float = 1.0 ,
247
267
patch_size : Tuple [int , int , int ] = (80 , 170 , 170 ),
268
+ patch_halo : Tuple [int , int , int ] = (8 , 16 , 16 ),
248
269
single_patch : bool = True ,
249
270
device : str = ALL_DEVICES [0 ]) -> Future [LayerDataTuple ]:
250
271
out_name = create_layer_name (image .name , f'iterative-{ model_name } -x{ num_iterations } ' )
@@ -258,6 +279,7 @@ def widget_iterative_unet_predictions(image: Image,
258
279
num_iterations = num_iterations ,
259
280
sigma = sigma ,
260
281
patch_size = patch_size ,
282
+ patch_halo = patch_halo ,
261
283
single_batch_mode = single_patch ,
262
284
device = device )
263
285
@@ -280,6 +302,11 @@ def _on_model_name_changed_iterative(model_name: str):
280
302
widget_iterative_unet_predictions .patch_size .value = tuple (patch_size )
281
303
282
304
305
+ @widget_iterative_unet_predictions .image .changed .connect
306
+ def _on_widget_iterative_unet_predictions_image_change (image : Image ):
307
+ _on_prediction_input_image_change (widget_iterative_unet_predictions , image )
308
+
309
+
283
310
@magicgui (call_button = 'Add Custom Model' ,
284
311
new_model_name = {'label' : 'New model name' },
285
312
model_location = {'label' : 'Model location' ,
0 commit comments