Skip to content

Commit c590123

Browse files
authored
Copy Json to SaveDir in NLP Demo (#1252)
1 parent 072ff7a commit c590123

File tree

2 files changed

+10
-2
lines changed

2 files changed

+10
-2
lines changed

example/auto_compression/nlp/run.py

+6-1
Original file line numberDiff line numberDiff line change
@@ -4,6 +4,7 @@
44
import functools
55
from functools import partial
66
import numpy as np
7+
import shutil
78
import paddle
89
import paddle.nn as nn
910
from paddle.io import Dataset, BatchSampler, DataLoader
@@ -305,13 +306,17 @@ def main():
305306
if 'HyperParameterOptimization' not in all_config else eval_dataloader,
306307
eval_dataloader=eval_dataloader)
307308

308-
ac.compress()
309+
if not os.path.exists(args.save_dir):
310+
os.makedirs(args.save_dir)
311+
309312
for file_name in os.listdir(global_config['model_dir']):
310313
if 'json' in file_name or 'txt' in file_name:
311314
shutil.copy(
312315
os.path.join(global_config['model_dir'], file_name),
313316
args.save_dir)
314317

318+
ac.compress()
319+
315320

316321
if __name__ == '__main__':
317322
paddle.enable_static()

example/auto_compression/pytorch_huggingface/run.py

+4-1
Original file line numberDiff line numberDiff line change
@@ -363,14 +363,17 @@ def main():
363363
'HyperParameterOptimization' not in all_config else eval_dataloader,
364364
eval_dataloader=eval_dataloader)
365365

366-
ac.compress()
366+
if not os.path.exists(args.save_dir):
367+
os.makedirs(args.save_dir)
367368

368369
for file_name in os.listdir(global_config['model_dir']):
369370
if 'json' in file_name or 'txt' in file_name:
370371
shutil.copy(
371372
os.path.join(global_config['model_dir'], file_name),
372373
args.save_dir)
373374

375+
ac.compress()
376+
374377

375378
if __name__ == '__main__':
376379
paddle.enable_static()

0 commit comments

Comments
 (0)