PP backward crashes on non-tensor forward kwargs ('bool' object has no attribute 'requires_grad') #3112

@wwwjn


Bug description

_get_grad_fn_or_grad_acc in torch/distributed/pipelining/_backward.py assumes all stage inputs are tensors:

  def _get_grad_fn_or_grad_acc(t):
      if t.requires_grad and t.grad_fn is None:  # crashes on non-tensor

When a model's forward() receives a non-tensor kwarg (e.g., skip_lm_head: bool = False) through the PP schedule, the backward pass iterates over all stored inputs including the bool and fails:

  File "torch/distributed/pipelining/_backward.py", line 25, in _get_grad_fn_or_grad_acc
      if t.requires_grad and t.grad_fn is None:
         ^^^^^^^^^^^^^^^
  AttributeError: 'bool' object has no attribute 'requires_grad'

Repro

  class MyModel(nn.Module):
      def forward(self, x: torch.Tensor, skip_head: bool = False):
          h = self.layers(x)
          if skip_head:
              return h
          return self.head(h)

  # Works without PP
  model(x, skip_head=True)

  # Crashes during PP backward
  pp_schedule.step(x, skip_head=True, target=labels, losses=losses)

Tested with DualPipeV schedule, but likely affects all PP schedules that call stage_backward_input.

Expected behavior

_get_grad_fn_or_grad_acc should skip non-tensor inputs, e.g., return early when not isinstance(t, torch.Tensor) (or have stage_backward_input filter non-tensors out with continue).

Workaround

Use a model attribute instead of a forward kwarg:

  model._skip_head = True  # set once before training
  # In forward:
  if self._skip_head:
      return h
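Put together, the workaround looks like this (layer shapes are hypothetical; the point is that the flag lives on the module, so the PP stage never stores a bool among its forward inputs):

```python
import torch
import torch.nn as nn

class MyModel(nn.Module):
    def __init__(self):
        super().__init__()
        self.layers = nn.Linear(4, 4)  # hypothetical stand-in layers
        self.head = nn.Linear(4, 4)
        self._skip_head = False

    def forward(self, x: torch.Tensor):
        h = self.layers(x)
        if self._skip_head:  # module attribute, not a forward kwarg
            return h
        return self.head(h)

model = MyModel()
model._skip_head = True          # set once before training, not per step()
out = model(torch.randn(2, 4))   # only tensors flow through forward
```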

Versions

PyTorch nightly 20260426

cc @sanketpurandare. I've temporarily filed this issue in torchtitan.
