@@ -180,6 +180,7 @@ def __init__(
         optimizer_cls: type[T],
         optimizer_kwargs: dict[str, Any],
         ft_manager: "ft.Manager",
+        use_ft_optimizer: bool = True,
     ) -> None:
         super().__init__(model_parts, optimizer_cls, optimizer_kwargs)
@@ -192,7 +193,9 @@ def __init__(
         }
         self.cache_state_dict: dict[str, Any] = {}
         self._ft_optimizer = ft.Optimizer(ft_manager, self)
-        self._call_from_ft: bool = False
+        # Whether to start quorum via ft.Optimizer; in semi-sync training,
+        # the synchronization step starts the quorum instead.
+        self._use_ft_optimizer: bool = use_ft_optimizer

    def init_cache_state_dict(self) -> None:
        self.cache_state_dict = super().state_dict()
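For reference, a minimal construction sketch of the new keyword (illustrative only; the model parts, optimizer settings, and the torchft `Manager` setup are placeholders and not part of this change):

```python
import torch

# Assumes an ft.Manager instance `manager` has already been created via torchft,
# and that FTOptimizersContainer is imported from this module.
model_parts = [torch.nn.Linear(8, 8)]
optimizers = FTOptimizersContainer(
    model_parts,
    torch.optim.AdamW,
    {"lr": 1e-3},
    manager,
    # Default True keeps the existing behavior (ft.Optimizer drives quorum);
    # pass False for semi-sync training, where the sync step starts quorum.
    use_ft_optimizer=False,
)
```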
@@ -211,28 +214,28 @@ def load_state_dict(self, state_dict: dict[str, Any]) -> None:
     def step(self, *args, **kwargs) -> None:
         """Calling the correct step() depending on the caller.

-        TorchFT's OptimizerWrapper.step() is designed to be callled only once
+        TorchFT's OptimizerWrapper.step() is designed to be called only once
         per train step per ft.Manager regardless how many optimizers are used.
         Hence we will need to appropriately dispatch the call.
         """
-        if self._call_from_ft:
-            super().step(*args, **kwargs)
-        else:
-            self._call_from_ft = True
+        if self._use_ft_optimizer:
+            self._use_ft_optimizer = False
             self._ft_optimizer.step(*args, **kwargs)
-            self._call_from_ft = False
+            self._use_ft_optimizer = True
+        else:
+            super().step(*args, **kwargs)

     def zero_grad(self, *args, **kwargs) -> None:
         """Calling the correct zero_grad() depending on the caller.

         Check the comment in ``step()``.
         """
-        if self._call_from_ft:
-            super().zero_grad(*args, **kwargs)
-        else:
-            self._call_from_ft = True
+        if self._use_ft_optimizer:
+            self._use_ft_optimizer = False
             self._ft_optimizer.zero_grad(*args, **kwargs)
-            self._call_from_ft = False
+            self._use_ft_optimizer = True
+        else:
+            super().zero_grad(*args, **kwargs)


 def build_optimizers(
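The dispatch above relies on `ft.Optimizer.step()` calling back into this container's `step()`: the flag is cleared before delegating, so the re-entrant call falls through to `super().step()`, and is restored afterwards. A self-contained sketch of that re-entrancy pattern with toy classes (not the torchft API):

```python
class ToyFTOptimizer:
    """Stand-in for ft.Optimizer: does quorum-related work, then calls back into the wrapper."""

    def __init__(self, wrapped):
        self.wrapped = wrapped

    def step(self):
        print("start quorum")   # placeholder for the ft.Manager work
        self.wrapped.step()     # re-enters the wrapper's step()


class ToyContainer:
    def __init__(self, use_ft_optimizer: bool = True):
        self._use_ft_optimizer = use_ft_optimizer
        self._ft_optimizer = ToyFTOptimizer(self)

    def _inner_step(self):
        print("real optimizer step")  # stands in for super().step()

    def step(self):
        if self._use_ft_optimizer:
            self._use_ft_optimizer = False
            self._ft_optimizer.step()  # re-enters step() and takes the else branch
            self._use_ft_optimizer = True
        else:
            self._inner_step()


ToyContainer().step()                         # quorum path: "start quorum", then the real step
ToyContainer(use_ft_optimizer=False).step()   # semi-sync path: the real step only
```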
@@ -297,7 +300,11 @@ def build_optimizers(
         )
     elif ft_manager.enabled:
         return FTOptimizersContainer(
-            model_parts, optimizer_cls, optimizer_kwargs, ft_manager.manager
+            model_parts,
+            optimizer_cls,
+            optimizer_kwargs,
+            ft_manager.manager,
+            use_ft_optimizer=job_config.fault_tolerance.semi_sync_method is None,
         )
     else:
         return OptimizersContainer(model_parts, optimizer_cls, optimizer_kwargs)
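The flag is derived from the fault-tolerance section of the job config: the per-step `ft.Optimizer` quorum is used only when no semi-sync method is configured. A sketch of that selection, using an illustrative dataclass in place of the real `job_config.fault_tolerance` section (the field name comes from the hunk above; the example values are placeholders):

```python
from dataclasses import dataclass
from typing import Optional


@dataclass
class FaultToleranceConfig:  # illustrative stand-in for job_config.fault_tolerance
    semi_sync_method: Optional[str] = None  # None means fully synchronous training


def wants_ft_optimizer(ft_cfg: FaultToleranceConfig) -> bool:
    # Mirrors the call site: ft.Optimizer drives quorum only when
    # no semi-sync method is configured.
    return ft_cfg.semi_sync_method is None


assert wants_ft_optimizer(FaultToleranceConfig()) is True
assert wants_ft_optimizer(FaultToleranceConfig(semi_sync_method="local_sgd")) is False
```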