[SimpleFSDP] SimpleFSDP silently ignores enable_fsdp_float8_all_gather=True #3123

@aditvenk

Bug description

When used with torchao's Float8LinearConfig(enable_fsdp_float8_all_gather=True), SimpleFSDP all-gathers weights in bf16 instead of fp8. That is 2× the expected all-gather communication volume, and no warning or error is raised. FSDP2 honors the flag correctly.
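As a rough illustration of the 2× figure, here is a minimal back-of-the-envelope sketch. The weight shape is a hypothetical example, not taken from the repro, and the calculation ignores the small amount of scale metadata that an fp8 all-gather also has to carry.

```python
# Bytes per element for the two dtypes discussed in this issue.
BYTES_PER_ELEM = {"bf16": 2, "fp8": 1}


def all_gather_payload_bytes(numel: int, dtype: str) -> int:
    """Size in bytes of the fully gathered weight for the given dtype."""
    return numel * BYTES_PER_ELEM[dtype]


# Hypothetical 4096 x 4096 Linear weight (assumption, for illustration only).
numel = 4096 * 4096

bf16_bytes = all_gather_payload_bytes(numel, "bf16")
fp8_bytes = all_gather_payload_bytes(numel, "fp8")
ratio = bf16_bytes / fp8_bytes

print(f"bf16 payload: {bf16_bytes} bytes")
print(f"fp8 payload:  {fp8_bytes} bytes")
print(f"ratio:        {ratio}x")
```

The ratio is independent of the weight shape, since it only reflects the element size of each dtype.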

SimpleFSDP's ReplicateComputation uses DTensor.redistribute(Replicate()), which dispatches to _c10d_functional.all_gather_into_tensor. This op isn't in torchao's _ops_to_preserve_subclass, so WeightWithDynamicFloat8CastTensor is stripped before the collective. FSDP2 avoids this by going through fsdp_pre_all_gather / fsdp_post_all_gather hooks.
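The stripping behavior described above can be pictured with a toy model in plain Python. This is only a sketch of the general "preserve list" pattern, not torchao's actual code: `Fp8CastWeight`, `PRESERVE`, and `dispatch` are hypothetical stand-ins, and the op names are illustrative strings rather than real operator objects.

```python
# Hypothetical stand-in for a list like _ops_to_preserve_subclass:
# only ops named here return a wrapped result.
PRESERVE = {"view", "slice"}


class Fp8CastWeight:
    """Toy wrapper standing in for a weight tensor subclass."""

    def __init__(self, data):
        self.data = data


def dispatch(op_name, fn, weight):
    """Run fn on the unwrapped data; re-wrap only for preserved ops."""
    out = fn(weight.data)
    if op_name in PRESERVE:
        return Fp8CastWeight(out)
    # Op is not on the preserve list: the wrapper is silently dropped.
    return out


w = Fp8CastWeight([1.0, 2.0, 3.0, 4.0])

# A preserved op keeps the wrapper, so later code can still see it.
sliced = dispatch("slice", lambda d: d[:2], w)

# A toy "all-gather" across two ranks; it is not on the preserve list,
# so the result is a plain list and the fp8-cast information is gone.
gathered = dispatch("all_gather_into_tensor", lambda d: d + d, w)

print(type(sliced).__name__)
print(type(gathered).__name__)
```

In this toy model nothing fails when the wrapper is dropped, which mirrors why the bf16 fallback goes unnoticed: the collective still produces a valid result, just not in the intended dtype.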

Repro: https://gist.github.com/aditvenk/9108e099a75150be64184aba40c984ea

torchrun --nproc_per_node=2 repro.py

SimpleFSDP's ReplicateComputation needs to handle quantized weights in a way analogous to FSDP2's pre/post all-gather hooks.

The repro script uses Float8, but the same issue should also apply to other types such as MXFP8.

Versions

Reproducible on latest PyTorch nightly.
