compile: turn off fullgraph=True to support llama4 #1182
base: gh/bdhirsh/3/base
Conversation
@@ -304,7 +304,7 @@ def apply_compile(model: nn.Module):
    repeated structure. Alternatively one can compile the whole model (after applying DP).
    """
    for layer_id, transformer_block in model.layers.named_children():
        transformer_block = torch.compile(transformer_block, fullgraph=True)
Can we add a comment/TODO to remind us to turn it back on when issues are resolved?
oh will do. there are two related things here:
(1) grouped_mm support in compile, which torchtitan uses in llama4. I added basic support in core in this PR: pytorch/pytorch#153384
(2) E2E llama4 + compile in torchtitan. The current reason this completely blows up today is that torchtitan's llama4 + FSDP2 integration requires wrapping the MoE layer, which requires installing backward hooks around the MoE layer. Compile does not support compiling backward hooks (we graph break), and so we need to do one of these options:
(a) allow the graph break (turn off fullgraph=True)
(b) tweak torchtitan so that instead of compiling each transformer layer, we compile MoE layers separately, and compile the rest of the transformer block layer separately as well.
I also mentioned this to @tianyu-l but calling it out here: (a) is easier to do, so I'm doing it here, but it does carry the risk that if any changes are made to core that increase the number of graph breaks in torchtitan, we won't error as loudly (we may see a perf drop instead). (b) is probably better to do at some point; I'm just doing the simpler thing here.
Are folks working on torchtitan interested in running benchmarks for titan + llama4 (with compile on/off?)
@fegin I actually tweaked the PR so that we still compile the "regular" transformer blocks with fullgraph=True, and only use fullgraph=False for the blocks with MoE layers. I think this should reduce the risk that we hit regressions, so this may be a reasonable long-term solution (when using FSDP2 in torchtitan), as long as we see reasonable perf numbers.
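For reference, here is a minimal sketch of that selective-compile approach. It is not the actual torchtitan implementation: the `moe_enabled` attribute and the `model.layers` structure are assumptions used for illustration.

```python
import torch
import torch.nn as nn


def apply_compile(model: nn.Module):
    """Compile each TransformerBlock, relaxing fullgraph only where MoE forces a graph break."""
    for layer_id, transformer_block in model.layers.named_children():
        # `moe_enabled` is an assumed marker for blocks that contain an MoE layer;
        # those blocks get fullgraph=False so the FSDP2 backward hooks can graph-break
        # instead of erroring, while dense blocks keep the stricter fullgraph=True.
        has_moe = getattr(transformer_block, "moe_enabled", False)
        compiled_block = torch.compile(transformer_block, fullgraph=not has_moe)
        model.layers.register_module(layer_id, compiled_block)
```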
@bdhirsh Thanks for the thorough explanation. The 8-GPU integration test timed out. Since llama4 is not in the integration test, the integration test issue should not be caused by this PR. I relaunched the test anyway, but feel free to land it.
This PR + pytorch/pytorch#153384 is enough to get torchtitan running for me with llama4 and compile:
```
CONFIG_FILE="./torchtitan/experiments/llama4/train_configs/debug_model.toml" ./run_train.sh --training.compile
```
> The current reason this completely blows up today is that torchtitan's llama4 + FSDP2 integration requires wrapping the MoE layer, which requires installing backward hooks around the MoE layer.
hmm, let's be more careful here. This is only true when EP is used, specifically dp2ep (e.g. in #732). Currently in Llama 4, EP is not supported yet, which means we are doing homogeneous FSDP2 wrapping on all the transformer blocks only (not on MoE modules). So I suppose full-graph compilation shouldn't be violated. If we set fullgraph=True, where would it break?
Locally, I was seeing that when we compile each transformer block layer, dynamo was trying (and failing) to graph break, because someone was attempting to install backward hooks inside of one of the transformer blocks. If that's surprising to you, I can try to find the code that is installing the backward hook.
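For reference, a tiny standalone illustration of the fullgraph semantics being discussed; the `torch._dynamo.graph_break()` call below is just a stand-in for whatever is installing the backward hook, not the actual torchtitan code path.

```python
import torch


def f(x):
    y = x.sin()
    torch._dynamo.graph_break()  # stand-in for the unsupported hook installation
    return y.cos()


x = torch.randn(4)

# fullgraph=False: dynamo splits the function into two graphs around the break and runs fine.
torch.compile(f, fullgraph=False)(x)

# fullgraph=True: the same graph break becomes a hard error instead of a silent split.
try:
    torch.compile(f, fullgraph=True)(x)
except Exception as e:
    print("errored as expected:", type(e).__name__)
```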
I think the backward hooks are from the auxiliary-loss-free load balancing (#1114). The load balancing algorithm maintains a bias term for each expert, based on the number of tokens the expert has seen so far. Using forward, forward pre, or backward pre hooks would conflict with activation checkpointing.
cc @soulitzer on this issue. I wonder if AC supports optionally bypassing some hook computation, even in full AC mode?
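For illustration only, here is a hypothetical hook of the shape described above. The MoE class, its buffers, and the update rule are made up for this sketch and are not the #1114 implementation.

```python
import torch
import torch.nn as nn


class MoE(nn.Module):
    """Toy stand-in for an MoE layer that keeps a per-expert routing bias."""

    def __init__(self, num_experts: int = 8):
        super().__init__()
        self.num_experts = num_experts
        self.register_buffer("expert_bias", torch.zeros(num_experts))
        self.register_buffer("tokens_per_expert", torch.zeros(num_experts))

    def forward(self, x):
        # Routing and expert dispatch elided; just bump the running counts here.
        self.tokens_per_expert += x.shape[0]
        return x


def update_expert_bias(module, args):
    # Forward pre hook with a side effect: adjust each expert's routing bias
    # from the running token counts, down-weighting overloaded experts.
    with torch.no_grad():
        load = module.tokens_per_expert / module.tokens_per_expert.sum().clamp(min=1)
        module.expert_bias -= 1e-3 * (load - 1.0 / module.num_experts)


moe = MoE()
moe.register_forward_pre_hook(update_expert_bias)
```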
@tianyu-l Do you have more details on what the conflict with activation checkpointing is? Without knowing any more context, one guess is that using forward, forward pre, or backward pre hooks conflicts with activation checkpointing because it would make the original forward computation and the recompute see different bias terms.
The problem is, the hook mutates the bias. The problem with a forward pre hook is: during AC recomputation, this hook will be executed again, but this time the bias has already been updated by the original forward, so the recompute sees a different bias than the original forward did.
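A small, self-contained repro of that double execution with plain `torch.utils.checkpoint` (not the torchtitan code; a counter stands in for the bias update):

```python
import torch
import torch.nn as nn
from torch.utils.checkpoint import checkpoint

calls = {"pre_hook": 0}
mlp = nn.Sequential(nn.Linear(4, 8), nn.ReLU(), nn.Linear(8, 4))


def counting_pre_hook(module, args):
    # Side effect in a forward pre hook: bump a counter (stand-in for updating the expert bias).
    calls["pre_hook"] += 1


mlp.register_forward_pre_hook(counting_pre_hook)

x = torch.randn(2, 4, requires_grad=True)
out = checkpoint(mlp, x, use_reentrant=False)
out.sum().backward()

# The hook ran once in the original forward and once more during the recomputation
# triggered by backward, so any state it mutates gets touched twice.
print(calls["pre_hook"])  # -> 2
```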
Thanks for the explanation, that makes sense!
This is tricky because AC treats the user function like a black box, so it always executes everything a second time. Adding more extension points to interpose additional setup/teardown logic (e.g. to remove/reinstall hooks) before/after recomputation seems to sidestep rather than solve the core issue, which is that side effects execute twice. I'm actually working on a new tracer-based version of AC (repo, doc) that should address this problem soon.
@soulitzer how would you expose control to users, so that when users want to recompute some forward hooks they can still do so? I guess I'm missing some context on the previous AC treatment of in-place ops.
Great! Let me try to get these in-place fixes in soon.
There are no APIs for such explicit control today, but if the forward hooks are in the checkpoint region, and contain aten ops, and those aten ops are required to recompute a saved tensor, then those aten ops would already be recomputed (unless marked saved). OTOH, things like a Python print, assignment to a Python global, etc. won't be replayed. Do you have more context on what types of operations you'd like to do in these forward hooks?
There were a couple cases for old eager AC wrt in-place:
In the new AC:
For compile, I would need to double-check what exactly happens, but for SAC in compile we have a warning that basically says "please don't use in-place".
@soulitzer Overall my request from the MoE auxiliary-loss-free load balancing is Specifically
So if the solution is to make them compatible with AC as forward hooks, we need to make sure the requirements above hold. Happy to meet some time to clarify the questions and paths.
I see, overall I'd say this is very much in line with what the new AC wants to achieve, and is also tricky to achieve in the old AC design, so it's great to have this as a concrete use case in MoEs! I've just now added some very basic in-place support, only tested in eager so far: "if a tensor is mutated, save its latest version; if someone depends on an earlier version, error". I'll take a look at the compile side next (about to sign off, but should have something early next week).
This PR + pytorch/pytorch#153384 is enough to get torchtitan running for me with llama4 and compile
Stack from ghstack (oldest at bottom):