-
Notifications
You must be signed in to change notification settings - Fork 352
/
Copy pathdistill.py
237 lines (208 loc) · 9.62 KB
/
distill.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
import os
import sys
import math
import logging
import paddle
import argparse
import functools
import numpy as np
# Point sys.path[0] (normally the script's own directory) at this script's
# parent directory; the local `models` / `utility` modules imported next
# presumably live there. Resolve it from the real `__file__`: the string
# literal "__file__" made dirname() return '' and tied the path to the
# current working directory instead of the script location.
sys.path[0] = os.path.join(
    os.path.dirname(os.path.abspath(__file__)), os.path.pardir)
import models
from utility import add_arguments, print_arguments, _download, _decompress
from paddleslim.dist import merge, l2, soft_label
from paddle.distributed import fleet
from paddle.distributed.fleet import DistributedStrategy
logging.basicConfig(format='%(asctime)s-%(levelname)s: %(message)s')
_logger = logging.getLogger(__name__)
_logger.setLevel(logging.INFO)

# This module has no docstring, so `description=__doc__` left `--help`
# without any description; state it explicitly.
parser = argparse.ArgumentParser(
    description="Train a student network with soft-label knowledge "
    "distillation from a pretrained teacher network.")
add_arg = functools.partial(add_arguments, argparser=parser)
# yapf: disable
add_arg('batch_size',               int,   256,               "Minibatch size.")
add_arg('use_gpu',                  bool,  True,              "Whether to use GPU or not.")
add_arg('save_inference',           bool,  False,             "Whether to save inference model.")
add_arg('total_images',             int,   1281167,           "Training image number.")
add_arg('image_shape',              str,   "3,224,224",       "Input image size (currently ignored: compress() derives the shape from --data).")
add_arg('lr',                       float, 0.1,               "The initial learning rate of the student network.")
add_arg('lr_strategy',              str,   "piecewise_decay", "The learning rate decay strategy: 'piecewise_decay' or 'cosine_decay'.")
add_arg('l2_decay',                 float, 3e-5,              "The l2_decay parameter.")
add_arg('momentum_rate',            float, 0.9,               "The value of momentum_rate.")
add_arg('num_epochs',               int,   120,               "The number of total epochs.")
add_arg('data',                     str,   "imagenet",        "Which data to use. 'cifar10' or 'imagenet'")
add_arg('log_period',               int,   20,                "Log period in batches.")
add_arg('model',                    str,   "MobileNet",       "Set the student network to use.")
# NOTE(review): compress() never reads args.pretrained_model; the flag is
# accepted but has no effect until student-weight loading is implemented.
add_arg('pretrained_model',         str,   None,              "Path to the student's pretrained model (currently not loaded by this script).")
add_arg('teacher_model',            str,   "ResNet50_vd",     "Set the teacher network to use.")
add_arg('teacher_pretrained_model', str,   "./ResNet50_vd_pretrained", "Path to the teacher's pretrained weights.")
parser.add_argument('--step_epochs', nargs='+', type=int, default=[30, 60, 90], help="Epochs at which piecewise decay multiplies the learning rate by 0.1.")
# yapf: enable

# Public names of the `models` package; compress() asserts --model is one of them.
model_list = [m for m in dir(models) if "__" not in m]
def piecewise_decay(args):
    """Build a step-wise decaying learning rate and its Momentum optimizer.

    The rate starts at ``args.lr`` and is multiplied by 0.1 at each epoch
    listed in ``args.step_epochs``. Boundaries are expressed in iterations
    (``ceil(total_images / batch_size)`` per epoch), so the scheduler is
    meant to be stepped once per batch.

    Args:
        args: parsed CLI namespace providing ``total_images``, ``batch_size``,
            ``step_epochs``, ``lr``, ``momentum_rate`` and ``l2_decay``.

    Returns:
        ``(scheduler, optimizer)`` tuple.
    """
    iters_per_epoch = int(math.ceil(float(args.total_images) / args.batch_size))
    boundaries = [iters_per_epoch * epoch for epoch in args.step_epochs]
    rates = [args.lr * (0.1**k) for k in range(len(boundaries) + 1)]

    scheduler = paddle.optimizer.lr.PiecewiseDecay(
        boundaries=boundaries, values=rates, verbose=False)
    l2_reg = paddle.regularizer.L2Decay(args.l2_decay)
    momentum_opt = paddle.optimizer.Momentum(
        learning_rate=scheduler,
        momentum=args.momentum_rate,
        weight_decay=l2_reg)
    return scheduler, momentum_opt
def cosine_decay(args):
    """Build a cosine-annealed learning rate and its Momentum optimizer.

    The rate anneals from ``args.lr`` over ``args.num_epochs`` epochs'
    worth of iterations, with ``ceil(total_images / batch_size)``
    iterations per epoch, so the scheduler is meant to be stepped once per
    batch.

    Args:
        args: parsed CLI namespace providing ``total_images``, ``batch_size``,
            ``num_epochs``, ``lr``, ``momentum_rate`` and ``l2_decay``.

    Returns:
        ``(scheduler, optimizer)`` tuple.
    """
    iters_per_epoch = int(math.ceil(float(args.total_images) / args.batch_size))
    total_iters = iters_per_epoch * args.num_epochs

    scheduler = paddle.optimizer.lr.CosineAnnealingDecay(
        learning_rate=args.lr, T_max=total_iters, verbose=False)
    l2_reg = paddle.regularizer.L2Decay(args.l2_decay)
    momentum_opt = paddle.optimizer.Momentum(
        learning_rate=scheduler,
        momentum=args.momentum_rate,
        weight_decay=l2_reg)
    return scheduler, momentum_opt
def create_optimizer(args):
    """Return ``(lr_scheduler, optimizer)`` for ``args.lr_strategy``.

    Supported strategies are ``"piecewise_decay"`` and ``"cosine_decay"``.

    Raises:
        ValueError: for any other strategy, so a typo in ``--lr_strategy``
            is reported here instead of as a confusing failure when the
            caller unpacks a ``None`` result.
    """
    if args.lr_strategy == "piecewise_decay":
        return piecewise_decay(args)
    if args.lr_strategy == "cosine_decay":
        return cosine_decay(args)
    raise ValueError("lr_strategy {} is not supported.".format(
        args.lr_strategy))
def _as_float(value):
    """Return the single number held by a fetched result as a Python float.

    ``Executor.run`` returns fetches as numpy arrays. A scalar loss/metric
    may come back as a 0-d array or with shape ``[1]`` depending on the
    Paddle release (TODO confirm for the installed version), and numpy only
    supports ``format(arr, '.6f')`` on 0-d arrays, so flatten first.
    """
    return float(np.asarray(value).reshape(-1)[0])


def compress(args):
    """Distill a pretrained teacher network into a student and train it.

    Builds the student program (cross-entropy loss plus top-1/top-5
    accuracy), builds the teacher program and loads its pretrained weights,
    merges the teacher into the student program with ``merge``, adds a
    ``soft_label`` distillation loss and trains with ``fleet``. After every
    epoch the student is evaluated on the validation set and, when
    ``args.save_inference`` is set, exported to ``./saved_models/<epoch>``.

    Args:
        args: parsed command-line namespace (see the ``add_arg`` calls).

    Raises:
        ValueError: if ``args.data`` is neither ``"cifar10"`` nor
            ``"imagenet"``.
    """
    fleet.init(is_collective=True)
    if args.data == "cifar10":
        train_dataset = paddle.vision.datasets.Cifar10(mode='train')
        val_dataset = paddle.vision.datasets.Cifar10(mode='test')
        class_dim = 10
        image_shape = "3,32,32"
    elif args.data == "imagenet":
        import imagenet_reader as reader
        train_dataset = reader.ImageNetDataset(mode='train')
        val_dataset = reader.ImageNetDataset(mode='val')
        class_dim = 1000
        image_shape = "3,224,224"
    else:
        raise ValueError("{} is not supported.".format(args.data))
    # The dataset choice, not args.image_shape, fixes the input shape.
    image_shape = [int(m) for m in image_shape.split(",")]
    assert args.model in model_list, "{} is not in lists: {}".format(
        args.model, model_list)

    student_program = paddle.static.Program()
    s_startup = paddle.static.Program()
    places = paddle.static.cuda_places(
    ) if args.use_gpu else paddle.static.cpu_places()
    place = places[0]

    with paddle.static.program_guard(student_program, s_startup):
        image = paddle.static.data(
            name='image', shape=[None] + image_shape, dtype='float32')
        label = paddle.static.data(name='label', shape=[None, 1], dtype='int64')
        # NOTE(review): training batches are not shuffled; confirm this is
        # intended before relying on this script for real training runs.
        sampler = paddle.io.DistributedBatchSampler(
            train_dataset,
            shuffle=False,
            drop_last=True,
            batch_size=args.batch_size)
        train_loader = paddle.io.DataLoader(
            train_dataset,
            places=places,
            feed_list=[image, label],
            batch_sampler=sampler,
            return_list=False,
            use_shared_memory=False,
            num_workers=4)
        valid_loader = paddle.io.DataLoader(
            val_dataset,
            places=place,
            feed_list=[image, label],
            drop_last=False,
            return_list=False,
            use_shared_memory=False,
            batch_size=args.batch_size,
            shuffle=False)
        # student model definition
        model = models.__dict__[args.model]()
        out = model.net(input=image, class_dim=class_dim)
        cost = paddle.nn.functional.loss.cross_entropy(input=out, label=label)
        avg_cost = paddle.mean(x=cost)
        acc_top1 = paddle.metric.accuracy(input=out, label=label, k=1)
        acc_top5 = paddle.metric.accuracy(input=out, label=label, k=5)

    # Cloned before merge() and minimize(), so evaluation runs only the
    # student's forward pass.
    val_program = student_program.clone(for_test=True)
    exe = paddle.static.Executor(place)

    teacher_model = models.__dict__[args.teacher_model]()
    # define teacher program
    teacher_program = paddle.static.Program()
    t_startup = paddle.static.Program()
    with paddle.static.program_guard(teacher_program, t_startup):
        with paddle.utils.unique_name.guard():
            # A separate Python name keeps `image` bound to the student's
            # input (exported by save_inference_model below); data_name_map
            # pairs the two inputs through their shared var name 'image'.
            teacher_image = paddle.static.data(
                name='image', shape=[None] + image_shape, dtype='float32')
            teacher_model.net(teacher_image, class_dim=class_dim)

    exe.run(t_startup)
    # NOTE(review): this always fetches ResNet50_vd weights, whatever
    # --teacher_model is; other teachers need a pre-existing checkpoint.
    if not os.path.exists(args.teacher_pretrained_model):
        _download(
            'http://paddle-imagenet-models-name.bj.bcebos.com/ResNet50_vd_pretrained.tar',
            '.')
        _decompress('./ResNet50_vd_pretrained.tar')
    assert args.teacher_pretrained_model and os.path.exists(
        args.teacher_pretrained_model
    ), "teacher_pretrained_model should be set when teacher_model is not None."

    def if_exist(var):
        """Return True if ``var`` should be restored from the checkpoint.

        A variable qualifies when a file named after it exists in the
        checkpoint directory. For cifar10 the teacher's ``fc_0`` weight and
        bias are skipped: the checkpoint's classifier was built for another
        ``class_dim`` (presumably the 1000 ImageNet classes), so the 10-class
        layer initialized by ``t_startup`` is kept instead.
        """
        exist = os.path.exists(
            os.path.join(args.teacher_pretrained_model, var.name))
        if args.data == "cifar10" and var.name in ('fc_0.w_0', 'fc_0.b_0'):
            exist = False
        return exist

    # Apply if_exist as the var_list so the cifar10 classifier exclusion
    # actually takes effect. It only applies to a directory of per-variable
    # files (the layout if_exist inspects); other checkpoint layouts keep
    # paddle.static.load's default of loading everything.
    if os.path.isdir(args.teacher_pretrained_model):
        teacher_vars = [
            var for var in teacher_program.list_vars() if if_exist(var)
        ]
    else:
        teacher_vars = None
    paddle.static.load(
        teacher_program,
        args.teacher_pretrained_model,
        exe,
        var_list=teacher_vars)

    data_name_map = {'image': 'image'}
    merge(teacher_program, student_program, data_name_map, place)

    build_strategy = paddle.static.BuildStrategy()
    dist_strategy = DistributedStrategy()
    dist_strategy.build_strategy = build_strategy

    with paddle.static.program_guard(student_program, s_startup):
        # Assumes both networks end in an auto-named `fc_0` layer and that
        # merge() exposes teacher variables under a "teacher_" prefix.
        distill_loss = soft_label("teacher_fc_0.tmp_0", "fc_0.tmp_0",
                                  student_program)
        loss = avg_cost + distill_loss
        lr, opt = create_optimizer(args)
        opt = fleet.distributed_optimizer(opt, strategy=dist_strategy)
        opt.minimize(loss)
    exe.run(s_startup)

    for epoch_id in range(args.num_epochs):
        for step_id, data in enumerate(train_loader):
            loss_1, loss_2, loss_3 = exe.run(
                student_program,
                feed=data,
                fetch_list=[loss.name, avg_cost.name, distill_loss.name])
            if step_id % args.log_period == 0:
                _logger.info(
                    "train_epoch {} step {} lr {:.6f}, loss {:.6f}, class loss {:.6f}, distill loss {:.6f}".
                    format(epoch_id, step_id,
                           lr.get_lr(),
                           _as_float(loss_1), _as_float(loss_2),
                           _as_float(loss_3)))
            # The LR schedules are defined in iterations: step once per batch.
            lr.step()

        val_acc1s = []
        val_acc5s = []
        for step_id, data in enumerate(valid_loader):
            val_loss, val_acc1, val_acc5 = exe.run(
                val_program,
                data,
                fetch_list=[avg_cost.name, acc_top1.name, acc_top5.name])
            val_acc1s.append(val_acc1)
            val_acc5s.append(val_acc5)
            if step_id % args.log_period == 0:
                _logger.info(
                    "valid_epoch {} step {} loss {:.6f}, top1 {:.6f}, top5 {:.6f}".
                    format(epoch_id, step_id, _as_float(val_loss),
                           _as_float(val_acc1), _as_float(val_acc5)))
        if args.save_inference:
            paddle.static.save_inference_model(
                os.path.join("./saved_models", str(epoch_id)), [image], [out],
                exe,
                program=student_program)
        # Unweighted mean of per-batch accuracies; with drop_last=False the
        # final, possibly smaller, batch counts as much as a full one.
        _logger.info("epoch {} top1 {:.6f}, top5 {:.6f}".format(
            epoch_id, np.mean(val_acc1s), np.mean(val_acc5s)))
def main():
    """Entry point: parse the command line, echo the settings, then train."""
    cli_args = parser.parse_args()
    print_arguments(cli_args)
    compress(cli_args)


if __name__ == '__main__':
    # Everything in this script builds paddle.static programs, so switch
    # Paddle into static-graph mode before any program is constructed.
    paddle.enable_static()
    main()