import argparse
import functools

+ import paddle
import paddle.fluid as fluid
- from paddle.fluid.dygraph.base import to_variable
from paddleslim.common import AvgrageMeter, get_logger
from paddleslim.nas.darts import count_parameters_in_MB

@@ -72,8 +72,8 @@ def train(model, train_reader, optimizer, epoch, drop_path_prob, args):

    for step_id, data in enumerate(train_reader()):
        image_np, label_np = data
-       image = to_variable(image_np)
-       label = to_variable(label_np)
+       image = paddle.to_tensor(image_np)
+       label = paddle.to_tensor(label_np)
        label.stop_gradient = True
        logits, logits_aux = model(image, drop_path_prob, True)

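This hunk and the next one make the same swap: the removed `to_variable` helper becomes `paddle.to_tensor`, the paddle 2.x call for wrapping a numpy array as an imperative-mode tensor. A minimal sketch of the new call, assuming paddle 2.x and a made-up numpy batch:

```python
import numpy as np
import paddle

image_np = np.random.rand(4, 3, 32, 32).astype("float32")       # hypothetical CIFAR-like batch
label_np = np.random.randint(0, 10, size=(4, 1)).astype("int64")

image = paddle.to_tensor(image_np)   # replaces to_variable(image_np)
label = paddle.to_tensor(label_np)
label.stop_gradient = True           # labels carry no gradient, as in the hunk above
```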
@@ -117,8 +117,8 @@ def valid(model, valid_reader, epoch, args):

    for step_id, data in enumerate(valid_reader()):
        image_np, label_np = data
-       image = to_variable(image_np)
-       label = to_variable(label_np)
+       image = paddle.to_tensor(image_np)
+       label = paddle.to_tensor(label_np)
        logits, _ = model(image, 0, False)
        prec1 = paddle.static.accuracy(input=logits, label=label, k=1)
        prec5 = paddle.static.accuracy(input=logits, label=label, k=5)
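The unchanged context lines compute top-1 and top-5 precision with `paddle.static.accuracy`; `paddle.metric.accuracy` takes the same `input`/`label`/`k` arguments and is the usual choice in dygraph code. A small self-contained check, with illustrative values:

```python
import paddle

logits = paddle.rand([8, 10])           # 8 samples, 10 classes
label = paddle.randint(0, 10, [8, 1])   # labels must have shape [N, 1], dtype int64

prec1 = paddle.metric.accuracy(input=logits, label=label, k=1)
prec5 = paddle.metric.accuracy(input=logits, label=label, k=5)
print(float(prec1), float(prec5))
```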
@@ -140,83 +140,75 @@ def main(args):
    place = paddle.CUDAPlace(paddle.distributed.parallel.ParallelEnv().dev_id) \
        if args.use_data_parallel else paddle.CUDAPlace(0)

-   with fluid.dygraph.guard(place):
-       genotype = eval("genotypes.%s" % args.arch)
-       model = Network(
-           C=args.init_channels,
-           num_classes=args.class_num,
-           layers=args.layers,
-           auxiliary=args.auxiliary,
-           genotype=genotype)
-
-       logger.info("param size = {:.6f}MB".format(
-           count_parameters_in_MB(model.parameters())))
-
-       device_num = paddle.distributed.parallel.ParallelEnv().nranks
-       step_per_epoch = int(args.trainset_num / (args.batch_size * device_num))
-       learning_rate = fluid.dygraph.CosineDecay(args.learning_rate,
-                                                 step_per_epoch, args.epochs)
-       clip = fluid.clip.GradientClipByGlobalNorm(clip_norm=args.grad_clip)
-       optimizer = paddle.optimizer.Momentum(
-           learning_rate,
-           momentum=args.momentum,
-           regularization=fluid.regularizer.L2Decay(args.weight_decay),
-           parameter_list=model.parameters(),
-           grad_clip=clip)
-
-       if args.use_data_parallel:
-           strategy = fluid.dygraph.parallel.prepare_context()
-           model = fluid.dygraph.parallel.DataParallel(model, strategy)
-
-       train_loader = fluid.io.DataLoader.from_generator(
-           capacity=64,
-           use_double_buffer=True,
-           iterable=True,
-           return_list=True,
-           use_multiprocess=args.use_multiprocess)
-       valid_loader = fluid.io.DataLoader.from_generator(
-           capacity=64,
-           use_double_buffer=True,
-           iterable=True,
-           return_list=True,
-           use_multiprocess=args.use_multiprocess)
-
-       train_reader = reader.train_valid(
-           batch_size=args.batch_size,
-           is_train=True,
-           is_shuffle=True,
-           args=args)
-       valid_reader = reader.train_valid(
-           batch_size=args.batch_size,
-           is_train=False,
-           is_shuffle=False,
-           args=args)
-       if args.use_data_parallel:
-           train_reader = fluid.contrib.reader.distributed_batch_reader(
-               train_reader)
-
-       train_loader.set_batch_generator(train_reader, places=place)
-       valid_loader.set_batch_generator(valid_reader, places=place)
-
-       save_parameters = (not args.use_data_parallel) or (
-           args.use_data_parallel and
-           paddle.distributed.parallel.ParallelEnv().local_rank == 0)
-       best_acc = 0
-       for epoch in range(args.epochs):
-           drop_path_prob = args.drop_path_prob * epoch / args.epochs
-           logger.info('Epoch {}, lr {:.6f}'.format(
-               epoch, optimizer.current_step_lr()))
-           train_top1 = train(model, train_loader, optimizer, epoch,
-                              drop_path_prob, args)
-           logger.info("Epoch {}, train_acc {:.6f}".format(epoch, train_top1))
-           valid_top1 = valid(model, valid_loader, epoch, args)
-           if valid_top1 > best_acc:
-               best_acc = valid_top1
-               if save_parameters:
-                   paddle.save(model.state_dict(),
-                               args.model_save_dir + "/best_model")
-           logger.info("Epoch {}, valid_acc {:.6f}, best_valid_acc {:.6f}".
-                       format(epoch, valid_top1, best_acc))
+   genotype = eval("genotypes.%s" % args.arch)
+   model = Network(
+       C=args.init_channels,
+       num_classes=args.class_num,
+       layers=args.layers,
+       auxiliary=args.auxiliary,
+       genotype=genotype)
+
+   logger.info("param size = {:.6f}MB".format(
+       count_parameters_in_MB(model.parameters())))
+
+   device_num = paddle.distributed.parallel.ParallelEnv().nranks
+   learning_rate = paddle.optimizer.lr.CosineAnnealingDecay(args.learning_rate,
+                                                            args.epochs / 2)
+   clip = paddle.nn.ClipGradByGlobalNorm(args.grad_clip)
+   optimizer = paddle.optimizer.Momentum(
+       learning_rate,
+       momentum=args.momentum,
+       weight_decay=paddle.regularizer.L2Decay(args.weight_decay),
+       parameters=model.parameters(),
+       grad_clip=clip)
+
+   if args.use_data_parallel:
+       strategy = paddle.distributed.init_parallel_env()
+       model = paddle.DataParallel(model, strategy)
+
+   train_loader = paddle.io.DataLoader.from_generator(
+       capacity=64,
+       use_double_buffer=True,
+       iterable=True,
+       return_list=True,
+       use_multiprocess=args.use_multiprocess)
+   valid_loader = paddle.io.DataLoader.from_generator(
+       capacity=64,
+       use_double_buffer=True,
+       iterable=True,
+       return_list=True,
+       use_multiprocess=args.use_multiprocess)
+
+   train_reader = reader.train_valid(
+       batch_size=args.batch_size, is_train=True, is_shuffle=True, args=args)
+   valid_reader = reader.train_valid(
+       batch_size=args.batch_size, is_train=False, is_shuffle=False, args=args)
+   if args.use_data_parallel:
+       train_reader = fluid.contrib.reader.distributed_batch_reader(
+           train_reader)
+
+   train_loader.set_batch_generator(train_reader, places=place)
+   valid_loader.set_batch_generator(valid_reader, places=place)
+
+   save_parameters = (not args.use_data_parallel) or (
+       args.use_data_parallel and
+       paddle.distributed.parallel.ParallelEnv().local_rank == 0)
+   best_acc = 0
+   for epoch in range(args.epochs):
+       drop_path_prob = args.drop_path_prob * epoch / args.epochs
+       logger.info('Epoch {}, lr {:.6f}'.format(epoch,
+                                                optimizer.get_lr()))
+       train_top1 = train(model, train_loader, optimizer, epoch,
+                          drop_path_prob, args)
+       logger.info("Epoch {}, train_acc {:.6f}".format(epoch, train_top1))
+       valid_top1 = valid(model, valid_loader, epoch, args)
+       if valid_top1 > best_acc:
+           best_acc = valid_top1
+           if save_parameters:
+               paddle.save(model.state_dict(),
+                           args.model_save_dir + "/best_model")
+       logger.info("Epoch {}, valid_acc {:.6f}, best_valid_acc {:.6f}".format(
+           epoch, valid_top1, best_acc))

if __name__ == '__main__':
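The scheduler swap in main() also changes stepping semantics: fluid.dygraph.CosineDecay advanced automatically with each optimizer step, while paddle.optimizer.lr.CosineAnnealingDecay is a paddle 2.x LRScheduler whose step() must be called explicitly (presumably once per epoch here, given T_max = args.epochs / 2). A minimal sketch of the migrated optimizer setup, with illustrative values and a stand-in model:

```python
import paddle

model = paddle.nn.Linear(4, 4)  # stand-in for the DARTS Network
scheduler = paddle.optimizer.lr.CosineAnnealingDecay(
    learning_rate=0.025, T_max=25)                   # T_max = args.epochs / 2 in the diff
optimizer = paddle.optimizer.Momentum(
    scheduler,
    momentum=0.9,
    weight_decay=paddle.regularizer.L2Decay(3e-4),   # 2.x name for fluid's regularization=
    parameters=model.parameters(),                   # 2.x name for fluid's parameter_list=
    grad_clip=paddle.nn.ClipGradByGlobalNorm(5.0))

for epoch in range(3):
    print(epoch, optimizer.get_lr())  # 2.x replacement for fluid's current_step_lr()
    scheduler.step()                  # 2.x schedulers are stepped manually
```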