Skip to content

Commit d578628

Browse files
authored
Add lite hr net (PaddlePaddle#3793)
* Add LiteHRNet backbone and config .yml
* Test that the lite18 network's param acc is the same as the original model's: 1. fix default darkpose=ON; 2. `+=` is not in-place; add new keypoint model Lite-HRNet
* Add new keypoint model Lite-HRNet
* 1. Add description of network type; 2. use channel_shuffle in ops.py
* Use normal to init conv2d
* Add network type description
1 parent 55fcc1f commit d578628

File tree

8 files changed

+1191
-24
lines changed

8 files changed

+1191
-24
lines changed
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,140 @@
1+
use_gpu: true
2+
log_iter: 5
3+
save_dir: output
4+
snapshot_epoch: 10
5+
weights: output/lite_hrnet_18_256x192_coco/model_final
6+
epoch: 210
7+
num_joints: &num_joints 17
8+
pixel_std: &pixel_std 200
9+
metric: KeyPointTopDownCOCOEval
10+
num_classes: 1
11+
train_height: &train_height 256
12+
train_width: &train_width 192
13+
trainsize: &trainsize [*train_width, *train_height]
14+
hmsize: &hmsize [48, 64]
15+
flip_perm: &flip_perm [[1, 2], [3, 4], [5, 6], [7, 8], [9, 10], [11, 12], [13, 14], [15, 16]]
16+
17+
18+
#####model
19+
architecture: TopDownHRNet
20+
21+
TopDownHRNet:
22+
backbone: LiteHRNet
23+
post_process: HRNetPostProcess
24+
flip_perm: *flip_perm
25+
num_joints: *num_joints
26+
width: &width 40
27+
loss: KeyPointMSELoss
28+
use_dark: false
29+
30+
LiteHRNet:
31+
network_type: lite_18
32+
freeze_at: -1
33+
freeze_norm: false
34+
return_idx: [0]
35+
36+
KeyPointMSELoss:
37+
use_target_weight: true
38+
loss_scale: 1.0
39+
40+
#####optimizer
41+
LearningRate:
42+
base_lr: 0.002
43+
schedulers:
44+
- !PiecewiseDecay
45+
milestones: [170, 200]
46+
gamma: 0.1
47+
- !LinearWarmup
48+
start_factor: 0.001
49+
steps: 500
50+
51+
OptimizerBuilder:
52+
optimizer:
53+
type: Adam
54+
regularizer:
55+
factor: 0.0
56+
type: L2
57+
58+
59+
#####data
60+
TrainDataset:
61+
!KeypointTopDownCocoDataset
62+
image_dir: train2017
63+
anno_path: annotations/person_keypoints_train2017.json
64+
dataset_dir: dataset/coco
65+
num_joints: *num_joints
66+
trainsize: *trainsize
67+
pixel_std: *pixel_std
68+
use_gt_bbox: True
69+
70+
71+
EvalDataset:
72+
!KeypointTopDownCocoDataset
73+
image_dir: val2017
74+
anno_path: annotations/person_keypoints_val2017.json
75+
dataset_dir: dataset/coco
76+
num_joints: *num_joints
77+
trainsize: *trainsize
78+
pixel_std: *pixel_std
79+
use_gt_bbox: True
80+
image_thre: 0.0
81+
82+
83+
TestDataset:
84+
!ImageFolder
85+
anno_path: dataset/coco/keypoint_imagelist.txt
86+
87+
worker_num: 2
88+
global_mean: &global_mean [0.485, 0.456, 0.406]
89+
global_std: &global_std [0.229, 0.224, 0.225]
90+
TrainReader:
91+
sample_transforms:
92+
- RandomFlipHalfBodyTransform:
93+
scale: 0.25
94+
rot: 30
95+
num_joints_half_body: 8
96+
prob_half_body: 0.3
97+
pixel_std: *pixel_std
98+
trainsize: *trainsize
99+
upper_body_ids: [0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10]
100+
flip_pairs: *flip_perm
101+
- TopDownAffine:
102+
trainsize: *trainsize
103+
- ToHeatmapsTopDown:
104+
hmsize: *hmsize
105+
sigma: 2
106+
batch_transforms:
107+
- NormalizeImage:
108+
mean: *global_mean
109+
std: *global_std
110+
is_scale: true
111+
- Permute: {}
112+
batch_size: 64
113+
shuffle: true
114+
drop_last: false
115+
116+
EvalReader:
117+
sample_transforms:
118+
- TopDownAffine:
119+
trainsize: *trainsize
120+
batch_transforms:
121+
- NormalizeImage:
122+
mean: *global_mean
123+
std: *global_std
124+
is_scale: true
125+
- Permute: {}
126+
batch_size: 16
127+
128+
TestReader:
129+
inputs_def:
130+
image_shape: [3, *train_height, *train_width]
131+
sample_transforms:
132+
- Decode: {}
133+
- TopDownEvalAffine:
134+
trainsize: *trainsize
135+
- NormalizeImage:
136+
mean: *global_mean
137+
std: *global_std
138+
is_scale: true
139+
- Permute: {}
140+
batch_size: 1
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,140 @@
1+
use_gpu: true
2+
log_iter: 5
3+
save_dir: output
4+
snapshot_epoch: 10
5+
weights: output/lite_hrnet_30_256x192_coco/model_final
6+
epoch: 210
7+
num_joints: &num_joints 17
8+
pixel_std: &pixel_std 200
9+
metric: KeyPointTopDownCOCOEval
10+
num_classes: 1
11+
train_height: &train_height 256
12+
train_width: &train_width 192
13+
trainsize: &trainsize [*train_width, *train_height]
14+
hmsize: &hmsize [48, 64]
15+
flip_perm: &flip_perm [[1, 2], [3, 4], [5, 6], [7, 8], [9, 10], [11, 12], [13, 14], [15, 16]]
16+
17+
18+
#####model
19+
architecture: TopDownHRNet
20+
21+
TopDownHRNet:
22+
backbone: LiteHRNet
23+
post_process: HRNetPostProcess
24+
flip_perm: *flip_perm
25+
num_joints: *num_joints
26+
width: &width 40
27+
loss: KeyPointMSELoss
28+
use_dark: false
29+
30+
LiteHRNet:
31+
network_type: lite_30
32+
freeze_at: -1
33+
freeze_norm: false
34+
return_idx: [0]
35+
36+
KeyPointMSELoss:
37+
use_target_weight: true
38+
loss_scale: 1.0
39+
40+
#####optimizer
41+
LearningRate:
42+
base_lr: 0.002
43+
schedulers:
44+
- !PiecewiseDecay
45+
milestones: [170, 200]
46+
gamma: 0.1
47+
- !LinearWarmup
48+
start_factor: 0.001
49+
steps: 500
50+
51+
OptimizerBuilder:
52+
optimizer:
53+
type: Adam
54+
regularizer:
55+
factor: 0.0
56+
type: L2
57+
58+
59+
#####data
60+
TrainDataset:
61+
!KeypointTopDownCocoDataset
62+
image_dir: train2017
63+
anno_path: annotations/person_keypoints_train2017.json
64+
dataset_dir: dataset/coco
65+
num_joints: *num_joints
66+
trainsize: *trainsize
67+
pixel_std: *pixel_std
68+
use_gt_bbox: True
69+
70+
71+
EvalDataset:
72+
!KeypointTopDownCocoDataset
73+
image_dir: val2017
74+
anno_path: annotations/person_keypoints_val2017.json
75+
dataset_dir: dataset/coco
76+
num_joints: *num_joints
77+
trainsize: *trainsize
78+
pixel_std: *pixel_std
79+
use_gt_bbox: True
80+
image_thre: 0.0
81+
82+
83+
TestDataset:
84+
!ImageFolder
85+
anno_path: dataset/coco/keypoint_imagelist.txt
86+
87+
worker_num: 4
88+
global_mean: &global_mean [0.485, 0.456, 0.406]
89+
global_std: &global_std [0.229, 0.224, 0.225]
90+
TrainReader:
91+
sample_transforms:
92+
- RandomFlipHalfBodyTransform:
93+
scale: 0.25
94+
rot: 30
95+
num_joints_half_body: 8
96+
prob_half_body: 0.3
97+
pixel_std: *pixel_std
98+
trainsize: *trainsize
99+
upper_body_ids: [0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10]
100+
flip_pairs: *flip_perm
101+
- TopDownAffine:
102+
trainsize: *trainsize
103+
- ToHeatmapsTopDown:
104+
hmsize: *hmsize
105+
sigma: 2
106+
batch_transforms:
107+
- NormalizeImage:
108+
mean: *global_mean
109+
std: *global_std
110+
is_scale: true
111+
- Permute: {}
112+
batch_size: 64
113+
shuffle: true
114+
drop_last: false
115+
116+
EvalReader:
117+
sample_transforms:
118+
- TopDownAffine:
119+
trainsize: *trainsize
120+
batch_transforms:
121+
- NormalizeImage:
122+
mean: *global_mean
123+
std: *global_std
124+
is_scale: true
125+
- Permute: {}
126+
batch_size: 16
127+
128+
TestReader:
129+
inputs_def:
130+
image_shape: [3, *train_height, *train_width]
131+
sample_transforms:
132+
- Decode: {}
133+
- TopDownEvalAffine:
134+
trainsize: *trainsize
135+
- NormalizeImage:
136+
mean: *global_mean
137+
std: *global_std
138+
is_scale: true
139+
- Permute: {}
140+
batch_size: 1

ppdet/modeling/architectures/keypoint_hrnet.py

+5-4
Original file line numberDiff line numberDiff line change
@@ -41,18 +41,20 @@ def __init__(self,
4141
post_process='HRNetPostProcess',
4242
flip_perm=None,
4343
flip=True,
44-
shift_heatmap=True):
44+
shift_heatmap=True,
45+
use_dark=True):
4546
"""
46-
HRNnet network, see https://arxiv.org/abs/1902.09212
47+
HRNet network, see https://arxiv.org/abs/1902.09212
4748
4849
Args:
4950
backbone (nn.Layer): backbone instance
5051
post_process (object): `HRNetPostProcess` instance
5152
flip_perm (list): The left-right joints exchange order list
53+
use_dark (bool): Whether to use DARK in post-processing
5254
"""
5355
super(TopDownHRNet, self).__init__()
5456
self.backbone = backbone
55-
self.post_process = HRNetPostProcess()
57+
self.post_process = HRNetPostProcess(use_dark)
5658
self.loss = loss
5759
self.flip_perm = flip_perm
5860
self.flip = flip
@@ -218,7 +220,6 @@ def get_final_preds(self, heatmaps, center, scale, kernelsize=3):
218220
preds: numpy.ndarray([batch_size, num_joints, 2]), keypoints coords
219221
maxvals: numpy.ndarray([batch_size, num_joints, 1]), the maximum confidence of the keypoints
220222
"""
221-
222223
coords, maxvals = self.get_max_preds(heatmaps)
223224

224225
heatmap_height = heatmaps.shape[2]

ppdet/modeling/backbones/__init__.py

+2
Original file line numberDiff line numberDiff line change
@@ -18,6 +18,7 @@
1818
from . import mobilenet_v1
1919
from . import mobilenet_v3
2020
from . import hrnet
21+
from . import lite_hrnet
2122
from . import blazenet
2223
from . import ghostnet
2324
from . import senet
@@ -31,6 +32,7 @@
3132
from .mobilenet_v1 import *
3233
from .mobilenet_v3 import *
3334
from .hrnet import *
35+
from .lite_hrnet import *
3436
from .blazenet import *
3537
from .ghostnet import *
3638
from .senet import *

0 commit comments

Comments
 (0)