From a009df83597b04a09e333ef968c753a0dd9e446d Mon Sep 17 00:00:00 2001 From: Qin Yu Date: Mon, 26 Feb 2024 14:13:46 +0100 Subject: [PATCH 1/4] Fix ignored halo: raw2seg CLI --- plantseg/pipeline/raw2seg.py | 4 +++- plantseg/predictions/predict.py | 5 +++-- plantseg/resources/config_predict_template.yaml | 3 --- plantseg/resources/config_train_example.yaml | 2 +- plantseg/resources/config_train_template.yaml | 2 +- 5 files changed, 8 insertions(+), 8 deletions(-) diff --git a/plantseg/pipeline/raw2seg.py b/plantseg/pipeline/raw2seg.py index 6d0b9ee4..90dd9846 100644 --- a/plantseg/pipeline/raw2seg.py +++ b/plantseg/pipeline/raw2seg.py @@ -41,8 +41,10 @@ def configure_cnn_step(input_paths, config): device = config.get('device', 'cuda') state = config.get('state', True) model_update = config.get('model_update', False) + patch_halo = config.get('patch_halo', None) return UnetPredictions(input_paths, model_name=model_name, input_key=input_key, input_channel=input_channel, - patch=patch, stride_ratio=stride_ratio, device=device, model_update=model_update, state=state) + patch=patch, stride_ratio=stride_ratio, device=device, model_update=model_update, + state=state, patch_halo=patch_halo) def configure_cnn_postprocessing_step(input_paths, config): diff --git a/plantseg/predictions/predict.py b/plantseg/predictions/predict.py index 38b54fc2..b875a499 100755 --- a/plantseg/predictions/predict.py +++ b/plantseg/predictions/predict.py @@ -36,7 +36,7 @@ def _check_patch_size(paths, patch_size): class UnetPredictions(GenericPipelineStep): def __init__(self, input_paths, model_name, input_key=None, input_channel=None, patch=(80, 160, 160), stride_ratio=0.75, device='cuda', - model_update=False, input_type="data_float32", output_type="data_float32", out_ext=".h5", state=True): + model_update=False, input_type="data_float32", output_type="data_float32", out_ext=".h5", state=True, patch_halo=None): self.patch = patch self.model_name = model_name self.stride_ratio = stride_ratio @@ -64,7 +64,8 @@ def __init__(self, input_paths, model_name, input_key=None, input_channel=None, model.load_state_dict(state) - patch_halo = get_patch_halo(model_name) + if patch_halo is None: + patch_halo = get_patch_halo(model_name) is_embedding = not model_config.get('is_segmentation', True) self.predictor = ArrayPredictor(model=model, in_channels=model_config['in_channels'], out_channels=model_config['out_channels'], device=device, patch=self.patch, diff --git a/plantseg/resources/config_predict_template.yaml b/plantseg/resources/config_predict_template.yaml index 3be20f95..5f3165de 100644 --- a/plantseg/resources/config_predict_template.yaml +++ b/plantseg/resources/config_predict_template.yaml @@ -42,6 +42,3 @@ loaders: - name: Standardize - name: ToTensor expand_dims: true - - - diff --git a/plantseg/resources/config_train_example.yaml b/plantseg/resources/config_train_example.yaml index 15e6d52d..c339df77 100644 --- a/plantseg/resources/config_train_example.yaml +++ b/plantseg/resources/config_train_example.yaml @@ -8,4 +8,4 @@ training: max_num_iters: 50000 dimensionality: 3D sparse: false - device: cuda \ No newline at end of file + device: cuda diff --git a/plantseg/resources/config_train_template.yaml b/plantseg/resources/config_train_template.yaml index 572c3534..77d11fe5 100644 --- a/plantseg/resources/config_train_template.yaml +++ b/plantseg/resources/config_train_template.yaml @@ -77,4 +77,4 @@ loaders: # minimum volume of the labels in the patch threshold: 0.1 # probability of accepting patches which do not fulfil the threshold criterion - slack_acceptance: 0.01 \ No newline at end of file + slack_acceptance: 0.01 From 2072f019acfff9fc658918d2aab2b18099c7d4f3 Mon Sep 17 00:00:00 2001 From: Qin Yu Date: Thu, 7 Mar 2024 01:45:27 +0100 Subject: [PATCH 2/4] Add widget input validation/adaptation and add halo to Napari GUI --- .../predictions/functional/predictions.py | 6 ++- plantseg/viewer/widget/predictions.py | 34 ++++++++++++-- plantseg/viewer/widget/validation.py | 46 +++++++++++++++++++ 3 files changed, 81 insertions(+), 5 deletions(-) create mode 100644 plantseg/viewer/widget/validation.py diff --git a/plantseg/predictions/functional/predictions.py b/plantseg/predictions/functional/predictions.py index 457c4b14..75b0b396 100644 --- a/plantseg/predictions/functional/predictions.py +++ b/plantseg/predictions/functional/predictions.py @@ -30,7 +30,7 @@ def unet_predictions(raw: np.array, model_name: str, patch: Tuple[int, int, int] Defaults to 'cuda'. model_update (bool, optional): if True will update the model to the latest version. Defaults to False. disable_tqdm (bool, optional): if True will disable tqdm progress bar. Defaults to False. - output_ndim (int, optional): output ndim, must be one of [3, 4]. Only use `4` if network output is + output_ndim (int, optional): output ndim, must be one of [3, 4]. Only use `4` if network output is multi-channel 3D pmap. Now `4` only used in `widget_unet_predictions()`. Returns: @@ -45,7 +45,9 @@ def unet_predictions(raw: np.array, model_name: str, patch: Tuple[int, int, int] state = state['model_state_dict'] model.load_state_dict(state) - patch_halo = get_patch_halo(model_name) + patch_halo = kwargs.get('patch_halo', None) + if patch_halo is None: + patch_halo = get_patch_halo(model_name) predictor = ArrayPredictor(model=model, in_channels=model_config['in_channels'], out_channels=model_config['out_channels'], device=device, patch=patch, patch_halo=patch_halo, single_batch_mode=single_batch_mode, headless=False, diff --git a/plantseg/viewer/widget/predictions.py b/plantseg/viewer/widget/predictions.py index fd28b577..d8134ace 100644 --- a/plantseg/viewer/widget/predictions.py +++ b/plantseg/viewer/widget/predictions.py @@ -19,6 +19,7 @@ from plantseg.viewer.widget.segmentation import widget_agglomeration, widget_lifted_multicut, widget_simple_dt_ws from plantseg.viewer.widget.utils import return_value_if_widget from plantseg.viewer.widget.utils import start_threading_process, start_prediction_process, create_layer_name, layer_properties +from plantseg.viewer.widget.validation import change_handler, get_image_volume_from_layer, widgets_inactive ALL_CUDA_DEVICES = [f'cuda:{i}' for i in range(torch.cuda.device_count())] MPS = ['mps'] if torch.backends.mps.is_available() else [] @@ -61,6 +62,8 @@ def unet_predictions_wrapper(raw, device, **kwargs): 'choices': LIST_ALL_MODELS}, patch_size={'label': 'Patch size', 'tooltip': 'Patch size use to processed the data.'}, + patch_halo={'label': 'Patch halo', + 'tooltip': 'Patch halo is extra padding for correct prediction on image boarder.'}, single_patch={'label': 'Single Patch', 'tooltip': 'If True, a single patch will be processed at a time to save memory.'}, device={'label': 'Device', @@ -73,6 +76,7 @@ def widget_unet_predictions(viewer: Viewer, modality: str = 'All', output_type: str = 'All', patch_size: Tuple[int, int, int] = (80, 170, 170), + patch_halo: Tuple[int, int, int] = (8, 16, 16), single_patch: bool = True, device: str = ALL_DEVICES[0], ) -> Future[LayerDataTuple]: out_name = create_layer_name(image.name, model_name) @@ -85,7 +89,7 @@ def widget_unet_predictions(viewer: Viewer, layer_kwargs['metadata']['pmap'] = True # this is used to warn the user that the layer is a pmap layer_type = 'image' - step_kwargs = dict(model_name=model_name, patch=patch_size, single_batch_mode=single_patch) + step_kwargs = dict(model_name=model_name, patch=patch_size, patch_halo=patch_halo, single_batch_mode=single_patch) return start_prediction_process(unet_predictions_wrapper, runtime_kwargs={'raw': image.data, @@ -105,6 +109,26 @@ def widget_unet_predictions(viewer: Viewer, ) +@change_handler(widget_unet_predictions.image, init=False) +def _image_change(image: Image): + shape = get_image_volume_from_layer(image).shape + ndim = len(shape) + widget_unet_predictions.image.tooltip = f"Shape: {shape}" + + size_z = widget_unet_predictions.patch_size[0] + halo_z = widget_unet_predictions.patch_halo[0] + if ndim == 2 or (ndim == 3 and shape[0] == 1): # 2D image imported by Napari thus no Z, or by PlantSeg widget + size_z.value = 0 + halo_z.value = 0 + widgets_inactive(size_z, halo_z, active=False) + elif ndim == 3 and shape[0] > 1: # 3D + size_z.value = min(64, shape[0]) # TODO: fetch model default + halo_z.value = 8 + widgets_inactive(size_z, halo_z, active=True) + else: + raise ValueError(f"Unsupported number of dimensions: {ndim}") + + def _on_any_metadata_changed(dimensionality, modality, output_type): dimensionality = [dimensionality] if dimensionality != 'All' else None modality = [modality] if modality != 'All' else None @@ -152,7 +176,7 @@ def _on_model_name_changed(model_name: str): widget_unet_predictions.model_name.tooltip = f'Select a pretrained model. Current model description: {description}' -def _compute_multiple_predictions(image, patch_size, device): +def _compute_multiple_predictions(image, patch_size, patch_halo, device): out_layers = [] for i, model_name in enumerate(list_models()): @@ -167,7 +191,7 @@ def _compute_multiple_predictions(image, patch_size, device): layer_type = 'image' try: pmap = unet_predictions(raw=image.data, model_name=model_name, patch=patch_size, single_batch_mode=True, - device=device) + device=device, patch_halo=patch_halo) out_layers.append((pmap, layer_kwargs, layer_type)) except Exception as e: @@ -181,15 +205,19 @@ def _compute_multiple_predictions(image, patch_size, device): 'tooltip': 'Raw image to be processed with a neural network.'}, patch_size={'label': 'Patch size', 'tooltip': 'Patch size use to processed the data.'}, + patch_halo={'label': 'Patch halo', + 'tooltip': 'Patch halo is extra padding for correct prediction on image boarder.'}, device={'label': 'Device', 'choices': ALL_DEVICES} ) def widget_test_all_unet_predictions(image: Image, patch_size: Tuple[int, int, int] = (80, 170, 170), + patch_halo: Tuple[int, int, int] = (2, 4, 4), device: str = ALL_DEVICES[0]) -> Future[List[LayerDataTuple]]: func = thread_worker(partial(_compute_multiple_predictions, image=image, patch_size=patch_size, + patch_halo=patch_halo, device=device)) future = Future() diff --git a/plantseg/viewer/widget/validation.py b/plantseg/viewer/widget/validation.py new file mode 100644 index 00000000..8425bc3c --- /dev/null +++ b/plantseg/viewer/widget/validation.py @@ -0,0 +1,46 @@ +"""Widget input validation""" + +from psygnal import Signal +from functools import wraps + +def change_handler(*widgets, init=True, debug=False): + def decorator_change_handler(handler): + @wraps(handler) + def wrapper(*args): + source = Signal.sender() + emitter = Signal.current_emitter() + if debug: + # print(f"{emitter}: {source} = {args!r}") + print(f"EVENT '{str(emitter.name)}': {source.name:>20} = {args!r}") + # print(f" {source.name:>14}.value = {source.value}") + return handler(*args) + + for widget in widgets: + widget.changed.connect(wrapper) + if init: + widget.changed(widget.value) + return wrapper + + return decorator_change_handler + + +def get_image_volume_from_layer(image): + """Used for widget parameter validation in `change_handler`s.""" + image = image.data[0] if image.multiscale else image.data + if not all(hasattr(image, attr) for attr in ("shape", "ndim", "__getitem__")): + image = np.asanyarray(image) + return image + + +def widgets_inactive(*widgets, active): + """Toggle visibility of widgets.""" + for widget in widgets: + widget.visible = active + + +def widgets_valid(*widgets, valid): + """Toggle background warning color of widgets.""" + for widget in widgets: + widget.native.setStyleSheet("" if valid else "background-color: lightcoral") + + From 39e6b4701434c55ad5a15e5d3d6a96023bd69765 Mon Sep 17 00:00:00 2001 From: Qin Yu Date: Thu, 7 Mar 2024 02:03:14 +0100 Subject: [PATCH 3/4] Add halo to Legacy GUI --- plantseg/legacy_gui/gui_widgets.py | 12 ++++++++++-- 1 file changed, 10 insertions(+), 2 deletions(-) diff --git a/plantseg/legacy_gui/gui_widgets.py b/plantseg/legacy_gui/gui_widgets.py index bfe4d40a..92360f47 100644 --- a/plantseg/legacy_gui/gui_widgets.py +++ b/plantseg/legacy_gui/gui_widgets.py @@ -206,7 +206,7 @@ def __init__(self, frame, config, col=0, module_name="preprocessing", font=None, class UnetPredictionFrame(ModuleFramePrototype): - def __init__(self, frame, config, col=0, module_name="preprocessing", font=None, show_all=True): + def __init__(self, frame, config, col=0, module_name="prediction", font=None, show_all=True): self.prediction_frame = tkinter.Frame(frame) self.prediction_style = { "bg": "white", @@ -278,10 +278,18 @@ def __init__(self, frame, config, col=0, module_name="preprocessing", font=None, type=int, font=font, ), + "patch_halo": ListEntry( + self.prediction_frame, + text="Patch Halo: ", + row=5, + column=0, + type=int, + font=font, + ), "device": MenuEntry( self.prediction_frame, text="Device Type: ", - row=5, + row=6, column=0, menu=["cuda", "cpu"], default=config[self.module]["device"], From b65ae75b856d0d735b3743aef2d5a9bfc9b1bd27 Mon Sep 17 00:00:00 2001 From: Qin Yu Date: Mon, 11 Mar 2024 16:19:59 +0100 Subject: [PATCH 4/4] Improve prediction widget input validation and add halo to iterative prediction --- plantseg/viewer/widget/predictions.py | 43 ++++++++++--------- plantseg/viewer/widget/validation.py | 59 +++++++++++++-------------- 2 files changed, 50 insertions(+), 52 deletions(-) diff --git a/plantseg/viewer/widget/predictions.py b/plantseg/viewer/widget/predictions.py index d8134ace..5af94b5f 100644 --- a/plantseg/viewer/widget/predictions.py +++ b/plantseg/viewer/widget/predictions.py @@ -19,7 +19,7 @@ from plantseg.viewer.widget.segmentation import widget_agglomeration, widget_lifted_multicut, widget_simple_dt_ws from plantseg.viewer.widget.utils import return_value_if_widget from plantseg.viewer.widget.utils import start_threading_process, start_prediction_process, create_layer_name, layer_properties -from plantseg.viewer.widget.validation import change_handler, get_image_volume_from_layer, widgets_inactive +from plantseg.viewer.widget.validation import _on_prediction_input_image_change, widgets_inactive ALL_CUDA_DEVICES = [f'cuda:{i}' for i in range(torch.cuda.device_count())] MPS = ['mps'] if torch.backends.mps.is_available() else [] @@ -109,24 +109,9 @@ def widget_unet_predictions(viewer: Viewer, ) -@change_handler(widget_unet_predictions.image, init=False) -def _image_change(image: Image): - shape = get_image_volume_from_layer(image).shape - ndim = len(shape) - widget_unet_predictions.image.tooltip = f"Shape: {shape}" - - size_z = widget_unet_predictions.patch_size[0] - halo_z = widget_unet_predictions.patch_halo[0] - if ndim == 2 or (ndim == 3 and shape[0] == 1): # 2D image imported by Napari thus no Z, or by PlantSeg widget - size_z.value = 0 - halo_z.value = 0 - widgets_inactive(size_z, halo_z, active=False) - elif ndim == 3 and shape[0] > 1: # 3D - size_z.value = min(64, shape[0]) # TODO: fetch model default - halo_z.value = 8 - widgets_inactive(size_z, halo_z, active=True) - else: - raise ValueError(f"Unsupported number of dimensions: {ndim}") +@widget_unet_predictions.image.changed.connect +def _on_widget_unet_predictions_image_change(image: Image): + _on_prediction_input_image_change(widget_unet_predictions, image) def _on_any_metadata_changed(dimensionality, modality, output_type): @@ -231,9 +216,14 @@ def on_done(result): return future -def _compute_iterative_predictions(pmap, model_name, num_iterations, sigma, patch_size, single_batch_mode, device): - func = partial(unet_predictions, model_name=model_name, patch=patch_size, single_batch_mode=single_batch_mode, - device=device) +@widget_test_all_unet_predictions.image.changed.connect +def _on_widget_test_all_unet_predictions_image_change(image: Image): + _on_prediction_input_image_change(widget_test_all_unet_predictions, image) + + +def _compute_iterative_predictions(pmap, model_name, num_iterations, sigma, patch_size, patch_halo, single_batch_mode, device): + func = partial(unet_predictions, model_name=model_name, patch=patch_size, patch_halo=patch_halo, + single_batch_mode=single_batch_mode, device=device) for i in range(num_iterations - 1): pmap = func(pmap) pmap = image_gaussian_smoothing(image=pmap, sigma=sigma) @@ -258,6 +248,8 @@ def _compute_iterative_predictions(pmap, model_name, num_iterations, sigma, patc 'min': 0.}, patch_size={'label': 'Patch size', 'tooltip': 'Patch size use to processed the data.'}, + patch_halo={'label': 'Patch halo', + 'tooltip': 'Patch halo is extra padding for correct prediction on image boarder.'}, single_patch={'label': 'Single Patch', 'tooltip': 'If True, a single patch will be processed at a time to save memory.'}, device={'label': 'Device', @@ -268,6 +260,7 @@ def widget_iterative_unet_predictions(image: Image, num_iterations: int = 2, sigma: float = 1.0, patch_size: Tuple[int, int, int] = (80, 170, 170), + patch_halo: Tuple[int, int, int] = (8, 16, 16), single_patch: bool = True, device: str = ALL_DEVICES[0]) -> Future[LayerDataTuple]: out_name = create_layer_name(image.name, f'iterative-{model_name}-x{num_iterations}') @@ -281,6 +274,7 @@ def widget_iterative_unet_predictions(image: Image, num_iterations=num_iterations, sigma=sigma, patch_size=patch_size, + patch_halo=patch_halo, single_batch_mode=single_patch, device=device) @@ -303,6 +297,11 @@ def _on_model_name_changed_iterative(model_name: str): widget_iterative_unet_predictions.patch_size.value = tuple(patch_size) +@widget_iterative_unet_predictions.image.changed.connect +def _on_widget_iterative_unet_predictions_image_change(image: Image): + _on_prediction_input_image_change(widget_iterative_unet_predictions, image) + + @magicgui(call_button='Add Custom Model', new_model_name={'label': 'New model name'}, model_location={'label': 'Model location', diff --git a/plantseg/viewer/widget/validation.py b/plantseg/viewer/widget/validation.py index 8425bc3c..b8f73bd9 100644 --- a/plantseg/viewer/widget/validation.py +++ b/plantseg/viewer/widget/validation.py @@ -1,35 +1,7 @@ """Widget input validation""" -from psygnal import Signal -from functools import wraps - -def change_handler(*widgets, init=True, debug=False): - def decorator_change_handler(handler): - @wraps(handler) - def wrapper(*args): - source = Signal.sender() - emitter = Signal.current_emitter() - if debug: - # print(f"{emitter}: {source} = {args!r}") - print(f"EVENT '{str(emitter.name)}': {source.name:>20} = {args!r}") - # print(f" {source.name:>14}.value = {source.value}") - return handler(*args) - - for widget in widgets: - widget.changed.connect(wrapper) - if init: - widget.changed(widget.value) - return wrapper - - return decorator_change_handler - - -def get_image_volume_from_layer(image): - """Used for widget parameter validation in `change_handler`s.""" - image = image.data[0] if image.multiscale else image.data - if not all(hasattr(image, attr) for attr in ("shape", "ndim", "__getitem__")): - image = np.asanyarray(image) - return image +from napari.layers import Image +from magicgui.widgets import Widget def widgets_inactive(*widgets, active): @@ -44,3 +16,30 @@ def widgets_valid(*widgets, valid): widget.native.setStyleSheet("" if valid else "background-color: lightcoral") +def get_image_volume_from_layer(image): + """Used for widget parameter validation in change-handlers.""" + image = image.data[0] if image.multiscale else image.data + if not all(hasattr(image, attr) for attr in ("shape", "ndim", "__getitem__")): + from numpy import asanyarray + + image = asanyarray(image) + return image + + +def _on_prediction_input_image_change(widget: Widget, image: Image): + shape = get_image_volume_from_layer(image).shape + ndim = len(shape) + widget.image.tooltip = f"Shape: {shape}" + + size_z = widget.patch_size[0] + halo_z = widget.patch_halo[0] + if ndim == 2 or (ndim == 3 and shape[0] == 1): # 2D image imported by Napari thus no Z, or by PlantSeg widget + size_z.value = 0 + halo_z.value = 0 + widgets_inactive(size_z, halo_z, active=False) + elif ndim == 3 and shape[0] > 1: # 3D + size_z.value = min(64, shape[0]) # TODO: fetch model default + halo_z.value = 8 + widgets_inactive(size_z, halo_z, active=True) + else: + raise ValueError(f"Unsupported number of dimensions: {ndim}")