Bug: FP8 + TP≥2 causes NaN loss from step 3 (use_local_output=False incompatible with torchao amax) #3109

@zzhaobh

Bug description

### Summary

After torchtitan commit eb518a1d8 changed all `ColwiseParallel` layers to
`use_local_output=False`, training with `dtype=fp8` and `TP≥2` produces NaN
loss starting at step 3, with an explosive grad_norm spike at step 2.

### Environment

- PyTorch: 2.11.0
- torchao: (latest)
- GPUs: 8× H800, TP=2, EP=4
- Model: Qwen3-30B-A3B (MoE)
- FP8 recipe: rowwise (`fp8_recipe_name = "rowwise"`)

### Steps to Reproduce

Run any training with:
- `tp > 1`
- `dtype = "fp8"` + `fp8_recipe_name = "rowwise"`
- torchao's `Float8ColwiseParallel` / `Float8RowwiseParallel` TP plan

### Observed Behavior

```
step: 1  loss: 2.833  grad_norm: 105.0
step: 2  loss: 2.831  grad_norm: 572.0   <- spike
step: 3  loss: nan    grad_norm: nan     <- permanent NaN
step: 4  loss: nan    grad_norm: nan
...
```


### Root Cause

`torchao`'s `Float8ColwiseParallel` computes the per-tensor amax by calling
`.to_local()` on its output activation before quantizing to FP8.

Before eb518a1d8, `ColwiseParallel` used `use_local_output=True`, so the
output was already a plain `torch.Tensor` (the local shard). `.to_local()` was
effectively a no-op and returned the full local tensor.

After eb518a1d8, `use_local_output=False` is set on all `ColwiseParallel`
layers so that DTensor metadata propagates through the TP region. The output is
now a `DTensor(Shard(-1))`. When `Float8ColwiseParallel` calls `.to_local()` on
a `DTensor`, it gets back only the **local shard** — half the activations for
TP=2. The amax is therefore underestimated by roughly 2×, the FP8 scale is too
small, and activations overflow to `inf`/`NaN` within a few steps.

`Float8RowwiseParallel` is **not affected**: its output is `Replicate()` after
the all-reduce, so `.to_local()` on a `DTensor(Replicate)` returns the full
tensor correctly.
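The underestimated amax translates directly into an oversized FP8 scale and overflow. A minimal numeric sketch of the failure mode (illustrative values only; 448.0 is the e4m3 max, and the shard contents are made up):

```python
FP8_E4M3_MAX = 448.0  # max representable magnitude in float8 e4m3

def fp8_scale(amax):
    # Per-tensor scale: map the observed max magnitude onto the FP8 range.
    return FP8_E4M3_MAX / amax

# Activations split column-wise across two TP ranks; the global max lives on rank 1.
rank0_shard = [1.0, 3.0]
rank1_shard = [2.0, 8.0]

global_amax = max(rank0_shard + rank1_shard)  # 8.0 -- correct amax
local_amax = max(rank0_shard)                 # 3.0 -- what .to_local() sees on rank 0

good_scale = fp8_scale(global_amax)  # 56.0
bad_scale = fp8_scale(local_amax)    # ~149.3, too large

# Quantizing the global max with the underestimated scale exceeds the FP8 range:
overflows = 8.0 * bad_scale > FP8_E4M3_MAX   # True -> clamps/overflows to inf
fits = 8.0 * good_scale <= FP8_E4M3_MAX      # True
```

With the correct scale the largest activation lands exactly at the FP8 boundary; with the shard-local amax it lands well outside it, and the resulting `inf` values poison the gradients within a couple of steps.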

### Minimal Fix (workaround applied on our side)

When `Float8ColwiseParallel` is used, set `use_local_output=True` so the output
is a plain Tensor and amax sees the full local activations:

```python
# FP8 path
colwise_local_output = True   # workaround: Float8ColwiseParallel's amax path
                              # calls .to_local(), which returns only half the
                              # activations when the output is DTensor(Shard(-1))
# non-FP8 path
colwise_local_output = False  # correct DTensor propagation
```

The proper fix is for torchao's `Float8ColwiseParallel` to use
`tensor.full_tensor()` (or an all-gather) instead of `.to_local()` when the
output is a `DTensor`, so the amax reflects the global activation range.
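A hedged sketch of that fix direction (the helper name and the duck-typed check are mine, not torchao's actual code):

```python
def tensor_for_amax(t):
    """Return the tensor whose amax should seed the FP8 scale.

    DTensor exposes full_tensor(), which all-gathers the shards so the amax
    reflects the global activation range; .to_local() would return only this
    rank's shard and underestimate the amax under Shard(-1).
    """
    if hasattr(t, "full_tensor"):  # duck-typed stand-in for isinstance(t, DTensor)
        return t.full_tensor()
    return t  # plain torch.Tensor: already the full local view
```

In real code the check would be `isinstance(t, DTensor)`, and for `Replicate()` placements `full_tensor()` is cheap since the data is already materialized on every rank.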

### Expected Behavior

FP8 + TP≥2 should train without NaN. The amax should be computed over the
global activation tensor (or at minimum over the full local shard before
sharding), not over a DTensor local view.
