diff --git a/plantseg/legacy_gui/gui_widgets.py b/plantseg/legacy_gui/gui_widgets.py
index bfe4d40a..92360f47 100644
--- a/plantseg/legacy_gui/gui_widgets.py
+++ b/plantseg/legacy_gui/gui_widgets.py
@@ -206,7 +206,7 @@ def __init__(self, frame, config, col=0, module_name="preprocessing", font=None,
 
 
 class UnetPredictionFrame(ModuleFramePrototype):
-    def __init__(self, frame, config, col=0, module_name="preprocessing", font=None, show_all=True):
+    def __init__(self, frame, config, col=0, module_name="prediction", font=None, show_all=True):
         self.prediction_frame = tkinter.Frame(frame)
         self.prediction_style = {
             "bg": "white",
@@ -278,10 +278,18 @@ def __init__(self, frame, config, col=0, module_name="preprocessing", font=None,
                 type=int,
                 font=font,
             ),
+            "patch_halo": ListEntry(
+                self.prediction_frame,
+                text="Patch Halo: ",
+                row=5,
+                column=0,
+                type=int,
+                font=font,
+            ),
             "device": MenuEntry(
                 self.prediction_frame,
                 text="Device Type: ",
-                row=5,
+                row=6,
                 column=0,
                 menu=["cuda", "cpu"],
                 default=config[self.module]["device"],
diff --git a/plantseg/pipeline/raw2seg.py b/plantseg/pipeline/raw2seg.py
index 6d0b9ee4..90dd9846 100644
--- a/plantseg/pipeline/raw2seg.py
+++ b/plantseg/pipeline/raw2seg.py
@@ -41,8 +41,10 @@ def configure_cnn_step(input_paths, config):
     device = config.get('device', 'cuda')
     state = config.get('state', True)
     model_update = config.get('model_update', False)
+    patch_halo = config.get('patch_halo', None)
     return UnetPredictions(input_paths, model_name=model_name, input_key=input_key, input_channel=input_channel,
-                           patch=patch, stride_ratio=stride_ratio, device=device, model_update=model_update, state=state)
+                           patch=patch, stride_ratio=stride_ratio, device=device, model_update=model_update,
+                           state=state, patch_halo=patch_halo)
 
 
 def configure_cnn_postprocessing_step(input_paths, config):
diff --git a/plantseg/predictions/functional/predictions.py b/plantseg/predictions/functional/predictions.py
index 457c4b14..75b0b396 100644
--- a/plantseg/predictions/functional/predictions.py
+++ b/plantseg/predictions/functional/predictions.py
@@ -30,7 +30,7 @@ def unet_predictions(raw: np.array, model_name: str, patch: Tuple[int, int, int]
             Defaults to 'cuda'.
         model_update (bool, optional): if True will update the model to the latest version. Defaults to False.
         disable_tqdm (bool, optional): if True will disable tqdm progress bar. Defaults to False.
-        output_ndim (int, optional): output ndim, must be one of [3, 4]. Only use `4` if network output is 
+        output_ndim (int, optional): output ndim, must be one of [3, 4]. Only use `4` if network output is
             multi-channel 3D pmap. Now `4` only used in `widget_unet_predictions()`.
 
     Returns:
@@ -45,7 +45,9 @@ def unet_predictions(raw: np.array, model_name: str, patch: Tuple[int, int, int]
         state = state['model_state_dict']
     model.load_state_dict(state)
 
-    patch_halo = get_patch_halo(model_name)
+    patch_halo = kwargs.get('patch_halo', None)
+    if patch_halo is None:
+        patch_halo = get_patch_halo(model_name)
     predictor = ArrayPredictor(model=model, in_channels=model_config['in_channels'],
                                out_channels=model_config['out_channels'], device=device, patch=patch,
                                patch_halo=patch_halo, single_batch_mode=single_batch_mode, headless=False,
diff --git a/plantseg/predictions/predict.py b/plantseg/predictions/predict.py
index 38b54fc2..b875a499 100755
--- a/plantseg/predictions/predict.py
+++ b/plantseg/predictions/predict.py
@@ -36,7 +36,7 @@ def _check_patch_size(paths, patch_size):
 
 class UnetPredictions(GenericPipelineStep):
     def __init__(self, input_paths, model_name, input_key=None, input_channel=None, patch=(80, 160, 160), stride_ratio=0.75, device='cuda',
-                 model_update=False, input_type="data_float32", output_type="data_float32", out_ext=".h5", state=True):
+                 model_update=False, input_type="data_float32", output_type="data_float32", out_ext=".h5", state=True, patch_halo=None):
         self.patch = patch
         self.model_name = model_name
         self.stride_ratio = stride_ratio
@@ -64,7 +64,8 @@ def __init__(self, input_paths, model_name, input_key=None, input_channel=None,
 
         model.load_state_dict(state)
 
-        patch_halo = get_patch_halo(model_name)
+        if patch_halo is None:
+            patch_halo = get_patch_halo(model_name)
         is_embedding = not model_config.get('is_segmentation', True)
         self.predictor = ArrayPredictor(model=model, in_channels=model_config['in_channels'],
                                         out_channels=model_config['out_channels'], device=device, patch=self.patch,
diff --git a/plantseg/resources/config_predict_template.yaml b/plantseg/resources/config_predict_template.yaml
index 3be20f95..5f3165de 100644
--- a/plantseg/resources/config_predict_template.yaml
+++ b/plantseg/resources/config_predict_template.yaml
@@ -42,6 +42,3 @@ loaders:
         - name: Standardize
         - name: ToTensor
           expand_dims: true
-
-
-
diff --git a/plantseg/resources/config_train_example.yaml b/plantseg/resources/config_train_example.yaml
index 15e6d52d..c339df77 100644
--- a/plantseg/resources/config_train_example.yaml
+++ b/plantseg/resources/config_train_example.yaml
@@ -8,4 +8,4 @@ training:
   max_num_iters: 50000
   dimensionality: 3D
   sparse: false
-  device: cuda
\ No newline at end of file
+  device: cuda
diff --git a/plantseg/resources/config_train_template.yaml b/plantseg/resources/config_train_template.yaml
index 572c3534..77d11fe5 100644
--- a/plantseg/resources/config_train_template.yaml
+++ b/plantseg/resources/config_train_template.yaml
@@ -77,4 +77,4 @@ loaders:
       # minimum volume of the labels in the patch
      threshold: 0.1
      # probability of accepting patches which do not fulfil the threshold criterion
-      slack_acceptance: 0.01
\ No newline at end of file
+      slack_acceptance: 0.01
diff --git a/plantseg/viewer/widget/predictions.py b/plantseg/viewer/widget/predictions.py
index 9dc99235..f4fa39da 100644
--- a/plantseg/viewer/widget/predictions.py
+++ b/plantseg/viewer/widget/predictions.py
@@ -19,6 +19,7 @@
 from plantseg.viewer.widget.segmentation import widget_agglomeration, widget_lifted_multicut, widget_simple_dt_ws
 from plantseg.viewer.widget.utils import return_value_if_widget
 from plantseg.viewer.widget.utils import start_threading_process, start_prediction_process, create_layer_name, layer_properties
+from plantseg.viewer.widget.validation import _on_prediction_input_image_change, widgets_inactive
 
 ALL_CUDA_DEVICES = [f'cuda:{i}' for i in range(torch.cuda.device_count())]
 MPS = ['mps'] if torch.backends.mps.is_available() else []
@@ -61,6 +62,8 @@ def unet_predictions_wrapper(raw, device, **kwargs):
                       'choices': LIST_ALL_MODELS},
          patch_size={'label': 'Patch size',
                      'tooltip': 'Patch size use to processed the data.'},
+         patch_halo={'label': 'Patch halo',
+                     'tooltip': 'Patch halo is extra padding for correct prediction on image border.'},
          single_patch={'label': 'Single Patch',
                        'tooltip': 'If True, a single patch will be processed at a time to save memory.'},
          device={'label': 'Device',
@@ -73,6 +76,7 @@ def widget_unet_predictions(viewer: Viewer,
                             modality: str = 'All',
                             output_type: str = 'All',
                             patch_size: Tuple[int, int, int] = (80, 170, 170),
+                            patch_halo: Tuple[int, int, int] = (8, 16, 16),
                             single_patch: bool = True,
                             device: str = ALL_DEVICES[0], ) -> Future[LayerDataTuple]:
     out_name = create_layer_name(image.name, model_name)
@@ -85,7 +89,7 @@ def widget_unet_predictions(viewer: Viewer,
         layer_kwargs['metadata']['pmap'] = True  # this is used to warn the user that the layer is a pmap
         layer_type = 'image'
 
-    step_kwargs = dict(model_name=model_name, patch=patch_size, single_batch_mode=single_patch)
+    step_kwargs = dict(model_name=model_name, patch=patch_size, patch_halo=patch_halo, single_batch_mode=single_patch)
 
     return start_prediction_process(unet_predictions_wrapper,
                                     runtime_kwargs={'raw': image.data,
@@ -105,6 +109,11 @@ def widget_unet_predictions(viewer: Viewer,
                                     )
 
 
+@widget_unet_predictions.image.changed.connect
+def _on_widget_unet_predictions_image_change(image: Image):
+    _on_prediction_input_image_change(widget_unet_predictions, image)
+
+
 def _on_any_metadata_changed(dimensionality, modality, output_type):
     dimensionality = [dimensionality] if dimensionality != 'All' else None
     modality = [modality] if modality != 'All' else None
@@ -152,7 +161,7 @@ def _on_model_name_changed(model_name: str):
     widget_unet_predictions.model_name.tooltip = f'Select a pretrained model. Current model description: {description}'
 
 
-def _compute_multiple_predictions(image, patch_size, device, use_custom_models=True):
+def _compute_multiple_predictions(image, patch_size, patch_halo, device, use_custom_models=True):
     out_layers = []
     model_list = list_models(use_custom_models=use_custom_models)
     for i, model_name in enumerate(model_list):
@@ -168,7 +177,7 @@ def _compute_multiple_predictions(image, patch_size, device, use_custom_models=T
             layer_type = 'image'
             try:
                 pmap = unet_predictions(raw=image.data, model_name=model_name, patch=patch_size, single_batch_mode=True,
-                                        device=device)
+                                        device=device, patch_halo=patch_halo)
                 out_layers.append((pmap, layer_kwargs, layer_type))
 
             except Exception as e:
@@ -182,6 +191,8 @@
                 'tooltip': 'Raw image to be processed with a neural network.'},
          patch_size={'label': 'Patch size',
                      'tooltip': 'Patch size use to processed the data.'},
+         patch_halo={'label': 'Patch halo',
+                     'tooltip': 'Patch halo is extra padding for correct prediction on image border.'},
          device={'label': 'Device',
                  'choices': ALL_DEVICES},
          use_custom_models={'label': 'Use custom models',
@@ -189,11 +200,13 @@
          )
 def widget_test_all_unet_predictions(image: Image,
                                      patch_size: Tuple[int, int, int] = (80, 170, 170),
+                                     patch_halo: Tuple[int, int, int] = (2, 4, 4),
                                      device: str = ALL_DEVICES[0],
                                      use_custom_models: bool = True) -> Future[List[LayerDataTuple]]:
 
     func = thread_worker(partial(_compute_multiple_predictions,
                                  image=image,
                                  patch_size=patch_size,
+                                 patch_halo=patch_halo,
                                  device=device,
                                  use_custom_models=use_custom_models,))
@@ -208,9 +221,14 @@ def on_done(result):
     return future
 
 
-def _compute_iterative_predictions(pmap, model_name, num_iterations, sigma, patch_size, single_batch_mode, device):
-    func = partial(unet_predictions, model_name=model_name, patch=patch_size, single_batch_mode=single_batch_mode,
-                   device=device)
+@widget_test_all_unet_predictions.image.changed.connect
+def _on_widget_test_all_unet_predictions_image_change(image: Image):
+    _on_prediction_input_image_change(widget_test_all_unet_predictions, image)
+
+
+def _compute_iterative_predictions(pmap, model_name, num_iterations, sigma, patch_size, patch_halo, single_batch_mode, device):
+    func = partial(unet_predictions, model_name=model_name, patch=patch_size, patch_halo=patch_halo,
+                   single_batch_mode=single_batch_mode, device=device)
     for i in range(num_iterations - 1):
         pmap = func(pmap)
         pmap = image_gaussian_smoothing(image=pmap, sigma=sigma)
@@ -235,6 +253,8 @@ def _compute_iterative_predictions(pmap, model_name, num_iterations, sigma, patc
                 'min': 0.},
          patch_size={'label': 'Patch size',
                      'tooltip': 'Patch size use to processed the data.'},
+         patch_halo={'label': 'Patch halo',
+                     'tooltip': 'Patch halo is extra padding for correct prediction on image border.'},
          single_patch={'label': 'Single Patch',
                        'tooltip': 'If True, a single patch will be processed at a time to save memory.'},
          device={'label': 'Device',
@@ -245,6 +265,7 @@ def widget_iterative_unet_predictions(image: Image,
                                       num_iterations: int = 2,
                                       sigma: float = 1.0,
                                       patch_size: Tuple[int, int, int] = (80, 170, 170),
+                                      patch_halo: Tuple[int, int, int] = (8, 16, 16),
                                       single_patch: bool = True,
                                       device: str = ALL_DEVICES[0]) -> Future[LayerDataTuple]:
     out_name = create_layer_name(image.name, f'iterative-{model_name}-x{num_iterations}')
@@ -258,6 +279,7 @@ def widget_iterative_unet_predictions(image: Image,
                            num_iterations=num_iterations,
                            sigma=sigma,
                            patch_size=patch_size,
+                           patch_halo=patch_halo,
                            single_batch_mode=single_patch,
                            device=device)
 
@@ -280,6 +302,11 @@ def _on_model_name_changed_iterative(model_name: str):
     widget_iterative_unet_predictions.patch_size.value = tuple(patch_size)
 
 
+@widget_iterative_unet_predictions.image.changed.connect
+def _on_widget_iterative_unet_predictions_image_change(image: Image):
+    _on_prediction_input_image_change(widget_iterative_unet_predictions, image)
+
+
 @magicgui(call_button='Add Custom Model',
          new_model_name={'label': 'New model name'},
          model_location={'label': 'Model location',
diff --git a/plantseg/viewer/widget/validation.py b/plantseg/viewer/widget/validation.py
new file mode 100644
index 00000000..b8f73bd9
--- /dev/null
+++ b/plantseg/viewer/widget/validation.py
@@ -0,0 +1,45 @@
+"""Widget input validation"""
+
+from napari.layers import Image
+from magicgui.widgets import Widget
+
+
+def widgets_inactive(*widgets, active):
+    """Toggle visibility of widgets."""
+    for widget in widgets:
+        widget.visible = active
+
+
+def widgets_valid(*widgets, valid):
+    """Toggle background warning color of widgets."""
+    for widget in widgets:
+        widget.native.setStyleSheet("" if valid else "background-color: lightcoral")
+
+
+def get_image_volume_from_layer(image):
+    """Used for widget parameter validation in change-handlers."""
+    image = image.data[0] if image.multiscale else image.data
+    if not all(hasattr(image, attr) for attr in ("shape", "ndim", "__getitem__")):
+        from numpy import asanyarray
+
+        image = asanyarray(image)
+    return image
+
+
+def _on_prediction_input_image_change(widget: Widget, image: Image):
+    shape = get_image_volume_from_layer(image).shape
+    ndim = len(shape)
+    widget.image.tooltip = f"Shape: {shape}"
+
+    size_z = widget.patch_size[0]
+    halo_z = widget.patch_halo[0]
+    if ndim == 2 or (ndim == 3 and shape[0] == 1):  # 2D image imported by Napari thus no Z, or by PlantSeg widget
+        size_z.value = 0
+        halo_z.value = 0
+        widgets_inactive(size_z, halo_z, active=False)
+    elif ndim == 3 and shape[0] > 1:  # 3D
+        size_z.value = min(64, shape[0])  # TODO: fetch model default
+        halo_z.value = 8
+        widgets_inactive(size_z, halo_z, active=True)
+    else:
+        raise ValueError(f"Unsupported number of dimensions: {ndim}")
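Below is a minimal usage sketch (not part of the diff) of the `patch_halo` keyword introduced above, mirroring the call the napari widgets make into `unet_predictions`. The model name and the random input volume are placeholders; as the diff shows, when `patch_halo` is omitted or `None`, both `unet_predictions` and `UnetPredictions` fall back to the per-model default from `get_patch_halo`, so existing workflows keep their previous behaviour.

```python
import numpy as np

# unet_predictions lives in the module touched by this diff.
from plantseg.predictions.functional.predictions import unet_predictions

# Placeholder ZYX volume standing in for a real raw image.
raw = np.random.rand(32, 256, 256).astype("float32")

pmap = unet_predictions(
    raw=raw,
    model_name="generic_confocal_3D_unet",  # placeholder: any model from the PlantSeg zoo
    patch=(32, 128, 128),
    patch_halo=(8, 16, 16),  # new keyword; None falls back to get_patch_halo(model_name)
    single_batch_mode=True,
    device="cpu",
)
```

For the classic pipeline, `configure_cnn_step` now reads the same value from the prediction section of the pipeline config via `config.get('patch_halo', None)`, e.g. a `patch_halo: [8, 16, 16]` entry; leaving it out keeps the old per-model defaults.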