Commit 132b1ee

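[Flux] Improve reshard_after_forward for Flux model's last layer (#1097)
As the title says: set reshard_after_forward=False for the last layer to avoid an all-gather immediately after the reshard. Backward starts at the last layer as soon as forward finishes, so resharding its parameters only forces them to be gathered again. This is similar to what is done for llama, as discussed in #1091.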
1 parent 89cdc43 commit 132b1ee
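
The reasoning in the commit message can be illustrated with a toy count of parameter all-gathers over one forward and backward pass. This is a minimal sketch, not FSDP itself: `count_all_gathers` is a hypothetical helper, and it assumes each layer is its own shard group that must be gathered before it runs and is resharded after backward.

```python
def count_all_gathers(num_layers: int, reshard_last: bool) -> int:
    """Count all-gathers in one forward + backward pass of a toy layered model.

    Toy assumptions (not taken from FSDP internals): every layer is sharded
    on its own, must be gathered before use, and every layer except possibly
    the last is resharded right after its forward.
    """
    unsharded = set()  # indices of layers whose parameters are currently gathered
    gathers = 0

    def ensure_gathered(i: int) -> None:
        nonlocal gathers
        if i not in unsharded:
            unsharded.add(i)
            gathers += 1

    # Forward: first layer to last layer.
    for i in range(num_layers):
        ensure_gathered(i)
        is_last = i == num_layers - 1
        if reshard_last or not is_last:
            unsharded.discard(i)  # reshard after forward

    # Backward: last layer to first layer.
    for i in reversed(range(num_layers)):
        ensure_gathered(i)
        unsharded.discard(i)  # reshard after backward

    return gathers


with_reshard = count_all_gathers(4, reshard_last=True)
without_reshard = count_all_gathers(4, reshard_last=False)
```

In this toy model, keeping the last layer unsharded saves exactly one all-gather per step, at the cost of holding that layer's full parameters in memory between forward and backward.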

1 file changed: +3 -2 lines changed

torchtitan/experiments/flux/parallelize_flux.py

Lines changed: 3 additions & 2 deletions
@@ -95,8 +95,9 @@ def apply_fsdp(
             block,
             **fsdp_config,
         )
-    # apply FSDP to last layer
-    fully_shard(model.final_layer, **fsdp_config)
+    # apply FSDP to last layer. Set reshard_after_forward=False for last layer to avoid gather right after reshard
+    fully_shard(model.final_layer, **fsdp_config, reshard_after_forward=False)
+
     # Wrap all the rest of model
     fully_shard(model, **fsdp_config)
 
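
One thing to keep in mind with the `**fsdp_config, reshard_after_forward=False` call pattern: the diff does not show what `fsdp_config` contains, and Python raises a `TypeError` if the same keyword arrives both through `**` unpacking and explicitly. The sketch below illustrates that behaviour with a hypothetical stand-in, `fully_shard_stub`, and made-up config values; it does not call the real `fully_shard`.

```python
def fully_shard_stub(module, **kwargs):
    # Hypothetical stand-in for fully_shard: returns the keyword arguments it received.
    return kwargs


def last_layer_kwargs(fsdp_config):
    # Merge into a new dict so the explicit override wins even if the key already exists.
    return {**fsdp_config, "reshard_after_forward": False}


# Illustrative values only; the real fsdp_config is not shown in the diff.
base = {"mesh": "dp_mesh", "reshard_after_forward": True}

try:
    # Same keyword supplied twice: once via **base, once explicitly.
    fully_shard_stub("final_layer", **base, reshard_after_forward=False)
    duplicate_rejected = False
except TypeError:
    duplicate_rejected = True

# Merging first avoids the clash and applies the override.
merged = fully_shard_stub("final_layer", **last_layer_kwargs(base))
```

If `fsdp_config` never carries `reshard_after_forward`, as appears to be the case in this change, the call in the diff works as written and the merge step is unnecessary.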
