Skip to content

Commit 5d2a22c

Browse files
authored
Merge PR #202 | A generic pmap-guided object removal widget
2 parents cb39dc3 + ac0f2bf commit 5d2a22c

File tree

4 files changed

+87
-7
lines changed

4 files changed

+87
-7
lines changed

plantseg/dataprocessing/functional/advanced_dataprocessing.py

+51-2
Original file line numberDiff line numberDiff line change
@@ -5,8 +5,8 @@
55
import numpy as np
66
import tqdm
77
from skimage.filters import gaussian
8-
from skimage.segmentation import watershed
9-
8+
from skimage.segmentation import watershed, relabel_sequential
9+
from skimage.measure import regionprops
1010

1111
def get_bbox(mask: np.array, pixel_toll: int = 0) -> tuple[tuple, int, int, int]:
1212
"""
@@ -235,3 +235,52 @@ def fix_over_under_segmentation_from_nuclei(cell_seg: np.array,
235235
cell_assignments,
236236
cell_idx=None)
237237
return _cell_seg
238+
239+
240+
def remove_false_positives_by_foreground_probability(segmentation: np.array,
241+
foreground: np.array,
242+
threshold: float) -> np.array:
243+
"""
244+
Remove false positive regions in a segmentation based on a foreground probability map in a smart way.
245+
If the mean(an instance * its own probability region) < threshold, it is removed.
246+
247+
Args:
248+
segmentation (np.ndarray): The segmentation array, where each unique non-zero value indicates a distinct region.
249+
foreground (np.ndarray): The foreground probability map, same shape as `segmentation`.
250+
threshold (float): Probability threshold below which regions are considered false positives.
251+
252+
Returns:
253+
np.ndarray: The modified segmentation array with false positives removed.
254+
"""
255+
# TODO: make a channel for removed regions for easier inspection
256+
# TODO: use `relabel_sequential` to recover the original labels
257+
258+
if not segmentation.shape == foreground.shape:
259+
raise ValueError("Shape of segmentation and probability map must match.")
260+
if foreground.max() > 1:
261+
raise ValueError("Foreground must be a probability map probability map.")
262+
263+
instances, _, _ = relabel_sequential(segmentation)
264+
265+
regions = regionprops(instances)
266+
to_keep = np.ones(len(regions) + 1)
267+
pixel_count = np.zeros(len(regions) + 1)
268+
pixel_value = np.zeros(len(regions) + 1)
269+
270+
for region in tqdm.tqdm(regions):
271+
bbox = region.bbox
272+
cube = instances[bbox[0]:bbox[3], bbox[1]:bbox[4], bbox[2]:bbox[5]] == region.label # other instances may exist, don't use `> 0`
273+
prob = foreground[bbox[0]:bbox[3], bbox[1]:bbox[4], bbox[2]:bbox[5]]
274+
pixel_count[region.label] = region.area
275+
pixel_value[region.label] = (cube * prob).sum()
276+
277+
likelihood = pixel_value / pixel_count
278+
to_keep[likelihood < threshold] = 0
279+
ids_to_delete = np.argwhere(to_keep == 0)
280+
assert ids_to_delete.shape[1] == 1
281+
ids_to_delete = ids_to_delete.flatten()
282+
# print(f" Removing instance {region.label}: pixel count: {pixel_count}, pixel value: {pixel_value}, likelihood: {likelihood}")
283+
284+
instances[np.isin(instances, ids_to_delete)] = 0
285+
instances, _, _ = relabel_sequential(instances)
286+
return instances

plantseg/segmentation/gasp.py

-2
Original file line numberDiff line numberDiff line change
@@ -56,9 +56,7 @@ def __init__(self,
5656

5757
# Postprocessing size threshold
5858
self.post_minsize = post_minsize
59-
6059
self.n_threads = n_threads
61-
6260
self.dt_watershed = partial(dt_watershed,
6361
threshold=ws_threshold, sigma_seeds=ws_sigma,
6462
stacked=ws_2D, sigma_weights=ws_w_sigma,

plantseg/viewer/containers.py

+3-1
Original file line numberDiff line numberDiff line change
@@ -13,6 +13,7 @@
1313
from plantseg.viewer.widget.proofreading.proofreading import widget_split_and_merge_from_scribbles
1414
from plantseg.viewer.widget.segmentation import widget_dt_ws, widget_agglomeration
1515
from plantseg.viewer.widget.segmentation import widget_fix_over_under_segmentation_from_nuclei
16+
from plantseg.viewer.widget.segmentation import widget_fix_false_positive_from_foreground_pmap
1617
from plantseg.viewer.widget.segmentation import widget_lifted_multicut
1718
from plantseg.viewer.widget.segmentation import widget_simple_dt_ws
1819

@@ -65,7 +66,8 @@ def get_gasp_workflow():
6566
def get_extra_seg():
6667
container = MainWindow(widgets=[widget_dt_ws,
6768
widget_lifted_multicut,
68-
widget_fix_over_under_segmentation_from_nuclei],
69+
widget_fix_over_under_segmentation_from_nuclei,
70+
widget_fix_false_positive_from_foreground_pmap],
6971
labels=False)
7072
container = setup_menu(container, path='https://hci-unihd.github.io/plant-seg/chapters/plantseg_interactive_napari/extra_seg.html')
7173
return container

plantseg/viewer/widget/segmentation.py

+33-2
Original file line numberDiff line numberDiff line change
@@ -7,7 +7,7 @@
77
from napari.types import LayerDataTuple
88

99
from napari import Viewer
10-
from plantseg.dataprocessing.functional.advanced_dataprocessing import fix_over_under_segmentation_from_nuclei
10+
from plantseg.dataprocessing.functional.advanced_dataprocessing import fix_over_under_segmentation_from_nuclei, remove_false_positives_by_foreground_probability
1111
from plantseg.dataprocessing.functional.dataprocessing import normalize_01
1212
from plantseg.segmentation.functional import gasp, multicut, dt_watershed, mutex_ws
1313
from plantseg.segmentation.functional import lifted_multicut_from_nuclei_segmentation, lifted_multicut_from_nuclei_pmaps
@@ -330,5 +330,36 @@ def widget_fix_over_under_segmentation_from_nuclei(cell_segmentation: Labels,
330330
input_keys=inputs_names,
331331
layer_kwarg=layer_kwargs,
332332
layer_type=layer_type,
333-
step_name=f'Fix Over / Under segmentation',
333+
step_name='Fix Over / Under Segmentation',
334+
)
335+
336+
337+
@magicgui(call_button='Run Segmentation Fix from Foreground Pmap',
338+
segmentation={'label': 'Segmentation'},
339+
foreground={'label': 'Foreground Pmap'},
340+
threshold={'label': 'Threshold',
341+
'widget_type': 'FloatSlider', 'max': 1., 'min': 0.})
342+
def widget_fix_false_positive_from_foreground_pmap(segmentation: Labels,
343+
foreground: Image, # TODO: maybe also allow labels
344+
threshold=0.6) -> Future[LayerDataTuple]:
345+
out_name = create_layer_name(segmentation.name, 'FGPmapFix')
346+
347+
inputs_names = (segmentation.name, foreground.name)
348+
func_kwargs = {'segmentation': segmentation.data,
349+
'foreground': foreground.data}
350+
351+
layer_kwargs = layer_properties(name=out_name,
352+
scale=segmentation.scale,
353+
metadata=segmentation.metadata)
354+
layer_type = 'labels'
355+
step_kwargs = dict(threshold=threshold)
356+
357+
return start_threading_process(remove_false_positives_by_foreground_probability,
358+
runtime_kwargs=func_kwargs,
359+
statics_kwargs=step_kwargs,
360+
out_name=out_name,
361+
input_keys=inputs_names,
362+
layer_kwarg=layer_kwargs,
363+
layer_type=layer_type,
364+
step_name='Reduce False Positives',
334365
)

0 commit comments

Comments
 (0)