Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Add validation, adaptation and halo to prediction widgets #211

Merged
merged 5 commits into from
Mar 11, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
12 changes: 10 additions & 2 deletions plantseg/legacy_gui/gui_widgets.py
Original file line number Diff line number Diff line change
Expand Up @@ -206,7 +206,7 @@ def __init__(self, frame, config, col=0, module_name="preprocessing", font=None,


class UnetPredictionFrame(ModuleFramePrototype):
def __init__(self, frame, config, col=0, module_name="preprocessing", font=None, show_all=True):
def __init__(self, frame, config, col=0, module_name="prediction", font=None, show_all=True):
self.prediction_frame = tkinter.Frame(frame)
self.prediction_style = {
"bg": "white",
Expand Down Expand Up @@ -278,10 +278,18 @@ def __init__(self, frame, config, col=0, module_name="preprocessing", font=None,
type=int,
font=font,
),
"patch_halo": ListEntry(
self.prediction_frame,
text="Patch Halo: ",
row=5,
column=0,
type=int,
font=font,
),
"device": MenuEntry(
self.prediction_frame,
text="Device Type: ",
row=5,
row=6,
column=0,
menu=["cuda", "cpu"],
default=config[self.module]["device"],
Expand Down
4 changes: 3 additions & 1 deletion plantseg/pipeline/raw2seg.py
Original file line number Diff line number Diff line change
Expand Up @@ -41,8 +41,10 @@ def configure_cnn_step(input_paths, config):
device = config.get('device', 'cuda')
state = config.get('state', True)
model_update = config.get('model_update', False)
patch_halo = config.get('patch_halo', None)
return UnetPredictions(input_paths, model_name=model_name, input_key=input_key, input_channel=input_channel,
patch=patch, stride_ratio=stride_ratio, device=device, model_update=model_update, state=state)
patch=patch, stride_ratio=stride_ratio, device=device, model_update=model_update,
state=state, patch_halo=patch_halo)


def configure_cnn_postprocessing_step(input_paths, config):
Expand Down
6 changes: 4 additions & 2 deletions plantseg/predictions/functional/predictions.py
Original file line number Diff line number Diff line change
Expand Up @@ -30,7 +30,7 @@ def unet_predictions(raw: np.array, model_name: str, patch: Tuple[int, int, int]
Defaults to 'cuda'.
model_update (bool, optional): if True will update the model to the latest version. Defaults to False.
disable_tqdm (bool, optional): if True will disable tqdm progress bar. Defaults to False.
output_ndim (int, optional): output ndim, must be one of [3, 4]. Only use `4` if network output is
output_ndim (int, optional): output ndim, must be one of [3, 4]. Only use `4` if network output is
multi-channel 3D pmap. Now `4` only used in `widget_unet_predictions()`.

Returns:
Expand All @@ -45,7 +45,9 @@ def unet_predictions(raw: np.array, model_name: str, patch: Tuple[int, int, int]
state = state['model_state_dict']
model.load_state_dict(state)

patch_halo = get_patch_halo(model_name)
patch_halo = kwargs.get('patch_halo', None)
if patch_halo is None:
patch_halo = get_patch_halo(model_name)
predictor = ArrayPredictor(model=model, in_channels=model_config['in_channels'],
out_channels=model_config['out_channels'], device=device, patch=patch,
patch_halo=patch_halo, single_batch_mode=single_batch_mode, headless=False,
Expand Down
5 changes: 3 additions & 2 deletions plantseg/predictions/predict.py
Original file line number Diff line number Diff line change
Expand Up @@ -36,7 +36,7 @@ def _check_patch_size(paths, patch_size):

class UnetPredictions(GenericPipelineStep):
def __init__(self, input_paths, model_name, input_key=None, input_channel=None, patch=(80, 160, 160), stride_ratio=0.75, device='cuda',
model_update=False, input_type="data_float32", output_type="data_float32", out_ext=".h5", state=True):
model_update=False, input_type="data_float32", output_type="data_float32", out_ext=".h5", state=True, patch_halo=None):
self.patch = patch
self.model_name = model_name
self.stride_ratio = stride_ratio
Expand Down Expand Up @@ -64,7 +64,8 @@ def __init__(self, input_paths, model_name, input_key=None, input_channel=None,

model.load_state_dict(state)

patch_halo = get_patch_halo(model_name)
if patch_halo is None:
patch_halo = get_patch_halo(model_name)
is_embedding = not model_config.get('is_segmentation', True)
self.predictor = ArrayPredictor(model=model, in_channels=model_config['in_channels'],
out_channels=model_config['out_channels'], device=device, patch=self.patch,
Expand Down
3 changes: 0 additions & 3 deletions plantseg/resources/config_predict_template.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -42,6 +42,3 @@ loaders:
- name: Standardize
- name: ToTensor
expand_dims: true



2 changes: 1 addition & 1 deletion plantseg/resources/config_train_example.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -8,4 +8,4 @@ training:
max_num_iters: 50000
dimensionality: 3D
sparse: false
device: cuda
device: cuda
2 changes: 1 addition & 1 deletion plantseg/resources/config_train_template.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -77,4 +77,4 @@ loaders:
# minimum volume of the labels in the patch
threshold: 0.1
# probability of accepting patches which do not fulfil the threshold criterion
slack_acceptance: 0.01
slack_acceptance: 0.01
39 changes: 33 additions & 6 deletions plantseg/viewer/widget/predictions.py
Original file line number Diff line number Diff line change
Expand Up @@ -19,6 +19,7 @@
from plantseg.viewer.widget.segmentation import widget_agglomeration, widget_lifted_multicut, widget_simple_dt_ws
from plantseg.viewer.widget.utils import return_value_if_widget
from plantseg.viewer.widget.utils import start_threading_process, start_prediction_process, create_layer_name, layer_properties
from plantseg.viewer.widget.validation import _on_prediction_input_image_change, widgets_inactive

ALL_CUDA_DEVICES = [f'cuda:{i}' for i in range(torch.cuda.device_count())]
MPS = ['mps'] if torch.backends.mps.is_available() else []
Expand Down Expand Up @@ -61,6 +62,8 @@ def unet_predictions_wrapper(raw, device, **kwargs):
'choices': LIST_ALL_MODELS},
patch_size={'label': 'Patch size',
'tooltip': 'Patch size use to processed the data.'},
patch_halo={'label': 'Patch halo',
'tooltip': 'Patch halo is extra padding for correct prediction on image boarder.'},
single_patch={'label': 'Single Patch',
'tooltip': 'If True, a single patch will be processed at a time to save memory.'},
device={'label': 'Device',
Expand All @@ -73,6 +76,7 @@ def widget_unet_predictions(viewer: Viewer,
modality: str = 'All',
output_type: str = 'All',
patch_size: Tuple[int, int, int] = (80, 170, 170),
patch_halo: Tuple[int, int, int] = (8, 16, 16),
single_patch: bool = True,
device: str = ALL_DEVICES[0], ) -> Future[LayerDataTuple]:
out_name = create_layer_name(image.name, model_name)
Expand All @@ -85,7 +89,7 @@ def widget_unet_predictions(viewer: Viewer,
layer_kwargs['metadata']['pmap'] = True # this is used to warn the user that the layer is a pmap

layer_type = 'image'
step_kwargs = dict(model_name=model_name, patch=patch_size, single_batch_mode=single_patch)
step_kwargs = dict(model_name=model_name, patch=patch_size, patch_halo=patch_halo, single_batch_mode=single_patch)

return start_prediction_process(unet_predictions_wrapper,
runtime_kwargs={'raw': image.data,
Expand All @@ -105,6 +109,11 @@ def widget_unet_predictions(viewer: Viewer,
)


@widget_unet_predictions.image.changed.connect
def _on_widget_unet_predictions_image_change(image: Image):
_on_prediction_input_image_change(widget_unet_predictions, image)


def _on_any_metadata_changed(dimensionality, modality, output_type):
dimensionality = [dimensionality] if dimensionality != 'All' else None
modality = [modality] if modality != 'All' else None
Expand Down Expand Up @@ -152,7 +161,7 @@ def _on_model_name_changed(model_name: str):
widget_unet_predictions.model_name.tooltip = f'Select a pretrained model. Current model description: {description}'


def _compute_multiple_predictions(image, patch_size, device, use_custom_models=True):
def _compute_multiple_predictions(image, patch_size, patch_halo, device, use_custom_models=True):
out_layers = []
model_list = list_models(use_custom_models=use_custom_models)
for i, model_name in enumerate(model_list):
Expand All @@ -168,7 +177,7 @@ def _compute_multiple_predictions(image, patch_size, device, use_custom_models=T
layer_type = 'image'
try:
pmap = unet_predictions(raw=image.data, model_name=model_name, patch=patch_size, single_batch_mode=True,
device=device)
device=device, patch_halo=patch_halo)
out_layers.append((pmap, layer_kwargs, layer_type))

except Exception as e:
Expand All @@ -182,18 +191,22 @@ def _compute_multiple_predictions(image, patch_size, device, use_custom_models=T
'tooltip': 'Raw image to be processed with a neural network.'},
patch_size={'label': 'Patch size',
'tooltip': 'Patch size use to processed the data.'},
patch_halo={'label': 'Patch halo',
'tooltip': 'Patch halo is extra padding for correct prediction on image boarder.'},
device={'label': 'Device',
'choices': ALL_DEVICES},
use_custom_models={'label': 'Use custom models',
'tooltip': 'If True, custom models will also be used.'}
)
def widget_test_all_unet_predictions(image: Image,
patch_size: Tuple[int, int, int] = (80, 170, 170),
patch_halo: Tuple[int, int, int] = (2, 4, 4),
device: str = ALL_DEVICES[0],
use_custom_models: bool = True) -> Future[List[LayerDataTuple]]:
func = thread_worker(partial(_compute_multiple_predictions,
image=image,
patch_size=patch_size,
patch_halo=patch_halo,
device=device,
use_custom_models=use_custom_models,))

Expand All @@ -208,9 +221,14 @@ def on_done(result):
return future


def _compute_iterative_predictions(pmap, model_name, num_iterations, sigma, patch_size, single_batch_mode, device):
func = partial(unet_predictions, model_name=model_name, patch=patch_size, single_batch_mode=single_batch_mode,
device=device)
@widget_test_all_unet_predictions.image.changed.connect
def _on_widget_test_all_unet_predictions_image_change(image: Image):
_on_prediction_input_image_change(widget_test_all_unet_predictions, image)


def _compute_iterative_predictions(pmap, model_name, num_iterations, sigma, patch_size, patch_halo, single_batch_mode, device):
func = partial(unet_predictions, model_name=model_name, patch=patch_size, patch_halo=patch_halo,
single_batch_mode=single_batch_mode, device=device)
for i in range(num_iterations - 1):
pmap = func(pmap)
pmap = image_gaussian_smoothing(image=pmap, sigma=sigma)
Expand All @@ -235,6 +253,8 @@ def _compute_iterative_predictions(pmap, model_name, num_iterations, sigma, patc
'min': 0.},
patch_size={'label': 'Patch size',
'tooltip': 'Patch size use to processed the data.'},
patch_halo={'label': 'Patch halo',
'tooltip': 'Patch halo is extra padding for correct prediction on image boarder.'},
single_patch={'label': 'Single Patch',
'tooltip': 'If True, a single patch will be processed at a time to save memory.'},
device={'label': 'Device',
Expand All @@ -245,6 +265,7 @@ def widget_iterative_unet_predictions(image: Image,
num_iterations: int = 2,
sigma: float = 1.0,
patch_size: Tuple[int, int, int] = (80, 170, 170),
patch_halo: Tuple[int, int, int] = (8, 16, 16),
single_patch: bool = True,
device: str = ALL_DEVICES[0]) -> Future[LayerDataTuple]:
out_name = create_layer_name(image.name, f'iterative-{model_name}-x{num_iterations}')
Expand All @@ -258,6 +279,7 @@ def widget_iterative_unet_predictions(image: Image,
num_iterations=num_iterations,
sigma=sigma,
patch_size=patch_size,
patch_halo=patch_halo,
single_batch_mode=single_patch,
device=device)

Expand All @@ -280,6 +302,11 @@ def _on_model_name_changed_iterative(model_name: str):
widget_iterative_unet_predictions.patch_size.value = tuple(patch_size)


@widget_iterative_unet_predictions.image.changed.connect
def _on_widget_iterative_unet_predictions_image_change(image: Image):
_on_prediction_input_image_change(widget_iterative_unet_predictions, image)


@magicgui(call_button='Add Custom Model',
new_model_name={'label': 'New model name'},
model_location={'label': 'Model location',
Expand Down
45 changes: 45 additions & 0 deletions plantseg/viewer/widget/validation.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,45 @@
"""Widget input validation"""

from napari.layers import Image
from magicgui.widgets import Widget


def widgets_inactive(*widgets, active):
"""Toggle visibility of widgets."""
for widget in widgets:
widget.visible = active


def widgets_valid(*widgets, valid):
"""Toggle background warning color of widgets."""
for widget in widgets:
widget.native.setStyleSheet("" if valid else "background-color: lightcoral")


def get_image_volume_from_layer(image):
"""Used for widget parameter validation in change-handlers."""
image = image.data[0] if image.multiscale else image.data
if not all(hasattr(image, attr) for attr in ("shape", "ndim", "__getitem__")):
from numpy import asanyarray

image = asanyarray(image)
return image


def _on_prediction_input_image_change(widget: Widget, image: Image):
shape = get_image_volume_from_layer(image).shape
ndim = len(shape)
widget.image.tooltip = f"Shape: {shape}"

size_z = widget.patch_size[0]
halo_z = widget.patch_halo[0]
if ndim == 2 or (ndim == 3 and shape[0] == 1): # 2D image imported by Napari thus no Z, or by PlantSeg widget
size_z.value = 0
halo_z.value = 0
widgets_inactive(size_z, halo_z, active=False)
elif ndim == 3 and shape[0] > 1: # 3D
size_z.value = min(64, shape[0]) # TODO: fetch model default
halo_z.value = 8
widgets_inactive(size_z, halo_z, active=True)
else:
raise ValueError(f"Unsupported number of dimensions: {ndim}")
Loading