Skip to content

Commit 2629917

Browse files
thuyngch authored and ZwwWayne committed
[Feature] Support Label Assignment Distillation (LAD) (open-mmlab#6342)
* add LAD
* inherit LAD from KnowledgeDistillationSingleStageDetector
* add configs/lad/lad_r101_paa_r50_fpn_coco_1x.py
* update LAD readme
* update configs/lad/README.md
* try not to use abbreviations for variable names
* add unittest for lad_head
* update test_lad_head
* remove main in tests/test_models/test_dense_heads/test_lad_head.py
1 parent 65b0fa4 commit 2629917

File tree

8 files changed

+748
-2
lines changed

8 files changed

+748
-2
lines changed

configs/lad/README.md

+32
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,32 @@
1+
# Improving Object Detection by Label Assignment Distillation
2+
3+
<!-- [ALGORITHM] -->
4+
5+
```latex
6+
@inproceedings{nguyen2021improving,
7+
title={Improving Object Detection by Label Assignment Distillation},
8+
author={Chuong H. Nguyen and Thuy C. Nguyen and Tuan N. Tang and Nam L. H. Phan},
9+
booktitle = {WACV},
10+
year={2022}
11+
}
12+
```
13+
14+
## Results and Models
15+
16+
We provide config files to reproduce the object detection results in the
17+
WACV 2022 paper for Improving Object Detection by Label Assignment
18+
Distillation.
19+
20+
### PAA with LAD
21+
22+
| Teacher | Student | Training schedule | AP (val) | Config |
23+
| :-------: | :-----: | :---------------: | :------: | :----------------------------------------------------: |
24+
| -- | R-50 | 1x | 40.4 | |
25+
| -- | R-101 | 1x | 42.6 | |
26+
| R-101 | R-50 | 1x | 41.6 | [config](./lad_r50_paa_r101_fpn_coco_1x.py) |
27+
| R-50 | R-101 | 1x | 43.2 | [config](./lad_r101_paa_r50_fpn_coco_1x.py) |
28+
29+
## Note
30+
31+
- Meaning of the config name: `lad_r50(student model)_paa(based on PAA)_r101(teacher model)_fpn(neck)_coco(dataset)_1x(12 epochs).py`
32+
- Results may fluctuate by about 0.2 mAP.
configs/lad/lad_r101_paa_r50_fpn_coco_1x.py

+121
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,121 @@
1+
_base_ = [
2+
'../_base_/datasets/coco_detection.py',
3+
'../_base_/schedules/schedule_1x.py', '../_base_/default_runtime.py'
4+
]
5+
teacher_ckpt = 'https://download.openmmlab.com/mmdetection/v2.0/paa/paa_r50_fpn_1x_coco/paa_r50_fpn_1x_coco_20200821-936edec3.pth' # noqa
6+
model = dict(
7+
type='LAD',
8+
# student
9+
backbone=dict(
10+
type='ResNet',
11+
depth=101,
12+
num_stages=4,
13+
out_indices=(0, 1, 2, 3),
14+
frozen_stages=1,
15+
norm_cfg=dict(type='BN', requires_grad=True),
16+
norm_eval=True,
17+
style='pytorch',
18+
init_cfg=dict(type='Pretrained',
19+
checkpoint='torchvision://resnet101')),
20+
neck=dict(
21+
type='FPN',
22+
in_channels=[256, 512, 1024, 2048],
23+
out_channels=256,
24+
start_level=1,
25+
add_extra_convs='on_output',
26+
num_outs=5),
27+
bbox_head=dict(
28+
type='LADHead',
29+
reg_decoded_bbox=True,
30+
score_voting=True,
31+
topk=9,
32+
num_classes=80,
33+
in_channels=256,
34+
stacked_convs=4,
35+
feat_channels=256,
36+
anchor_generator=dict(
37+
type='AnchorGenerator',
38+
ratios=[1.0],
39+
octave_base_scale=8,
40+
scales_per_octave=1,
41+
strides=[8, 16, 32, 64, 128]),
42+
bbox_coder=dict(
43+
type='DeltaXYWHBBoxCoder',
44+
target_means=[.0, .0, .0, .0],
45+
target_stds=[0.1, 0.1, 0.2, 0.2]),
46+
loss_cls=dict(
47+
type='FocalLoss',
48+
use_sigmoid=True,
49+
gamma=2.0,
50+
alpha=0.25,
51+
loss_weight=1.0),
52+
loss_bbox=dict(type='GIoULoss', loss_weight=1.3),
53+
loss_centerness=dict(
54+
type='CrossEntropyLoss', use_sigmoid=True, loss_weight=0.5)),
55+
# teacher
56+
teacher_ckpt=teacher_ckpt,
57+
teacher_backbone=dict(
58+
type='ResNet',
59+
depth=50,
60+
num_stages=4,
61+
out_indices=(0, 1, 2, 3),
62+
frozen_stages=1,
63+
norm_cfg=dict(type='BN', requires_grad=True),
64+
norm_eval=True,
65+
style='pytorch'),
66+
teacher_neck=dict(
67+
type='FPN',
68+
in_channels=[256, 512, 1024, 2048],
69+
out_channels=256,
70+
start_level=1,
71+
add_extra_convs='on_output',
72+
num_outs=5),
73+
teacher_bbox_head=dict(
74+
type='LADHead',
75+
reg_decoded_bbox=True,
76+
score_voting=True,
77+
topk=9,
78+
num_classes=80,
79+
in_channels=256,
80+
stacked_convs=4,
81+
feat_channels=256,
82+
anchor_generator=dict(
83+
type='AnchorGenerator',
84+
ratios=[1.0],
85+
octave_base_scale=8,
86+
scales_per_octave=1,
87+
strides=[8, 16, 32, 64, 128]),
88+
bbox_coder=dict(
89+
type='DeltaXYWHBBoxCoder',
90+
target_means=[.0, .0, .0, .0],
91+
target_stds=[0.1, 0.1, 0.2, 0.2]),
92+
loss_cls=dict(
93+
type='FocalLoss',
94+
use_sigmoid=True,
95+
gamma=2.0,
96+
alpha=0.25,
97+
loss_weight=1.0),
98+
loss_bbox=dict(type='GIoULoss', loss_weight=1.3),
99+
loss_centerness=dict(
100+
type='CrossEntropyLoss', use_sigmoid=True, loss_weight=0.5)),
101+
# training and testing settings
102+
train_cfg=dict(
103+
assigner=dict(
104+
type='MaxIoUAssigner',
105+
pos_iou_thr=0.1,
106+
neg_iou_thr=0.1,
107+
min_pos_iou=0,
108+
ignore_iof_thr=-1),
109+
allowed_border=-1,
110+
pos_weight=-1,
111+
debug=False),
112+
test_cfg=dict(
113+
nms_pre=1000,
114+
min_bbox_size=0,
115+
score_thr=0.05,
116+
score_voting=True,
117+
nms=dict(type='nms', iou_threshold=0.6),
118+
max_per_img=100))
119+
data = dict(samples_per_gpu=8, workers_per_gpu=4)
120+
optimizer = dict(lr=0.01)
121+
fp16 = dict(loss_scale=512.)
configs/lad/lad_r50_paa_r101_fpn_coco_1x.py

+120
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,120 @@
1+
_base_ = [
2+
'../_base_/datasets/coco_detection.py',
3+
'../_base_/schedules/schedule_1x.py', '../_base_/default_runtime.py'
4+
]
5+
# Pretrained PAA R-101 teacher checkpoint. Fetched over HTTPS so the weights
# cannot be tampered with in transit (matches the scheme used by the sibling
# LAD config for the R-50 teacher).
teacher_ckpt = 'https://download.openmmlab.com/mmdetection/v2.0/paa/paa_r101_fpn_1x_coco/paa_r101_fpn_1x_coco_20200821-0a1825a4.pth'  # noqa
6+
model = dict(
7+
type='LAD',
8+
# student
9+
backbone=dict(
10+
type='ResNet',
11+
depth=50,
12+
num_stages=4,
13+
out_indices=(0, 1, 2, 3),
14+
frozen_stages=1,
15+
norm_cfg=dict(type='BN', requires_grad=True),
16+
norm_eval=True,
17+
style='pytorch',
18+
init_cfg=dict(type='Pretrained', checkpoint='torchvision://resnet50')),
19+
neck=dict(
20+
type='FPN',
21+
in_channels=[256, 512, 1024, 2048],
22+
out_channels=256,
23+
start_level=1,
24+
add_extra_convs='on_output',
25+
num_outs=5),
26+
bbox_head=dict(
27+
type='LADHead',
28+
reg_decoded_bbox=True,
29+
score_voting=True,
30+
topk=9,
31+
num_classes=80,
32+
in_channels=256,
33+
stacked_convs=4,
34+
feat_channels=256,
35+
anchor_generator=dict(
36+
type='AnchorGenerator',
37+
ratios=[1.0],
38+
octave_base_scale=8,
39+
scales_per_octave=1,
40+
strides=[8, 16, 32, 64, 128]),
41+
bbox_coder=dict(
42+
type='DeltaXYWHBBoxCoder',
43+
target_means=[.0, .0, .0, .0],
44+
target_stds=[0.1, 0.1, 0.2, 0.2]),
45+
loss_cls=dict(
46+
type='FocalLoss',
47+
use_sigmoid=True,
48+
gamma=2.0,
49+
alpha=0.25,
50+
loss_weight=1.0),
51+
loss_bbox=dict(type='GIoULoss', loss_weight=1.3),
52+
loss_centerness=dict(
53+
type='CrossEntropyLoss', use_sigmoid=True, loss_weight=0.5)),
54+
# teacher
55+
teacher_ckpt=teacher_ckpt,
56+
teacher_backbone=dict(
57+
type='ResNet',
58+
depth=101,
59+
num_stages=4,
60+
out_indices=(0, 1, 2, 3),
61+
frozen_stages=1,
62+
norm_cfg=dict(type='BN', requires_grad=True),
63+
norm_eval=True,
64+
style='pytorch'),
65+
teacher_neck=dict(
66+
type='FPN',
67+
in_channels=[256, 512, 1024, 2048],
68+
out_channels=256,
69+
start_level=1,
70+
add_extra_convs='on_output',
71+
num_outs=5),
72+
teacher_bbox_head=dict(
73+
type='LADHead',
74+
reg_decoded_bbox=True,
75+
score_voting=True,
76+
topk=9,
77+
num_classes=80,
78+
in_channels=256,
79+
stacked_convs=4,
80+
feat_channels=256,
81+
anchor_generator=dict(
82+
type='AnchorGenerator',
83+
ratios=[1.0],
84+
octave_base_scale=8,
85+
scales_per_octave=1,
86+
strides=[8, 16, 32, 64, 128]),
87+
bbox_coder=dict(
88+
type='DeltaXYWHBBoxCoder',
89+
target_means=[.0, .0, .0, .0],
90+
target_stds=[0.1, 0.1, 0.2, 0.2]),
91+
loss_cls=dict(
92+
type='FocalLoss',
93+
use_sigmoid=True,
94+
gamma=2.0,
95+
alpha=0.25,
96+
loss_weight=1.0),
97+
loss_bbox=dict(type='GIoULoss', loss_weight=1.3),
98+
loss_centerness=dict(
99+
type='CrossEntropyLoss', use_sigmoid=True, loss_weight=0.5)),
100+
# training and testing settings
101+
train_cfg=dict(
102+
assigner=dict(
103+
type='MaxIoUAssigner',
104+
pos_iou_thr=0.1,
105+
neg_iou_thr=0.1,
106+
min_pos_iou=0,
107+
ignore_iof_thr=-1),
108+
allowed_border=-1,
109+
pos_weight=-1,
110+
debug=False),
111+
test_cfg=dict(
112+
nms_pre=1000,
113+
min_bbox_size=0,
114+
score_thr=0.05,
115+
score_voting=True,
116+
nms=dict(type='nms', iou_threshold=0.6),
117+
max_per_img=100))
118+
data = dict(samples_per_gpu=8, workers_per_gpu=4)
119+
optimizer = dict(lr=0.01)
120+
fp16 = dict(loss_scale=512.)

mmdet/models/dense_heads/__init__.py

+2-1
Original file line numberDiff line numberDiff line change
@@ -18,6 +18,7 @@
1818
from .ga_rpn_head import GARPNHead
1919
from .gfl_head import GFLHead
2020
from .guided_anchor_head import FeatureAdaption, GuidedAnchorHead
21+
from .lad_head import LADHead
2122
from .ld_head import LDHead
2223
from .nasfcos_head import NASFCOSHead
2324
from .paa_head import PAAHead
@@ -47,5 +48,5 @@
4748
'CascadeRPNHead', 'EmbeddingRPNHead', 'LDHead', 'CascadeRPNHead',
4849
'AutoAssignHead', 'DETRHead', 'YOLOFHead', 'DeformableDETRHead',
4950
'SOLOHead', 'DecoupledSOLOHead', 'CenterNetHead', 'YOLOXHead',
50-
'DecoupledSOLOLightHead'
51+
'DecoupledSOLOLightHead', 'LADHead'
5152
]

0 commit comments

Comments
 (0)