@@ -15,6 +15,7 @@
 import models
 from utility import add_arguments, print_arguments
 import paddle.vision.transforms as T
+from paddle.distributed import fleet

 _logger = get_logger(__name__, level=logging.INFO)

@@ -40,6 +41,7 @@
 add_arg('criterion', str, "l1_norm", "The prune criterion to be used, support l1_norm and batch_norm_scale.")
 add_arg('save_inference', bool, False, "Whether to save inference model.")
 add_arg('ce_test', bool, False, "Whether to CE test.")
+parser.add_argument('--fleet', action='store_true', help="Whether to turn on distributed training.")
 # yapf: enable

 model_list = models.__all__
@@ -96,6 +98,8 @@ def create_optimizer(args, step_per_epoch):


 def compress(args):
+    if args.fleet:
+        fleet.init(is_collective=True)

     num_workers = 4
     shuffle = True
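Note: fleet.init(is_collective=True) selects collective (all-reduce) training rather than parameter-server mode; rank, device, and peer endpoints are picked up from the environment variables that paddle.distributed.launch sets for each spawned process.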
@@ -130,8 +134,8 @@ def compress(args):
     else:
         raise ValueError("{} is not supported.".format(args.data))
     image_shape = [int(m) for m in image_shape.split(",")]
-    assert args.model in model_list, "{} is not in lists: {}".format(args.model,
-                                                                     model_list)
+    assert args.model in model_list, "{} is not in lists: {}".format(
+        args.model, model_list)
     places = paddle.static.cuda_places(
     ) if args.use_gpu else paddle.static.cpu_places()
     place = places[0]
@@ -140,13 +144,16 @@ def compress(args):
         name='image', shape=[None] + image_shape, dtype='float32')
     label = paddle.static.data(name='label', shape=[None, 1], dtype='int64')
     batch_size_per_card = int(args.batch_size / len(places))
+    sampler = paddle.io.DistributedBatchSampler(
+        train_dataset,
+        shuffle=shuffle,
+        drop_last=True,
+        batch_size=batch_size_per_card)
     train_loader = paddle.io.DataLoader(
         train_dataset,
         places=places,
         feed_list=[image, label],
-        drop_last=True,
-        batch_size=batch_size_per_card,
-        shuffle=shuffle,
+        batch_sampler=sampler,
         return_list=False,
         use_shared_memory=True,
         num_workers=num_workers)
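For context, a minimal, self-contained sketch of the sampler's behavior (the toy dataset and sizes below are illustrative, not part of this change): each of the N launched processes iterates over a distinct 1/N shard of the dataset, so batches are not duplicated across cards, and batch_size is per card, matching batch_size_per_card above.

import numpy as np
from paddle.io import Dataset, DistributedBatchSampler

class ToyDataset(Dataset):
    # Hypothetical 64-sample dataset, used only to illustrate sharding.
    def __len__(self):
        return 64

    def __getitem__(self, idx):
        image = np.random.rand(3, 32, 32).astype('float32')
        label = np.array([idx % 10], dtype='int64')
        return image, label

# Per-card batch size; with N processes each rank sees len(dataset) / N samples.
sampler = DistributedBatchSampler(
    ToyDataset(), batch_size=8, shuffle=True, drop_last=True)
print(len(sampler))  # number of batches this rank will iterate per epoch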
@@ -171,6 +178,8 @@ def compress(args):
     acc_top5 = paddle.metric.accuracy(input=out, label=label, k=5)
     val_program = paddle.static.default_main_program().clone(for_test=True)
     opt, learning_rate = create_optimizer(args, step_per_epoch)
+    if args.fleet:
+        opt = fleet.distributed_optimizer(opt)
     opt.minimize(avg_cost)

     exe.run(paddle.static.default_startup_program())
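Taken together, the fleet pieces follow Paddle's standard static-graph collective recipe: init fleet, build the program, wrap the optimizer, minimize. A minimal sketch under that assumption (the toy network below is illustrative, not this PR's model):

import paddle
from paddle.distributed import fleet

paddle.enable_static()
fleet.init(is_collective=True)

image = paddle.static.data(name='image', shape=[None, 784], dtype='float32')
label = paddle.static.data(name='label', shape=[None, 1], dtype='int64')
out = paddle.static.nn.fc(image, size=10)
loss = paddle.nn.functional.cross_entropy(out, label)

opt = paddle.optimizer.Momentum(learning_rate=0.01, momentum=0.9)
# Wrapping the optimizer is what inserts gradient all-reduce ops into the graph.
opt = fleet.distributed_optimizer(opt)
opt.minimize(loss)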
@@ -180,8 +189,8 @@ def compress(args):
     def if_exist(var):
         return os.path.exists(os.path.join(args.pretrained_model, var.name))

-    _logger.info("Load pretrained model from {}".format(
-        args.pretrained_model))
+    _logger.info(
+        "Load pretrained model from {}".format(args.pretrained_model))
     paddle.static.load(paddle.static.default_main_program(),
                        args.pretrained_model, exe)
@@ -247,13 +256,10 @@ def train(epoch, program):
         place=place)
     _logger.info("FLOPs after pruning: {}".format(flops(pruned_program)))

-    build_strategy = paddle.static.BuildStrategy()
-    exec_strategy = paddle.static.ExecutionStrategy()
-    train_program = paddle.static.CompiledProgram(
-        pruned_program).with_data_parallel(
-            loss_name=avg_cost.name,
-            build_strategy=build_strategy,
-            exec_strategy=exec_strategy)
+    if args.fleet:
+        train_program = paddle.static.CompiledProgram(pruned_program)
+    else:
+        train_program = pruned_program

     for i in range(args.num_epochs):
         train(i, train_program)
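Design note: the removed with_data_parallel path was the pre-fleet way to get multi-card gradient synchronization, and it was later deprecated in Paddle. Under fleet, the distributed optimizer already inserts the communication ops during minimize, so the pruned program only needs to be compiled (or, in the single-process case, run as-is).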
@@ -268,8 +274,8 @@ def train(epoch, program):
         infer_model_path, [image], [out],
         exe,
         program=pruned_val_program)
-    _logger.info("Saved inference model into [{}]".format(
-        infer_model_path))
+    _logger.info(
+        "Saved inference model into [{}]".format(infer_model_path))


 def main():
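Usage note (assumed from the new flag, not shown in this diff): multi-card runs would be started through Paddle's launcher so each process receives its distributed environment, e.g. python -m paddle.distributed.launch --gpus "0,1,2,3" train.py --fleet, with the remaining training arguments unchanged; without --fleet the script keeps its single-process behavior.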