Skip to content

Commit 7eecf62

Browse files
committed
New tiling method for inference on large images
1 parent a8835a1 commit 7eecf62

File tree

2 files changed

+266
-3
lines changed

2 files changed

+266
-3
lines changed

deepliif/models/__init__.py

+61-3
Original file line numberDiff line numberDiff line change
@@ -369,8 +369,8 @@ def get_net_tiles(n):
369369
return images
370370

371371

372-
def inference(img, tile_size, overlap_size, model_path, use_torchserve=False, eager_mode=False,
373-
color_dapi=False, color_marker=False, opt=None):
372+
def inference_old2(img, tile_size, overlap_size, model_path, use_torchserve=False, eager_mode=False,
373+
color_dapi=False, color_marker=False, opt=None):
374374
if not opt:
375375
opt = get_opt(model_path)
376376
#print_options(opt)
@@ -489,6 +489,63 @@ def get_net_tiles(n):
489489
raise Exception(f'inference() not implemented for model {opt.model}')
490490

491491

def inference(img, tile_size, overlap_size, model_path, use_torchserve=False,
              eager_mode=False, color_dapi=False, color_marker=False, opt=None):
    """
    Run tiled inference on a (potentially large) input image.

    The image is split into overlapping tiles with InferenceTiler, each
    tile is run through the model, and the result tiles are stitched
    back into full-size result images.

    Parameters
    ----------
    img : Image
        Input image. For the 'SDG' model this is a horizontal strip of
        `opt.input_no` equal-width modality images.
    tile_size : int
        Size (width and height) of the tiles run through the model.
    overlap_size : int
        Amount of overlap on each side of a tile.
    model_path : str
        Path to the model directory (used to load options and weights).
    use_torchserve : bool [default: False]
        Run inference via torchserve instead of dask.
    eager_mode : bool [default: False]
        Run the model in eager mode.
    color_dapi : bool [default: False]
        Colorize the DAPI result image ('DeepLIIF' model only).
    color_marker : bool [default: False]
        Colorize the Marker result image ('DeepLIIF' model only).
    opt : [default: None]
        Pre-loaded model options; loaded from model_path when not given.

    Returns
    -------
    dict(str: Image)
        Result images keyed by modality name, or by raw net name for
        models without dedicated key names.
    """
    if not opt:
        opt = get_opt(model_path)

    run_fn = run_torchserve if use_torchserve else run_dask

    if opt.model == 'SDG':
        # SDG could have multiple input images/modalities, hence the input could be a rectangle.
        # We split the input to get each modality image then create tiles for each set of input images.
        w, h = int(img.width / opt.input_no), img.height
        orig = [img.crop((w * i, 0, w * (i + 1), h)) for i in range(opt.input_no)]
    else:
        # Otherwise expect a single input image, which is used directly.
        orig = img

    tiler = InferenceTiler(orig, tile_size, overlap_size)
    for tile in tiler:
        tiler.stitch(run_wrapper(tile, run_fn, model_path, eager_mode, opt))
    results = tiler.results()

    if opt.model == 'DeepLIIF':
        images = {
            'Hema': results['G1'],
            'DAPI': results['G2'],
            'Lap2': results['G3'],
            'Marker': results['G4'],
            'Seg': results['G5'],
        }
        # ITU-R 601 luma coefficients (0.299, 0.587, 0.114): each 4-tuple
        # is one output-channel row of the RGB conversion matrix, placing
        # the grayscale luminance into the selected channels.
        luma = (299/1000, 587/1000, 114/1000, 0)
        zero = (0, 0, 0, 0)
        if color_dapi:
            # DAPI rendered in cyan: luminance in the G and B channels.
            images['DAPI'] = images['DAPI'].convert('RGB', zero + luma + luma)
        if color_marker:
            # Marker rendered in yellow: luminance in the R and G channels.
            images['Marker'] = images['Marker'].convert('RGB', luma + luma + zero)
        return images

    elif opt.model == 'DeepLIIFExt':
        images = {f'mod{i}': results[f'G_{i}'] for i in range(1, opt.modalities_no + 1)}
        if opt.seg_gen:
            images.update({f'Seg{i}': results[f'GS_{i}'] for i in range(1, opt.modalities_no + 1)})
        return images

    elif opt.model == 'SDG':
        images = {f'mod{i}': results[f'G_{i}'] for i in range(1, opt.modalities_no + 1)}
        return images

    else:
        # Unrecognized model: return result images with default key
        # names (i.e., the net names) instead of raising.
        return results
547+
548+
492549
def postprocess(orig, images, tile_size, model, seg_thresh=150, size_thresh='auto', marker_thresh='auto', size_thresh_upper=None):
493550
if model == 'DeepLIIF':
494551
resolution = '40x' if tile_size > 384 else ('20x' if tile_size > 192 else '10x')
@@ -546,7 +603,8 @@ def infer_modalities(img, tile_size, model_dir, eager_mode=False,
546603
images = inference(
547604
img,
548605
tile_size=tile_size,
549-
overlap_size=compute_overlap(img_size, tile_size),
606+
#overlap_size=compute_overlap(img_size, tile_size),
607+
overlap_size=tile_size//16,
550608
model_path=model_dir,
551609
eager_mode=eager_mode,
552610
color_dapi=color_dapi,

deepliif/util/__init__.py

+205
Original file line numberDiff line numberDiff line change
@@ -118,6 +118,211 @@ def stitch_tile(img, tile, tile_size, overlap_size, i, j):
118118
img.paste(tile, (i * tile_size, j * tile_size))
119119

120120

class InferenceTiler:
    """
    Iterable class to tile image(s) and stitch result tiles together.

    To perform inference on a large image, that image will need to be
    tiled into smaller tiles that can be run individually and then
    stitched back together. This class wraps the functionality as an
    iterable object that can accept a single image or list of images
    if multiple images are taken as input for inference.

    An overlap size can be specified so that neighboring tiles will
    overlap at the edges, helping to reduce seams or other artifacts
    near the edge of a tile. Padding of a solid color around the
    perimeter of the tile is also possible, if needed. The specified
    tile size includes this overlap and pad sizes, so a tile size of
    512 with an overlap size of 32 and pad size of 16 would have a
    central area of 416 pixels that are stitched into the result image.

    Example Usage
    -------------
    tiler = InferenceTiler(img, 512, 32)
    for tile in tiler:
        result_tiles = infer(tile)
        tiler.stitch(result_tiles)
    images = tiler.results()
    """

    def __init__(self, orig, tile_size, overlap_size=0, pad_size=0, pad_color=(255, 255, 255)):
        """
        Initialize for tiling an image or list of images.

        Parameters
        ----------
        orig : Image | list(Image)
            Original image or list of images to be tiled.
        tile_size: int
            Size (width and height) of the tiles to be generated.
        overlap_size: int [default: 0]
            Amount of overlap on each side of the tile.
        pad_size: int [default: 0]
            Amount of solid color padding around perimeter of tile.
        pad_color: tuple(int, int, int) [default: (255,255,255)]
            RGB color to use for padding.
        """

        if tile_size <= 0:
            raise ValueError('InferenceTiler input tile_size must be positive and non-zero')
        if overlap_size < 0:
            raise ValueError('InferenceTiler input overlap_size must be positive or zero')
        if pad_size < 0:
            raise ValueError('InferenceTiler input pad_size must be positive or zero')

        self.single_orig = not isinstance(orig, list)
        if self.single_orig:
            orig = [orig]
        else:
            # Work on a copy so that mirror/flip expansion below does not
            # mutate the caller's list in place.
            orig = list(orig)

        for i in range(1, len(orig)):
            if orig[i].size != orig[0].size:
                raise ValueError('InferenceTiler input images do not have the same size.')
        self.orig_width = orig[0].width
        self.orig_height = orig[0].height

        # patch size to extract from input image, which is then padded to tile size
        patch_size = tile_size - (2 * pad_size)

        # make sure width and height are both at least patch_size,
        # extending the image by mirroring/flipping it as needed
        if orig[0].width < patch_size:
            for i in range(len(orig)):
                while orig[i].width < patch_size:
                    mirrored = ImageOps.mirror(orig[i])
                    orig[i] = ImageOps.expand(orig[i], (0, 0, orig[i].width, 0))
                    orig[i].paste(mirrored, (mirrored.width, 0))
                orig[i] = orig[i].crop((0, 0, patch_size, orig[i].height))
        if orig[0].height < patch_size:
            for i in range(len(orig)):
                while orig[i].height < patch_size:
                    flipped = ImageOps.flip(orig[i])
                    orig[i] = ImageOps.expand(orig[i], (0, 0, 0, orig[i].height))
                    orig[i].paste(flipped, (0, flipped.height))
                orig[i] = orig[i].crop((0, 0, orig[i].width, patch_size))
        self.image_width = orig[0].width
        self.image_height = orig[0].height

        # no overlap needed along a dimension if a single patch covers it
        overlap_width = 0 if patch_size >= self.image_width else overlap_size
        overlap_height = 0 if patch_size >= self.image_height else overlap_size
        center_width = patch_size - (2 * overlap_width)
        center_height = patch_size - (2 * overlap_height)
        if center_width <= 0 or center_height <= 0:
            raise ValueError('InferenceTiler combined overlap_size and pad_size are too large')

        self.c0x = pad_size  # crop offset for left of non-pad content in result tile
        self.c0y = pad_size  # crop offset for top of non-pad content in result tile
        self.c1x = overlap_width + pad_size  # crop offset for left of center region in result tile
        self.c1y = overlap_height + pad_size  # crop offset for top of center region in result tile
        self.c2x = patch_size - overlap_width + pad_size  # crop offset for right of center region in result tile
        self.c2y = patch_size - overlap_height + pad_size  # crop offset for bottom of center region in result tile
        self.c3x = patch_size + pad_size  # crop offset for right of non-pad content in result tile
        self.c3y = patch_size + pad_size  # crop offset for bottom of non-pad content in result tile
        self.p1x = overlap_width  # paste offset for left of center region w.r.t (x,y) coord
        self.p1y = overlap_height  # paste offset for top of center region w.r.t (x,y) coord
        self.p2x = patch_size - overlap_width  # paste offset for right of center region w.r.t (x,y) coord
        self.p2y = patch_size - overlap_height  # paste offset for bottom of center region w.r.t (x,y) coord

        self.overlap_width = overlap_width
        self.overlap_height = overlap_height
        self.patch_size = patch_size
        self.center_width = center_width
        self.center_height = center_height

        self.orig = orig
        self.tile_size = tile_size
        self.pad_size = pad_size
        self.pad_color = pad_color
        self.res = {}  # result images keyed by result-tile name

    def __iter__(self):
        """
        Generate the tiles as an iterable.

        Tiles are created and iterated over from top left to bottom
        right, going across the rows. The yielded tile(s) match the
        type of the original input when initialized (either a single
        image or a list of images in the same order as initialized).
        The (x, y) coordinate of the current tile is maintained
        internally for use in the stitch function.
        """

        for y in range(0, self.image_height, self.center_height):
            for x in range(0, self.image_width, self.center_width):
                # clamp the last patch in each row/column to the image edge
                if x + self.patch_size > self.image_width:
                    x = self.image_width - self.patch_size
                if y + self.patch_size > self.image_height:
                    y = self.image_height - self.patch_size
                self.x = x
                self.y = y
                tiles = [im.crop((x, y, x + self.patch_size, y + self.patch_size)) for im in self.orig]
                if self.pad_size != 0:
                    tiles = [ImageOps.expand(t, self.pad_size, self.pad_color) for t in tiles]
                yield tiles[0] if self.single_orig else tiles

    def stitch(self, result_tiles):
        """
        Stitch result tiles into the result images.

        The key names for the dictionary of result tiles are used to
        stitch each tile into its corresponding final image in the
        results attribute. If a result image does not exist for a
        result tile key name, then it will be created. The result tiles
        are stitched at the location from which the last iterated tile
        was extracted.

        Parameters
        ----------
        result_tiles : dict(str: Image)
            Dictionary of result tiles from the inference.
        """

        for k, tile in result_tiles.items():
            if k not in self.res:
                self.res[k] = Image.new('RGB', (self.image_width, self.image_height))
            if tile.size != (self.tile_size, self.tile_size):
                tile = tile.resize((self.tile_size, self.tile_size))
            # center region is always pasted
            self.res[k].paste(tile.crop((self.c1x, self.c1y, self.c2x, self.c2y)), (self.x + self.p1x, self.y + self.p1y))

            # overlap regions are pasted only on the outer boundary of the
            # image, where no neighboring tile covers them
            # top left corner
            if self.x == 0 and self.y == 0:
                self.res[k].paste(tile.crop((self.c0x, self.c0y, self.c1x, self.c1y)), (self.x, self.y))
            # top row
            if self.y == 0:
                self.res[k].paste(tile.crop((self.c1x, self.c0y, self.c2x, self.c1y)), (self.x + self.p1x, self.y))
            # top right corner
            if self.x == self.image_width - self.patch_size and self.y == 0:
                self.res[k].paste(tile.crop((self.c2x, self.c0y, self.c3x, self.c1y)), (self.x + self.p2x, self.y))
            # left column
            if self.x == 0:
                self.res[k].paste(tile.crop((self.c0x, self.c1y, self.c1x, self.c2y)), (self.x, self.y + self.p1y))
            # right column
            if self.x == self.image_width - self.patch_size:
                self.res[k].paste(tile.crop((self.c2x, self.c1y, self.c3x, self.c2y)), (self.x + self.p2x, self.y + self.p1y))
            # bottom left corner
            if self.x == 0 and self.y == self.image_height - self.patch_size:
                self.res[k].paste(tile.crop((self.c0x, self.c2y, self.c1x, self.c3y)), (self.x, self.y + self.p2y))
            # bottom row
            if self.y == self.image_height - self.patch_size:
                self.res[k].paste(tile.crop((self.c1x, self.c2y, self.c2x, self.c3y)), (self.x + self.p1x, self.y + self.p2y))
            # bottom right corner
            if self.x == self.image_width - self.patch_size and self.y == self.image_height - self.patch_size:
                self.res[k].paste(tile.crop((self.c2x, self.c2y, self.c3x, self.c3y)), (self.x + self.p2x, self.y + self.p2y))

    def results(self):
        """
        Return a dictionary of result images.

        The keys for the result images are the same as those used for
        the result tiles in the stitch function. This function should
        only be called once, since the stitched images will be cropped
        if the original image size was less than the patch size.
        """

        if self.orig_width != self.image_width or self.orig_height != self.image_height:
            # undo the mirror/flip expansion applied for small images
            return {k: im.crop((0, 0, self.orig_width, self.orig_height)) for k, im in self.res.items()}
        else:
            return dict(self.res)
324+
325+
121326
def calculate_background_mean_value(img):
122327
img = cv2.fastNlMeansDenoisingColored(np.array(img), None, 10, 10, 7, 21)
123328
img = np.array(img, dtype=float)

0 commit comments

Comments
 (0)