Skip to content

Commit c4ee247

Browse files
authored
Fix bug of loading saved model and tests. (#1744)
* fix bug of loading saved model. * fix distill demo. * fix flops.py. * fix tests.
1 parent f60c6a0 commit c4ee247

File tree

4 files changed

+23
-13
lines changed

4 files changed

+23
-13
lines changed

demo/distillation/distill.py

+4-5
Original file line numberDiff line numberDiff line change
@@ -97,8 +97,8 @@ def compress(args):
9797
raise ValueError("{} is not supported.".format(args.data))
9898
image_shape = [int(m) for m in image_shape.split(",")]
9999

100-
assert args.model in model_list, "{} is not in lists: {}".format(args.model,
101-
model_list)
100+
assert args.model in model_list, "{} is not in lists: {}".format(
101+
args.model, model_list)
102102
student_program = paddle.static.Program()
103103
s_startup = paddle.static.Program()
104104
places = paddle.static.cuda_places(
@@ -202,7 +202,7 @@ def if_exist(var):
202202
_logger.info(
203203
"train_epoch {} step {} lr {:.6f}, loss {:.6f}, class loss {:.6f}, distill loss {:.6f}".
204204
format(epoch_id, step_id,
205-
lr.get_lr(), loss_1[0], loss_2[0], loss_3[0]))
205+
lr.get_lr(), loss_1, loss_2, loss_3))
206206
lr.step()
207207
val_acc1s = []
208208
val_acc5s = []
@@ -216,8 +216,7 @@ def if_exist(var):
216216
if step_id % args.log_period == 0:
217217
_logger.info(
218218
"valid_epoch {} step {} loss {:.6f}, top1 {:.6f}, top5 {:.6f}".
219-
format(epoch_id, step_id, val_loss[0], val_acc1[0],
220-
val_acc5[0]))
219+
format(epoch_id, step_id, val_loss, val_acc1, val_acc5))
221220
if args.save_inference:
222221
paddle.static.save_inference_model(
223222
os.path.join("./saved_models", str(epoch_id)), [image], [out],

paddleslim/analysis/flops.py

+7-2
Original file line numberDiff line numberDiff line change
@@ -84,7 +84,8 @@ def _graph_flops(graph, only_conv=True, detail=False):
8484
output_shape = op.outputs("Out")[0].shape()
8585
_, c_out, h_out, w_out = output_shape
8686
k_size = op.attr("ksize")
87-
flops += h_out * w_out * c_out * (k_size[0]**2)
87+
if op.attr('pooling_type') == 'avg':
88+
flops += (h_out * w_out * c_out * (k_size[0]**2) * 2)
8889

8990
elif op.type() in ['mul', 'matmul', 'matmul_v2']:
9091
x_shape = list(op.inputs("X")[0].shape())
@@ -101,7 +102,11 @@ def _graph_flops(graph, only_conv=True, detail=False):
101102
input_shape = list(op.inputs("X")[0].shape())
102103
if input_shape[0] == -1:
103104
input_shape[0] = 1
104-
flops += np.product(input_shape)
105+
if op.type() == 'batch_norm':
106+
op_flops = np.product(input_shape) * 2
107+
else:
108+
op_flops = np.product(input_shape)
109+
flops += op_flops
105110

106111
if detail:
107112
return flops, params2flops

paddleslim/auto_compression/compressor.py

+10-3
Original file line numberDiff line numberDiff line change
@@ -570,7 +570,7 @@ def create_tmp_dir(self, base_dir, prefix="tmp"):
570570
tmp_base_name = "_".join([prefix, str(os.getppid()), s_datetime])
571571
tmp_dir = os.path.join(base_dir, tmp_base_name)
572572
if not os.path.exists(tmp_dir):
573-
os.makedirs(tmp_dir)
573+
os.makedirs(tmp_dir, exist_ok=True)
574574
return tmp_dir
575575

576576
def compress(self):
@@ -609,10 +609,17 @@ def compress(self):
609609
shutil.rmtree(self.tmp_dir)
610610

611611
if self.eval_function is not None and self.final_metric < 0.0:
612+
model_filename = None
613+
if self.model_filename is None:
614+
model_filename = "model.pdmodel"
615+
elif self.model_filename.endswith(".pdmodel"):
616+
model_filename = self.model_filename
617+
else:
618+
model_filename = self.model_filename + '.pdmodel'
619+
612620
[inference_program, feed_target_names, fetch_targets]= load_inference_model( \
613621
final_model_path, \
614-
model_filename=self.model_filename, params_filename=self.params_filename,
615-
executor=self._exe)
622+
model_filename=model_filename, executor=self._exe)
616623
self.final_metric = self.eval_function(
617624
self._exe, inference_program, feed_target_names,
618625
fetch_targets)

tests/dygraph/test_flops.py

+2-3
Original file line numberDiff line numberDiff line change
@@ -73,9 +73,8 @@ def runTest(self):
7373

7474
def add_cases(suite):
7575
suite.addTest(
76-
TestFlops(
77-
net=paddle.vision.models.mobilenet_v1, gt=11792896.0))
78-
suite.addTest(TestFlops(net=paddle.vision.models.resnet50, gt=83872768.0))
76+
TestFlops(net=paddle.vision.models.mobilenet_v1, gt=12920832.0))
77+
suite.addTest(TestFlops(net=paddle.vision.models.resnet50, gt=86112768.0))
7978
suite.addTest(TestFLOPsCase1())
8079
suite.addTest(TestFLOPsCase2())
8180

0 commit comments

Comments
 (0)