torch.compile + DTensor: backward crash on TP-parallelized sub-module with a param-shadowing pre-hook #3122

@mreso

Bug description

Hi,

I am working on integrating HF transformers MoE models into the transformer_backend, but I am running into a compile issue.

#2679

Compiling an MLP sub-module (e.g. an MoE router gate) crashes inside AOT autograd on the second backward pass when the sub-module has a forward pre-hook that overwrites module.__dict__["weight"] with a fresh to_local() view on each call.

I can work around the issue by decorating the pre-hook with torch.compiler.disable, but I am not sure that is the best solution.

Why the pre-hook exists

The HF integration replicates MoE router gate params on the TP mesh via distribute_module(gate, tp_mesh), which makes gate.weight a DTensor. This works for routers that are plain nn.Linear (Mixtral, Qwen3, DeepSeek V3, etc. — handled cleanly by NoParallel).

But some routers override forward() with custom autograd Functions that don't tolerate DTensor inputs — e.g. PhiMoE's sparsemixer, whose backward calls scatter_add_ and breaks on DTensor placement changes. To make those routers run under TP without modifying upstream HF code, we shadow the DTensor weight with a local view before each call:

gate.__dict__["weight"] = gate.weight.to_local()

Python's attribute lookup checks the instance dict before the parameter store, so self.weight inside the router's forward sees a plain torch.Tensor. The shadow has to refresh every forward because FSDP unsharding produces a new local tensor each iteration.
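The lookup order can be checked in isolation with a small single-process sketch. Here `detach()` is only a stand-in for `to_local()` so that no process group is needed, and `Gate` is a toy module rather than any HF router.

```python
import torch
import torch.nn as nn


class Gate(nn.Module):
    def __init__(self, dim, num_experts):
        super().__init__()
        self.weight = nn.Parameter(torch.randn(num_experts, dim) * 0.02)


gate = Gate(dim=8, num_experts=4)
param = gate.weight

# nn.Module keeps parameters in module._parameters, so "weight" is not a key
# of the instance __dict__ until we put it there ourselves.
in_dict_before = "weight" in gate.__dict__

# Shadow the parameter with a plain tensor (stand-in for weight.to_local()).
shadow = param.detach()
gate.__dict__["weight"] = shadow

sees_shadow = gate.weight is shadow                  # attribute access hits __dict__ first
still_registered = gate._parameters["weight"] is param
named = dict(gate.named_parameters())["weight"] is param

# Removing the shadow makes attribute access fall back to the parameter store.
del gate.__dict__["weight"]
restored = gate.weight is param
```

One consequence worth keeping in mind: while the shadow is in place, `module.weight` returns the shadow itself, so code that needs the original parameter has to read it from `module._parameters` (or `named_parameters()`).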

Failure

File ".../torch/_functorch/_aot_autograd/runtime_wrappers.py:2554", in load_tensors
  RuntimeError: Trying to backward through the graph a second time
    (or directly access saved tensors after they have already been freed).

First call: forward and backward succeed. Second call: forward succeeds, backward crashes.
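For readers unfamiliar with the message: in plain eager autograd it appears when a backward pass reaches a node whose saved tensors were already freed by an earlier backward, for example when a non-leaf tensor from one iteration is reused in the next. The sketch below reproduces only that generic situation on CPU. It is not a reduction of this bug, and whether the compiled gate hits the same pattern is an open question.

```python
import torch

w = torch.randn(3, requires_grad=True)

# A non-leaf tensor whose grad_fn saves a tensor for its backward.
shared = w.exp()

# First backward succeeds and frees the saved tensors of the exp node.
shared.sum().backward()

# Reusing `shared` means the second backward walks through the same exp node.
error_message = None
try:
    shared.sum().backward()
except RuntimeError as exc:
    error_message = str(exc)
```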

Minimum repro

  • TP=2 on the MoE block (PrepareModuleInputOutput, Shard(1)→Replicate).
  • A to_local(grad_placements=...) pre-hook on the MoE block.
  • A pre-hook on gate that mutates gate.__dict__["weight"] each call.
  • torch.compile(gate, fullgraph=True).
  • Two consecutive forward+backward passes (the first passes; the second's backward crashes).
"""
Run: torchrun --nproc_per_node=2 repro.py
"""
import os
import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.distributed.device_mesh import init_device_mesh
from torch.distributed.tensor import DTensor, Partial, Replicate, Shard, distribute_module  # noqa: F401
from torch.distributed.tensor.parallel import PrepareModuleInputOutput, parallelize_module


class Gate(nn.Module):
    """Mimics Qwen3MoeTopKRouter: nn.Module with a Parameter and custom forward."""
    def __init__(self, dim, num_experts):
        super().__init__()
        self.weight = nn.Parameter(torch.randn(num_experts, dim) * 0.02)

    def forward(self, x):
        return F.linear(x, self.weight)


class MoEBlock(nn.Module):
    def __init__(self, dim, num_experts):
        super().__init__()
        self.gate = Gate(dim, num_experts)
        # "experts": a trivial eager linear acting on the routed mix.
        self.experts = nn.Linear(dim, dim, bias=False)

    def forward(self, x):
        # x is local at this point (pre-hook converted it).
        scores = self.gate(x).softmax(dim=-1)         # (B*S, E)
        # cheap dispatch: weighted sum of expert output along E.
        return self.experts(x) * scores.mean(dim=-1, keepdim=True)


class Block(nn.Module):
    def __init__(self, dim, num_experts):
        super().__init__()
        self.mlp = MoEBlock(dim, num_experts)

    def forward(self, x):
        return self.mlp(x)


def main():
    rank = int(os.environ["RANK"])
    world = int(os.environ["WORLD_SIZE"])
    torch.cuda.set_device(rank)
    torch.distributed.init_process_group("nccl")
    mesh = init_device_mesh("cuda", (world,), mesh_dim_names=("tp",))
    tp_mesh = mesh["tp"]

    dim, num_experts, B, S = 64, 4, 2, 32
    torch.manual_seed(0)
    model = Block(dim, num_experts).cuda()

    moe = model.mlp

    # 1. Replicate gate params on TP mesh (no input/output fns).
    distribute_module(moe.gate, tp_mesh)

    # 2. Pre-hook: shadow DTensor weight with local before each forward
    #    (mirrors _router_params_to_local in the real code).
    def _gate_to_local_hook(module, args):
        w = module.weight
        if isinstance(w, DTensor):
            module.__dict__["weight"] = w.to_local()
    moe.gate.register_forward_pre_hook(_gate_to_local_hook)

    # 3. TP boundary on the MoE block: input Shard(1) -> Replicate (all-gather),
    #    output Partial -> Shard(1) (reduce-scatter).
    parallelize_module(
        moe,
        tp_mesh,
        PrepareModuleInputOutput(
            input_layouts=(Shard(1),),
            desired_input_layouts=(Replicate(),),
            use_local_input=False,
            output_layouts=(Partial(),),
            desired_output_layouts=(Shard(1),),
        ),
    )

    # 4. Pre-hook to materialize the (Replicate) DTensor input as local with
    #    grad_placements=(Partial(),) — mirrors _make_moe_to_local_pre_hook.
    def _moe_to_local(module, args):
        x = args[0]
        if isinstance(x, DTensor):
            return (x.to_local(grad_placements=(Partial(),)).clone(),)
    moe.register_forward_pre_hook(_moe_to_local)

    # 5. Compile ONLY the gate (the trigger).
    moe.gate = torch.compile(moe.gate, fullgraph=True)

    # 2 steps: step 2 backward crashes with the load_tensors error.
    for step in range(2):
        x = torch.randn(B, S, dim, device="cuda", requires_grad=True)
        # Sequence-shard the input: matches SequenceParallel upstream.
        x_sharded = DTensor.from_local(x.chunk(world, dim=1)[rank].contiguous(),
                                       tp_mesh, (Shard(1),))
        out = model(x_sharded)
        loss = out.to_local().sum() if isinstance(out, DTensor) else out.sum()
        if rank == 0:
            print(f"step {step}: forward ok, loss={loss.item():.4f}")
        loss.backward()
        if rank == 0:
            print(f"step {step}: backward ok")

    torch.distributed.destroy_process_group()


if __name__ == "__main__":
    main()

cc @tianyu-l @wwwjn @fegin @acisseJZhong @xmfan

Versions

torch==2.13.0.dev20260421+cu130
