Commit 596d8ac

[cp][flex_attention] integration test trial
ghstack-source-id: 7e12a16
Pull Request resolved: #1160
1 parent: a4ed09c

3 files changed: 37 additions, 3 deletions

torchtitan/distributed/utils.py (13 additions, 1 deletion)

```diff
@@ -9,13 +9,16 @@
 import os
 from collections.abc import Generator, Iterable
 from datetime import timedelta
+from typing import Optional

 import torch
 import torch.distributed._functional_collectives as funcol
 import torch.distributed.distributed_c10d as c10d
 from torch import distributed as dist
 from torch.distributed.device_mesh import DeviceMesh
 from torch.distributed.tensor import DTensor
+from torch.distributed.tensor.experimental._attention import FlexAttentionSharder
+from torch.nn.attention.flex_attention import BlockMask

 from torchtitan.tools.logging import logger
 from torchtitan.tools.utils import device_module, device_type
@@ -154,22 +157,31 @@ def create_context_parallel_ctx(
     cp_seq_dims: list[int],
     cp_no_restore_buffers: set[torch.Tensor],
     cp_rotate_method: str,
+    block_mask: Optional[BlockMask] = None,
+    sharder: Optional[FlexAttentionSharder] = None,
 ):
     try:
         from torch.distributed.tensor.experimental import context_parallel
-        from torch.distributed.tensor.experimental._attention import set_rotate_method
+        from torch.distributed.tensor.experimental._attention import (
+            _dispatch_mode,
+            _DispatchMode,
+            set_rotate_method,
+        )
     except ImportError:
         print(
             f"PyTorch version {torch.__version__} does not include the experimental "
             "Context Parallel API. Please update to a newer version."
         )

     set_rotate_method(cp_rotate_method)
+    _dispatch_mode = _DispatchMode.TORCH_DISPATCH
     return context_parallel(
         cp_mesh,
         buffers=cp_buffers,
         buffer_seq_dims=cp_seq_dims,
         no_restore_buffers=cp_no_restore_buffers,
+        block_mask=block_mask,
+        sharder=sharder,
     )

```

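The new `block_mask` and `sharder` arguments are passed straight through to `context_parallel`. As a rough illustration of what a caller could hand in, the sketch below builds a causal `BlockMask` with the public `create_block_mask` helper from `torch.nn.attention.flex_attention`; the mesh, buffers, and rotate method in the commented call are placeholders rather than values from this commit, and `FlexAttentionContiguousSharder` is the sharder used later in `train.py`.

```python
# Minimal sketch (not from the commit): construct a causal BlockMask that could be
# passed via the new `block_mask` parameter of create_context_parallel_ctx.
from torch.nn.attention.flex_attention import create_block_mask


def causal_mask_mod(b, h, q_idx, kv_idx):
    # Allow attention only to keys at or before the query position.
    return q_idx >= kv_idx


seq_len = 2048
# B=None / H=None broadcasts the mask over batch and heads; create_block_mask
# defaults to CUDA, so pass the device explicitly.
block_mask = create_block_mask(
    causal_mask_mod, B=None, H=None, Q_LEN=seq_len, KV_LEN=seq_len, device="cuda"
)

# Hypothetical wiring into the updated helper (mesh, buffers, and rotate method
# are placeholders):
#
# ctx = create_context_parallel_ctx(
#     cp_mesh=world_mesh["cp"],
#     cp_buffers=[inputs, labels],
#     cp_seq_dims=[1, 1],
#     cp_no_restore_buffers={inputs, labels},
#     cp_rotate_method="allgather",
#     block_mask=block_mask,
#     sharder=FlexAttentionContiguousSharder(),
# )
```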
torchtitan/experiments/llama4/model/args.py (2 additions, 1 deletion)

```diff
@@ -74,12 +74,13 @@ def update_from_config(self, job_config: JobConfig, tokenizer: Tokenizer) -> None:
                 "FlexAttention is not compatible with selective AC yet. "
                 "See https://github.com/pytorch/pytorch/issues/147879"
             )
-
+        """
         if job_config.parallelism.context_parallel_degree > 1 and self.use_flex_attn:
             raise ValueError(
                 "FlexAttention is not compatible with CP yet. "
                 "We are still working on this."
             )
+        """

     def get_nparams_and_flops(
         self, model: nn.Module, seq_len: int
```

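This change neutralizes the CP + FlexAttention guard by wrapping it in a triple-quoted string, so the check is parsed as an unused string literal instead of executing. A toy illustration of the effect (the function and arguments below are made up for demonstration, not repository code):

```python
# Toy example: statements wrapped in a triple-quoted string become a single
# no-op string expression, so the guard never runs.
def validate(context_parallel_degree: int, use_flex_attn: bool) -> None:
    """
    if context_parallel_degree > 1 and use_flex_attn:
        raise ValueError("FlexAttention is not compatible with CP yet.")
    """


validate(context_parallel_degree=2, use_flex_attn=True)  # no longer raises
```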
torchtitan/train.py (22 additions, 1 deletion)

```diff
@@ -11,10 +11,13 @@
 from typing import Any, Generator, Iterable, Optional

 import torch
-from torch.distributed.elastic.multiprocessing.errors import record

 import torchtitan.components.ft as ft
 import torchtitan.protocols.train_spec as train_spec_module
+from torch.distributed.elastic.multiprocessing.errors import record
+from torch.distributed.tensor.experimental._attention import (
+    FlexAttentionContiguousSharder,
+)

 from torchtitan.components.checkpoint import CheckpointManager
 from torchtitan.components.metrics import (
@@ -133,7 +136,9 @@ def __init__(self, job_config: JobConfig):

         # build model (using meta init)
         model_cls = self.train_spec.cls
+        # NOTE (xilunwu): need to store model_args.use_flex_attn for train_step
         model_args = self.train_spec.config[job_config.model.flavor]
+        self.model_args = model_args
         # set the model args from training job configs
         model_args.update_from_config(job_config, tokenizer)

@@ -319,13 +324,29 @@ def train_step(self, input_dict: dict[str, torch.Tensor], labels: torch.Tensor):
         # apply context parallelism if cp is enabled
         # ensure CP handles the separate freqs_cis buffer for each pp stage
         inputs = input_dict["input"]
+
+        # TODO: move this into `create_context_parallel_ctx`
+        # init block_mask for flex_attention
+        block_mask = None
+        if self.model_args.use_flex_attn:
+            from torchtitan.models.attention import FlexAttention
+
+            mask_mod = FlexAttention._get_causal_mask_mod()
+            batch_dimension = 1
+            seq_len = inputs.shape[1]
+            block_mask = FlexAttention.compiled_create_block_mask(
+                mask_mod, batch_dimension, None, seq_len, seq_len
+            )
+
         optional_context_parallel_ctx = (
             dist_utils.create_context_parallel_ctx(
                 cp_mesh=world_mesh["cp"],
                 cp_buffers=[inputs, labels] + [m.freqs_cis for m in model_parts],
                 cp_seq_dims=[1, 1] + [0 for _ in model_parts],
                 cp_no_restore_buffers={inputs, labels},
                 cp_rotate_method=self.job_config.parallelism.context_parallel_rotate_method,
+                block_mask=block_mask,
+                sharder=FlexAttentionContiguousSharder(),
             )
             if parallel_dims.cp_enabled
             else None
```

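For context, the optional context manager produced above is typically entered around the forward/backward of the training step, mirroring how torchtitan's train loop consumes it. The sketch below shows one hedged way that wiring could look; the model, loss function, and tensors are placeholders, not code from this commit.

```python
# Illustrative only: enter the CP context (or a null context when CP is disabled)
# around forward and backward.
import contextlib


def run_step(model, loss_fn, inputs, labels, optional_context_parallel_ctx=None):
    # Fall back to a no-op context when context parallelism is not enabled.
    ctx = (
        optional_context_parallel_ctx
        if optional_context_parallel_ctx is not None
        else contextlib.nullcontext()
    )
    with ctx:
        pred = model(inputs)
        loss = loss_fn(pred, labels)
        loss.backward()
    return loss
```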