Skip to content

Commit 93d00c1

Browse files
authored
replace pycocotools with faster-coco-eval (#548)
1 parent df74223 commit 93d00c1

File tree

9 files changed

+37
-868
lines changed

9 files changed

+37
-868
lines changed

rtdetrv2_pytorch/requirements.txt

+1-1
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,5 @@
11
torch>=2.0.1
22
torchvision>=0.15.2
3-
pycocotools
3+
faster-coco-eval>=1.6.5
44
PyYAML
55
tensorboard

rtdetrv2_pytorch/src/data/dataset/coco_dataset.py

+1-1
Original file line numberDiff line numberDiff line change
@@ -12,7 +12,7 @@
1212
torchvision.disable_beta_transforms_warning()
1313

1414
from PIL import Image
15-
from pycocotools import mask as coco_mask
15+
from faster_coco_eval.core import mask as coco_mask
1616

1717
from ._dataset import DetDataset
1818
from .._misc import convert_to_tv_tensor

rtdetrv2_pytorch/src/data/dataset/coco_eval.py

+30-105
Original file line numberDiff line numberDiff line change
@@ -4,20 +4,20 @@
44
Mostly copy-paste from https://github.com/pytorch/vision/blob/edfd5a7/references/detection/coco_eval.py
55
The difference is that there is less copy-pasting from pycocotools
66
in the end of the file, as python3 can suppress prints with contextlib
7+
8+
# MiXaiLL76 replacing pycocotools with faster-coco-eval for better performance and support.
79
"""
10+
811
import os
912
import contextlib
1013
import copy
1114
import numpy as np
1215
import torch
1316

14-
from pycocotools.cocoeval import COCOeval
15-
from pycocotools.coco import COCO
16-
import pycocotools.mask as mask_util
17-
18-
from ...misc import dist_utils
17+
from faster_coco_eval import COCO, COCOeval_faster
18+
import faster_coco_eval.core.mask as mask_util
1919
from ...core import register
20-
20+
from ...misc import dist_utils
2121
__all__ = ['CocoEvaluator',]
2222

2323

@@ -26,47 +26,49 @@ class CocoEvaluator(object):
2626
def __init__(self, coco_gt, iou_types):
2727
assert isinstance(iou_types, (list, tuple))
2828
coco_gt = copy.deepcopy(coco_gt)
29-
self.coco_gt = coco_gt
29+
self.coco_gt : COCO = coco_gt
3030
self.iou_types = iou_types
31-
31+
3232
self.coco_eval = {}
3333
for iou_type in iou_types:
34-
self.coco_eval[iou_type] = COCOeval(coco_gt, iouType=iou_type)
34+
self.coco_eval[iou_type] = COCOeval_faster(coco_gt, iouType=iou_type, print_function=print, separate_eval=True)
3535

3636
self.img_ids = []
3737
self.eval_imgs = {k: [] for k in iou_types}
38-
38+
3939
def cleanup(self):
4040
self.coco_eval = {}
4141
for iou_type in self.iou_types:
42-
self.coco_eval[iou_type] = COCOeval(self.coco_gt, iouType=iou_type)
42+
self.coco_eval[iou_type] = COCOeval_faster(self.coco_gt, iouType=iou_type, print_function=print, separate_eval=True)
4343
self.img_ids = []
4444
self.eval_imgs = {k: [] for k in self.iou_types}
45-
46-
45+
46+
4747
def update(self, predictions):
4848
img_ids = list(np.unique(list(predictions.keys())))
4949
self.img_ids.extend(img_ids)
5050

5151
for iou_type in self.iou_types:
5252
results = self.prepare(predictions, iou_type)
53+
coco_eval = self.coco_eval[iou_type]
5354

54-
# suppress pycocotools prints
5555
with open(os.devnull, 'w') as devnull:
5656
with contextlib.redirect_stdout(devnull):
57-
coco_dt = COCO.loadRes(self.coco_gt, results) if results else COCO()
58-
coco_eval = self.coco_eval[iou_type]
59-
60-
coco_eval.cocoDt = coco_dt
61-
coco_eval.params.imgIds = list(img_ids)
62-
img_ids, eval_imgs = evaluate(coco_eval)
57+
coco_dt = self.coco_gt.loadRes(results) if results else COCO()
58+
coco_eval.cocoDt = coco_dt
59+
coco_eval.params.imgIds = list(img_ids)
60+
coco_eval.evaluate()
6361

64-
self.eval_imgs[iou_type].append(eval_imgs)
62+
self.eval_imgs[iou_type].append(np.array(coco_eval._evalImgs_cpp).reshape(len(coco_eval.params.catIds), len(coco_eval.params.areaRng), len(coco_eval.params.imgIds)))
6563

6664
def synchronize_between_processes(self):
6765
for iou_type in self.iou_types:
68-
self.eval_imgs[iou_type] = np.concatenate(self.eval_imgs[iou_type], 2)
69-
create_common_coco_eval(self.coco_eval[iou_type], self.img_ids, self.eval_imgs[iou_type])
66+
img_ids, eval_imgs = merge(self.img_ids, self.eval_imgs[iou_type])
67+
68+
coco_eval = self.coco_eval[iou_type]
69+
coco_eval.params.imgIds = img_ids
70+
coco_eval._paramsEval = copy.deepcopy(coco_eval.params)
71+
coco_eval._evalImgs_cpp = eval_imgs
7072

7173
def accumulate(self):
7274
for coco_eval in self.coco_eval.values():
@@ -177,7 +179,6 @@ def convert_to_xywh(boxes):
177179
xmin, ymin, xmax, ymax = boxes.unbind(1)
178180
return torch.stack((xmin, ymin, xmax - xmin, ymax - ymin), dim=1)
179181

180-
181182
def merge(img_ids, eval_imgs):
182183
all_img_ids = dist_utils.all_gather(img_ids)
183184
all_eval_imgs = dist_utils.all_gather(eval_imgs)
@@ -188,90 +189,14 @@ def merge(img_ids, eval_imgs):
188189

189190
merged_eval_imgs = []
190191
for p in all_eval_imgs:
191-
merged_eval_imgs.append(p)
192+
merged_eval_imgs.extend(p)
193+
192194

193195
merged_img_ids = np.array(merged_img_ids)
194-
merged_eval_imgs = np.concatenate(merged_eval_imgs, 2)
196+
merged_eval_imgs = np.concatenate(merged_eval_imgs, axis=2).ravel()
197+
# merged_eval_imgs = np.array(merged_eval_imgs).T.ravel()
195198

196199
# keep only unique (and in sorted order) images
197200
merged_img_ids, idx = np.unique(merged_img_ids, return_index=True)
198-
merged_eval_imgs = merged_eval_imgs[..., idx]
199-
200-
return merged_img_ids, merged_eval_imgs
201-
202-
203-
def create_common_coco_eval(coco_eval, img_ids, eval_imgs):
204-
img_ids, eval_imgs = merge(img_ids, eval_imgs)
205-
img_ids = list(img_ids)
206-
eval_imgs = list(eval_imgs.flatten())
207-
208-
coco_eval.evalImgs = eval_imgs
209-
coco_eval.params.imgIds = img_ids
210-
coco_eval._paramsEval = copy.deepcopy(coco_eval.params)
211-
212-
213-
#################################################################
214-
# From pycocotools, just removed the prints and fixed
215-
# a Python3 bug about unicode not defined
216-
#################################################################
217-
218-
219-
# import io
220-
# from contextlib import redirect_stdout
221-
# def evaluate(imgs):
222-
# with redirect_stdout(io.StringIO()):
223-
# imgs.evaluate()
224-
# return imgs.params.imgIds, np.asarray(imgs.evalImgs).reshape(-1, len(imgs.params.areaRng), len(imgs.params.imgIds))
225-
226-
227-
def evaluate(self):
228-
"""
229-
Run per image evaluation on given images and store results (a list of dict) in self.evalImgs
230-
:return: None
231-
"""
232-
# tic = time.time()
233-
# print('Running per image evaluation...')
234-
p = self.params
235-
# add backward compatibility if useSegm is specified in params
236-
if p.useSegm is not None:
237-
p.iouType = 'segm' if p.useSegm == 1 else 'bbox'
238-
print('useSegm (deprecated) is not None. Running {} evaluation'.format(p.iouType))
239-
# print('Evaluate annotation type *{}*'.format(p.iouType))
240-
p.imgIds = list(np.unique(p.imgIds))
241-
if p.useCats:
242-
p.catIds = list(np.unique(p.catIds))
243-
p.maxDets = sorted(p.maxDets)
244-
self.params = p
245-
246-
self._prepare()
247-
# loop through images, area range, max detection number
248-
catIds = p.catIds if p.useCats else [-1]
249-
250-
if p.iouType == 'segm' or p.iouType == 'bbox':
251-
computeIoU = self.computeIoU
252-
elif p.iouType == 'keypoints':
253-
computeIoU = self.computeOks
254-
self.ious = {
255-
(imgId, catId): computeIoU(imgId, catId)
256-
for imgId in p.imgIds
257-
for catId in catIds}
258-
259-
evaluateImg = self.evaluateImg
260-
maxDet = p.maxDets[-1]
261-
evalImgs = [
262-
evaluateImg(imgId, catId, areaRng, maxDet)
263-
for catId in catIds
264-
for areaRng in p.areaRng
265-
for imgId in p.imgIds
266-
]
267-
# this is NOT in the pycocotools code, but could be done outside
268-
evalImgs = np.asarray(evalImgs).reshape(len(catIds), len(p.areaRng), len(p.imgIds))
269-
self._paramsEval = copy.deepcopy(self.params)
270-
# toc = time.time()
271-
# print('DONE (t={:0.2f}s).'.format(toc-tic))
272-
return p.imgIds, evalImgs
273-
274-
#################################################################
275-
# end of straight copy from pycocotools, just removing the prints
276-
#################################################################
277201

202+
return merged_img_ids.tolist(), merged_eval_imgs.tolist()

rtdetrv2_pytorch/src/data/dataset/coco_fasteval.py

-139
This file was deleted.

rtdetrv2_pytorch/src/data/dataset/coco_utils.py

+5-6
Original file line numberDiff line numberDiff line change
@@ -9,15 +9,14 @@
99
import torch.utils.data
1010
import torchvision
1111
import torchvision.transforms.functional as TVF
12-
from pycocotools import mask as coco_mask
13-
from pycocotools.coco import COCO
14-
12+
from faster_coco_eval import COCO
13+
import faster_coco_eval.core.mask as mask_util
1514

1615
def convert_coco_poly_to_mask(segmentations, height, width):
1716
masks = []
1817
for polygons in segmentations:
19-
rles = coco_mask.frPyObjects(polygons, height, width)
20-
mask = coco_mask.decode(rles)
18+
rles = mask_util.frPyObjects(polygons, height, width)
19+
mask = mask_util.decode(rles)
2120
if len(mask.shape) < 3:
2221
mask = mask[..., None]
2322
mask = torch.as_tensor(mask, dtype=torch.uint8)
@@ -169,7 +168,7 @@ def convert_to_coco_api(ds):
169168
ann["iscrowd"] = iscrowd[i]
170169
ann["id"] = ann_id
171170
if "masks" in targets:
172-
ann["segmentation"] = coco_mask.encode(masks[i].numpy())
171+
ann["segmentation"] = mask_util.encode(masks[i].numpy())
173172
if "keypoints" in targets:
174173
ann["keypoints"] = keypoints[i]
175174
ann["num_keypoints"] = sum(k != 0 for k in keypoints[i][2::3])

0 commit comments

Comments
 (0)