Bug description
When used with torchao's Float8LinearConfig(enable_fsdp_float8_all_gather=True), SimpleFSDP all-gathers weights in bf16 instead of fp8. This doubles the expected all-gather communication volume, and no warning is emitted. FSDP2 honors the flag correctly.
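To make the "2×" concrete, here is a rough back-of-the-envelope sketch of how many bytes one rank receives during the weight all-gather. It assumes even sharding across ranks, ignores padding and any small side-band metadata such as scales, and uses a hypothetical 4096x4096 weight; the function name is illustrative only.

```python
# Element sizes in bytes: bfloat16 is 16-bit, float8 (e4m3fn / e5m2) is 8-bit.
BF16_BYTES = 2
FP8_BYTES = 1


def all_gather_recv_bytes(numel: int, bytes_per_elem: int, world_size: int) -> int:
    """Bytes one rank receives when all-gathering an evenly sharded tensor.

    Each rank owns numel // world_size elements and must receive the
    (world_size - 1) shards held by its peers. Padding and metadata are ignored.
    """
    shard_numel = numel // world_size
    return shard_numel * (world_size - 1) * bytes_per_elem


if __name__ == "__main__":
    numel = 4096 * 4096  # hypothetical 4096x4096 linear weight
    world_size = 2       # matches --nproc_per_node=2 in the repro command
    bf16 = all_gather_recv_bytes(numel, BF16_BYTES, world_size)
    fp8 = all_gather_recv_bytes(numel, FP8_BYTES, world_size)
    print(f"bf16: {bf16 / 2**20:.0f} MiB, fp8: {fp8 / 2**20:.0f} MiB, ratio: {bf16 / fp8:.1f}x")
```

Because the element count is the same on both paths, the ratio depends only on element width, so it stays at 2x regardless of weight shape or world size.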
SimpleFSDP's ReplicateComputation calls DTensor.redistribute(Replicate()), which dispatches to _c10d_functional.all_gather_into_tensor. That op is not in torchao's _ops_to_preserve_subclass, so the WeightWithDynamicFloat8CastTensor wrapper is stripped before the collective and the plain bf16 tensor is gathered. FSDP2 avoids this because it routes the all-gather through the fsdp_pre_all_gather / fsdp_post_all_gather hooks.
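The stripping behavior can be modeled with a toy, pure-Python dispatcher. This is not torchao's code: PRESERVE_OPS, Float8CastWeight, dispatch, and the two op stand-ins are hypothetical names that only mirror the pattern described above, where a wrapper subclass is re-wrapped for allowlisted ops and silently dropped for everything else.

```python
from dataclasses import dataclass
from typing import Any, Callable, List


# Hypothetical stand-ins for aten / c10d ops; only their identity matters here.
def aten_view(data: List[float]) -> List[float]:
    return list(data)


def c10d_all_gather_into_tensor(data: List[float], world_size: int = 2) -> List[float]:
    # Pretend every rank holds the same shard, so the gather just repeats it.
    return data * world_size


# Toy analogue of an allowlist like torchao's _ops_to_preserve_subclass.
PRESERVE_OPS = {aten_view}


@dataclass
class Float8CastWeight:
    """Toy wrapper whose presence means 'cast to fp8 before communicating'."""
    data: List[float]


def dispatch(op: Callable, arg: Any, **kwargs: Any) -> Any:
    """Unwrap, run the op, and re-wrap the result only for allowlisted ops."""
    inner = arg.data if isinstance(arg, Float8CastWeight) else arg
    out = op(inner, **kwargs)
    if isinstance(arg, Float8CastWeight) and op in PRESERVE_OPS:
        return Float8CastWeight(out)
    return out  # wrapper, and with it the fp8 intent, is silently dropped
```

For example, dispatch(aten_view, Float8CastWeight([1.0, 2.0])) still returns a Float8CastWeight, while dispatch(c10d_all_gather_into_tensor, ...) returns a plain list. In the real stack the equivalent loss happens without an error, which matches the "no warning" symptom.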
Repro: https://gist.github.com/aditvenk/9108e099a75150be64184aba40c984ea
torchrun --nproc_per_node=2 repro.py
SimpleFSDP's ReplicateComputation needs to handle quantized weights the same way FSDP2 does.
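A pure-Python toy of the hook-based flow may help frame a fix. Every name here (QuantizedShard, pre_all_gather, post_all_gather, simulated_all_gather, replicate_with_hooks) is hypothetical, and the method signatures are deliberately simpler than the real fsdp_pre_all_gather / fsdp_post_all_gather. Rounding to a signed 8-bit grid stands in for the fp8 cast, and the scale is assumed to be identical on every rank.

```python
from typing import List, Tuple


def quantize(values: List[float], scale: float) -> List[int]:
    # Stand-in for an fp8 cast: round onto a signed 8-bit grid and clamp.
    return [max(-127, min(127, round(v / scale))) for v in values]


def dequantize(codes: List[int], scale: float) -> List[float]:
    return [c * scale for c in codes]


class QuantizedShard:
    """Toy shard that opts into low-precision all-gather via pre/post hooks."""

    def __init__(self, values: List[float], scale: float):
        self.values = values
        self.scale = scale  # assumed identical on every rank

    def pre_all_gather(self) -> Tuple[List[int], float]:
        # Before the collective: emit the 1-byte payload plus metadata.
        return quantize(self.values, self.scale), self.scale

    @staticmethod
    def post_all_gather(gathered: List[int], scale: float) -> List[float]:
        # After the collective: rebuild the full weight from the gathered payload.
        return dequantize(gathered, scale)


def simulated_all_gather(payloads: List[List[int]]) -> List[int]:
    # Stand-in for all_gather_into_tensor: concatenate shards in rank order.
    return [x for p in payloads for x in p]


def replicate_with_hooks(shards: List[QuantizedShard]) -> List[float]:
    # The collective only ever sees the low-precision payloads.
    payloads, metas = zip(*(s.pre_all_gather() for s in shards))
    gathered = simulated_all_gather(list(payloads))
    return QuantizedShard.post_all_gather(gathered, metas[0])
```

The design point is that the collective never sees the high-precision tensor: the low-precision representation is produced before communication and expanded after it, which is the property the issue asks ReplicateComputation to preserve.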
The repro script uses Float8, but the same issue likely applies to other quantized formats such as MXFP8.
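For a sense of what MXFP8 would save, here is a rough payload-size sketch. It assumes the OCP Microscaling layout (8-bit elements plus one 8-bit E8M0 scale per 32-element block) and assumes the block scales travel through the collective together with the data. Whether a given implementation gathers scales this way is an assumption, and the function names are illustrative.

```python
import math

MX_BLOCK_SIZE = 32  # OCP MX: one shared scale per 32-element block
BF16_BYTES = 2


def mxfp8_bytes(numel: int) -> int:
    # 1 byte per fp8 element plus 1 byte (E8M0) per block scale.
    return numel + math.ceil(numel / MX_BLOCK_SIZE)


def bf16_bytes(numel: int) -> int:
    return numel * BF16_BYTES


if __name__ == "__main__":
    numel = 4096 * 4096  # hypothetical 4096x4096 linear weight
    print(f"bf16 / mxfp8 payload ratio: {bf16_bytes(numel) / mxfp8_bytes(numel):.3f}")
```

Under these assumptions, gathering in bf16 instead of MXFP8 costs about 64/33 (roughly 1.94x) the bytes, so the regression is nearly as large as for per-tensor Float8.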
Versions
Reproducible on latest PyTorch nightly.