@@ -586,28 +586,11 @@ def compress(self):
586
586
tmp_model_path = os .path .join (
587
587
self .tmp_dir , 'strategy_{}' .format (str (strategy_idx + 1 )))
588
588
final_model_path = os .path .join (self .final_dir )
589
- if not os .path .exists (final_model_path ):
590
- os .makedirs (final_model_path )
591
-
592
- tmp_model_file = "." .join ([tmp_model_path , "pdmodel" ])
593
- if not os .path .exists (tmp_model_file ):
594
- tmp_model_file = os .path .join (tmp_model_path , self .model_filename )
595
-
596
- tmp_params_file = "." .join ([tmp_model_path , "pdiparams" ])
597
- if not os .path .exists (tmp_params_file ):
598
- tmp_params_file = os .path .join (tmp_model_path , self .params_filename )
599
-
600
- if self .model_filename is None :
601
- self .model_filename = "infer.pdmodel"
602
- if self .params_filename is None :
603
- self .params_filename = "infer.pdiparams"
604
-
605
- final_model_file = os .path .join (final_model_path , self .model_filename )
606
- final_params_file = os .path .join (final_model_path , self .params_filename )
607
-
608
589
if paddle .distributed .get_rank () == 0 :
609
- shutil .move (tmp_model_file , final_model_file )
610
- shutil .move (tmp_params_file , final_params_file )
590
+ for _file in os .listdir (tmp_model_path ):
591
+ _file_path = os .path .join (tmp_model_path , _file )
592
+ if os .path .isfile (_file_path ):
593
+ shutil .copy (_file_path , final_model_path )
611
594
shutil .rmtree (self .tmp_dir )
612
595
_logger .info (
613
596
"==> The ACT compression has been completed and the final model is saved in `{}`" .
0 commit comments