Skip to content

Commit c0b2e80

Browse files
hhaAndroidchhluoRangiLyuAronLinAndreaPi
authored andcommitted
Refactor YOLOX (open-mmlab#6443)
* Fix aug test error when the number of prediction bboxes is 0 (open-mmlab#6398) * Fix aug test error when the number of prediction bboxes is 0 * test * test * fix lint * Support custom pin_memory and persistent_workers * [Docs] Chinese version of robustness_benchmarking.md (open-mmlab#6375) * Chinese version of robustness_benchmarking.md * Update docs_zh-CN/robustness_benchmarking.md Co-authored-by: RangiLyu <[email protected]> * Update docs_zh-CN/robustness_benchmarking.md Co-authored-by: RangiLyu <[email protected]> * Update docs_zh-CN/robustness_benchmarking.md Co-authored-by: RangiLyu <[email protected]> * Update docs_zh-CN/robustness_benchmarking.md Co-authored-by: RangiLyu <[email protected]> * Update docs_zh-CN/robustness_benchmarking.md Co-authored-by: RangiLyu <[email protected]> * Update docs_zh-CN/robustness_benchmarking.md Co-authored-by: RangiLyu <[email protected]> * Update robustness_benchmarking.md * Update robustness_benchmarking.md * Update robustness_benchmarking.md * Update robustness_benchmarking.md * Update robustness_benchmarking.md * Update robustness_benchmarking.md Co-authored-by: RangiLyu <[email protected]> * update yolox_s * update yolox_s * support dynamic eval interval * fix some error * support ceph * fix none error * fix batch error * replace resize * fix comment * fix docstr * Update the link of checkpoints (open-mmlab#6460) * [Feature]: Support plot confusion matrix. (open-mmlab#6344) * remove pin_memory * update * fix unittest * update cfg * fix error * add unittest * [Fix] Fix SpatialReductionAttention in PVT. (open-mmlab#6488) * [Fix] Fix SpatialReductionAttention in PVT * Add warning * Save coco summarize print information to logger (open-mmlab#6505) * Fix type error in 2_new_data_mode (open-mmlab#6469) * Always map location to cpu when load checkpoint (open-mmlab#6405) * configs: update groie README (open-mmlab#6401) Signed-off-by: Leonardo Rossi <[email protected]> * [Fix] fix config path in docs (open-mmlab#6396) * [Enchance] Set a random seed when the user does not set a seed. (open-mmlab#6457) * fix random seed bug * add comment * enchance random seed * rename Co-authored-by: Haobo Yuan <[email protected]> * [BugFixed] fix wrong trunc_normal_init use (open-mmlab#6432) * fix wrong trunc_normal_init use * fix wrong trunc_normal_init use * fix open-mmlab#6446 Co-authored-by: Uno Wu <[email protected]> Co-authored-by: Leonardo Rossi <[email protected]> Co-authored-by: BigDong <[email protected]> Co-authored-by: Haian Huang(深度眸) <[email protected]> Co-authored-by: Haobo Yuan <[email protected]> Co-authored-by: Shusheng Yang <[email protected]> * bump version to v2.18.1 (open-mmlab#6510) * bump version to v2.18.1 * Update changelog.md * add some comment * fix some comment * update readme * fix lint * add reduce mean * update * update readme * update params Co-authored-by: Cedric Luo <[email protected]> Co-authored-by: RangiLyu <[email protected]> Co-authored-by: Guangchen Lin <[email protected]> Co-authored-by: Andrea Panizza <[email protected]> Co-authored-by: Uno Wu <[email protected]> Co-authored-by: Leonardo Rossi <[email protected]> Co-authored-by: BigDong <[email protected]> Co-authored-by: Haobo Yuan <[email protected]> Co-authored-by: Shusheng Yang <[email protected]>
1 parent fced16b commit c0b2e80

File tree

20 files changed

+503
-184
lines changed

20 files changed

+503
-184
lines changed

configs/yolox/README.md

+2-2
Original file line numberDiff line numberDiff line change
@@ -17,9 +17,9 @@
1717

1818
| Backbone | size | Mem (GB) | box AP | Config | Download |
1919
|:---------:|:-------:|:-------:|:-------:|:--------:|:------:|
20-
| YOLOX-Tiny | 416 | 3.6 | 31.6 | [config](https://github.com/open-mmlab/mmdetection/tree/master/configs/yolox/yolox_tiny_8x8_300e_coco.py) |[model](https://download.openmmlab.com/mmdetection/v2.0/yolox/yolox_tiny_8x8_300e_coco/yolox_tiny_8x8_300e_coco_20210806_234250-4ff3b67e.pth) &#124; [log](https://download.openmmlab.com/mmdetection/v2.0/yolox/yolox_tiny_8x8_300e_coco/yolox_tiny_8x8_300e_coco_20210806_234250.log.json) |
20+
| YOLOX-s | 640 | 7.6 | 40.5 | [config](https://github.com/open-mmlab/mmdetection/tree/master/configs/yolox/yolox_s_8x8_300e_coco.py) |[model](https://download.openmmlab.com/mmdetection/v2.0/yolox/yolox_s_8x8_300e_coco/yolox_s_8x8_300e_coco_20211121_095711-4592a793.pth) &#124; [log](https://download.openmmlab.com/mmdetection/v2.0/yolox/yolox_s_8x8_300e_coco/yolox_s_8x8_300e_coco_20211121_095711.log.json) |
21+
2122

2223
**Note**:
2324

2425
1. The test score threshold is 0.001.
25-
2. We find that the performance is unstable and may fluctuate by about 0.7 mAP. We will continue to investigate and improve it.

configs/yolox/metafile.yml

+6-5
Original file line numberDiff line numberDiff line change
@@ -18,16 +18,17 @@ Collections:
1818
URL: https://github.com/open-mmlab/mmdetection/blob/v2.15.1/mmdet/models/detectors/yolox.py#L6
1919
Version: v2.15.1
2020

21+
2122
Models:
22-
- Name: yolox_tiny_8x8_300e_coco
23+
- Name: yolox_s_8x8_300e_coco
2324
In Collection: YOLOX
24-
Config: configs/yolox/yolox_tiny_8x8_300e_coco.py
25+
Config: configs/yolox/yolox_s_8x8_300e_coco.py
2526
Metadata:
26-
Training Memory (GB): 3.6
27+
Training Memory (GB): 7.6
2728
Epochs: 300
2829
Results:
2930
- Task: Object Detection
3031
Dataset: COCO
3132
Metrics:
32-
box AP: 31.6
33-
Weights: https://download.openmmlab.com/mmdetection/v2.0/yolox/yolox_tiny_8x8_300e_coco/yolox_tiny_8x8_300e_coco_20210806_234250-4ff3b67e.pth
33+
box AP: 40.5
34+
Weights: https://download.openmmlab.com/mmdetection/v2.0/yolox/yolox_s_8x8_300e_coco/yolox_s_8x8_300e_coco_20211121_095711-4592a793.pth

configs/yolox/yolox_s_8x8_300e_coco.py

+49-31
Original file line numberDiff line numberDiff line change
@@ -1,8 +1,13 @@
11
_base_ = ['../_base_/schedules/schedule_1x.py', '../_base_/default_runtime.py']
22

3+
img_scale = (640, 640)
4+
35
# model settings
46
model = dict(
57
type='YOLOX',
8+
input_size=img_scale,
9+
random_size_range=(15, 25),
10+
random_size_interval=10,
611
backbone=dict(type='CSPDarknet', deepen_factor=0.33, widen_factor=0.5),
712
neck=dict(
813
type='YOLOXPAFPN',
@@ -20,11 +25,6 @@
2025
data_root = 'data/coco/'
2126
dataset_type = 'CocoDataset'
2227

23-
img_norm_cfg = dict(
24-
mean=[123.675, 116.28, 103.53], std=[58.395, 57.12, 57.375], to_rgb=True)
25-
26-
img_scale = (640, 640)
27-
2828
train_pipeline = [
2929
dict(type='Mosaic', img_scale=img_scale, pad_val=114.0),
3030
dict(
@@ -36,16 +36,19 @@
3636
img_scale=img_scale,
3737
ratio_range=(0.8, 1.6),
3838
pad_val=114.0),
39-
dict(
40-
type='PhotoMetricDistortion',
41-
brightness_delta=32,
42-
contrast_range=(0.5, 1.5),
43-
saturation_range=(0.5, 1.5),
44-
hue_delta=18),
39+
dict(type='YOLOXHSVRandomAug'),
4540
dict(type='RandomFlip', flip_ratio=0.5),
46-
dict(type='Resize', keep_ratio=True),
47-
dict(type='Pad', pad_to_square=True, pad_val=114.0),
48-
dict(type='Normalize', **img_norm_cfg),
41+
# According to the official implementation, multi-scale
42+
# training is not considered here but in the
43+
# 'mmdet/models/detectors/yolox.py'.
44+
dict(type='Resize', img_scale=img_scale, keep_ratio=True),
45+
dict(
46+
type='Pad',
47+
pad_to_square=True,
48+
# If the image is three-channel, the pad value needs
49+
# to be set separately for each channel.
50+
pad_val=dict(img=(114.0, 114.0, 114.0))),
51+
dict(type='FilterAnnotations', min_gt_bbox_wh=(1, 1), keep_empty=False),
4952
dict(type='DefaultFormatBundle'),
5053
dict(type='Collect', keys=['img', 'gt_bboxes', 'gt_labels'])
5154
]
@@ -57,13 +60,12 @@
5760
ann_file=data_root + 'annotations/instances_train2017.json',
5861
img_prefix=data_root + 'train2017/',
5962
pipeline=[
60-
dict(type='LoadImageFromFile', to_float32=True),
63+
dict(type='LoadImageFromFile'),
6164
dict(type='LoadAnnotations', with_bbox=True)
6265
],
6366
filter_empty_gt=False,
6467
),
65-
pipeline=train_pipeline,
66-
dynamic_scale=img_scale)
68+
pipeline=train_pipeline)
6769

6870
test_pipeline = [
6971
dict(type='LoadImageFromFile'),
@@ -74,16 +76,19 @@
7476
transforms=[
7577
dict(type='Resize', keep_ratio=True),
7678
dict(type='RandomFlip'),
77-
dict(type='Pad', size=img_scale, pad_val=114.0),
78-
dict(type='Normalize', **img_norm_cfg),
79+
dict(
80+
type='Pad',
81+
pad_to_square=True,
82+
pad_val=dict(img=(114.0, 114.0, 114.0))),
7983
dict(type='DefaultFormatBundle'),
8084
dict(type='Collect', keys=['img'])
8185
])
8286
]
8387

8488
data = dict(
8589
samples_per_gpu=8,
86-
workers_per_gpu=2,
90+
workers_per_gpu=4,
91+
persistent_workers=True,
8792
train=train_dataset,
8893
val=dict(
8994
type=dataset_type,
@@ -107,6 +112,11 @@
107112
paramwise_cfg=dict(norm_decay_mult=0., bias_decay_mult=0.))
108113
optimizer_config = dict(grad_clip=None)
109114

115+
max_epochs = 300
116+
num_last_epochs = 15
117+
resume_from = None
118+
interval = 10
119+
110120
# learning policy
111121
lr_config = dict(
112122
_delete_=True,
@@ -116,27 +126,35 @@
116126
warmup_by_epoch=True,
117127
warmup_ratio=1,
118128
warmup_iters=5, # 5 epoch
119-
num_last_epochs=15,
129+
num_last_epochs=num_last_epochs,
120130
min_lr_ratio=0.05)
121-
runner = dict(type='EpochBasedRunner', max_epochs=300)
122131

123-
resume_from = None
124-
interval = 10
132+
runner = dict(type='EpochBasedRunner', max_epochs=max_epochs)
125133

126134
custom_hooks = [
127-
dict(type='YOLOXModeSwitchHook', num_last_epochs=15, priority=48),
128135
dict(
129-
type='SyncRandomSizeHook',
130-
ratio_range=(14, 26),
131-
img_scale=img_scale,
136+
type='YOLOXModeSwitchHook',
137+
num_last_epochs=num_last_epochs,
132138
priority=48),
133139
dict(
134140
type='SyncNormHook',
135-
num_last_epochs=15,
141+
num_last_epochs=num_last_epochs,
136142
interval=interval,
137143
priority=48),
138-
dict(type='ExpMomentumEMAHook', resume_from=resume_from, priority=49)
144+
dict(
145+
type='ExpMomentumEMAHook',
146+
resume_from=resume_from,
147+
momentum=0.0001,
148+
priority=49)
139149
]
140150
checkpoint_config = dict(interval=interval)
141-
evaluation = dict(interval=interval, metric='bbox')
151+
evaluation = dict(
152+
save_best='auto',
153+
# The evaluation interval is 'interval' when running epoch is
154+
# less than ‘max_epochs - num_last_epochs’.
155+
# The evaluation interval is 1 when running epoch is greater than
156+
# or equal to ‘max_epochs - num_last_epochs’.
157+
interval=interval,
158+
dynamic_intervals=[(max_epochs - num_last_epochs, 1)],
159+
metric='bbox')
142160
log_config = dict(interval=50)

configs/yolox/yolox_tiny_8x8_300e_coco.py

+12-37
Original file line numberDiff line numberDiff line change
@@ -2,14 +2,11 @@
22

33
# model settings
44
model = dict(
5+
random_size_range=(10, 20),
56
backbone=dict(deepen_factor=0.33, widen_factor=0.375),
67
neck=dict(in_channels=[96, 192, 384], out_channels=96),
78
bbox_head=dict(in_channels=96, feat_channels=96))
89

9-
# dataset settings
10-
img_norm_cfg = dict(
11-
mean=[123.675, 116.28, 103.53], std=[58.395, 57.12, 57.375], to_rgb=True)
12-
1310
img_scale = (640, 640)
1411

1512
train_pipeline = [
@@ -18,16 +15,14 @@
1815
type='RandomAffine',
1916
scaling_ratio_range=(0.5, 1.5),
2017
border=(-img_scale[0] // 2, -img_scale[1] // 2)),
21-
dict(
22-
type='PhotoMetricDistortion',
23-
brightness_delta=32,
24-
contrast_range=(0.5, 1.5),
25-
saturation_range=(0.5, 1.5),
26-
hue_delta=18),
18+
dict(type='YOLOXHSVRandomAug'),
2719
dict(type='RandomFlip', flip_ratio=0.5),
28-
dict(type='Resize', keep_ratio=True),
29-
dict(type='Pad', pad_to_square=True, pad_val=114.0),
30-
dict(type='Normalize', **img_norm_cfg),
20+
dict(type='Resize', img_scale=img_scale, keep_ratio=True),
21+
dict(
22+
type='Pad',
23+
pad_to_square=True,
24+
pad_val=dict(img=(114.0, 114.0, 114.0))),
25+
dict(type='FilterAnnotations', min_gt_bbox_wh=(1, 1), keep_empty=False),
3126
dict(type='DefaultFormatBundle'),
3227
dict(type='Collect', keys=['img', 'gt_bboxes', 'gt_labels'])
3328
]
@@ -41,8 +36,10 @@
4136
transforms=[
4237
dict(type='Resize', keep_ratio=True),
4338
dict(type='RandomFlip'),
44-
dict(type='Pad', size=(416, 416), pad_val=114.0),
45-
dict(type='Normalize', **img_norm_cfg),
39+
dict(
40+
type='Pad',
41+
pad_to_square=True,
42+
pad_val=dict(img=(114.0, 114.0, 114.0))),
4643
dict(type='DefaultFormatBundle'),
4744
dict(type='Collect', keys=['img'])
4845
])
@@ -54,25 +51,3 @@
5451
train=train_dataset,
5552
val=dict(pipeline=test_pipeline),
5653
test=dict(pipeline=test_pipeline))
57-
58-
resume_from = None
59-
interval = 10
60-
61-
# Execute in the order of insertion when the priority is the same.
62-
# The smaller the value, the higher the priority
63-
custom_hooks = [
64-
dict(type='YOLOXModeSwitchHook', num_last_epochs=15, priority=48),
65-
dict(
66-
type='SyncRandomSizeHook',
67-
ratio_range=(10, 20),
68-
img_scale=img_scale,
69-
priority=48),
70-
dict(
71-
type='SyncNormHook',
72-
num_last_epochs=15,
73-
interval=interval,
74-
priority=48),
75-
dict(type='ExpMomentumEMAHook', resume_from=resume_from, priority=49)
76-
]
77-
checkpoint_config = dict(interval=interval)
78-
evaluation = dict(interval=interval, metric='bbox')

mmdet/core/bbox/assigners/sim_ota_assigner.py

+2-1
Original file line numberDiff line numberDiff line change
@@ -228,7 +228,8 @@ def get_in_gt_and_in_center_info(self, priors, gt_bboxes):
228228
def dynamic_k_matching(self, cost, pairwise_ious, num_gt, valid_mask):
229229
matching_matrix = torch.zeros_like(cost)
230230
# select candidate topk ious for dynamic-k calculation
231-
topk_ious, _ = torch.topk(pairwise_ious, self.candidate_topk, dim=0)
231+
candidate_topk = min(self.candidate_topk, pairwise_ious.size(0))
232+
topk_ious, _ = torch.topk(pairwise_ious, candidate_topk, dim=0)
232233
# calculate dynamic k for each gt
233234
dynamic_ks = torch.clamp(topk_ious.sum(0).int(), min=1)
234235
for gt_idx in range(num_gt):

mmdet/core/evaluation/eval_hooks.py

+65
Original file line numberDiff line numberDiff line change
@@ -1,14 +1,52 @@
11
# Copyright (c) OpenMMLab. All rights reserved.
2+
import bisect
23
import os.path as osp
34

5+
import mmcv
46
import torch.distributed as dist
57
from mmcv.runner import DistEvalHook as BaseDistEvalHook
68
from mmcv.runner import EvalHook as BaseEvalHook
79
from torch.nn.modules.batchnorm import _BatchNorm
810

911

12+
def _calc_dynamic_intervals(start_interval, dynamic_interval_list):
13+
assert mmcv.is_list_of(dynamic_interval_list, tuple)
14+
15+
dynamic_milestones = [0]
16+
dynamic_milestones.extend(
17+
[dynamic_interval[0] for dynamic_interval in dynamic_interval_list])
18+
dynamic_intervals = [start_interval]
19+
dynamic_intervals.extend(
20+
[dynamic_interval[1] for dynamic_interval in dynamic_interval_list])
21+
return dynamic_milestones, dynamic_intervals
22+
23+
1024
class EvalHook(BaseEvalHook):
1125

26+
def __init__(self, *args, dynamic_intervals=None, **kwargs):
27+
super(EvalHook, self).__init__(*args, **kwargs)
28+
29+
self.use_dynamic_intervals = dynamic_intervals is not None
30+
if self.use_dynamic_intervals:
31+
self.dynamic_milestones, self.dynamic_intervals = \
32+
_calc_dynamic_intervals(self.interval, dynamic_intervals)
33+
34+
def _decide_interval(self, runner):
35+
if self.use_dynamic_intervals:
36+
progress = runner.epoch if self.by_epoch else runner.iter
37+
step = bisect.bisect(self.dynamic_milestones, (progress + 1))
38+
# Dynamically modify the evaluation interval
39+
self.interval = self.dynamic_intervals[step - 1]
40+
41+
def before_train_epoch(self, runner):
42+
"""Evaluate the model only at the start of training by epoch."""
43+
self._decide_interval(runner)
44+
super().before_train_epoch(runner)
45+
46+
def before_train_iter(self, runner):
47+
self._decide_interval(runner)
48+
super().before_train_iter(runner)
49+
1250
def _do_evaluate(self, runner):
1351
"""perform evaluation and save ckpt."""
1452
if not self._should_evaluate(runner):
@@ -22,8 +60,35 @@ def _do_evaluate(self, runner):
2260
self._save_ckpt(runner, key_score)
2361

2462

63+
# Note: Considering that MMCV's EvalHook updated its interface in V1.3.16,
64+
# in order to avoid strong version dependency, we did not directly
65+
# inherit EvalHook but BaseDistEvalHook.
2566
class DistEvalHook(BaseDistEvalHook):
2667

68+
def __init__(self, *args, dynamic_intervals=None, **kwargs):
69+
super(DistEvalHook, self).__init__(*args, **kwargs)
70+
71+
self.use_dynamic_intervals = dynamic_intervals is not None
72+
if self.use_dynamic_intervals:
73+
self.dynamic_milestones, self.dynamic_intervals = \
74+
_calc_dynamic_intervals(self.interval, dynamic_intervals)
75+
76+
def _decide_interval(self, runner):
77+
if self.use_dynamic_intervals:
78+
progress = runner.epoch if self.by_epoch else runner.iter
79+
step = bisect.bisect(self.dynamic_milestones, (progress + 1))
80+
# Dynamically modify the evaluation interval
81+
self.interval = self.dynamic_intervals[step - 1]
82+
83+
def before_train_epoch(self, runner):
84+
"""Evaluate the model only at the start of training by epoch."""
85+
self._decide_interval(runner)
86+
super().before_train_epoch(runner)
87+
88+
def before_train_iter(self, runner):
89+
self._decide_interval(runner)
90+
super().before_train_iter(runner)
91+
2792
def _do_evaluate(self, runner):
2893
"""perform evaluation and save ckpt."""
2994
# Synchronization of BatchNorm's buffer (running_mean

mmdet/core/hook/yolox_mode_switch_hook.py

+14
Original file line numberDiff line numberDiff line change
@@ -23,6 +23,7 @@ def __init__(self,
2323
skip_type_keys=('Mosaic', 'RandomAffine', 'MixUp')):
2424
self.num_last_epochs = num_last_epochs
2525
self.skip_type_keys = skip_type_keys
26+
self._restart_dataloader = False
2627

2728
def before_train_epoch(self, runner):
2829
"""Close mosaic and mixup augmentation and switches to use L1 loss."""
@@ -33,6 +34,19 @@ def before_train_epoch(self, runner):
3334
model = model.module
3435
if (epoch + 1) == runner.max_epochs - self.num_last_epochs:
3536
runner.logger.info('No mosaic and mixup aug now!')
37+
# The dataset pipeline cannot be updated when persistent_workers
38+
# is True, so we need to force the dataloader's multi-process
39+
# restart. This is a very hacky approach.
3640
train_loader.dataset.update_skip_type_keys(self.skip_type_keys)
41+
if hasattr(train_loader, 'persistent_workers'
42+
) and train_loader.persistent_workers is True:
43+
train_loader._DataLoader__initialized = False
44+
train_loader._iterator = None
45+
self._restart_dataloader = True
3746
runner.logger.info('Add additional L1 loss now!')
3847
model.bbox_head.use_l1 = True
48+
else:
49+
# Once the restart is complete, we need to restore
50+
# the initialization flag.
51+
if self._restart_dataloader:
52+
train_loader._DataLoader__initialized = True

0 commit comments

Comments
 (0)