
[cp][flex_attention] integration test trial #1160


Draft: wants to merge 3 commits into base: gh/XilunWu/18/base

Conversation

XilunWu (Contributor) commented May 1, 2025

[ghstack-poisoned]
XilunWu added a commit that referenced this pull request May 1, 2025
ghstack-source-id: 7e12a16
Pull-Request-resolved: #1160
facebook-github-bot added the CLA Signed label (managed by the Meta Open Source bot) May 1, 2025
XilunWu requested a review from fegin May 1, 2025 21:22
XilunWu marked this pull request as draft May 1, 2025 21:22
@@ -74,12 +74,13 @@ def update_from_config(self, job_config: JobConfig, tokenizer: Tokenizer) -> None:
"FlexAttention is not compatible with selective AC yet. "
"See https://github.com/pytorch/pytorch/issues/147879"
)

"""
Contributor:

You can just remove this block.

mask_mod = FlexAttention._get_causal_mask_mod()
batch_dimension = 1
seq_len = inputs.shape[1]
block_mask = FlexAttention.compiled_create_block_mask(
Contributor:

We should let flex attention provide this compiled_create_block_mask, to minimize the dependency on users' code when parallelizing with CP. cc @drisspg
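
For illustration only (not code from this PR): a minimal sketch of how the flex_attention side could own the compiled block-mask factory, built on the public torch.nn.attention.flex_attention API. The name compiled_create_block_mask and the torch.compile wrapping are assumptions mirroring this thread, not necessarily torchtitan's implementation.

# Sketch: the attention library owns the compiled factory, so user code
# never calls torch.compile or create_block_mask directly.
import torch
from torch.nn.attention.flex_attention import create_block_mask, flex_attention

# Assumed helper name, mirroring the one discussed above.
compiled_create_block_mask = torch.compile(create_block_mask)

def causal_mask_mod(b, h, q_idx, kv_idx):
    # Standard causal mask: a query attends only to keys at or before its position.
    return q_idx >= kv_idx

# Usage sketch (shapes are placeholders):
# block_mask = compiled_create_block_mask(causal_mask_mod, B=1, H=None,
#                                         Q_LEN=seq_len, KV_LEN=seq_len)
# out = flex_attention(q, k, v, block_mask=block_mask)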

Contributor:

meaning that Flex provides the compiled partial with no mask_mod args?

Contributor Author:

For CP + flex_attention, this PR generates 3 compiled BlockMask objects for each mask_mod:

  1. QKV sharding -- load balancing requires a compiled BlockMask built from the global batch input and the mask_mod.
  2. Actual training. The first FlexAttention module in the model creates a compiled BlockMask from the sharded batch input and the mask_mod. Note that applying this mask_mod to the sharded batch input is meaningless, so this BlockMask is not used in the actual CP flex_attention computation.
  3. Actual training. When forward flex_attention is called over the sharded batch input for the first time in the current step, a BlockMask is created from the sharded batch input and a remapped mask_mod that corresponds to the local region of the attention score (the Q_LEN by KV_LEN rectangle).

(1) introduces a dependency in user code in order to adopt CP flex_attention. (2) is how we define the mask_mod in torchtitan and can be modified. Ideally (1) and (2) would be merged so that using CP requires neither redundant work nor user-code modification.
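
For illustration only (not this PR's code): a conceptual sketch of the remapping in (3). A mask_mod is written against global sequence positions, so once each CP rank holds only a local shard it must be wrapped so that local indices are translated back to global ones before the original predicate is evaluated. The offset-based translation below is a simplification; the real mapping depends on the load-balanced sharding, which is elided here.

import torch
from torch.nn.attention.flex_attention import create_block_mask

def causal_mask_mod(b, h, q_idx, kv_idx):
    # Defined in terms of global positions.
    return q_idx >= kv_idx

def remap_for_local_shard(mask_mod, q_offset, kv_offset):
    # Hypothetical wrapper: translate this rank's local indices to global
    # positions before calling the original mask_mod. A load-balanced CP
    # layout would need a more involved index map than a single offset.
    def local_mask_mod(b, h, q_idx, kv_idx):
        return mask_mod(b, h, q_idx + q_offset, kv_idx + kv_offset)
    return local_mask_mod

# (1) global BlockMask, used only to derive the load-balanced QKV sharding:
# global_mask = create_block_mask(causal_mask_mod, 1, None, seq_len, seq_len)
# (3) per-rank BlockMask over the local Q_LEN x KV_LEN rectangle:
# local_mask = create_block_mask(
#     remap_for_local_shard(causal_mask_mod, q_offset, kv_offset),
#     1, None, local_q_len, local_kv_len)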

if self.model_args.use_flex_attn:
    from torchtitan.models.attention import FlexAttention

    mask_mod = FlexAttention._get_causal_mask_mod()
Contributor:

I think mask_mod should be an input of context_parallel(), and then we can directly call compiled_create_block_mask. See the comment below.
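
For illustration only: a hypothetical sketch of the API shape this suggestion points at. Neither the function name context_parallel_with_mask nor its parameters exist today; they are assumptions used to show how passing mask_mod into the CP entry point would let it build the BlockMask itself.

import torch
from typing import Callable
from torch.nn.attention.flex_attention import create_block_mask

def context_parallel_with_mask(inputs: torch.Tensor, mask_mod: Callable, cp_degree: int):
    # Hypothetical entry point: because it receives mask_mod, it can build the
    # global BlockMask needed for load-balanced sharding internally, so user
    # code never constructs one.
    seq_len = inputs.shape[1]
    block_mask = create_block_mask(
        mask_mod, B=1, H=None, Q_LEN=seq_len, KV_LEN=seq_len, device=str(inputs.device)
    )
    # ...use block_mask to decide the load-balanced QKV sharding across
    # cp_degree ranks (details elided; interface sketch only).
    return block_mask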

[ghstack-poisoned]
XilunWu added a commit that referenced this pull request May 12, 2025
ghstack-source-id: b0b6434
Pull-Request-resolved: #1160
[ghstack-poisoned]
XilunWu added a commit that referenced this pull request May 21, 2025
ghstack-source-id: 9d38242
Pull-Request-resolved: #1160
XilunWu added a commit that referenced this pull request May 27, 2025
ghstack-source-id: 9d38242
Pull-Request-resolved: #1160

[ghstack-poisoned]

ghstack-source-id: 53d803f
Pull Request resolved: #1228
Labels: CLA Signed (managed by the Meta Open Source bot), module: context parallel
Projects: None yet
4 participants