Commit a61afbd

[cp][flex_attention] integration test trial
ghstack-source-id: b0b6434
Pull Request resolved: #1160
Parent: a4ed09c

6 files changed, +51 -10 lines

6 files changed

+51
-10
lines changed

torchtitan/distributed/utils.py

Lines changed: 13 additions & 1 deletion
@@ -9,13 +9,16 @@
 import os
 from collections.abc import Generator, Iterable
 from datetime import timedelta
+from typing import Optional

 import torch
 import torch.distributed._functional_collectives as funcol
 import torch.distributed.distributed_c10d as c10d
 from torch import distributed as dist
 from torch.distributed.device_mesh import DeviceMesh
 from torch.distributed.tensor import DTensor
+from torch.distributed.tensor.experimental._attention import FlexAttentionSharder
+from torch.nn.attention.flex_attention import BlockMask

 from torchtitan.tools.logging import logger
 from torchtitan.tools.utils import device_module, device_type
@@ -154,22 +157,31 @@ def create_context_parallel_ctx(
     cp_seq_dims: list[int],
     cp_no_restore_buffers: set[torch.Tensor],
     cp_rotate_method: str,
+    block_mask: Optional[BlockMask] = None,
+    sharder: Optional[FlexAttentionSharder] = None,
 ):
     try:
         from torch.distributed.tensor.experimental import context_parallel
-        from torch.distributed.tensor.experimental._attention import set_rotate_method
+        from torch.distributed.tensor.experimental._attention import (
+            _dispatch_mode,
+            _DispatchMode,
+            set_rotate_method,
+        )
     except ImportError:
         print(
             f"PyTorch version {torch.__version__} does not include the experimental "
             "Context Parallel API. Please update to a newer version."
         )

     set_rotate_method(cp_rotate_method)
+    _dispatch_mode = _DispatchMode.TORCH_DISPATCH
     return context_parallel(
         cp_mesh,
         buffers=cp_buffers,
         buffer_seq_dims=cp_seq_dims,
         no_restore_buffers=cp_no_restore_buffers,
+        block_mask=block_mask,
+        sharder=sharder,
     )

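For orientation, here is a minimal caller sketch of the extended `create_context_parallel_ctx` signature. Only the keyword names come from the diff above; the mesh construction, buffer shapes, import path, and the rotate-method string are illustrative assumptions.

# Hypothetical caller sketch (not part of this commit): exercising the new
# block_mask/sharder keywords. Mesh setup and rotate method are illustrative.
import torch.distributed as dist
from torch.distributed.device_mesh import init_device_mesh
from torch.distributed.tensor.experimental._attention import (
    FlexAttentionContiguousSharder,
)
from torch.nn.attention.flex_attention import BlockMask

from torchtitan.distributed import utils as dist_utils  # assumed import path


def cp_forward(model, inputs, labels, freqs_cis, block_mask: BlockMask):
    # One-dimensional "cp" mesh over all ranks (illustrative).
    cp_mesh = init_device_mesh(
        "cuda", (dist.get_world_size(),), mesh_dim_names=("cp",)
    )
    cp_ctx = dist_utils.create_context_parallel_ctx(
        cp_mesh=cp_mesh,
        cp_buffers=[inputs, labels, freqs_cis],
        cp_seq_dims=[1, 1, 0],
        cp_no_restore_buffers={inputs, labels},
        cp_rotate_method="allgather",
        block_mask=block_mask,                     # new in this commit
        sharder=FlexAttentionContiguousSharder(),  # new in this commit
    )
    # create_context_parallel_ctx returns an unentered context manager;
    # entering it shards the listed buffers along their sequence dimensions.
    with cp_ctx:
        return model(inputs)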
torchtitan/experiments/llama4/__init__.py

Lines changed: 2 additions & 0 deletions
@@ -40,6 +40,8 @@
         rope_theta=500000,
         num_experts=16,
         interleave_moe_layer_step=1,
+        use_flex_attn=True,
+        attn_mask_type="block_causal",
     ),
     "17bx128e": TransformerModelArgs(
         dim=5120,

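The `attn_mask_type="block_causal"` setting refers to document-aware causal masking. A minimal sketch of what such a mask_mod means with the public FlexAttention API follows; the `document_id` tensor and its values are illustrative, not torchtitan's implementation.

# Illustrative block-causal mask_mod: causal attention restricted to tokens of
# the same packed document. Not torchtitan's implementation.
import torch
from torch.nn.attention.flex_attention import create_block_mask

seq_len = 16
# Hypothetical packed batch: first 6 tokens belong to doc 0, the rest to doc 1.
document_id = torch.tensor([0] * 6 + [1] * 10)


def block_causal(b, h, q_idx, kv_idx):
    causal = q_idx >= kv_idx
    same_doc = document_id[q_idx] == document_id[kv_idx]
    return causal & same_doc


# B=None, H=None: one mask broadcast across batch and heads.
mask = create_block_mask(block_causal, None, None, seq_len, seq_len, device="cpu")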
torchtitan/experiments/llama4/model/args.py

Lines changed: 3 additions & 2 deletions
@@ -55,7 +55,7 @@ class TransformerModelArgs(BaseModelArgs):
     interleave_moe_layer_step: int = 2
     # token-choice
     top_k: int = 1
-    use_grouped_mm: bool = True  # grouped mm or for-loop for the experts computation
+    use_grouped_mm: bool = False  # grouped mm or for-loop for the experts computation
     load_balance_coeff: float | None = 1e-3

     def update_from_config(self, job_config: JobConfig, tokenizer: Tokenizer) -> None:
@@ -74,12 +74,13 @@ def update_from_config(self, job_config: JobConfig, tokenizer: Tokenizer) -> Non
                 "FlexAttention is not compatible with selective AC yet. "
                 "See https://github.com/pytorch/pytorch/issues/147879"
             )
-
+        """
         if job_config.parallelism.context_parallel_degree > 1 and self.use_flex_attn:
             raise ValueError(
                 "FlexAttention is not compatible with CP yet. "
                 "We are still working on this."
             )
+        """

     def get_nparams_and_flops(
         self, model: nn.Module, seq_len: int

torchtitan/models/llama3/__init__.py

Lines changed: 11 additions & 0 deletions
@@ -47,6 +47,17 @@
         multiple_of=1024,
         rope_theta=500000,
     ),
+    "8B_flex_attn": TransformerModelArgs(
+        dim=4096,
+        n_layers=32,
+        n_heads=32,
+        n_kv_heads=8,
+        ffn_dim_multiplier=1.3,
+        multiple_of=1024,
+        rope_theta=500000,
+        use_flex_attn=True,
+        attn_mask_type="block_causal",
+    ),
     "70B": TransformerModelArgs(
         dim=8192,
         n_layers=80,

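Assuming the registry dict defined in this file is named `llama3_configs` (an assumption, not shown in the diff), the new flavor resolves through the same lookup that train.py performs below.

# Sketch of how the new flavor resolves to model args; `llama3_configs` is the
# assumed name of the registry dict defined in this file.
from torchtitan.models.llama3 import llama3_configs

model_args = llama3_configs["8B_flex_attn"]
assert model_args.use_flex_attn
assert model_args.attn_mask_type == "block_causal"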
torchtitan/models/llama3/model.py

Lines changed: 0 additions & 6 deletions
@@ -51,12 +51,6 @@ def update_from_config(self, job_config: JobConfig, tokenizer: Tokenizer) -> Non
                 "See https://github.com/pytorch/pytorch/issues/147879"
             )

-        if job_config.parallelism.context_parallel_degree > 1 and self.use_flex_attn:
-            raise ValueError(
-                "FlexAttention is not compatible with CP yet. "
-                "We are still working on this."
-            )
-
     def get_nparams_and_flops(self, model: nn.Module, seq_len: int) -> tuple[int, int]:
         nparams = sum(p.numel() for p in model.parameters())
         nparams_embedding = sum(

torchtitan/train.py

Lines changed: 22 additions & 1 deletion
@@ -11,10 +11,13 @@
 from typing import Any, Generator, Iterable, Optional

 import torch
-from torch.distributed.elastic.multiprocessing.errors import record

 import torchtitan.components.ft as ft
 import torchtitan.protocols.train_spec as train_spec_module
+from torch.distributed.elastic.multiprocessing.errors import record
+from torch.distributed.tensor.experimental._attention import (
+    FlexAttentionContiguousSharder,
+)

 from torchtitan.components.checkpoint import CheckpointManager
 from torchtitan.components.metrics import (
@@ -133,7 +136,9 @@ def __init__(self, job_config: JobConfig):

         # build model (using meta init)
         model_cls = self.train_spec.cls
+        # NOTE (xilunwu): need to store model_args.use_flex_attn for train_step
         model_args = self.train_spec.config[job_config.model.flavor]
+        self.model_args = model_args
         # set the model args from training job configs
         model_args.update_from_config(job_config, tokenizer)

@@ -319,13 +324,29 @@ def train_step(self, input_dict: dict[str, torch.Tensor], labels: torch.Tensor):
         # apply context parallelism if cp is enabled
         # ensure CP handles the separate freqs_cis buffer for each pp stage
         inputs = input_dict["input"]
+
+        # TODO: move this into `create_context_parallel_ctx`
+        # init block_mask for flex_attention
+        block_mask = None
+        if self.model_args.use_flex_attn:
+            from torchtitan.models.attention import FlexAttention
+
+            mask_mod = FlexAttention._get_causal_mask_mod()
+            batch_dimension = 1
+            seq_len = inputs.shape[1]
+            block_mask = FlexAttention.compiled_create_block_mask(
+                mask_mod, batch_dimension, None, seq_len, seq_len
+            )
+
         optional_context_parallel_ctx = (
             dist_utils.create_context_parallel_ctx(
                 cp_mesh=world_mesh["cp"],
                 cp_buffers=[inputs, labels] + [m.freqs_cis for m in model_parts],
                 cp_seq_dims=[1, 1] + [0 for _ in model_parts],
                 cp_no_restore_buffers={inputs, labels},
                 cp_rotate_method=self.job_config.parallelism.context_parallel_rotate_method,
+                block_mask=block_mask,
+                sharder=FlexAttentionContiguousSharder(),
             )
             if parallel_dims.cp_enabled
             else None

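The block_mask construction added to train_step is roughly equivalent to the standalone sketch below, written against the public FlexAttention API. The torchtitan helpers referenced in the diff (`FlexAttention._get_causal_mask_mod`, `compiled_create_block_mask`) are internal; their assumed behavior here is a plain causal mask_mod and a compiled `create_block_mask`.

# Rough standalone equivalent of the block_mask setup in train_step, using the
# public FlexAttention API instead of torchtitan's internal helpers.
import torch
from torch.nn.attention.flex_attention import BlockMask, create_block_mask


def causal(b, h, q_idx, kv_idx):
    # Assumed to match the mask_mod returned by FlexAttention._get_causal_mask_mod().
    return q_idx >= kv_idx


def build_block_mask(inputs: torch.Tensor) -> BlockMask:
    seq_len = inputs.shape[1]
    # B=1, H=None: a single mask broadcast over heads and shared by the batch,
    # mirroring `batch_dimension = 1` and the `None` head argument in the diff.
    return create_block_mask(causal, 1, None, seq_len, seq_len)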