4
4
Mostly copy-paste from https://github.com/pytorch/vision/blob/edfd5a7/references/detection/coco_eval.py
5
5
The difference is that there is less copy-pasting from pycocotools
6
6
in the end of the file, as python3 can suppress prints with contextlib
7
+
8
+ # MiXaiLL76 replacing pycocotools with faster-coco-eval for better performance and support.
7
9
"""
10
+
8
11
import os
9
12
import contextlib
10
13
import copy
11
14
import numpy as np
12
15
import torch
13
16
14
- from pycocotools .cocoeval import COCOeval
15
- from pycocotools .coco import COCO
16
- import pycocotools .mask as mask_util
17
-
18
- from ...misc import dist_utils
17
+ from faster_coco_eval import COCO , COCOeval_faster
18
+ import faster_coco_eval .core .mask as mask_util
19
19
from ...core import register
20
-
20
+ from ... misc import dist_utils
21
21
__all__ = ['CocoEvaluator' ,]
22
22
23
23
@@ -26,47 +26,49 @@ class CocoEvaluator(object):
26
26
def __init__ (self , coco_gt , iou_types ):
27
27
assert isinstance (iou_types , (list , tuple ))
28
28
coco_gt = copy .deepcopy (coco_gt )
29
- self .coco_gt = coco_gt
29
+ self .coco_gt : COCO = coco_gt
30
30
self .iou_types = iou_types
31
-
31
+
32
32
self .coco_eval = {}
33
33
for iou_type in iou_types :
34
- self .coco_eval [iou_type ] = COCOeval (coco_gt , iouType = iou_type )
34
+ self .coco_eval [iou_type ] = COCOeval_faster (coco_gt , iouType = iou_type , print_function = print , separate_eval = True )
35
35
36
36
self .img_ids = []
37
37
self .eval_imgs = {k : [] for k in iou_types }
38
-
38
+
39
39
def cleanup (self ):
40
40
self .coco_eval = {}
41
41
for iou_type in self .iou_types :
42
- self .coco_eval [iou_type ] = COCOeval (self .coco_gt , iouType = iou_type )
42
+ self .coco_eval [iou_type ] = COCOeval_faster (self .coco_gt , iouType = iou_type , print_function = print , separate_eval = True )
43
43
self .img_ids = []
44
44
self .eval_imgs = {k : [] for k in self .iou_types }
45
-
46
-
45
+
46
+
47
47
def update (self , predictions ):
48
48
img_ids = list (np .unique (list (predictions .keys ())))
49
49
self .img_ids .extend (img_ids )
50
50
51
51
for iou_type in self .iou_types :
52
52
results = self .prepare (predictions , iou_type )
53
+ coco_eval = self .coco_eval [iou_type ]
53
54
54
- # suppress pycocotools prints
55
55
with open (os .devnull , 'w' ) as devnull :
56
56
with contextlib .redirect_stdout (devnull ):
57
- coco_dt = COCO .loadRes (self .coco_gt , results ) if results else COCO ()
58
- coco_eval = self .coco_eval [iou_type ]
59
-
60
- coco_eval .cocoDt = coco_dt
61
- coco_eval .params .imgIds = list (img_ids )
62
- img_ids , eval_imgs = evaluate (coco_eval )
57
+ coco_dt = self .coco_gt .loadRes (results ) if results else COCO ()
58
+ coco_eval .cocoDt = coco_dt
59
+ coco_eval .params .imgIds = list (img_ids )
60
+ coco_eval .evaluate ()
63
61
64
- self .eval_imgs [iou_type ].append (eval_imgs )
62
+ self .eval_imgs [iou_type ].append (np . array ( coco_eval . _evalImgs_cpp ). reshape ( len ( coco_eval . params . catIds ), len ( coco_eval . params . areaRng ), len ( coco_eval . params . imgIds )) )
65
63
66
64
def synchronize_between_processes (self ):
67
65
for iou_type in self .iou_types :
68
- self .eval_imgs [iou_type ] = np .concatenate (self .eval_imgs [iou_type ], 2 )
69
- create_common_coco_eval (self .coco_eval [iou_type ], self .img_ids , self .eval_imgs [iou_type ])
66
+ img_ids , eval_imgs = merge (self .img_ids , self .eval_imgs [iou_type ])
67
+
68
+ coco_eval = self .coco_eval [iou_type ]
69
+ coco_eval .params .imgIds = img_ids
70
+ coco_eval ._paramsEval = copy .deepcopy (coco_eval .params )
71
+ coco_eval ._evalImgs_cpp = eval_imgs
70
72
71
73
def accumulate (self ):
72
74
for coco_eval in self .coco_eval .values ():
@@ -177,7 +179,6 @@ def convert_to_xywh(boxes):
177
179
xmin , ymin , xmax , ymax = boxes .unbind (1 )
178
180
return torch .stack ((xmin , ymin , xmax - xmin , ymax - ymin ), dim = 1 )
179
181
180
-
181
182
def merge (img_ids , eval_imgs ):
182
183
all_img_ids = dist_utils .all_gather (img_ids )
183
184
all_eval_imgs = dist_utils .all_gather (eval_imgs )
@@ -188,90 +189,14 @@ def merge(img_ids, eval_imgs):
188
189
189
190
merged_eval_imgs = []
190
191
for p in all_eval_imgs :
191
- merged_eval_imgs .append (p )
192
+ merged_eval_imgs .extend (p )
193
+
192
194
193
195
merged_img_ids = np .array (merged_img_ids )
194
- merged_eval_imgs = np .concatenate (merged_eval_imgs , 2 )
196
+ merged_eval_imgs = np .concatenate (merged_eval_imgs , axis = 2 ).ravel ()
197
+ # merged_eval_imgs = np.array(merged_eval_imgs).T.ravel()
195
198
196
199
# keep only unique (and in sorted order) images
197
200
merged_img_ids , idx = np .unique (merged_img_ids , return_index = True )
198
- merged_eval_imgs = merged_eval_imgs [..., idx ]
199
-
200
- return merged_img_ids , merged_eval_imgs
201
-
202
-
203
- def create_common_coco_eval (coco_eval , img_ids , eval_imgs ):
204
- img_ids , eval_imgs = merge (img_ids , eval_imgs )
205
- img_ids = list (img_ids )
206
- eval_imgs = list (eval_imgs .flatten ())
207
-
208
- coco_eval .evalImgs = eval_imgs
209
- coco_eval .params .imgIds = img_ids
210
- coco_eval ._paramsEval = copy .deepcopy (coco_eval .params )
211
-
212
-
213
- #################################################################
214
- # From pycocotools, just removed the prints and fixed
215
- # a Python3 bug about unicode not defined
216
- #################################################################
217
-
218
-
219
- # import io
220
- # from contextlib import redirect_stdout
221
- # def evaluate(imgs):
222
- # with redirect_stdout(io.StringIO()):
223
- # imgs.evaluate()
224
- # return imgs.params.imgIds, np.asarray(imgs.evalImgs).reshape(-1, len(imgs.params.areaRng), len(imgs.params.imgIds))
225
-
226
-
227
- def evaluate (self ):
228
- """
229
- Run per image evaluation on given images and store results (a list of dict) in self.evalImgs
230
- :return: None
231
- """
232
- # tic = time.time()
233
- # print('Running per image evaluation...')
234
- p = self .params
235
- # add backward compatibility if useSegm is specified in params
236
- if p .useSegm is not None :
237
- p .iouType = 'segm' if p .useSegm == 1 else 'bbox'
238
- print ('useSegm (deprecated) is not None. Running {} evaluation' .format (p .iouType ))
239
- # print('Evaluate annotation type *{}*'.format(p.iouType))
240
- p .imgIds = list (np .unique (p .imgIds ))
241
- if p .useCats :
242
- p .catIds = list (np .unique (p .catIds ))
243
- p .maxDets = sorted (p .maxDets )
244
- self .params = p
245
-
246
- self ._prepare ()
247
- # loop through images, area range, max detection number
248
- catIds = p .catIds if p .useCats else [- 1 ]
249
-
250
- if p .iouType == 'segm' or p .iouType == 'bbox' :
251
- computeIoU = self .computeIoU
252
- elif p .iouType == 'keypoints' :
253
- computeIoU = self .computeOks
254
- self .ious = {
255
- (imgId , catId ): computeIoU (imgId , catId )
256
- for imgId in p .imgIds
257
- for catId in catIds }
258
-
259
- evaluateImg = self .evaluateImg
260
- maxDet = p .maxDets [- 1 ]
261
- evalImgs = [
262
- evaluateImg (imgId , catId , areaRng , maxDet )
263
- for catId in catIds
264
- for areaRng in p .areaRng
265
- for imgId in p .imgIds
266
- ]
267
- # this is NOT in the pycocotools code, but could be done outside
268
- evalImgs = np .asarray (evalImgs ).reshape (len (catIds ), len (p .areaRng ), len (p .imgIds ))
269
- self ._paramsEval = copy .deepcopy (self .params )
270
- # toc = time.time()
271
- # print('DONE (t={:0.2f}s).'.format(toc-tic))
272
- return p .imgIds , evalImgs
273
-
274
- #################################################################
275
- # end of straight copy from pycocotools, just removing the prints
276
- #################################################################
277
201
202
+ return merged_img_ids .tolist (), merged_eval_imgs .tolist ()
0 commit comments