@@ -118,6 +118,211 @@ def stitch_tile(img, tile, tile_size, overlap_size, i, j):
118
118
img .paste (tile , (i * tile_size , j * tile_size ))
119
119
120
120
121
+ class InferenceTiler :
122
+ """
123
+ Iterable class to tile image(s) and stitch result tiles together.
124
+
125
+ To perform inference on a large image, that image will need to be
126
+ tiled into smaller tiles that can be run individually and then
127
+ stitched back together. This class wraps the functionality as an
128
+ iterable object that can accept a single image or list of images
129
+ if multiple images are taken as input for inference.
130
+
131
+ An overlap size can be specified so that neighboring tiles will
132
+ overlap at the edges, helping to reduce seams or other artifacts
133
+ near the edge of a tile. Padding of a solid color around the
134
+ perimeter of the tile is also possible, if needed. The specified
135
+ tile size includes this overlap and pad sizes, so a tile size of
136
+ 512 with an overlap size of 32 and pad size of 16 would have a
137
+ central area of 416 pixels that are stitched into the result image.
138
+
139
+ Example Usage
140
+ -------------
141
+ tiler = InferenceTiler(img, 512, 32)
142
+ for tile in tiler:
143
+ result_tiles = infer(tile)
144
+ tiler.stitch(result_tiles)
145
+ images = tiler.results()
146
+ """
147
+
148
+ def __init__ (self , orig , tile_size , overlap_size = 0 , pad_size = 0 , pad_color = (255 , 255 , 255 )):
149
+ """
150
+ Initialize for tiling an image or list of images.
151
+
152
+ Parameters
153
+ ----------
154
+ orig : Image | list(Image)
155
+ Original image or list of images to be tiled.
156
+ tile_size: int
157
+ Size (width and height) of the tiles to be generated.
158
+ overlap_size: int [default: 0]
159
+ Amount of overlap on each side of the tile.
160
+ pad_size: int [default: 0]
161
+ Amount of solid color padding around perimeter of tile.
162
+ pad_color: tuple(int, int, int) [default: (255,255,255)]
163
+ RGB color to use for padding.
164
+ """
165
+
166
+ if tile_size <= 0 :
167
+ raise ValueError ('InfereneTiler input tile_size must be positive and non-zero' )
168
+ if overlap_size < 0 :
169
+ raise ValueError ('InfereneTiler input overlap_size must be positive or zero' )
170
+ if pad_size < 0 :
171
+ raise ValueError ('InfereneTiler input pad_size must be positive or zero' )
172
+
173
+ self .single_orig = not type (orig ) is list
174
+ if self .single_orig :
175
+ orig = [orig ]
176
+
177
+ for i in range (1 , len (orig )):
178
+ if orig [i ].size != orig [0 ].size :
179
+ raise ValueError ('InferenceTiler input images do not have the same size.' )
180
+ self .orig_width = orig [0 ].width
181
+ self .orig_height = orig [0 ].height
182
+
183
+ # patch size to extract from input image, which is then padded to tile size
184
+ patch_size = tile_size - (2 * pad_size )
185
+
186
+ # make sure width and height are both at least patch_size
187
+ if orig [0 ].width < patch_size :
188
+ for i in range (len (orig )):
189
+ while orig [i ].width < patch_size :
190
+ mirrored = ImageOps .mirror (orig [i ])
191
+ orig [i ] = ImageOps .expand (orig [i ], (0 , 0 , orig [i ].width , 0 ))
192
+ orig [i ].paste (mirrored , (mirrored .width , 0 ))
193
+ orig [i ] = orig [i ].crop ((0 , 0 , patch_size , orig [i ].height ))
194
+ if orig [0 ].height < patch_size :
195
+ for i in range (len (orig )):
196
+ while orig [i ].height < patch_size :
197
+ flipped = ImageOps .flip (orig [i ])
198
+ orig [i ] = ImageOps .expand (orig [i ], (0 , 0 , 0 , orig [i ].height ))
199
+ orig [i ].paste (flipped , (0 , flipped .height ))
200
+ orig [i ] = orig [i ].crop ((0 , 0 , orig [i ].width , patch_size ))
201
+ self .image_width = orig [0 ].width
202
+ self .image_height = orig [0 ].height
203
+
204
+ overlap_width = 0 if patch_size >= self .image_width else overlap_size
205
+ overlap_height = 0 if patch_size >= self .image_height else overlap_size
206
+ center_width = patch_size - (2 * overlap_width )
207
+ center_height = patch_size - (2 * overlap_height )
208
+ if center_width <= 0 or center_height <= 0 :
209
+ raise ValueError ('InferenceTiler combined overlap_size and pad_size are too large' )
210
+
211
+ self .c0x = pad_size # crop offset for left of non-pad content in result tile
212
+ self .c0y = pad_size # crop offset for top of non-pad content in result tile
213
+ self .c1x = overlap_width + pad_size # crop offset for left of center region in result tile
214
+ self .c1y = overlap_height + pad_size # crop offset for top of center region in result tile
215
+ self .c2x = patch_size - overlap_width + pad_size # crop offset for right of center region in result tile
216
+ self .c2y = patch_size - overlap_height + pad_size # crop offset for bottom of center region in result tile
217
+ self .c3x = patch_size + pad_size # crop offset for right of non-pad content in result tile
218
+ self .c3y = patch_size + pad_size # crop offset for bottom of non-pad content in result tile
219
+ self .p1x = overlap_width # paste offset for left of center region w.r.t (x,y) coord
220
+ self .p1y = overlap_height # paste offset for top of center region w.r.t (x,y) coord
221
+ self .p2x = patch_size - overlap_width # paste offset for right of center region w.r.t (x,y) coord
222
+ self .p2y = patch_size - overlap_height # paste offset for bottom of center region w.r.t (x,y) coord
223
+
224
+ self .overlap_width = overlap_width
225
+ self .overlap_height = overlap_height
226
+ self .patch_size = patch_size
227
+ self .center_width = center_width
228
+ self .center_height = center_height
229
+
230
+ self .orig = orig
231
+ self .tile_size = tile_size
232
+ self .pad_size = pad_size
233
+ self .pad_color = pad_color
234
+ self .res = {}
235
+
236
+ def __iter__ (self ):
237
+ """
238
+ Generate the tiles as an iterable.
239
+
240
+ Tiles are created and iterated over from top left to bottom
241
+ right, going across the rows. The yielded tile(s) match the
242
+ type of the original input when initialized (either a single
243
+ image or a list of images in the same order as initialized).
244
+ The (x, y) coordinate of the current tile is maintained
245
+ internally for use in the stitch function.
246
+ """
247
+
248
+ for y in range (0 , self .image_height , self .center_height ):
249
+ for x in range (0 , self .image_width , self .center_width ):
250
+ if x + self .patch_size > self .image_width :
251
+ x = self .image_width - self .patch_size
252
+ if y + self .patch_size > self .image_height :
253
+ y = self .image_height - self .patch_size
254
+ self .x = x
255
+ self .y = y
256
+ tiles = [im .crop ((x , y , x + self .patch_size , y + self .patch_size )) for im in self .orig ]
257
+ if self .pad_size != 0 :
258
+ tiles = [ImageOps .expand (t , self .pad_size , self .pad_color ) for t in tiles ]
259
+ yield tiles [0 ] if self .single_orig else tiles
260
+
261
+ def stitch (self , result_tiles ):
262
+ """
263
+ Stitch result tiles into the result images.
264
+
265
+ The key names for the dictionary of result tiles are used to
266
+ stitch each tile into its corresponding final image in the
267
+ results attribute. If a result image does not exist for a
268
+ result tile key name, then it will be created. The result tiles
269
+ are stitched at the location from which the list iterated tile
270
+ was extracted.
271
+
272
+ Parameters
273
+ ----------
274
+ result_tiles : dict(str: Image)
275
+ Dictionary of result tiles from the inference.
276
+ """
277
+
278
+ for k , tile in result_tiles .items ():
279
+ if k not in self .res :
280
+ self .res [k ] = Image .new ('RGB' , (self .image_width , self .image_height ))
281
+ if tile .size != (self .tile_size , self .tile_size ):
282
+ tile = tile .resize ((self .tile_size , self .tile_size ))
283
+ self .res [k ].paste (tile .crop ((self .c1x , self .c1y , self .c2x , self .c2y )), (self .x + self .p1x , self .y + self .p1y ))
284
+
285
+ # top left corner
286
+ if self .x == 0 and self .y == 0 :
287
+ self .res [k ].paste (tile .crop ((self .c0x , self .c0y , self .c1x , self .c1y )), (self .x , self .y ))
288
+ # top row
289
+ if self .y == 0 :
290
+ self .res [k ].paste (tile .crop ((self .c1x , self .c0y , self .c2x , self .c1y )), (self .x + self .p1x , self .y ))
291
+ # top right corner
292
+ if self .x == self .image_width - self .patch_size and self .y == 0 :
293
+ self .res [k ].paste (tile .crop ((self .c2x , self .c0y , self .c3x , self .c1y )), (self .x + self .p2x , self .y ))
294
+ # left column
295
+ if self .x == 0 :
296
+ self .res [k ].paste (tile .crop ((self .c0x , self .c1y , self .c1x , self .c2y )), (self .x , self .y + self .p1y ))
297
+ # right column
298
+ if self .x == self .image_width - self .patch_size :
299
+ self .res [k ].paste (tile .crop ((self .c2x , self .c1y , self .c3x , self .c2y )), (self .x + self .p2x , self .y + self .p1y ))
300
+ # bottom left corner
301
+ if self .x == 0 and self .y == self .image_height - self .patch_size :
302
+ self .res [k ].paste (tile .crop ((self .c0x , self .c2y , self .c1x , self .c3y )), (self .x , self .y + self .p2y ))
303
+ # bottom row
304
+ if self .y == self .image_height - self .patch_size :
305
+ self .res [k ].paste (tile .crop ((self .c1x , self .c2y , self .c2x , self .c3y )), (self .x + self .p1x , self .y + self .p2y ))
306
+ # bottom right corner
307
+ if self .x == self .image_width - self .patch_size and self .y == self .image_height - self .patch_size :
308
+ self .res [k ].paste (tile .crop ((self .c2x , self .c2y , self .c3x , self .c3y )), (self .x + self .p2x , self .y + self .p2y ))
309
+
310
+ def results (self ):
311
+ """
312
+ Return a dictionary of result images.
313
+
314
+ The keys for the result images are the same as those used for
315
+ the result tiles in the stitch function. This function should
316
+ only be called once, since the stitched images will be cropped
317
+ if the original image size was less than the patch size.
318
+ """
319
+
320
+ if self .orig_width != self .image_width or self .orig_height != self .image_height :
321
+ return {k : im .crop ((0 , 0 , self .orig_width , self .orig_height )) for k , im in self .res .items ()}
322
+ else :
323
+ return {k : im for k , im in self .res .items ()}
324
+
325
+
121
326
def calculate_background_mean_value (img ):
122
327
img = cv2 .fastNlMeansDenoisingColored (np .array (img ), None , 10 , 10 , 7 , 21 )
123
328
img = np .array (img , dtype = float )
0 commit comments