diff --git a/torchtitan/components/optimizer.py b/torchtitan/components/optimizer.py index f45d4db6ce..6a84cddcde 100644 --- a/torchtitan/components/optimizer.py +++ b/torchtitan/components/optimizer.py @@ -180,6 +180,7 @@ def __init__( optimizer_cls: type[T], optimizer_kwargs: dict[str, Any], ft_manager: "ft.Manager", + use_ft_optimizer: bool = True, ) -> None: super().__init__(model_parts, optimizer_cls, optimizer_kwargs) @@ -192,7 +193,9 @@ def __init__( } self.cache_state_dict: dict[str, Any] = {} self._ft_optimizer = ft.Optimizer(ft_manager, self) - self._call_from_ft: bool = False + # Whether to determine quorum using FT.optimizer, + # in semi-sync training we use the synchronization step to start quorum + self._use_ft_optimizer: bool = use_ft_optimizer def init_cache_state_dict(self) -> None: self.cache_state_dict = super().state_dict() @@ -211,28 +214,28 @@ def load_state_dict(self, state_dict: dict[str, Any]) -> None: def step(self, *args, **kwargs) -> None: """Calling the correct step() depending on the caller. - TorchFT's OptimizerWrapper.step() is designed to be callled only once + TorchFT's OptimizerWrapper.step() is designed to be called only once per train step per ft.Manager regardless how many optimizers are used. Hence we will need to appropriately dispatch the call. """ - if self._call_from_ft: - super().step(*args, **kwargs) - else: - self._call_from_ft = True + if self._use_ft_optimizer: + self._use_ft_optimizer = False self._ft_optimizer.step(*args, **kwargs) - self._call_from_ft = False + self._use_ft_optimizer = True + else: + super().step(*args, **kwargs) def zero_grad(self, *args, **kwargs) -> None: """Calling the correct zero_grad() depending on the caller. Check the comment in ``step()``. """ - if self._call_from_ft: - super().zero_grad(*args, **kwargs) - else: - self._call_from_ft = True + if self._use_ft_optimizer: + self._use_ft_optimizer = False self._ft_optimizer.zero_grad(*args, **kwargs) - self._call_from_ft = False + self._use_ft_optimizer = True + else: + super().zero_grad(*args, **kwargs) def build_optimizers( @@ -297,7 +300,11 @@ def build_optimizers( ) elif ft_manager.enabled: return FTOptimizersContainer( - model_parts, optimizer_cls, optimizer_kwargs, ft_manager.manager + model_parts, + optimizer_cls, + optimizer_kwargs, + ft_manager.manager, + use_ft_optimizer=job_config.fault_tolerance.semi_sync_method is None, ) else: return OptimizersContainer(model_parts, optimizer_cls, optimizer_kwargs) diff --git a/torchtitan/train.py b/torchtitan/train.py index d0228ae782..88340fba69 100644 --- a/torchtitan/train.py +++ b/torchtitan/train.py @@ -230,7 +230,10 @@ def __init__(self, job_config: JobConfig): self.model_parts = [model] - if self.ft_manager.enabled: + if ( + self.ft_manager.enabled + and job_config.fault_tolerance.semi_sync_method is None + ): self.ft_manager.set_all_reduce_hook(self.model_parts) # initialize device memory monitor and get peak flops for MFU calculation @@ -388,7 +391,12 @@ def train_step(self, input_dict: dict[str, torch.Tensor], labels: torch.Tensor): or self.ft_manager.enabled ): loss = loss.detach() - ft_pg = self.ft_manager.replicate_pg if self.ft_manager.enabled else None + # Skip ft manager communication when using semi sync training + use_ft_pg = ( + self.ft_manager.enabled + and self.job_config.fault_tolerance.semi_sync_method is None + ) + ft_pg = self.ft_manager.replicate_pg if use_ft_pg else None global_avg_loss, global_max_loss = ( dist_utils.dist_mean(loss, world_mesh["dp_cp"], ft_pg), dist_utils.dist_max(loss, world_mesh["dp_cp"], ft_pg),