Skip to content

Commit cb39dc3

Browse files
authored
Add option to only run official models in "Try all Available Models" (#207)
1 parent b625246 commit cb39dc3

File tree

2 files changed

+20
-13
lines changed

2 files changed

+20
-13
lines changed

plantseg/utils.py

+9-7
Original file line numberDiff line numberDiff line change
@@ -26,7 +26,7 @@ def load_config(config_path: str) -> dict:
2626
return config
2727

2828

29-
def get_model_zoo() -> dict:
29+
def get_model_zoo(get_custom: bool = True) -> dict:
3030
"""
3131
returns a dictionary of all models in the model zoo.
3232
example:
@@ -43,22 +43,24 @@ def get_model_zoo() -> dict:
4343

4444
zoo_config = load_config(zoo_config)
4545

46-
custom_zoo_config = load_config(custom_zoo)
46+
if get_custom:
47+
custom_zoo_config = load_config(custom_zoo)
4748

48-
if custom_zoo_config is None:
49-
custom_zoo_config = {}
49+
if custom_zoo_config is None:
50+
custom_zoo_config = {}
5051

51-
zoo_config.update(custom_zoo_config)
52+
zoo_config.update(custom_zoo_config)
5253
return zoo_config
5354

5455

5556
def list_models(dimensionality_filter: list[str] = None,
5657
modality_filter: list[str] = None,
57-
output_type_filter: list[str] = None) -> list[str]:
58+
output_type_filter: list[str] = None,
59+
use_custom_models: bool = True) -> list[str]:
5860
"""
5961
return a list of models in the model zoo by name
6062
"""
61-
zoo_config = get_model_zoo()
63+
zoo_config = get_model_zoo(use_custom_models)
6264
models = list(zoo_config.keys())
6365

6466
if dimensionality_filter is not None:

plantseg/viewer/widget/predictions.py

+11-6
Original file line numberDiff line numberDiff line change
@@ -152,11 +152,12 @@ def _on_model_name_changed(model_name: str):
152152
widget_unet_predictions.model_name.tooltip = f'Select a pretrained model. Current model description: {description}'
153153

154154

155-
def _compute_multiple_predictions(image, patch_size, device):
155+
def _compute_multiple_predictions(image, patch_size, device, use_custom_models=True):
156156
out_layers = []
157-
for i, model_name in enumerate(list_models()):
157+
model_list = list_models(use_custom_models=use_custom_models)
158+
for i, model_name in enumerate(model_list):
158159

159-
napari_formatted_logging(f'Running UNet Predictions: {model_name} {i}/{len(list_models())}',
160+
napari_formatted_logging(f'Running UNet Predictions: {model_name} {i}/{len(model_list)}',
160161
thread='UNet Grid Predictions')
161162

162163
out_name = create_layer_name(image.name, model_name)
@@ -182,15 +183,19 @@ def _compute_multiple_predictions(image, patch_size, device):
182183
patch_size={'label': 'Patch size',
183184
'tooltip': 'Patch size use to processed the data.'},
184185
device={'label': 'Device',
185-
'choices': ALL_DEVICES}
186+
'choices': ALL_DEVICES},
187+
use_custom_models={'label': 'Use custom models',
188+
'tooltip': 'If True, custom models will also be used.'}
186189
)
187190
def widget_test_all_unet_predictions(image: Image,
188191
patch_size: Tuple[int, int, int] = (80, 170, 170),
189-
device: str = ALL_DEVICES[0]) -> Future[List[LayerDataTuple]]:
192+
device: str = ALL_DEVICES[0],
193+
use_custom_models: bool = True) -> Future[List[LayerDataTuple]]:
190194
func = thread_worker(partial(_compute_multiple_predictions,
191195
image=image,
192196
patch_size=patch_size,
193-
device=device))
197+
device=device,
198+
use_custom_models=use_custom_models,))
194199

195200
future = Future()
196201

0 commit comments

Comments
 (0)