|
5 | 5 | import numpy as np
|
6 | 6 | import tqdm
|
7 | 7 | from skimage.filters import gaussian
|
8 |
| -from skimage.segmentation import watershed |
9 |
| - |
| 8 | +from skimage.segmentation import watershed, relabel_sequential |
| 9 | +from skimage.measure import regionprops |
10 | 10 |
|
11 | 11 | def get_bbox(mask: np.array, pixel_toll: int = 0) -> tuple[tuple, int, int, int]:
|
12 | 12 | """
|
@@ -235,3 +235,52 @@ def fix_over_under_segmentation_from_nuclei(cell_seg: np.array,
|
235 | 235 | cell_assignments,
|
236 | 236 | cell_idx=None)
|
237 | 237 | return _cell_seg
|
| 238 | + |
| 239 | + |
| 240 | +def remove_false_positives_by_foreground_probability(segmentation: np.array, |
| 241 | + foreground: np.array, |
| 242 | + threshold: float) -> np.array: |
| 243 | + """ |
| 244 | + Remove false positive regions in a segmentation based on a foreground probability map in a smart way. |
| 245 | + If the mean(an instance * its own probability region) < threshold, it is removed. |
| 246 | +
|
| 247 | + Args: |
| 248 | + segmentation (np.ndarray): The segmentation array, where each unique non-zero value indicates a distinct region. |
| 249 | + foreground (np.ndarray): The foreground probability map, same shape as `segmentation`. |
| 250 | + threshold (float): Probability threshold below which regions are considered false positives. |
| 251 | +
|
| 252 | + Returns: |
| 253 | + np.ndarray: The modified segmentation array with false positives removed. |
| 254 | + """ |
| 255 | + # TODO: make a channel for removed regions for easier inspection |
| 256 | + # TODO: use `relabel_sequential` to recover the original labels |
| 257 | + |
| 258 | + if not segmentation.shape == foreground.shape: |
| 259 | + raise ValueError("Shape of segmentation and probability map must match.") |
| 260 | + if foreground.max() > 1: |
| 261 | + raise ValueError("Foreground must be a probability map probability map.") |
| 262 | + |
| 263 | + instances, _, _ = relabel_sequential(segmentation) |
| 264 | + |
| 265 | + regions = regionprops(instances) |
| 266 | + to_keep = np.ones(len(regions) + 1) |
| 267 | + pixel_count = np.zeros(len(regions) + 1) |
| 268 | + pixel_value = np.zeros(len(regions) + 1) |
| 269 | + |
| 270 | + for region in tqdm.tqdm(regions): |
| 271 | + bbox = region.bbox |
| 272 | + cube = instances[bbox[0]:bbox[3], bbox[1]:bbox[4], bbox[2]:bbox[5]] == region.label # other instances may exist, don't use `> 0` |
| 273 | + prob = foreground[bbox[0]:bbox[3], bbox[1]:bbox[4], bbox[2]:bbox[5]] |
| 274 | + pixel_count[region.label] = region.area |
| 275 | + pixel_value[region.label] = (cube * prob).sum() |
| 276 | + |
| 277 | + likelihood = pixel_value / pixel_count |
| 278 | + to_keep[likelihood < threshold] = 0 |
| 279 | + ids_to_delete = np.argwhere(to_keep == 0) |
| 280 | + assert ids_to_delete.shape[1] == 1 |
| 281 | + ids_to_delete = ids_to_delete.flatten() |
| 282 | + # print(f" Removing instance {region.label}: pixel count: {pixel_count}, pixel value: {pixel_value}, likelihood: {likelihood}") |
| 283 | + |
| 284 | + instances[np.isin(instances, ids_to_delete)] = 0 |
| 285 | + instances, _, _ = relabel_sequential(instances) |
| 286 | + return instances |
0 commit comments