88"""
99
1010import argparse
11+ from datetime import datetime
1112import functools
1213import logging
1314import math
1819import socket
1920import re
2021from typing import Dict , Optional , Any , List , Tuple , Callable
22+ from urllib .parse import urlparse
2123import warnings
2224
2325import numpy as np
@@ -512,37 +514,51 @@ def _checkpoint_add_directory(basename):
512514 return m [1 ], f"checkpoint{ m [3 ]} "
513515
514516
515- def post_checkpoint_callback (cfg , num_updates , training_finished , filename ):
517+ def _get_basename (path ):
518+ res = urlparse (path )
519+ if res .scheme :
520+ return os .path .basename (res .path )
521+ else :
522+ return os .path .basename (path )
523+
524+
def _get_destination_path(path: str, destination: str) -> str:
    """Calculate the destination path, with handling for remote paths.

    The basename of ``path`` (as computed by ``_get_basename``) is appended to
    ``destination``. When ``destination`` parses with a URL scheme, the
    basename is appended to the URL's path component only, and the remaining
    URL parts (netloc, query, fragment) are carried over unchanged.

    Args:
        path: Source file path or URL whose basename should be kept.
        destination: Target directory, either a local path or a URL.

    Returns:
        ``destination`` joined with the basename of ``path``.
    """
    import posixpath

    basename = _get_basename(path)
    res = urlparse(destination)
    if res.scheme:
        # URL paths are always "/"-separated, regardless of the host OS, so
        # join with posixpath rather than the platform-specific os.path.
        new_path = posixpath.join(res.path, basename)
        return res._replace(path=new_path).geturl()
    return os.path.join(destination, basename)
535+
536+
537+ def post_checkpoint_callback (
538+ cfg , num_updates , training_finished , filename , files_to_symlink_to
539+ ):
516540 if cfg .checkpoint .cloud_upload_path is not None :
517541 if "blob.core.windows.net" in cfg .checkpoint .cloud_upload_path :
518- azcopy_logs = filename + "_azcopy_logs"
519- os .environ ["AZCOPY_CONCURRENCY_VALUE" ] = "10"
520- os .environ ["AZCOPY_LOG_LOCATION" ] = azcopy_logs
521- os .makedirs (azcopy_logs , exist_ok = True )
522- logger .info (
523- f"preparing to azcopy { filename } to { cfg .checkpoint .cloud_upload_path } ; logs in { azcopy_logs } "
542+ azcopy_log_dir = os .path .dirname (filename )
543+ final_path = _get_destination_path (
544+ filename , cfg .checkpoint .cloud_upload_path
524545 )
525- cmd = [
526- "azcopy" , # TODO(susanz): require azcopy to be installed.
527- "copy" ,
528- "--cap-mbps" ,
529- "96.0" ,
530- filename ,
531- cfg .checkpoint .cloud_upload_path ,
532- ]
533- res = _run_azcopy (cmd , stdout = subprocess .PIPE , stderr = subprocess .PIPE )
534- if res .returncode != 0 :
535- print ("Error: {}, azcopy failed" .format (res .returncode ))
536- print ("Azcopy stdout = {}" .format (res .stdout ))
537- sys .exit (1 )
546+ _copy_to_azure (filename , final_path , azcopy_log_dir )
547+
538548 # Delete original checkpoint on local storage
539549 # TODO make this configurable
540- logger .info (
541- f"Successfully copied { filename } to { cfg .checkpoint .cloud_upload_path } "
542- )
543550 os .remove (filename )
551+
552+ # Azure Blob doesn't support symlinks so make full copies
553+ if files_to_symlink_to :
554+ for other_checkpoint in files_to_symlink_to :
555+ dest = _get_destination_path (
556+ other_checkpoint , cfg .checkpoint .cloud_upload_path
557+ )
558+ _copy_to_azure (final_path , dest , azcopy_log_dir )
559+
544560 elif cfg .checkpoint .cloud_upload_path .startswith ("nfs:" ):
545- path , basename = os .path .split (filename )
561+ basename = os .path .basename (filename )
546562 checkpoint_dir , checkpoint_file = _checkpoint_add_directory (basename )
547563 destination_checkpoints_dir = cfg .checkpoint .cloud_upload_path [4 :]
548564 temporary_checkpoint_file = f"_{ checkpoint_file } "
@@ -566,6 +582,9 @@ def post_checkpoint_callback(cfg, num_updates, training_finished, filename):
566582 )
567583
568584 logger .info (f"Renaming { temporary_checkpoint_file } -> { checkpoint_file } " )
585+ final_path = os .path .join (
586+ destination_checkpoints_dir , checkpoint_dir , checkpoint_file
587+ )
569588 # atomic rename _checkpointfile -> checkpointfile
570589 # this way we know that if present the checkpoint file is complete
571590 os .rename (
@@ -574,12 +593,20 @@ def post_checkpoint_callback(cfg, num_updates, training_finished, filename):
574593 checkpoint_dir ,
575594 temporary_checkpoint_file ,
576595 ),
577- os .path .join (
578- destination_checkpoints_dir , checkpoint_dir , checkpoint_file
579- ),
596+ final_path ,
580597 )
581598 os .remove (filename )
582599
600+ if files_to_symlink_to :
601+ dest_dir = os .path .dirname (final_path )
602+ for other_checkpoint in files_to_symlink_to :
603+ dest = _get_destination_path (other_checkpoint , dest_dir )
604+ if PathManager .islink (dest ):
605+ PathManager .rm (dest )
606+ assert PathManager .symlink (
607+ final_path , dest
608+ ), f"Failed to symlink { final_path } to { dest } "
609+
583610 # Start running evals on uploaded checkpoint
584611 nfs_evaluation (
585612 cfg ,
@@ -593,13 +620,18 @@ def post_checkpoint_callback(cfg, num_updates, training_finished, filename):
593620 try :
594621 # PathManager only supports writing to S3, but this function call
595622 # can be replaced with other APIs for copying checkpoints.
596- PathManager .copy_from_local (
597- filename ,
598- os .path .join (
599- cfg .checkpoint .cloud_upload_path , os .path .basename (filename )
600- ),
601- overwrite = True ,
623+ final_path = _get_destination_path (
624+ filename , cfg .checkpoint .cloud_upload_path
602625 )
626+ PathManager .copy_from_local (filename , final_path , overwrite = True )
627+
628+ # Some non-native PathHandlers don't support symlinks so default to full copies
629+ if files_to_symlink_to :
630+ for other_checkpoint in files_to_symlink_to :
631+ dest = _get_destination_path (
632+ other_checkpoint , cfg .checkpoint .cloud_upload_path
633+ )
634+ PathManager .copy (final_path , dest , overwrite = True )
603635 except (FileNotFoundError , AssertionError ) as e :
604636 logger .info (f"could not upload { filename } : { e } " )
605637
@@ -665,6 +697,31 @@ def nfs_evaluation(
665697 )
666698
667699
def _copy_to_azure(source: str, destination: str, log_dir: str) -> None:
    """Copy ``source`` to ``destination`` using the ``azcopy`` command-line tool.

    A fresh log directory named after the destination basename and the current
    UTC time is created under ``log_dir``. ``AZCOPY_CONCURRENCY_VALUE`` and
    ``AZCOPY_LOG_LOCATION`` are set in ``os.environ`` before the command is
    run (these assignments remain in the process environment afterwards).

    Args:
        source: Path or URL to copy from.
        destination: Path or URL to copy to.
        log_dir: Directory under which the azcopy log directory is created.

    Raises:
        SystemExit: With status 1 if azcopy returns a non-zero exit code.
    """
    from datetime import timezone

    # /dir/checkpoint_last.pt -> /dir/checkpoint_last.pt_azcopy_logs_2000-01-01T00_00_00
    basename = _get_basename(destination)
    # Format explicitly instead of slicing isoformat(): isoformat() omits the
    # fractional part when microsecond == 0, so a fixed-width slice would then
    # cut into the real timestamp.
    timestamp = datetime.now(timezone.utc).strftime("%Y-%m-%dT%H_%M_%S")
    azcopy_logs = os.path.join(log_dir, f"{basename}_azcopy_logs_{timestamp}")
    os.environ["AZCOPY_CONCURRENCY_VALUE"] = "10"
    os.environ["AZCOPY_LOG_LOCATION"] = azcopy_logs
    os.makedirs(azcopy_logs, exist_ok=True)
    logger.info(f"preparing to azcopy {source} to {destination}; logs in {azcopy_logs}")
    cmd = [
        "azcopy",  # TODO(susanz): require azcopy to be installed.
        "copy",
        "--cap-mbps",
        "96.0",
        source,
        destination,
    ]
    res = _run_azcopy(cmd, stdout=subprocess.PIPE, stderr=subprocess.PIPE)
    if res.returncode != 0:
        print("Error: {}, azcopy failed".format(res.returncode))
        print("Azcopy stdout = {}".format(res.stdout))
        # stderr is captured as well; surface it so failures are diagnosable.
        print("Azcopy stderr = {}".format(res.stderr))
        sys.exit(1)
    logger.info(f"Successfully copied {source} to {destination}")
723+
724+
668725def _run_azcopy (cmd , stdout , stderr ):
669726 return subprocess .run (cmd , stdout = subprocess .PIPE , stderr = subprocess .PIPE )
670727
0 commit comments