Bug description
Hi,
I am working on integrating HF transformers MoE models into the transformer_backend but am running into a compile issue.
#2679
Compiling an MLP sub-module (e.g. an MoE router gate) that has a forward pre-hook overwriting module.__dict__["weight"] with a fresh to_local() view on each call crashes inside AOT autograd on the second backward pass.
I can work around the issue by decorating the pre-hook with torch.compiler.disable, but I am not sure that is the best solution.
Why the pre-hook exists
The HF integration replicates MoE router gate params on the TP mesh via distribute_module(gate, tp_mesh), which makes gate.weight a DTensor. This works for routers that are plain nn.Linear (Mixtral, Qwen3, DeepSeek V3, etc.), which NoParallel handles cleanly.
But some routers override forward() with custom autograd Functions that don't tolerate DTensor inputs, e.g. PhiMoE's sparsemixer, whose backward calls scatter_add_ and breaks on DTensor placement changes. To make those routers run under TP without modifying upstream HF code, we shadow the DTensor weight with a local view before each call:
gate.__dict__["weight"] = gate.weight.to_local()
Normal Python attribute lookup checks the instance __dict__ first, and nn.Module.__getattr__ only falls back to the registered parameters (module._parameters) when that lookup fails, so self.weight inside the router's forward sees a plain torch.Tensor. The shadow has to be refreshed on every forward because FSDP unsharding produces a new local tensor each iteration.
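The lookup rule can be checked on CPU without any distributed setup. In this sketch a detached, scaled copy of the weight stands in for `to_local()` (an assumption; only the shadowing matters here), and a bias-free `nn.Linear` stands in for the router.

```python
import torch
import torch.nn as nn
import torch.nn.functional as F

lin = nn.Linear(8, 4, bias=False)
registered = lin._parameters["weight"]

# Stand-in for w.to_local(): any plain tensor shows the lookup rule.
shadow = registered.detach().clone() * 2.0
lin.__dict__["weight"] = shadow

# Normal lookup finds the instance __dict__ entry first, so
# nn.Module.__getattr__ (the path that reads _parameters) is never reached.
assert lin.weight is shadow

# The registered parameter is untouched and still what .parameters() yields.
assert lin._parameters["weight"] is registered
assert any(p is registered for p in lin.parameters())

# Linear.forward reads self.weight, so it now computes with the shadow.
x = torch.randn(3, 8)
assert torch.allclose(lin(x), F.linear(x, shadow))
```

A consequence of the same rule: once the shadow exists, `module.weight` returns it, so a refresh hook that reads `module.weight` to find the DTensor only sees the DTensor on its first call. Reading `module._parameters["weight"]` always returns the registered parameter.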
Failure
File ".../torch/_functorch/_aot_autograd/runtime_wrappers.py:2554", in load_tensors
RuntimeError: Trying to backward through the graph a second time
(or directly access saved tensors after they have already been freed).
First call: forward + backward succeed. Second call: forward succeeds, backward crashes.
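For reference, eager autograd raises the same RuntimeError whenever a backward pass reaches graph nodes whose saved tensors were already freed by an earlier backward. A minimal CPU sketch, unrelated to DTensor or torch.compile:

```python
import torch

x = torch.randn(4, requires_grad=True)
# exp() saves its output for use in backward; the first backward() frees
# that saved tensor because retain_graph defaults to False.
loss = x.exp().sum()
loss.backward()

error = None
try:
    # Walks the same graph again, so autograd needs the freed tensor.
    loss.backward()
except RuntimeError as e:
    error = str(e)

print(error)
```

In eager mode, passing `retain_graph=True` to the first `backward()` keeps the saved tensors alive. Since each repro step builds a fresh input, one reading consistent with the message is that step 2's backward reaches a node recorded during step 1; that is a hypothesis, not a diagnosis.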
Minimal repro
- TP=2 on the MoE block (PrepareModuleInputOutput, Shard(1)→Replicate).
- A to_local(grad_placements=...) pre-hook on the MoE block.
- A pre-hook on gate that mutates gate.__dict__["weight"] on each call.
- torch.compile(gate, fullgraph=True).
- Two consecutive forward+backward passes (the first passes; the second's backward crashes).
"""
Run: torchrun --nproc_per_node=2 repro.py
"""
import os
import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.distributed.device_mesh import init_device_mesh
from torch.distributed.tensor import DTensor, Partial, Replicate, Shard, distribute_module # noqa: F401
from torch.distributed.tensor.parallel import PrepareModuleInputOutput, parallelize_module
class Gate(nn.Module):
"""Mimics Qwen3MoeTopKRouter: nn.Module with a Parameter and custom forward."""
def __init__(self, dim, num_experts):
super().__init__()
self.weight = nn.Parameter(torch.randn(num_experts, dim) * 0.02)
def forward(self, x):
return F.linear(x, self.weight)
class MoEBlock(nn.Module):
def __init__(self, dim, num_experts):
super().__init__()
self.gate = Gate(dim, num_experts)
# "experts": a trivial eager linear acting on the routed mix.
self.experts = nn.Linear(dim, dim, bias=False)
def forward(self, x):
# x is local at this point (pre-hook converted it).
scores = self.gate(x).softmax(dim=-1) # (B*S, E)
# cheap dispatch: weighted sum of expert output along E.
return self.experts(x) * scores.mean(dim=-1, keepdim=True)
class Block(nn.Module):
def __init__(self, dim, num_experts):
super().__init__()
self.mlp = MoEBlock(dim, num_experts)
def forward(self, x):
return self.mlp(x)
def main():
rank = int(os.environ["RANK"])
world = int(os.environ["WORLD_SIZE"])
torch.cuda.set_device(rank)
torch.distributed.init_process_group("nccl")
mesh = init_device_mesh("cuda", (world,), mesh_dim_names=("tp",))
tp_mesh = mesh["tp"]
dim, num_experts, B, S = 64, 4, 2, 32
torch.manual_seed(0)
model = Block(dim, num_experts).cuda()
moe = model.mlp
# 1. Replicate gate params on TP mesh (no input/output fns).
distribute_module(moe.gate, tp_mesh)
# 2. Pre-hook: shadow DTensor weight with local before each forward
# (mirrors _router_params_to_local in the real code).
def _gate_to_local_hook(module, args):
w = module.weight
if isinstance(w, DTensor):
module.__dict__["weight"] = w.to_local()
moe.gate.register_forward_pre_hook(_gate_to_local_hook)
# 3. TP boundary on the MoE block: input Shard(1) -> Replicate (all-gather),
# output Partial -> Shard(1) (reduce-scatter).
parallelize_module(
moe,
tp_mesh,
PrepareModuleInputOutput(
input_layouts=(Shard(1),),
desired_input_layouts=(Replicate(),),
use_local_input=False,
output_layouts=(Partial(),),
desired_output_layouts=(Shard(1),),
),
)
# 4. Pre-hook to materialize the (Replicate) DTensor input as local with
# grad_placements=(Partial(),) — mirrors _make_moe_to_local_pre_hook.
def _moe_to_local(module, args):
x = args[0]
if isinstance(x, DTensor):
return (x.to_local(grad_placements=(Partial(),)).clone(),)
moe.register_forward_pre_hook(_moe_to_local)
# 5. Compile ONLY the gate (the trigger).
moe.gate = torch.compile(moe.gate, fullgraph=True)
# 2 steps: step 2 backward crashes with the load_tensors error.
for step in range(2):
x = torch.randn(B, S, dim, device="cuda", requires_grad=True)
# Sequence-shard the input: matches SequenceParallel upstream.
x_sharded = DTensor.from_local(x.chunk(world, dim=1)[rank].contiguous(),
tp_mesh, (Shard(1),))
out = model(x_sharded)
loss = out.to_local().sum() if isinstance(out, DTensor) else out.sum()
if rank == 0:
print(f"step {step}: forward ok, loss={loss.item():.4f}")
loss.backward()
if rank == 0:
print(f"step {step}: backward ok")
torch.distributed.destroy_process_group()
if __name__ == "__main__":
main()
cc @tianyu-l @wwwjn @fegin @acisseJZhong @xmfan
Versions
torch==2.13.0.dev20260421+cu130