
Commit 92b40cf

Author: liukai
Message: add group fisher algorithm implementation and configs

1 parent 1209937 · commit 92b40cf
30 files changed · +1288 −0 lines changed

projects/group_fisher/__init__.py

Lines changed: 11 additions & 0 deletions
@@ -0,0 +1,11 @@
```python
from .modules.group_fisher_algorthm import GroupFisherAlgorithm
from .modules.group_fisher_channel_mutator import GroupFisherChannelMutator
from .modules.group_fisher_channel_unit import GroupFisherChannelUnit
from .modules.group_fisher_ops import GroupFisherMixin

__all__ = [
    'GroupFisherChannelMutator',
    'GroupFisherAlgorithm',
    'GroupFisherMixin',
    'GroupFisherChannelUnit',
]
```
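These four classes are the project's public interface; importing `projects` (the configs below do this via `custom_imports`) registers them so configs can refer to them by `type`. Below is a rough, hypothetical sketch of building the algorithm directly from Python, assuming an environment with mmrazor and mmcls installed; the inline MobileNetV2 classifier dict is an illustrative stand-in for a full model config, not something this commit contains.

```python
# Hypothetical usage sketch, not part of this commit. The architecture dict
# is an illustrative MobileNetV2 classifier; any mmcls/mmdet model config
# could take its place. Mutator settings are copied from the configs below.
from mmengine.registry import MODELS

import projects  # noqa: F401  # registers the GroupFisher classes above

algo = MODELS.build(
    dict(
        _scope_='mmrazor',
        type='GroupFisherAlgorithm',
        architecture=dict(
            type='mmcls.ImageClassifier',
            backbone=dict(type='mmcls.MobileNetV2', widen_factor=1.0),
            neck=dict(type='mmcls.GlobalAveragePooling'),
            head=dict(
                type='mmcls.LinearClsHead',
                num_classes=1000,
                in_channels=1280,
                loss=dict(type='mmcls.CrossEntropyLoss'),
            ),
        ),
        interval=25,  # prune a channel every 25 training iterations
        mutator=dict(
            type='GroupFisherChannelMutator',
            parse_cfg=dict(type='ChannelAnalyzer', tracer_type='FxTracer'),
            channel_unit_cfg=dict(type='GroupFisherChannelUnit'),
        ),
    ))
```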
Lines changed: 49 additions & 0 deletions
@@ -0,0 +1,49 @@
# Group Fisher Pruning

> [Group Fisher Pruning for Practical Network Compression](https://arxiv.org/pdf/2108.00708.pdf)

## Abstract

Network compression has been widely studied since it is able to reduce the memory and computation cost during inference. However, previous methods seldom deal with complicated structures such as residual connections, group/depthwise convolution and feature pyramid networks, where the channels of multiple layers are coupled and need to be pruned simultaneously. In this paper, we present a general channel pruning approach that can be applied to various complicated structures. Particularly, we propose a layer grouping algorithm to find coupled channels automatically. Then we derive a unified metric based on Fisher information to evaluate the importance of a single channel and of coupled channels. Moreover, we find that inference speedup on GPUs is more correlated with the reduction of memory than with the reduction of FLOPs, and thus we employ the memory reduction of each channel to normalize the importance. Our method can be used to prune any structure, including those with coupled channels. We conduct extensive experiments on various backbones, including the classic ResNet and ResNeXt, the mobile-friendly MobileNetV2, and the NAS-based RegNet, on both image classification and object detection, the latter of which is under-explored. Experimental results validate that our method can effectively prune sophisticated networks, boosting inference speed without sacrificing accuracy.

![pipeline](https://github.com/jshilong/FisherPruning/blob/main/resources/structures.png)

## Results and models

### Detection

| Dataset | Detector  | Backbone | Baseline (mAP) | Pruned & Finetuned (mAP) | Model | Log |
| :-----: | :-------: | :------: | :------------: | :----------------------: | :---: | :-: |
| COCO    | RetinaNet | R-50-FPN | 36.5           | 36.5 (50% FLOPs)         | [Baseline](https://download.openmmlab.com/mmdetection/v2.0/retinanet/retinanet_r50_fpn_1x_coco/retinanet_r50_fpn_1x_coco_20200130-c2398f9e.pth) / [Pruned](<>) / [Finetuned](<>) | [Prune](<>) / [Finetune](<>) |

## Citation

```latex
@InProceedings{liu2021group,
  title     = {Group Fisher Pruning for Practical Network Compression},
  author    = {Liu, Liyang and Zhang, Shilong and Kuang, Zhanghui and Zhou, Aojun and Xue, Jing-Hao and Wang, Xinjiang and Chen, Yimin and Yang, Wenming and Liao, Qingmin and Zhang, Wayne},
  booktitle = {Proceedings of the 38th International Conference on Machine Learning},
  year      = {2021},
  series    = {Proceedings of Machine Learning Research},
  month     = {18--24 Jul},
  publisher = {PMLR},
}
```

## Get Started

### Pruning

```bash
CUDA_VISIBLE_DEVICES=0,1,2,3,4,5,6,7 PORT=29500 ./tools/dist_train.sh \
  configs/pruning/mmdet/group_fisher/group-fisher-pruning_retinanet_resnet50_8xb2_coco.py 8 \
  --work-dir $WORK_DIR
```

### Finetune

Update `pruned_path` to the local path of the pruned checkpoint saved by the pruning step.

```bash
CUDA_VISIBLE_DEVICES=0,1,2,3,4,5,6,7 PORT=29500 ./tools/dist_train.sh \
  configs/pruning/mmdet/group_fisher/group-fisher-finetune_retinanet_resnet50_8xb2_coco.py 8 \
  --work-dir $WORK_DIR
```
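The unified metric named in the abstract reduces, for a single channel, to accumulating activation × gradient over spatial positions, squaring the per-sample sum, and summing over the batch; coupled channels found by the layer grouping share one accumulated score. Below is a minimal, self-contained PyTorch sketch of that accumulation, illustrative only: it is not the hook code added by this commit, and `fisher_scores`/`track_fisher` are made-up names.

```python
# Illustrative sketch of the per-channel Fisher score from the paper:
# s_c = sum_n (sum_{hw} a_nchw * g_nchw)^2. Not this commit's code;
# `fisher_scores` and `track_fisher` are made-up names.
import torch
from torch import nn

fisher_scores: dict = {}

def track_fisher(name: str, conv: nn.Conv2d) -> None:
    """Attach hooks that accumulate a Fisher score per output channel."""

    def fwd_hook(mod, inputs, output):
        if not output.requires_grad:
            return
        act = output.detach()

        def bwd_hook(grad):
            # (N, C, H, W) -> (N, C): inner product of activation and
            # gradient over spatial positions, then square and sum over
            # the batch to get one score per channel.
            prod = (act * grad).sum(dim=(2, 3))
            fisher_scores[name] = fisher_scores.get(name, 0.) + prod.pow(2).sum(dim=0)

        output.register_hook(bwd_hook)

    conv.register_forward_hook(fwd_hook)

# Usage: attach to every conv, then run an ordinary forward/backward pass.
model = nn.Sequential(nn.Conv2d(3, 8, 3, padding=1), nn.ReLU(),
                      nn.Conv2d(8, 4, 3, padding=1))
for n, m in model.named_modules():
    if isinstance(m, nn.Conv2d):
        track_fisher(n, m)
model(torch.randn(2, 3, 32, 32)).sum().backward()
print({k: v.shape for k, v in fisher_scores.items()})  # one score per channel
```

In this commit, the analogous bookkeeping roughly corresponds to `GroupFisherMixin` (which records op inputs) and `GroupFisherChannelUnit` (which maintains the per-unit scores).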
Lines changed: 60 additions & 0 deletions
@@ -0,0 +1,60 @@
| name        | FLOPs (G) | Params (M) | Finetuned top-1 (%) |
| ----------- | --------- | ---------- | ------------------- |
| baseline    | 0.319     | 3.5        | 71.86               |
| fisher_act  | 0.20      | 3.14       | 70.79               |
| fisher_flop | 0.20      | 2.78       | 70.87               |
fisher_act:

```json
{
    "backbone.conv1.conv_(0, 32)_32": 21,
    "backbone.layer1.0.conv.1.conv_(0, 16)_16": 10,
    "backbone.layer2.0.conv.0.conv_(0, 96)_96": 45,
    "backbone.layer2.0.conv.2.conv_(0, 24)_24": 24,
    "backbone.layer2.1.conv.0.conv_(0, 144)_144": 73,
    "backbone.layer3.0.conv.0.conv_(0, 144)_144": 85,
    "backbone.layer3.0.conv.2.conv_(0, 32)_32": 32,
    "backbone.layer3.1.conv.0.conv_(0, 192)_192": 95,
    "backbone.layer3.2.conv.0.conv_(0, 192)_192": 76,
    "backbone.layer4.0.conv.0.conv_(0, 192)_192": 160,
    "backbone.layer4.0.conv.2.conv_(0, 64)_64": 64,
    "backbone.layer4.1.conv.0.conv_(0, 384)_384": 204,
    "backbone.layer4.2.conv.0.conv_(0, 384)_384": 200,
    "backbone.layer4.3.conv.0.conv_(0, 384)_384": 217,
    "backbone.layer5.0.conv.0.conv_(0, 384)_384": 344,
    "backbone.layer5.0.conv.2.conv_(0, 96)_96": 96,
    "backbone.layer5.1.conv.0.conv_(0, 576)_576": 348,
    "backbone.layer5.2.conv.0.conv_(0, 576)_576": 338,
    "backbone.layer6.0.conv.0.conv_(0, 576)_576": 543,
    "backbone.layer6.0.conv.2.conv_(0, 160)_160": 160,
    "backbone.layer6.1.conv.0.conv_(0, 960)_960": 810,
    "backbone.layer6.2.conv.0.conv_(0, 960)_960": 803,
    "backbone.layer7.0.conv.0.conv_(0, 960)_960": 944,
    "backbone.layer7.0.conv.2.conv_(0, 320)_320": 320
}
```

fisher_flop:

```json
{
    "backbone.conv1.conv_(0, 32)_32": 27,
    "backbone.layer1.0.conv.1.conv_(0, 16)_16": 16,
    "backbone.layer2.0.conv.0.conv_(0, 96)_96": 77,
    "backbone.layer2.0.conv.2.conv_(0, 24)_24": 24,
    "backbone.layer2.1.conv.0.conv_(0, 144)_144": 85,
    "backbone.layer3.0.conv.0.conv_(0, 144)_144": 115,
    "backbone.layer3.0.conv.2.conv_(0, 32)_32": 32,
    "backbone.layer3.1.conv.0.conv_(0, 192)_192": 102,
    "backbone.layer3.2.conv.0.conv_(0, 192)_192": 95,
    "backbone.layer4.0.conv.0.conv_(0, 192)_192": 181,
    "backbone.layer4.0.conv.2.conv_(0, 64)_64": 64,
    "backbone.layer4.1.conv.0.conv_(0, 384)_384": 169,
    "backbone.layer4.2.conv.0.conv_(0, 384)_384": 176,
    "backbone.layer4.3.conv.0.conv_(0, 384)_384": 180,
    "backbone.layer5.0.conv.0.conv_(0, 384)_384": 308,
    "backbone.layer5.0.conv.2.conv_(0, 96)_96": 96,
    "backbone.layer5.1.conv.0.conv_(0, 576)_576": 223,
    "backbone.layer5.2.conv.0.conv_(0, 576)_576": 241,
    "backbone.layer6.0.conv.0.conv_(0, 576)_576": 511,
    "backbone.layer6.0.conv.2.conv_(0, 160)_160": 160,
    "backbone.layer6.1.conv.0.conv_(0, 960)_960": 467,
    "backbone.layer6.2.conv.0.conv_(0, 960)_960": 510,
    "backbone.layer7.0.conv.0.conv_(0, 960)_960": 771,
    "backbone.layer7.0.conv.2.conv_(0, 320)_320": 320
}
```
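Each key above follows the pattern `<layer>_(start, end)_<total>`, and each value is the number of channels kept after pruning, so per-layer pruning ratios can be read off mechanically. Below is a small self-contained parsing sketch; the two entries are copied from the fisher_act result above, and `pruned_structure` is a made-up name for this example.

```python
# Parse the structure-dict key format "<layer>_(start, end)_<total>" and
# report the fraction of channels kept per layer.
import re

pruned_structure = {
    "backbone.conv1.conv_(0, 32)_32": 21,
    "backbone.layer1.0.conv.1.conv_(0, 16)_16": 10,
}

for key, kept in pruned_structure.items():
    layer, total = re.match(r'(.+)_\(\d+, \d+\)_(\d+)$', key).groups()
    print(f'{layer}: kept {kept}/{total} channels ({kept / int(total):.0%})')
```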
Lines changed: 37 additions & 0 deletions
@@ -0,0 +1,37 @@
```python
_base_ = './mobilenet_v2_group_fisher_prune_flop.py'
custom_imports = dict(imports=['projects'])

algorithm = _base_.model
pruned_path = './work_dirs/mobilenet_v2_group_fisher_prune_flop/flops_0.65.pth'
algorithm.init_cfg = dict(type='Pretrained', checkpoint=pruned_path)

# Wrap the pruned algorithm so it trains as a plain (already pruned) model.
model = dict(
    _delete_=True,
    _scope_='mmrazor',
    type='PruneDeployWrapper',
    algorithm=algorithm,
)

# Restore the original optimizer and LR schedule of the base classification
# config (the pruning config overrides them).
optim_wrapper = dict(
    _delete_=True,
    optimizer=dict(
        type='SGD',
        lr=0.045,
        momentum=0.9,
        weight_decay=4e-05,
        _scope_='mmcls'))
param_scheduler = dict(
    _delete_=True,
    type='StepLR',
    by_epoch=True,
    step_size=1,
    gamma=0.98,
    _scope_='mmcls')

# Remove the pruning-related hooks appended by the base config.
custom_hooks = _base_.custom_hooks[:-2]

# The pruning-specific DDP wrapper is not needed for finetuning.
model_wrapper_cfg = None
```
Lines changed: 4 additions & 0 deletions
@@ -0,0 +1,4 @@
```python
_base_ = '../mobilenet_v2_group_fisher_prune.py'
model = dict(
    mutator=dict(
        channel_unit_cfg=dict(default_args=dict(detla_type='flop', ), ), ), )
```
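The only override here is `detla_type` (spelled as in this commit's implementation): it selects what normalizes a channel's Fisher score, the activation-memory reduction (`'act'`) or the FLOPs reduction (`'flop'`) that pruning the channel would yield, following the paper's observation that GPU speedup tracks memory reduction more closely than FLOPs. Below is a hedged sketch of the idea, not the actual `GroupFisherChannelUnit` code; `fisher`, `delta_flop` and `delta_memory` are made-up stand-ins for the unit's accumulated score and estimated per-channel resource reductions.

```python
# Hedged sketch of delta normalization; not GroupFisherChannelUnit's code.
def normalized_importance(fisher: float, delta_flop: float,
                          delta_memory: float,
                          detla_type: str = 'act') -> float:
    """Normalize a channel's Fisher score by the resource it would free."""
    if detla_type == 'flop':
        return fisher / delta_flop
    if detla_type == 'act':
        return fisher / delta_memory
    raise ValueError(f'unsupported detla_type: {detla_type!r}')

# The channel with the smallest normalized score is pruned first.
print(normalized_importance(fisher=1.2e-4, delta_flop=3.6e6,
                            delta_memory=1.2e4, detla_type='flop'))
```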
Lines changed: 2 additions & 0 deletions
@@ -0,0 +1,2 @@
```bash
bash ./tools/dist_train.sh projects/group_fisher/configs/mmcls/mobilenet/flop/mobilenet_v2_group_fisher_prune_flop.py 8
bash ./tools/dist_train.sh projects/group_fisher/configs/mmcls/mobilenet/flop/mobilenet_v2_group_fisher_finetune_flop.py 8
```
Lines changed: 37 additions & 0 deletions
@@ -0,0 +1,37 @@
```python
_base_ = './mobilenet_v2_group_fisher_prune.py'
custom_imports = dict(imports=['projects'])

algorithm = _base_.model
pruned_path = './work_dirs/mobilenet_v2_group_fisher_prune/flops_0.65.pth'
algorithm.init_cfg = dict(type='Pretrained', checkpoint=pruned_path)

# Wrap the pruned algorithm so it trains as a plain (already pruned) model.
model = dict(
    _delete_=True,
    _scope_='mmrazor',
    type='PruneDeployWrapper',
    algorithm=algorithm,
)

# Restore the original optimizer and LR schedule of the base classification
# config (the pruning config overrides them).
optim_wrapper = dict(
    _delete_=True,
    optimizer=dict(
        type='SGD',
        lr=0.045,
        momentum=0.9,
        weight_decay=4e-05,
        _scope_='mmcls'))
param_scheduler = dict(
    _delete_=True,
    type='StepLR',
    by_epoch=True,
    step_size=1,
    gamma=0.98,
    _scope_='mmcls')

# Remove the pruning-related hooks appended by the base config.
custom_hooks = _base_.custom_hooks[:-2]

# The pruning-specific DDP wrapper is not needed for finetuning.
model_wrapper_cfg = None
```
Lines changed: 59 additions & 0 deletions
@@ -0,0 +1,59 @@
```python
_base_ = 'mmcls::mobilenet_v2/mobilenet-v2_8xb32_in1k.py'
custom_imports = dict(imports=['projects'])
architecture = _base_.model
pretrained_path = 'https://download.openmmlab.com/mmclassification/v0/mobilenet_v2/mobilenet_v2_batch256_imagenet_20200708-3b2dc3af.pth'  # noqa
architecture.init_cfg = dict(type='Pretrained', checkpoint=pretrained_path)
architecture.update({
    'data_preprocessor': _base_.data_preprocessor,
})
data_preprocessor = None

model = dict(
    _delete_=True,
    _scope_='mmrazor',
    type='GroupFisherAlgorithm',
    architecture=architecture,
    interval=25,
    mutator=dict(
        type='GroupFisherChannelMutator',
        parse_cfg=dict(type='ChannelAnalyzer', tracer_type='FxTracer'),
        channel_unit_cfg=dict(
            type='GroupFisherChannelUnit',
            default_args=dict(detla_type='act', ),
        ),
    ),
)
model_wrapper_cfg = dict(
    type='mmrazor.GroupFisherDDP',
    broadcast_buffers=False,
)

# Update the optimizer: a lower constant LR for pruning, no LR schedule.
optim_wrapper = dict(optimizer=dict(lr=0.004, ))
param_scheduler = None

custom_hooks = [
    dict(type='mmrazor.PruningStructureHook'),
    dict(
        type='mmrazor.ResourceInfoHook',
        interval=25,
        demo_input=dict(
            type='mmrazor.DefaultDemoInput',
            input_shape=[1, 3, 224, 224],
        ),
        save_ckpt_delta_thr=[0.65, 0.33],
    ),
]

# Original optimizer settings of the base config, kept for reference:
"""
optim_wrapper = dict(
    optimizer=dict(
        type='SGD',
        lr=0.045,
        momentum=0.9,
        weight_decay=4e-05,
        _scope_='mmcls'))
param_scheduler = dict(
    type='StepLR', by_epoch=True, step_size=1, gamma=0.98, _scope_='mmcls')
"""
```
Lines changed: 2 additions & 0 deletions
@@ -0,0 +1,2 @@
```bash
bash ./tools/dist_train.sh ./projects/group_fisher/configs/mmcls/mobilenet/mobilenet_v2_group_fisher_prune.py 8
bash ./tools/dist_train.sh ./projects/group_fisher/configs/mmcls/mobilenet/mobilenet_v2_group_fisher_finetune.py 8
```
Lines changed: 81 additions & 0 deletions
@@ -0,0 +1,81 @@
| name        | FLOPs (G) | Params (M) | Finetuned top-1 (%) |
| ----------- | --------- | ---------- | ------------------- |
| fisher_act  | 2.05      | 16.22      | 75.2                |
| fisher_flop | 2.05      | 16.22      | 75.6                |
act:

```json
{
    "backbone.conv1_(0, 64)_64": 61,
    "backbone.layer1.0.conv1_(0, 64)_64": 27,
    "backbone.layer1.0.conv2_(0, 64)_64": 35,
    "backbone.layer1.0.conv3_(0, 256)_256": 241,
    "backbone.layer1.1.conv1_(0, 64)_64": 32,
    "backbone.layer1.1.conv2_(0, 64)_64": 29,
    "backbone.layer1.2.conv1_(0, 64)_64": 27,
    "backbone.layer1.2.conv2_(0, 64)_64": 42,
    "backbone.layer2.0.conv1_(0, 128)_128": 87,
    "backbone.layer2.0.conv2_(0, 128)_128": 107,
    "backbone.layer2.0.conv3_(0, 512)_512": 512,
    "backbone.layer2.1.conv1_(0, 128)_128": 44,
    "backbone.layer2.1.conv2_(0, 128)_128": 50,
    "backbone.layer2.2.conv1_(0, 128)_128": 52,
    "backbone.layer2.2.conv2_(0, 128)_128": 81,
    "backbone.layer2.3.conv1_(0, 128)_128": 47,
    "backbone.layer2.3.conv2_(0, 128)_128": 50,
    "backbone.layer3.0.conv1_(0, 256)_256": 210,
    "backbone.layer3.0.conv2_(0, 256)_256": 206,
    "backbone.layer3.0.conv3_(0, 1024)_1024": 1024,
    "backbone.layer3.1.conv1_(0, 256)_256": 107,
    "backbone.layer3.1.conv2_(0, 256)_256": 108,
    "backbone.layer3.2.conv1_(0, 256)_256": 86,
    "backbone.layer3.2.conv2_(0, 256)_256": 126,
    "backbone.layer3.3.conv1_(0, 256)_256": 91,
    "backbone.layer3.3.conv2_(0, 256)_256": 112,
    "backbone.layer3.4.conv1_(0, 256)_256": 98,
    "backbone.layer3.4.conv2_(0, 256)_256": 110,
    "backbone.layer3.5.conv1_(0, 256)_256": 112,
    "backbone.layer3.5.conv2_(0, 256)_256": 115,
    "backbone.layer4.0.conv1_(0, 512)_512": 397,
    "backbone.layer4.0.conv2_(0, 512)_512": 427,
    "backbone.layer4.1.conv1_(0, 512)_512": 373,
    "backbone.layer4.1.conv2_(0, 512)_512": 348,
    "backbone.layer4.2.conv1_(0, 512)_512": 433,
    "backbone.layer4.2.conv2_(0, 512)_512": 384
}
```

flop:

```json
{
    "backbone.conv1_(0, 64)_64": 61,
    "backbone.layer1.0.conv1_(0, 64)_64": 28,
    "backbone.layer1.0.conv2_(0, 64)_64": 35,
    "backbone.layer1.0.conv3_(0, 256)_256": 242,
    "backbone.layer1.1.conv1_(0, 64)_64": 31,
    "backbone.layer1.1.conv2_(0, 64)_64": 28,
    "backbone.layer1.2.conv1_(0, 64)_64": 26,
    "backbone.layer1.2.conv2_(0, 64)_64": 41,
    "backbone.layer2.0.conv1_(0, 128)_128": 90,
    "backbone.layer2.0.conv2_(0, 128)_128": 107,
    "backbone.layer2.0.conv3_(0, 512)_512": 509,
    "backbone.layer2.1.conv1_(0, 128)_128": 42,
    "backbone.layer2.1.conv2_(0, 128)_128": 50,
    "backbone.layer2.2.conv1_(0, 128)_128": 51,
    "backbone.layer2.2.conv2_(0, 128)_128": 84,
    "backbone.layer2.3.conv1_(0, 128)_128": 49,
    "backbone.layer2.3.conv2_(0, 128)_128": 51,
    "backbone.layer3.0.conv1_(0, 256)_256": 210,
    "backbone.layer3.0.conv2_(0, 256)_256": 207,
    "backbone.layer3.0.conv3_(0, 1024)_1024": 1024,
    "backbone.layer3.1.conv1_(0, 256)_256": 103,
    "backbone.layer3.1.conv2_(0, 256)_256": 108,
    "backbone.layer3.2.conv1_(0, 256)_256": 90,
    "backbone.layer3.2.conv2_(0, 256)_256": 124,
    "backbone.layer3.3.conv1_(0, 256)_256": 94,
    "backbone.layer3.3.conv2_(0, 256)_256": 114,
    "backbone.layer3.4.conv1_(0, 256)_256": 99,
    "backbone.layer3.4.conv2_(0, 256)_256": 111,
    "backbone.layer3.5.conv1_(0, 256)_256": 108,
    "backbone.layer3.5.conv2_(0, 256)_256": 111,
    "backbone.layer4.0.conv1_(0, 512)_512": 400,
    "backbone.layer4.0.conv2_(0, 512)_512": 421,
    "backbone.layer4.1.conv1_(0, 512)_512": 377,
    "backbone.layer4.1.conv2_(0, 512)_512": 347,
    "backbone.layer4.2.conv1_(0, 512)_512": 443,
    "backbone.layer4.2.conv2_(0, 512)_512": 376
}
```
