Skip to content

Commit 663b508

Browse files
authored
[ACT] fix mkdirs on distributed training. (#1666)
* fix bugs * fix dead links * fix bug
1 parent e33dc48 commit 663b508

File tree

2 files changed

+10
-5
lines changed

2 files changed

+10
-5
lines changed

example/auto_compression/README.md

+1-1
Original file line numberDiff line numberDiff line change
@@ -255,7 +255,7 @@ ac.compress()
255255
## 进阶使用
256256

257257
- ACT可以自动处理常见的预测模型,如果有更特殊的改造需求,可以参考[ACT超参配置教程](./hyperparameter_tutorial.md)来进行单独配置压缩策略。
258-
- ACT接口各个参数详细含义可以参考 [ACT API文档](../docs/zh_cn/api_cn/static/auto-compression/auto_compression_api.rst)。
258+
- ACT接口各个参数详细含义可以参考 [ACT API文档](../../docs/zh_cn/api_cn/static/auto-compression/auto_compression_api.rst)。
259259

260260
## 社区交流
261261

paddleslim/auto_compression/compressor.py

+9-4
Original file line numberDiff line numberDiff line change
@@ -127,7 +127,7 @@ def __init__(self,
127127

128128
self.final_dir = save_dir
129129
if not os.path.exists(self.final_dir):
130-
os.makedirs(self.final_dir)
130+
os.makedirs(self.final_dir, exist_ok=True)
131131

132132
# load config
133133
if isinstance(config, str):
@@ -263,7 +263,7 @@ def _infer_shape(self, model_dir, model_filename, params_filename,
263263
op.desc.infer_shape(block.desc)
264264

265265
save_path = os.path.join(save_path, "infered_shape")
266-
os.makedirs(save_path)
266+
os.makedirs(save_path, exist_ok=True)
267267
paddle.static.save_inference_model(
268268
save_path,
269269
feed_vars,
@@ -763,8 +763,13 @@ def single_strategy_compress(self, strategy, config, strategy_idx,
763763
inference_program, feed_target_names, fetch_targets, patterns,
764764
strategy, config, train_config)
765765
if 'unstructure' in strategy:
766-
test_program_info.program._program = remove_unused_var_nodes(
767-
test_program_info.program._program)
766+
if isinstance(test_program_info.program,
767+
paddle.static.CompiledProgram):
768+
test_program_info.program._program = remove_unused_var_nodes(
769+
test_program_info.program._program)
770+
else:
771+
test_program_info.program = remove_unused_var_nodes(
772+
test_program_info.program)
768773
test_program_info = self._start_train(
769774
train_program_info, test_program_info, strategy, train_config)
770775
if paddle.distributed.get_rank() == 0:

0 commit comments

Comments
 (0)