[cp][flex_attention] integration test trial #1160
base: gh/XilunWu/18/base
Conversation
@@ -74,12 +74,13 @@ def update_from_config(self, job_config: JobConfig, tokenizer: Tokenizer) -> None:
        "FlexAttention is not compatible with selective AC yet. "
        "See https://github.com/pytorch/pytorch/issues/147879"
    )

"""
You can just remove this block.
mask_mod = FlexAttention._get_causal_mask_mod()
batch_dimension = 1
seq_len = inputs.shape[1]
block_mask = FlexAttention.compiled_create_block_mask(
We should let flex attention provide this compiled_create_block_mask, to minimize the dependency on users' code when parallelizing with CP. cc @drisspg
meaning that Flex provides the compiled partial with no mask_mod args?
For CP + flex_attention, this PR creates 3 compiled BlockMask objects, one for each mask_mod use site:
1. QKV sharding -- this requires a compiled BlockMask built from the global batch input and the mask_mod in order to load balance.
2. Actual training -- the first FlexAttention module in the model creates a compiled BlockMask from the sharded batch input and the mask_mod. Note that applying this mask_mod to the sharded batch input is meaningless, so this BlockMask is not used in the actual CP flex_attention computation.
3. Actual training -- when forward flex_attention is called over the sharded batch input for the first time in the current step, a BlockMask is created from the sharded batch input and a remapped mask_mod that corresponds to the local region of the attention score (the Q_LEN by KV_LEN rectangle).

(1) introduces a dependency in user code in order to adopt CP flex_attention. (2) is how we define the mask_mod in torchtitan and can be modified. Ideally (1) and (2) would be merged so that using CP requires neither the redundancy nor the user code modification.
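For reference, a minimal sketch of the block-mask pattern discussed above, using the public `torch.nn.attention.flex_attention` API. This is not the torchtitan code; `causal_mask_mod`, `make_block_mask`, and the shapes are illustrative.

```python
import torch
from torch.nn.attention.flex_attention import create_block_mask, flex_attention


def causal_mask_mod(b, h, q_idx, kv_idx):
    # Standard causal mask: each query position attends to itself and earlier keys.
    return q_idx >= kv_idx


# Compile create_block_mask once so mask construction is not re-traced every step.
compiled_create_block_mask = torch.compile(create_block_mask)


def make_block_mask(inputs: torch.Tensor):
    # inputs: token ids of shape (batch, seq_len); B=None / H=None broadcast the
    # mask over the batch and head dimensions.
    seq_len = inputs.shape[1]
    return compiled_create_block_mask(
        causal_mask_mod, B=None, H=None, Q_LEN=seq_len, KV_LEN=seq_len
    )


# Toy usage (assumes a CUDA device, the default for create_block_mask).
tokens = torch.zeros(1, 128, dtype=torch.long, device="cuda")
q = k = v = torch.randn(1, 8, 128, 64, device="cuda")
out = flex_attention(q, k, v, block_mask=make_block_mask(tokens))
```

The open question above is who owns `compiled_create_block_mask` (user code vs. the FlexAttention/CP machinery), since CP needs to rebuild the mask for the sharded inputs with the remapped mask_mod.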
if self.model_args.use_flex_attn:
    from torchtitan.models.attention import FlexAttention

    mask_mod = FlexAttention._get_causal_mask_mod()
I think mask_mod should be the input of context_parallel(), and we can directly call compiled_create_block_mask. See the comment below.
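To illustrate the suggestion (hypothetical only: mask_mod is not an argument of today's context_parallel API, and the helper name below is made up), the idea is that CP would own the compiled BlockMask construction instead of user code:

```python
# Hypothetical sketch of what context_parallel() could do internally once it
# accepts mask_mod: build the compiled global BlockMask needed for
# load-balanced QKV sharding, so user code never calls
# compiled_create_block_mask itself.
from typing import Callable

import torch
from torch.nn.attention.flex_attention import create_block_mask

compiled_create_block_mask = torch.compile(create_block_mask)


def cp_block_mask_from_mask_mod(mask_mod: Callable, seq_len: int):
    return compiled_create_block_mask(
        mask_mod, B=None, H=None, Q_LEN=seq_len, KV_LEN=seq_len
    )
```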
Stack from ghstack (oldest at bottom):