@@ -141,7 +141,7 @@ def __init__(
             phase=self._phase,
         )
 
-        self._steps = self._create_steps()
+        self._steps, self._first_grad_stage = self._create_steps()
 
         self._create_index()
 
@@ -214,8 +214,8 @@ def _create_index(self) -> None:
         # Consistency checks
         step_map = self._step_map.copy()
         for data_index in range(self._batch_config.num_inputs):
-            for type_ in (StepType.forward, StepType.backward) if self._is_training else (StepType.forward,):
-                for stage in range(self._num_stages):
+            for type_ in (StepType.forward, StepType.backward):
+                for stage in range(0 if type_ == StepType.forward else self._first_grad_stage, self._num_stages):
                     assert (
                         step_map.pop((type_, stage, data_index), None) is not None
                     ), f"Missing {type_.value} step with stage={stage}, data_index={data_index}"
@@ -225,7 +225,8 @@ def _create_index(self) -> None:
         for i, step in enumerate(self._steps):
             if self._is_training:
                 if step.type_ == StepType.forward:
-                    step.backward_step = self.get_step(StepType.backward, *step.map_index[1:])
+                    if step.stage >= self._first_grad_stage:
+                        step.backward_step = self.get_step(StepType.backward, *step.map_index[1:])
                 else:
                     step.forward_step = self.get_step(StepType.forward, *step.map_index[1:])
             if step.type_ == StepType.forward and step.stage == 0:
@@ -236,7 +237,8 @@ def _create_index(self) -> None:
                 step.prev_step = self.get_step(
                     step.type_, step.stage + (1 if step.type_ == StepType.backward else -1), *step.map_index[2:]
                 )
-            if step.type_ == StepType.backward and step.stage == 0:
+
+            if step.type_ == StepType.backward and step.stage == self._first_grad_stage:
                 step.next_step = None
             elif step.type_ == StepType.forward and step.stage == self._num_stages - 1:
                 step.next_step = self.get_step(StepType.backward, *step.map_index[1:]) if self._is_training else None
@@ -249,11 +251,15 @@ def _create_index(self) -> None:
         for step in self._steps:
             if self._is_training:
                 if step.type_ == StepType.forward:
-                    Assert.gt(step.backward_step.global_index, step.global_index)
-                    Assert.is_(step.backward_step.forward_step, step)
+                    if step.stage >= self._first_grad_stage:
+                        Assert.gt(step.backward_step.global_index, step.global_index)
+                        Assert.is_(step.backward_step.forward_step, step)
+                    else:
+                        assert step.backward_step is None
                 else:
                     Assert.lt(step.forward_step.global_index, step.global_index)
-                    Assert.is_(step.forward_step.backward_step, step)
+                    if step.stage >= self._first_grad_stage:
+                        Assert.is_(step.forward_step.backward_step, step)
             if step.next_step is not None:
                 Assert.gt(step.next_step.global_index, step.global_index)
                 Assert.is_(step.next_step.prev_step, step)
@@ -303,7 +309,10 @@ def _setup_reduce_steps(self, grad_buffer_indices: dict[int, int]) -> None:
             reduce_step.reduce_accumulate = reduction_count[reduce_step.stage] > 0
             reduction_count[reduce_step.stage] += 1
         for stage, count in enumerate(reduction_count):
-            assert (count > 0) == (stage % self._distributed.pipeline_parallel == self._distributed.pipeline_rank)
+            assert (count > 0) == (
+                stage >= self._first_grad_stage
+                and (stage % self._distributed.pipeline_parallel == self._distributed.pipeline_rank)
+            )
 
     def _setup_timeline(self) -> None:
         # TODO: Include network time
@@ -468,8 +477,16 @@ def get_data_index_split(
             micro_sequence,
         )
 
-    def _create_steps(self) -> list[Step]:
+    def _create_steps(self) -> tuple[list[Step], int]:
         steps = []
+        if self._is_training:
+            # The first stage(s) may not have any trainable parameters,
+            # in which case we shouldn't run the backward pass.
+            first_grad_stage = 0
+            while first_grad_stage < self._num_stages and not self._multi_stage.stages[first_grad_stage].requires_grad:
+                first_grad_stage += 1
+        else:
+            first_grad_stage = self._num_stages
         for depth_first_micro_batch in range(self._batch_config.depth_first_micro_batches):
             for stage in range(self._num_stages):
                 for breadth_first_micro_batch in range(self._batch_config.breadth_first_micro_batches):
@@ -485,7 +502,7 @@ def _create_steps(self) -> list[Step]:
                         )
                     )
         if self._is_training:
-            for stage in reversed(range(self._num_stages)):
+            for stage in reversed(range(first_grad_stage, self._num_stages)):
                 for breadth_first_micro_batch in range(self._batch_config.breadth_first_micro_batches):
                     for micro_sequence in reversed(range(self._batch_config.num_micro_sequences)):
                         steps.append(
@@ -498,4 +515,4 @@ def _create_steps(self) -> list[Step]:
                                 type_=StepType.backward,
                             )
                         )
-        return steps
+        return steps, first_grad_stage
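
The sketch below (not part of the diff) illustrates the `first_grad_stage` scan introduced above: leading pipeline stages with no trainable parameters are skipped, and backward steps are only scheduled from the first stage that requires gradients. Only the `requires_grad` flag and the scan logic come from the diff; the `Stage` dataclass and the standalone helper are assumptions for illustration, not the actual schedule API.

    # Illustrative sketch only; names other than `requires_grad` are hypothetical.
    import dataclasses


    @dataclasses.dataclass
    class Stage:
        requires_grad: bool


    def find_first_grad_stage(stages: list[Stage], is_training: bool) -> int:
        # No backward pass at all in inference mode, so every stage is past the threshold.
        if not is_training:
            return len(stages)
        first_grad_stage = 0
        # Frozen (no-grad) stages at the start of the pipeline never receive gradients,
        # so their backward steps can be omitted from the schedule.
        while first_grad_stage < len(stages) and not stages[first_grad_stage].requires_grad:
            first_grad_stage += 1
        return first_grad_stage


    # Example: the first two stages are frozen (e.g. a frozen embedding/encoder prefix).
    stages = [Stage(False), Stage(False), Stage(True), Stage(True)]
    assert find_first_grad_stage(stages, is_training=True) == 2
    assert find_first_grad_stage(stages, is_training=False) == 4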