Commit abdadd9

[ft] dont do HSDP for semi_sync
1 parent 83d5c16 commit abdadd9

File tree (2 files changed, +31 −15 lines):

torchtitan/components/optimizer.py
torchtitan/train.py


torchtitan/components/optimizer.py

Lines changed: 20 additions & 13 deletions
@@ -180,6 +180,7 @@ def __init__(
         optimizer_cls: type[T],
         optimizer_kwargs: dict[str, Any],
         ft_manager: "ft.Manager",
+        use_ft_optimizer: bool = True,
     ) -> None:
         super().__init__(model_parts, optimizer_cls, optimizer_kwargs)

@@ -192,7 +193,9 @@ def __init__(
         }
         self.cache_state_dict: dict[str, Any] = {}
         self._ft_optimizer = ft.Optimizer(ft_manager, self)
-        self._call_from_ft: bool = False
+        # Whether to determine quorum using FT.optimizer,
+        # in semi-sync training we use the synchronization step to start quorum
+        self._use_ft_optimizer: bool = use_ft_optimizer

     def init_cache_state_dict(self) -> None:
         self.cache_state_dict = super().state_dict()
@@ -211,28 +214,28 @@ def load_state_dict(self, state_dict: dict[str, Any]) -> None:
     def step(self, *args, **kwargs) -> None:
         """Calling the correct step() depending on the caller.

-        TorchFT's OptimizerWrapper.step() is designed to be callled only once
+        TorchFT's OptimizerWrapper.step() is designed to be called only once
         per train step per ft.Manager regardless how many optimizers are used.
         Hence we will need to appropriately dispatch the call.
         """
-        if self._call_from_ft:
-            super().step(*args, **kwargs)
-        else:
-            self._call_from_ft = True
+        if self._use_ft_optimizer:
+            self._use_ft_optimizer = False
             self._ft_optimizer.step(*args, **kwargs)
-            self._call_from_ft = False
+            self._use_ft_optimizer = True
+        else:
+            super().step(*args, **kwargs)

     def zero_grad(self, *args, **kwargs) -> None:
         """Calling the correct zero_grad() depending on the caller.

         Check the comment in ``step()``.
         """
-        if self._call_from_ft:
-            super().zero_grad(*args, **kwargs)
-        else:
-            self._call_from_ft = True
+        if self._use_ft_optimizer:
+            self._use_ft_optimizer = False
             self._ft_optimizer.zero_grad(*args, **kwargs)
-            self._call_from_ft = False
+            self._use_ft_optimizer = True
+        else:
+            super().zero_grad(*args, **kwargs)


 def build_optimizers(
@@ -297,7 +300,11 @@ def build_optimizers(
         )
     elif ft_manager.enabled:
         return FTOptimizersContainer(
-            model_parts, optimizer_cls, optimizer_kwargs, ft_manager.manager
+            model_parts,
+            optimizer_cls,
+            optimizer_kwargs,
+            ft_manager.manager,
+            use_ft_optimizer=job_config.fault_tolerance.semi_sync_method is None,
         )
     else:
         return OptimizersContainer(model_parts, optimizer_cls, optimizer_kwargs)
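
For context, the new use_ft_optimizer flag keeps the same re-entrant dispatch that _call_from_ft implemented, only inverted: when the flag is True, step() hands control to ft.Optimizer.step(), which starts the quorum and then calls back into the container's step(); because the flag is cleared for the duration of that call, the inner invocation falls through to the plain super().step(). When the container is built with use_ft_optimizer=False (the semi-sync case), the fault-tolerance wrapper is bypassed entirely. Below is a minimal, standalone sketch of that control flow; MockFTOptimizer, BaseContainer and Container are illustrative stand-ins, not the torchft or torchtitan APIs.

# Standalone sketch of the re-entrant dispatch above; the classes here are
# hypothetical stand-ins, not torchft APIs.

class MockFTOptimizer:
    """Plays the role of ft.Optimizer: does quorum work, then calls back."""

    def __init__(self, container):
        self.container = container

    def step(self):
        print("ft wrapper: start quorum / fault-tolerance bookkeeping")
        self.container.step()  # re-enters Container.step()


class BaseContainer:
    def step(self):
        print("base container: run optimizer.step() on every model part")


class Container(BaseContainer):
    def __init__(self, use_ft_optimizer: bool = True):
        self._ft_optimizer = MockFTOptimizer(self)
        # True: the ft wrapper drives quorum on every step.
        # False: semi-sync training starts quorum at its own synchronization
        #        step, so step() goes straight to the base implementation.
        self._use_ft_optimizer = use_ft_optimizer

    def step(self):
        if self._use_ft_optimizer:
            self._use_ft_optimizer = False   # let the re-entrant call fall through
            self._ft_optimizer.step()        # calls back into this method
            self._use_ft_optimizer = True    # restore for the next train step
        else:
            super().step()


Container(use_ft_optimizer=True).step()   # quorum line, then base step
Container(use_ft_optimizer=False).step()  # base step only (semi-sync path)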

torchtitan/train.py

Lines changed: 11 additions & 2 deletions
@@ -230,7 +230,10 @@ def __init__(self, job_config: JobConfig):

         self.model_parts = [model]

-        if self.ft_manager.enabled:
+        if (
+            self.ft_manager.enabled
+            and job_config.fault_tolerance.semi_sync_method is None
+        ):
             self.ft_manager.set_all_reduce_hook(self.model_parts)

         # initialize device memory monitor and get peak flops for MFU calculation
@@ -388,7 +391,13 @@ def train_step(self, input_dict: dict[str, torch.Tensor], labels: torch.Tensor):
             or self.ft_manager.enabled
         ):
             loss = loss.detach()
-            ft_pg = self.ft_manager.replicate_pg if self.ft_manager.enabled else None
+            # Skip ft manager communication when using semi sync training
+            use_ft_pg = (
+                self.ft_manager.enabled
+                and self.job_config.fault_tolerance.semi_sync_method is None
+            )
+            ft_pg = self.ft_manager.replicate_pg if use_ft_pg else None
+            global_avg_loss = global_max_loss = loss.item()
             global_avg_loss, global_max_loss = (
                 dist_utils.dist_mean(loss, world_mesh["dp_cp"], ft_pg),
                 dist_utils.dist_max(loss, world_mesh["dp_cp"], ft_pg),
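
Both train.py changes apply the same predicate: per-step fault-tolerance communication (the all-reduce hook on the model parts and the replicate process group folded into the loss reduction) happens only when FT is enabled and no semi-sync method is configured, since with semi-sync training replicas exchange state at the synchronization step instead of on every iteration. A small self-contained restatement of that predicate follows; FTManagerLike and FaultToleranceLike are hypothetical stand-ins, not torchtitan's real FTManager or job config.

# Sketch only; the types below are made-up stand-ins for torchtitan's
# FTManager and fault_tolerance job config.
from dataclasses import dataclass
from typing import Optional


@dataclass
class FaultToleranceLike:
    semi_sync_method: Optional[str] = None  # e.g. "diloco" (assumed value)


@dataclass
class FTManagerLike:
    enabled: bool = False
    replicate_pg: Optional[object] = None


def loss_reduce_pg(ft_manager: FTManagerLike, ft_config: FaultToleranceLike):
    """Process group to fold into the per-step loss all-reduce, or None.

    Fully-synchronous FT: the replicate group joins every step's reduction.
    Semi-sync FT: replicas only talk at the synchronization step, so the
    per-step communication is skipped entirely.
    """
    use_ft_pg = ft_manager.enabled and ft_config.semi_sync_method is None
    return ft_manager.replicate_pg if use_ft_pg else None


# FT enabled but a semi-sync method configured: no per-step replicate group.
assert loss_reduce_pg(FTManagerLike(enabled=True),
                      FaultToleranceLike(semi_sync_method="diloco")) is None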
