## Bug: FP8 + TP≥2 causes NaN loss from step 3 (use_local_output=False incompatible with torchao amax)
### Summary
After torchtitan commit eb518a1d8 changed all `ColwiseParallel` layers to
`use_local_output=False`, training with `dtype=fp8` and `TP≥2` produces NaN
loss starting at step 3, with an explosive grad_norm spike at step 2.
### Environment
- PyTorch: 2.11.0
- torchao: (latest)
- GPUs: 8× H800, TP=2, EP=4
- Model: Qwen3-30B-A3B (MoE)
- FP8 recipe: rowwise (`fp8_recipe_name = "rowwise"`)
### Steps to Reproduce
Run any training with:
- `tp > 1`
- `dtype = "fp8"` + `fp8_recipe_name = "rowwise"`
- torchao's `Float8ColwiseParallel` / `Float8RowwiseParallel` TP plan (see the sketch below)
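For reference, a minimal sketch of the TP-plan choice that triggers the issue. The helper name `colwise_style` is illustrative (not torchtitan's actual code), and the torchao import path is assumed from torchao's `float8` package:

```python
from torch.distributed.tensor.parallel import ColwiseParallel
from torchao.float8.float8_tensor_parallel import Float8ColwiseParallel

def colwise_style(enable_fp8: bool):
    # Post-eb518a1d8, torchtitan keeps use_local_output=False so DTensor placement
    # metadata propagates through the TP region; combining that with torchao's
    # FP8 colwise style is the configuration that reproduces the NaNs.
    cls = Float8ColwiseParallel if enable_fp8 else ColwiseParallel
    return cls(use_local_output=False)
```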
### Observed Behavior
```
step: 1  loss: 2.833  grad_norm: 105.0
step: 2  loss: 2.831  grad_norm: 572.0   ← spike
step: 3  loss: nan    grad_norm: nan     ← permanent NaN
step: 4  loss: nan    grad_norm: nan
...
```
### Root Cause
`torchao`'s `Float8ColwiseParallel` computes the per-tensor amax by calling
`.to_local()` on its output activation before quantizing to FP8.
Before eb518a1d8, `ColwiseParallel` used `use_local_output=True`, so the
output was already a plain `torch.Tensor` (the local shard). `.to_local()` was
effectively a no-op and returned the full local tensor.
After eb518a1d8, `use_local_output=False` is set on all `ColwiseParallel`
layers so that DTensor metadata propagates through the TP region. The output is
now a `DTensor(Shard(-1))`. When `Float8ColwiseParallel` calls `.to_local()` on
a `DTensor`, it gets back only the **local shard** — half the activations for
TP=2. The amax is therefore underestimated by roughly 2×, the FP8 scale is too
small, and activations overflow to `inf`/`NaN` within a few steps.
`Float8RowwiseParallel` is **not affected**: its output is `Replicate()` after
the all-reduce, so `.to_local()` on a `DTensor(Replicate)` returns the full
tensor correctly.
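A small standalone sketch of the mechanism (assumptions: 2 ranks, gloo/CPU, the public `torch.distributed.tensor` API; run with `torchrun --nproc_per_node=2`). It shows that `.to_local()` on a `DTensor(Shard(-1))` only exposes this rank's columns, so an amax taken from it can miss the maximum that `.full_tensor()` would see:

```python
import torch
import torch.distributed as dist
from torch.distributed.device_mesh import init_device_mesh
from torch.distributed.tensor import Shard, distribute_tensor

dist.init_process_group("gloo")
mesh = init_device_mesh("cpu", (dist.get_world_size(),))

torch.manual_seed(0)
full = torch.randn(4, 8)
full[0, -1] = 100.0  # global max lives in the last column (rank 1's shard)

# Shard(-1) mirrors the output placement of ColwiseParallel with use_local_output=False.
dt = distribute_tensor(full, mesh, [Shard(-1)])

local_amax = dt.to_local().abs().max()      # only this rank's shard
global_amax = dt.full_tensor().abs().max()  # the true activation range

print(f"rank {dist.get_rank()}: local amax={local_amax.item():.2f}, "
      f"global amax={global_amax.item():.2f}")
# Rank 0 reports a local amax of ~2 vs. a global amax of 100, so an FP8 scale
# derived from the local value misrepresents the global activation range.
dist.destroy_process_group()
```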
### Minimal Fix (workaround applied on our side)
When `Float8ColwiseParallel` is used, set `use_local_output=True` so the output
is a plain Tensor and amax sees the full local activations:
```python
# FP8 path
# Workaround: Float8ColwiseParallel's amax computation calls .to_local(), which
# returns only this rank's half of the activations when the output is
# DTensor(Shard(-1)), so keep the output as a plain local tensor.
colwise_local_output = True

# non-FP8 path: correct DTensor propagation through the TP region.
colwise_local_output = False
```
The proper fix is for torchao's `Float8ColwiseParallel` to use
`tensor.full_tensor()` (or an explicit all-gather) instead of `.to_local()` when
the output is a DTensor, so the amax reflects the global activation range.
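A hedged sketch of what that torchao-side fix could look like; the helper `activation_amax` is hypothetical and not torchao's actual code path:

```python
import torch
from torch.distributed.tensor import DTensor

def activation_amax(t: torch.Tensor) -> torch.Tensor:
    # Hypothetical helper: for a DTensor, full_tensor() all-gathers the shards so
    # the amax reflects the global activation range; plain tensors are used as-is.
    if isinstance(t, DTensor):
        return t.full_tensor().abs().amax()
    return t.abs().amax()
```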
### Expected Behavior
FP8 + TP≥2 should train without NaN. The amax should be computed over the
global activation tensor (or at minimum over the full local shard before
sharding), not over a DTensor local view.