Skip to content

Commit 671be70

Browse files
authored
fix ptq demo (#1066)
1 parent 6b8b683 commit 671be70

File tree

3 files changed

+8
-6
lines changed

3 files changed

+8
-6
lines changed

demo/auto-compression/README.md

+2-2
Original file line numberDiff line numberDiff line change
@@ -88,7 +88,7 @@ python demo_imagenet.py \
8888
--model_dir='infermodel_mobilenetv2' \
8989
--model_filename='inference.pdmodel' \
9090
--params_filename='./inference.pdiparams' \
91-
--save_dir='./save_qat_mbv2/' \
91+
--save_dir='./save_ptq_mbv2/' \
9292
--devices='gpu' \
9393
--batch_size=64 \
9494
--data_dir='../data/ILSVRC2012/' \
@@ -118,7 +118,7 @@ python demo_imagenet.py \
118118
--model_dir='infermodel_mobilenetv2' \
119119
--model_filename='inference.pdmodel' \
120120
--params_filename='./inference.pdiparams' \
121-
--save_dir='./save_qat_mbv2/' \
121+
--save_dir='./save_asp_mbv2/' \
122122
--devices='gpu' \
123123
--batch_size=64 \
124124
--data_dir='../data/ILSVRC2012/' \

demo/auto-compression/demo_glue.py

+1-1
Original file line numberDiff line numberDiff line change
@@ -189,7 +189,7 @@ def apply_decay_param_fun(name):
189189
strategy_config=compress_config,
190190
train_config=train_config,
191191
train_dataloader=train_dataloader,
192-
eval_callback=eval_function,
192+
eval_callback=eval_function if 'HyperParameterOptimization' not in compress_config else eval_dataloader,
193193
devices=args.devices)
194194

195195
ac.compress()

demo/auto-compression/demo_imagenet.py

+5-3
Original file line numberDiff line numberDiff line change
@@ -37,10 +37,12 @@ def gen():
3737

3838
return gen
3939

40+
def eval_reader(data_dir, batch_size):
41+
val_reader = paddle.batch(reader.val(data_dir=data_dir), batch_size=batch_size)
42+
return val_reader
4043

4144
def eval_function(exe, compiled_test_program, test_feed_names, test_fetch_list):
42-
43-
val_reader = paddle.batch(reader.val(data_dir=data_dir), batch_size=1)
45+
val_reader = eval_reader(data_dir, batch_size=1)
4446
image = paddle.static.data(
4547
name='x', shape=[None, 3, 224, 224], dtype='float32')
4648
label = paddle.static.data(name='label', shape=[None, 1], dtype='int64')
@@ -102,7 +104,7 @@ def eval_function(exe, compiled_test_program, test_feed_names, test_fetch_list):
102104
strategy_config=compress_config,
103105
train_config=train_config,
104106
train_dataloader=train_dataloader,
105-
eval_callback=eval_function if 'HyperParameterOptimization' not in compress_config else None,
107+
eval_callback=eval_function if 'HyperParameterOptimization' not in compress_config else reader_wrapper(eval_reader(data_dir, 64)),
106108
devices=args.devices)
107109

108110
ac.compress()

0 commit comments

Comments
 (0)