compile: turn off fullgraph=True to support llama4 #1182


Open

wants to merge 4 commits into gh/bdhirsh/3/base

Conversation

@bdhirsh commented May 12, 2025

This PR + pytorch/pytorch#153384 is enough to get torchtitan running for me with llama4 and compile

```
CONFIG_FILE="./torchtitan/experiments/llama4/train_configs/debug_model.toml" ./run_train.sh --training.compile
```

Stack from ghstack (oldest at bottom):

bdhirsh added a commit that referenced this pull request May 12, 2025
ghstack-source-id: 98bb0ed
Pull Request resolved: #1182
@facebook-github-bot added the CLA Signed label May 12, 2025
@@ -304,7 +304,7 @@ def apply_compile(model: nn.Module):
repeated structure. Alternatively one can compile the whole model (after applying DP).
"""
for layer_id, transformer_block in model.layers.named_children():
transformer_block = torch.compile(transformer_block, fullgraph=True)
Contributor

Can we add a comment/TODO to remind us to turn it back on when issues are resolved?

Author

oh will do. there are two related things here:

(1) grouped_mm support in compile, which torchtitan uses in llama4. I added basic support in core in this PR: pytorch/pytorch#153384

(2) E2E llama4 + compile in torchtitan. The reason this currently blows up is that torchtitan's llama4 + FSDP2 integration requires wrapping the MoE layer, which requires installing backward hooks around the MoE layer. Compile does not support compiling backward hooks (we graph break), so we need to do one of these options:

(a) allow the graph break (turn off fullgraph=True)

(b) tweak torchtitan so that instead of compiling each transformer block as a whole, we compile the MoE layers separately and compile the rest of each transformer block separately as well.

I also mentioned this to @tianyu-l but calling it out here: (a) is easier to do, so I'm doing it here, but it does have the risk that if any changes are made to core that increase the number of graph breaks in torchtitan, we won't error as loudly (we may just see a perf drop instead). (b) is probably better to do at some point; I'm just doing the simpler thing here.

Are folks working on torchtitan interested in running benchmarks for titan + llama4 (with compile on/off)?

Author

@fegin I actually tweaked the PR so that we still compile the "regular" transformer blocks with fullgraph=True, and only use fullgraph=False for the blocks with MoE layers. I think this should reduce the risk of hitting regressions, so this may be a reasonable long-term solution (when using FSDP2 in torchtitan), as long as we see reasonable perf numbers.
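
A rough sketch of what that could look like (this is not the exact diff from this PR; the `moe_enabled` attribute is an illustrative assumption, and the real code may detect MoE blocks differently):

```
import torch
import torch.nn as nn


def apply_compile(model: nn.Module):
    # Dense transformer blocks keep fullgraph=True so that new graph breaks still
    # error loudly; blocks containing MoE layers allow the known graph break caused
    # by the backward hooks installed around the MoE layer.
    for layer_id, transformer_block in model.layers.named_children():
        has_moe = getattr(transformer_block, "moe_enabled", False)  # hypothetical flag
        transformer_block = torch.compile(transformer_block, fullgraph=not has_moe)
        model.layers.register_module(layer_id, transformer_block)
```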

Contributor

@bdhirsh Thanks for the thorough explanation. The 8-GPU integration test timed out. Since llama4 is not covered by the integration test, the timeout should not be caused by this PR. I relaunched the test anyway, but feel free to land it.

bdhirsh added a commit that referenced this pull request May 12, 2025
ghstack-source-id: f22f920
Pull Request resolved: #1182
bdhirsh added a commit that referenced this pull request May 12, 2025
ghstack-source-id: cd16b65
Pull Request resolved: #1182
bdhirsh added a commit that referenced this pull request May 12, 2025
ghstack-source-id: 1539e21
Pull Request resolved: #1182
Contributor

@tianyu-l left a comment

The current reason this completely blows up today is that torchtitan's llama4 + FSDP2 integration requires wrapping the MoE layer, which requires installing backward hooks around the MoE layer.

hmm let's be more careful here. This is only true when EP is used, specifically dp2ep (e.g. in #732). Currently in Llama 4, EP is not supported yet, which means we are doing homogeneous FSDP2 wrapping of all the transformer blocks only (not of the MoE modules). So I suppose full-graph compilation shouldn't be violated. If we set fullgraph=True, where would it break?

@bdhirsh
Author

bdhirsh commented May 13, 2025

Locally, I was seeing that when we compile each transformer block, dynamo was trying (and failing) to graph break, because something was attempting to install backward hooks inside one of the transformer blocks. If that's surprising to you, I can try to find the code that is installing the backward hook.

@tianyu-l
Contributor

I think the backward hooks are from the auxiliary-loss-free load balancing (#1114).

The load-balancing algorithm maintains a bias term for each expert, based on the number of tokens that expert has seen so far.

  1. The single-device algorithm needs a backward hook to update the bias term after each iteration.
  2. For multi-device, we need another backward hook to all-reduce the bias term across all DP ranks, since different DP ranks see different inputs.

Using forward, forward pre, or backward pre hooks would conflict with activation checkpointing (see the sketch below).
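
A rough sketch of the kind of full backward hook described above (the attribute names `tokens_per_expert` and `expert_bias`, and the exact update rule, are illustrative assumptions; see #1114 for the actual implementation):

```
import torch
import torch.distributed as dist
import torch.nn as nn


def _update_expert_bias(module: nn.Module, grad_input, grad_output):
    # Full backward hook: runs once per iteration, after backward, so AC
    # recomputation never re-executes it.
    with torch.no_grad():
        counts = module.tokens_per_expert  # per-expert token counts for this iteration
        if dist.is_available() and dist.is_initialized():
            # Different DP ranks see different inputs, so combine their statistics.
            dist.all_reduce(counts)
        # Nudge under-loaded experts up and over-loaded experts down.
        module.expert_bias.add_(0.001 * torch.sign(counts.float().mean() - counts.float()))
        counts.zero_()  # reset the statistics for the next iteration


# Hypothetical registration on an MoE module:
# moe_module.register_full_backward_hook(_update_expert_bias)
```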

@tianyu-l
Contributor

cc @soulitzer on this issue
Basically, in order to update the bias terms without intruding on the model code, I needed to use hooks. To not interfere with AC, I had to use full backward hooks, which break full-graph compilation.

I wonder if AC supports optionally bypassing some hook computation, even in the full AC mode?

@soulitzer

@tianyu-l Do you have more details on what the conflict with activation checkpointing is?

Without knowing any more context, one guess is that using forward, forward pre, or backward pre hooks conflicts with activation checkpointing because it would make the original forward computation and the recompute see different bias terms.
And the reason the backward post hook is fine is that it means the bias term is updated AFTER the recompute has already been done?
(With this explanation though, forward pre hooks should be fine too?)

@tianyu-l
Contributor

@soulitzer

Do you have more details on what the conflict with activation checkpointing is?

The problem is that the hook _update_expert_bias should only be executed once, which only a full backward hook can achieve. With any of the other hook types, AC would run the hook repeatedly.

(With this explanation though, forward pre hooks should be fine too?)

The problem with a forward pre hook is:
after the expected call to _update_expert_bias happens, self.tokens_per_experts (which holds the statistics of the last iteration) gets cleared once self.expert_bias has been updated.
https://github.com/pytorch/torchtitan/pull/1114/files#diff-87cc24d85c768f0b3d1f5c54cca39dc9de52ee20e8f601814c3200722901aee5R242

During AC recomputation, this hook would be executed again, but by then self.tokens_per_experts has been updated by the forward pass to reflect the statistics of the current iteration.
https://github.com/pytorch/torchtitan/pull/1114/files#diff-87cc24d85c768f0b3d1f5c54cca39dc9de52ee20e8f601814c3200722901aee5R263
This would cause a discrepancy in self.expert_bias between the original forward and the recomputation.
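
A minimal, self-contained illustration (not torchtitan code) of the underlying issue: with non-reentrant AC, the module's forward, including its forward pre hooks, runs again during backward, so any state the hook reads or mutates can differ between the two runs:

```
import torch
import torch.nn as nn
from torch.utils.checkpoint import checkpoint

calls = []
lin = nn.Linear(8, 8)
lin.register_forward_pre_hook(lambda module, args: calls.append(len(calls)))

x = torch.randn(2, 8, requires_grad=True)
out = checkpoint(lin, x, use_reentrant=False)  # hook fires during the original forward
out.sum().backward()                           # and again during AC recomputation
print(calls)  # [0, 1] -> the pre hook ran twice for a single training step
```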

@soulitzer

Thanks for the explanation, that makes sense!

I wonder if AC supports optionally bypassing some hook computation, even in the full AC mode?

This is tricky because AC treats the user function like a black box, so it always executes everything a second time. Adding more extension points to interpose additional setup/teardown logic (e.g. to remove/reinstall hooks) before/after recomputation seems to sidestep, rather than solve, the core issue, which is that side effects execute twice.

I'm actually working on a new tracer-based version of AC (repo, doc) that should address this problem soon.
It traces out the aten operations that your forward executes and, during recompute, only executes those. This means that Python side effects no longer run twice (proper handling of in-place aten ops is not ready yet, but should be doable in a nicer way than in the previous AC).
It's still WIP, but so far I've tried it on llama3 8B without compile (I don't fully understand the FLOPs stats, but theoretically they should be strictly lower; peak memory matches full AC and is 2-3% lower with SAC).

@tianyu-l
Contributor

@soulitzer
This sounds exciting! Let's definitely try it and see if it can help avoid using backward hooks in this problem.

This would mean that python side effects no longer run twice

how would you expose control to users so that when users want to recompute some forward hooks they can still do so?

proper handling in-place aten ops is not ready yet, but should be able to be done in a nicer way than in the previous AC

I guess I'm missing some context on the previous AC treatment of in-place ops.
I do wanna flag that in a recent fix, the _update_expert_bias hook uses in-place addition to update the buffer self.expert_bias.
#1226

@soulitzer

soulitzer commented May 30, 2025

@tianyu-l

This sounds exciting! Let's definitely try it and see if it can help avoid using backward hooks in this problem.

Great! Let me try to get these in-place fixes in soon.

how would you expose control to users so that when users want to recompute some forward hooks they can still do so?

There are no APIs for such explicit control today, but if the forward hooks are in the checkpoint region, and contain aten ops, and those aten ops are required to recompute a saved tensor, then those aten ops would already be recomputed (unless marked saved). OTOH, things like a Python print, assignment to a Python global, etc. won't be replayed. Do you have more context on what types of operations you'd like to do in these forward hooks?

I guess I'm missing some context on the previous AC treatment of in-place ops.

There were a couple of cases for the old eager AC w.r.t. in-place ops:

  1. If the thing being mutated itself was recomputed in the AC region anyway, mutating is safe
  2. If the thing being mutated is an explicit input to the AC region (or saved by SAC), we always error because autograd sees the version counter changing.
  3. If the thing being mutated is a captured input to the AC region (usually parameters/buffers), that tensor gets silently mutated twice, e.g. batch norm (see the sketch at the end of this comment).

In the new AC:

  1. This case gets a little more complicated - mutations, even if they are not on inputs, are not necessarily safe anymore, because the new AC does not guarantee to replay the forward ops in the same order. We can decide to save the in-place op or do some kind of clone.
  2. (If the input did not get used by another op prior to mutation) we can instead save the updated version of the tensor instead. If another op depends on the previous version of the tensor, then we may have no other choice than to clone.
  3. (new AC doesn't distinguish between explicit/captured inputs)

For compile, I would need to double check what exactly happens, but for SAC in compile we have a warning that basically says "please don't use in-place"
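
A small eager-mode illustration of case 3 from the old-AC list above, using BatchNorm's running stats as the captured buffers that get silently mutated twice:

```
import torch
import torch.nn as nn
from torch.utils.checkpoint import checkpoint

bn = nn.BatchNorm1d(4)  # running_mean / running_var / num_batches_tracked are buffers
x = torch.randn(8, 4, requires_grad=True)

out = checkpoint(bn, x, use_reentrant=False)
out.sum().backward()  # the recompute re-runs bn's forward, updating the buffers again

# Expected: the buffers received two updates for a single training step,
# e.g. num_batches_tracked ends up at 2 instead of 1.
print(bn.num_batches_tracked)
```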

@tianyu-l
Contributor

@soulitzer
oh it sounds quite involved

Overall, my request from the MoE auxiliary-loss-free load balancing side is:
"being able to register the hooks in #1114 so that they work nicely with AC and torch.compile".

Specifically

  • torch.compile doesn't like backward hooks
  • the hooks I registered don't work well with AC if not as full backward hooks

So if the solution is to make them compatible with AC as forward hooks, we need to make sure that the in-place buffer updates those hooks perform (#1226) only happen in the original forward, not again during the recompute in backward.

Happy to meet some time to clarify the questions and paths.

@soulitzer

soulitzer commented May 30, 2025

@tianyu-l

those hooks to update buffers in place (#1226) only happen in forward but not in backward

I see. Overall I'd say this is very much in line with what the new AC wants to achieve, and is also what's tricky to achieve in the old AC design, so it's great to have this as a concrete use case in MoEs!
(From AC's perspective, forward hooks are no different from any other logic being executed in the forward pass.)

I've just now added some very basic in-place support, only tested in eager so far: "If a tensor is mutated, save its latest version. If someone depends on an earlier version, error".
soulitzer/ac-experimental@c7f495f#diff-3665d65394f4f58a56a256ad6dd8621c68118d90fe56a19387e251c19cec2d2eR406
There are still things to be done though, e.g. we may want to do smarter things in the case of ops like batch norm.

I'll take a look at the compile side next (about to sign off, but should have something early next week)

@tianyu-l mentioned this pull request Jun 2, 2025