Skip to content

Commit 18ed332

Browse files
authored
[cherry pick] fix smac output_dir (#1248)
1 parent 75f477f commit 18ed332

File tree

2 files changed

+11
-5
lines changed

2 files changed

+11
-5
lines changed

example/auto_compression/README.md

+2-2
Original file line numberDiff line numberDiff line change
@@ -33,7 +33,7 @@ import paddle
3333
from PIL import Image
3434
from paddle.vision.datasets import DatasetFolder
3535
from paddle.vision.transforms import transforms
36-
from paddleslim.auto_compression import AutoCompression, Quantization, HyperParameterOptimization
36+
from paddleslim.auto_compression import AutoCompression
3737
paddle.enable_static()
3838
# 定义DataSet
3939
class ImageNetDataset(DatasetFolder):
@@ -65,7 +65,7 @@ ac = AutoCompression(
6565
model_filename="inference.pdmodel",
6666
params_filename="inference.pdiparams",
6767
save_dir="output",
68-
config={'Quantization': Quantization(), "HyperParameterOptimization": HyperParameterOptimization(max_quant_count=5)},
68+
config={'Quantization': {}, "HyperParameterOptimization": {'max_quant_count': 5}},
6969
train_dataloader=train_loader,
7070
eval_dataloader=train_loader) # eval_function to verify accuracy
7171
ac.compress()

paddleslim/quant/post_quant_hpo.py

+9-3
Original file line numberDiff line numberDiff line change
@@ -17,6 +17,7 @@
1717
import sys
1818
import math
1919
import time
20+
from time import gmtime, strftime
2021
import numpy as np
2122
import shutil
2223
import paddle
@@ -40,7 +41,7 @@
4041

4142
_logger = get_logger(__name__, level=logging.INFO)
4243

43-
SMAC_TMP_FILE_PATTERN = "smac3-output*"
44+
SMAC_TMP_FILE_PATTERN = "smac3-output_"
4445

4546

4647
def remove(path):
@@ -496,14 +497,19 @@ def quant_post_hpo(
496497

497498
cs.add_hyperparameters(hyper_params)
498499

500+
s_datetime = strftime("%Y-%m-%d-%H:%M:%S", gmtime())
501+
smac_output_dir = SMAC_TMP_FILE_PATTERN + s_datetime
502+
499503
scenario = Scenario({
500504
"run_obj": "quality", # we optimize quality (alternative runtime)
501505
"runcount-limit":
502506
runcount_limit, # max. number of function evaluations; for this example set to a low number
503507
"cs": cs, # configuration space
504508
"deterministic": "True",
505509
"limit_resources": "False",
506-
"memory_limit": 4096 # adapt this to reasonable value for your hardware
510+
"memory_limit":
511+
4096, # adapt this to reasonable value for your hardware
512+
"output_dir": smac_output_dir # output_dir
507513
})
508514
# To optimize, we pass the function to the SMAC-object
509515
smac = SMAC4HPO(
@@ -523,5 +529,5 @@ def quant_post_hpo(
523529
inc_value = smac.get_tae_runner().run(incumbent, 1)[1]
524530
_logger.info("Optimized Value: %.8f" % inc_value)
525531
shutil.rmtree(g_quant_model_cache_path)
526-
remove(SMAC_TMP_FILE_PATTERN)
532+
remove(smac_output_dir)
527533
_logger.info("Quantization completed.")

0 commit comments

Comments
 (0)