## Bug description

### Bug
`_get_grad_fn_or_grad_acc` in `torch/distributed/pipelining/_backward.py` assumes all stage inputs are tensors:

```python
def _get_grad_fn_or_grad_acc(t):
    if t.requires_grad and t.grad_fn is None:  # crashes on non-tensor
```
When a model's `forward()` receives a non-tensor kwarg (e.g., `skip_lm_head: bool = False`) through the PP schedule, the backward pass iterates over all stored inputs, including the bool, and fails:
```
File "torch/distributed/pipelining/_backward.py", line 25, in _get_grad_fn_or_grad_acc
    if t.requires_grad and t.grad_fn is None:
       ^^^^^^^^^^^^^^^
AttributeError: 'bool' object has no attribute 'requires_grad'
```
### Repro
```python
class MyModel(nn.Module):
    def forward(self, x: torch.Tensor, skip_head: bool = False):
        h = self.layers(x)
        if skip_head:
            return h
        return self.head(h)

# Works without PP
model(x, skip_head=True)

# Crashes during PP backward
pp_schedule.step(x, skip_head=True, target=labels, losses=losses)
```
Tested with the DualPipeV schedule, but this likely affects all PP schedules that call `stage_backward_input`.
### Expected behavior
`_get_grad_fn_or_grad_acc` should skip non-tensor inputs (e.g., `if not isinstance(t, torch.Tensor): continue` where the stored inputs are iterated).
### Workaround
Use a model attribute instead of a forward kwarg:
```python
model._skip_head = True  # set once before training

# In forward():
if self._skip_head:
    return h
```
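For completeness, a self-contained sketch of the workaround pattern (the layer sizes and module structure here are illustrative, not from the original model):

```python
import torch
import torch.nn as nn

class MyModel(nn.Module):
    # Hypothetical minimal model mirroring the repro above.
    def __init__(self):
        super().__init__()
        self.layers = nn.Linear(4, 4)
        self.head = nn.Linear(4, 2)
        self._skip_head = False  # attribute flag instead of a forward kwarg

    def forward(self, x):
        h = self.layers(x)
        if self._skip_head:
            return h
        return self.head(h)

model = MyModel()
x = torch.randn(3, 4)
assert model(x).shape == (3, 2)   # head applied
model._skip_head = True
assert model(x).shape == (3, 4)   # head skipped
```

Because `forward()` now takes only tensors, the PP backward never sees a non-tensor input.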
### Versions
PyTorch nightly 20260426
cc @sanketpurandare. I temporarily filed this issue in torchtitan.