@@ -141,7 +141,7 @@ def __init__(
             phase=self._phase,
         )
 
-        self._steps = self._create_steps()
+        self._steps, self._first_grad_stage = self._create_steps()
 
         self._create_index()
 
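Note: _create_steps now returns, alongside the step list, the index of the first stage that actually needs gradients. A standalone sketch of the derivation, using a plain list of booleans as a hypothetical stand-in for the real stage objects (the actual scan over self._multi_stage.stages is in the _create_steps hunk below):

def first_grad_stage(stage_requires_grad: list[bool], is_training: bool) -> int:
    # At inference there is no backward pass at all, so the index is pushed
    # past the last stage and every backward range below comes out empty.
    if not is_training:
        return len(stage_requires_grad)
    # In training, skip the leading run of fully frozen stages.
    stage = 0
    while stage < len(stage_requires_grad) and not stage_requires_grad[stage]:
        stage += 1
    return stage

assert first_grad_stage([False, False, True, True], is_training=True) == 2
assert first_grad_stage([True, True], is_training=True) == 0
assert first_grad_stage([False, False], is_training=False) == 2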
@@ -214,8 +214,8 @@ def _create_index(self) -> None:
         # Consistency checks
         step_map = self._step_map.copy()
         for data_index in range(self._batch_config.num_inputs):
-            for type_ in (StepType.forward, StepType.backward) if self._is_training else (StepType.forward,):
-                for stage in range(self._num_stages):
+            for type_ in (StepType.forward, StepType.backward):
+                for stage in range(0 if type_ == StepType.forward else self._first_grad_stage, self._num_stages):
                     assert (
                         step_map.pop((type_, stage, data_index), None) is not None
                     ), f"Missing {type_.value} step with stage={stage}, data_index={data_index}"
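Note: the if self._is_training conditional can be dropped here because _first_grad_stage is set to self._num_stages at inference, which makes the backward range empty. A toy instance of the coverage the check now expects, with hypothetical sizes:

num_stages, first_grad_stage = 4, 2  # hypothetical: stages 0-1 frozen
forward_stages = list(range(0, num_stages))                  # [0, 1, 2, 3]
backward_stages = list(range(first_grad_stage, num_stages))  # [2, 3]
# At inference first_grad_stage == num_stages, so the backward range is [].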
@@ -225,7 +225,8 @@ def _create_index(self) -> None:
         for i, step in enumerate(self._steps):
             if self._is_training:
                 if step.type_ == StepType.forward:
-                    step.backward_step = self.get_step(StepType.backward, *step.map_index[1:])
+                    if step.stage >= self._first_grad_stage:
+                        step.backward_step = self.get_step(StepType.backward, *step.map_index[1:])
                 else:
                     step.forward_step = self.get_step(StepType.forward, *step.map_index[1:])
             if step.type_ == StepType.forward and step.stage == 0:
@@ -236,7 +237,8 @@ def _create_index(self) -> None:
                 step.prev_step = self.get_step(
                     step.type_, step.stage + (1 if step.type_ == StepType.backward else -1), *step.map_index[2:]
                 )
-            if step.type_ == StepType.backward and step.stage == 0:
+
+            if step.type_ == StepType.backward and step.stage == self._first_grad_stage:
                 step.next_step = None
             elif step.type_ == StepType.forward and step.stage == self._num_stages - 1:
                 step.next_step = self.get_step(StepType.backward, *step.map_index[1:]) if self._is_training else None
249251 for step in self ._steps :
250252 if self ._is_training :
251253 if step .type_ == StepType .forward :
252- Assert .gt (step .backward_step .global_index , step .global_index )
253- Assert .is_ (step .backward_step .forward_step , step )
254+ if step .stage >= self ._first_grad_stage :
255+ Assert .gt (step .backward_step .global_index , step .global_index )
256+ Assert .is_ (step .backward_step .forward_step , step )
257+ else :
258+ assert step .backward_step is None
254259 else :
255260 Assert .lt (step .forward_step .global_index , step .global_index )
256- Assert .is_ (step .forward_step .backward_step , step )
261+ if step .stage >= self ._first_grad_stage :
262+ Assert .is_ (step .forward_step .backward_step , step )
257263 if step .next_step is not None :
258264 Assert .gt (step .next_step .global_index , step .global_index )
259265 Assert .is_ (step .next_step .prev_step , step )
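Note: the consistency checks now encode that a forward step carries a paired backward step exactly when its stage is trainable. A minimal runnable restatement of that invariant, with hypothetical values:

first_grad_stage = 2                                      # hypothetical
backward_of = {0: None, 1: None, 2: "bwd@2", 3: "bwd@3"}  # stage -> paired backward step
for stage, bwd in backward_of.items():
    assert (bwd is not None) == (stage >= first_grad_stage)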
@@ -303,7 +309,10 @@ def _setup_reduce_steps(self, grad_buffer_indices: dict[int, int]) -> None:
             reduce_step.reduce_accumulate = reduction_count[reduce_step.stage] > 0
             reduction_count[reduce_step.stage] += 1
         for stage, count in enumerate(reduction_count):
-            assert (count > 0) == (stage % self._distributed.pipeline_parallel == self._distributed.pipeline_rank)
+            assert (count > 0) == (
+                stage >= self._first_grad_stage
+                and (stage % self._distributed.pipeline_parallel == self._distributed.pipeline_rank)
+            )
 
     def _setup_timeline(self) -> None:
         # TODO: Include network time
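Note: gradient reductions are likewise expected only for trainable stages owned by this pipeline rank. A worked instance of the strengthened assertion, with hypothetical sizes:

pipeline_parallel, pipeline_rank = 2, 0  # hypothetical
num_stages, first_grad_stage = 4, 1      # stage 0 frozen
for stage in range(num_stages):
    expects_reduce = stage >= first_grad_stage and stage % pipeline_parallel == pipeline_rank
    # stage 0: False (frozen)      stage 1: False (owned by rank 1)
    # stage 2: True  (reduced)     stage 3: False (owned by rank 1)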
@@ -468,8 +477,16 @@ def get_data_index_split(
             micro_sequence,
         )
 
-    def _create_steps(self) -> list[Step]:
+    def _create_steps(self) -> tuple[list[Step], int]:
         steps = []
+        if self._is_training:
+            # The first stage(s) may not have any trainable parameters,
+            # in which case we shouldn't run the backward pass.
+            first_grad_stage = 0
+            while first_grad_stage < self._num_stages and not self._multi_stage.stages[first_grad_stage].requires_grad:
+                first_grad_stage += 1
+        else:
+            first_grad_stage = self._num_stages
         for depth_first_micro_batch in range(self._batch_config.depth_first_micro_batches):
             for stage in range(self._num_stages):
                 for breadth_first_micro_batch in range(self._batch_config.breadth_first_micro_batches):
@@ -485,7 +502,7 @@ def _create_steps(self) -> list[Step]:
                         )
                     )
         if self._is_training:
-            for stage in reversed(range(self._num_stages)):
+            for stage in reversed(range(first_grad_stage, self._num_stages)):
                 for breadth_first_micro_batch in range(self._batch_config.breadth_first_micro_batches):
                     for micro_sequence in reversed(range(self._batch_config.num_micro_sequences)):
                         steps.append(
@@ -498,4 +515,4 @@ def _create_steps(self) -> list[Step]:
                                 type_=StepType.backward,
                             )
                         )
-        return steps
+        return steps, first_grad_stage
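Note: taken together, the two loops still emit forward steps for every stage, while backward steps start at the last stage and stop at first_grad_stage. A toy reduction of the resulting order for one micro-batch and one micro-sequence, with hypothetical sizes:

num_stages, first_grad_stage = 3, 1  # hypothetical: stage 0 frozen
order = [("forward", stage) for stage in range(num_stages)]
order += [("backward", stage) for stage in reversed(range(first_grad_stage, num_stages))]
assert order == [
    ("forward", 0), ("forward", 1), ("forward", 2),
    ("backward", 2), ("backward", 1),
]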