Commit 3cdc62a

[cp][flex_attention] integration test trial
ghstack-source-id: 9d38242
Pull Request resolved: #1160
1 parent 3381277 commit 3cdc62a

File tree: 10 files changed, +83 -22 lines


torchtitan/distributed/utils.py

Lines changed: 19 additions & 1 deletion
@@ -9,13 +9,16 @@
 import os
 from collections.abc import Generator, Iterable
 from datetime import timedelta
+from typing import Optional

 import torch
 import torch.distributed._functional_collectives as funcol
 import torch.distributed.distributed_c10d as c10d
 from torch import distributed as dist
 from torch.distributed.device_mesh import DeviceMesh
 from torch.distributed.tensor import DTensor
+from torch.distributed.tensor.experimental._attention import FlexAttentionSharder
+from torch.nn.attention.flex_attention import BlockMask

 from torchtitan.tools.logging import logger
 from torchtitan.tools.utils import device_module, device_type
@@ -154,22 +157,37 @@ def create_context_parallel_ctx(
     cp_seq_dims: list[int],
     cp_no_restore_buffers: set[torch.Tensor],
     cp_rotate_method: str,
+    block_mask: Optional[BlockMask] = None,
+    sharder: Optional[FlexAttentionSharder] = None,
 ):
     try:
         from torch.distributed.tensor.experimental import context_parallel
-        from torch.distributed.tensor.experimental._attention import set_rotate_method
+        from torch.distributed.tensor.experimental._attention import (
+            _dispatch_mode,
+            _DispatchMode,
+            set_rotate_method,
+        )
     except ImportError:
         print(
             f"PyTorch version {torch.__version__} does not include the experimental "
             "Context Parallel API. Please update to a newer version."
         )

     set_rotate_method(cp_rotate_method)
+    torch.distributed.tensor.experimental._attention._dispatch_mode = (
+        _DispatchMode.TORCH_DISPATCH
+    )
+    assert (
+        torch.distributed.tensor.experimental._attention._dispatch_mode
+        == _DispatchMode.TORCH_DISPATCH
+    )
     return context_parallel(
         cp_mesh,
         buffers=cp_buffers,
         buffer_seq_dims=cp_seq_dims,
         no_restore_buffers=cp_no_restore_buffers,
+        block_mask=block_mask,
+        sharder=sharder,
     )

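For context, a minimal call-site sketch of the new parameters, not part of this commit: a BlockMask can be built with the public torch.nn.attention.flex_attention.create_block_mask API and handed to create_context_parallel_ctx together with a sharder. The sharder classes live in the private torch.distributed.tensor.experimental._attention module touched by the companion PyTorch change, so they may not exist in every build; inputs, labels, and world_mesh below are placeholders.

# Sketch only: assumes a PyTorch build that ships FlexAttentionContiguousSharder
# and the block_mask/sharder parameters added in this commit.
import torch
from torch.nn.attention.flex_attention import create_block_mask


def causal_mask(b, h, q_idx, kv_idx):
    # standard causal mask_mod: a query attends only to itself and earlier tokens
    return q_idx >= kv_idx


seq_len = 8192
block_mask = create_block_mask(
    causal_mask, B=None, H=None, Q_LEN=seq_len, KV_LEN=seq_len
)

# cp_ctx = create_context_parallel_ctx(
#     cp_mesh=world_mesh["cp"],
#     cp_buffers=[inputs, labels],
#     cp_seq_dims=[1, 1],
#     cp_no_restore_buffers={inputs, labels},
#     cp_rotate_method="allgather",
#     block_mask=block_mask,
#     sharder=FlexAttentionContiguousSharder(),
# )
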
torchtitan/experiments/llama4/__init__.py

Lines changed: 3 additions & 0 deletions
@@ -40,6 +40,9 @@
         rope_theta=500000,
         num_experts=16,
         interleave_moe_layer_step=1,
+        use_flex_attn=True,
+        attn_mask_type="block_causal",
+        # attn_mask_type="causal",
     ),
     "17bx128e": TransformerModelArgs(
         dim=5120,

torchtitan/experiments/llama4/model/args.py

Lines changed: 3 additions & 2 deletions
@@ -55,7 +55,7 @@ class TransformerModelArgs(BaseModelArgs):
     interleave_moe_layer_step: int = 2
     # token-choice
     top_k: int = 1
-    use_grouped_mm: bool = True  # grouped mm or for-loop for the experts computation
+    use_grouped_mm: bool = False  # grouped mm or for-loop for the experts computation
     load_balance_coeff: float | None = 1e-3

     def update_from_config(self, job_config: JobConfig, tokenizer: Tokenizer) -> None:
@@ -74,12 +74,13 @@ def update_from_config(self, job_config: JobConfig, tokenizer: Tokenizer) -> None:
                 "FlexAttention is not compatible with selective AC yet. "
                 "See https://github.com/pytorch/pytorch/issues/147879"
             )
-
+        """
         if job_config.parallelism.context_parallel_degree > 1 and self.use_flex_attn:
             raise ValueError(
                 "FlexAttention is not compatible with CP yet. "
                 "We are still working on this."
             )
+        """

     def get_nparams_and_flops(
         self, model: nn.Module, seq_len: int

torchtitan/experiments/llama4/train_configs/debug_model.toml

Lines changed: 1 addition & 1 deletion
@@ -41,7 +41,7 @@ batch_size = 8
 seq_len = 2048
 max_norm = 1.0  # grad norm clipping
 steps = 10
-compile = false
+compile = true
 dataset = "c4_test"  # supported datasets: c4_test (2K), c4 (177M)

 [parallelism]

torchtitan/experiments/llama4/train_configs/llama4_17bx16e.toml

Lines changed: 11 additions & 3 deletions
@@ -11,7 +11,7 @@ profile_freq = 100

 [metrics]
 log_freq = 10
-enable_tensorboard = false
+enable_tensorboard = true
 save_tb_folder = "tb"

 [model]
@@ -27,23 +27,31 @@ eps = 1e-15

 [lr_scheduler]
 warmup_steps = 600
+# warmup_steps = 20
 lr_min = 0.1

 [training]
-batch_size = 8
+# batch_size = 8
+batch_size = 4
 seq_len = 8192
+# seq_len = 16384
+# seq_len = 32768
+# seq_len = 65536
 max_norm = 1.0  # grad norm clipping
 steps = 3000
+# steps = 100
 compile = false
+# compile = true
 dataset = "c4"
+deterministic = true

 [parallelism]
 data_parallel_replicate_degree = 1
 data_parallel_shard_degree = -1
 tensor_parallel_degree = 8
 enable_async_tensor_parallel = false
 pipeline_parallel_degree = 1
-context_parallel_degree = 1
+context_parallel_degree = 4

 [checkpoint]
 enable_checkpoint = false

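A quick sanity check on the parallelism layout above, written out as an illustration rather than something in the commit: with TP = 8, CP = 4, PP = 1, and DP-replicate = 1, the world size must be a multiple of 32, and data_parallel_shard_degree = -1 resolves to whatever factor is left over.

# Illustration of the degree arithmetic for this config; world_size = 64 is an
# assumed example cluster size, not something specified by the commit.
tp, cp, pp, dp_replicate = 8, 4, 1, 1
world_size = 64

assert world_size % (tp * cp * pp * dp_replicate) == 0, "degrees must divide world size"
dp_shard = world_size // (tp * cp * pp * dp_replicate)  # what the -1 shard degree resolves to
print(dp_shard)  # 2 -> FSDP shards across 2 ranks per replica group
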
torchtitan/models/llama3/__init__.py

Lines changed: 11 additions & 0 deletions
@@ -47,6 +47,17 @@
         multiple_of=1024,
         rope_theta=500000,
     ),
+    "8B_flex_attn": TransformerModelArgs(
+        dim=4096,
+        n_layers=32,
+        n_heads=32,
+        n_kv_heads=8,
+        ffn_dim_multiplier=1.3,
+        multiple_of=1024,
+        rope_theta=500000,
+        use_flex_attn=True,
+        attn_mask_type="block_causal",
+    ),
     "70B": TransformerModelArgs(
         dim=8192,
         n_layers=80,

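For readers unfamiliar with the attn_mask_type = "block_causal" setting used by the new 8B_flex_attn flavor: it denotes causal attention that is additionally confined to the current document when several documents are packed into one sequence. torchtitan builds this mask internally from its dataloader; the snippet below is a generic flex_attention rendering of the same idea, with the doc_ids tensor as an assumed stand-in for real document boundaries.

# Generic sketch of a block-causal mask with the public flex_attention API.
# `doc_ids` is a made-up packing of two documents, used only for illustration.
import torch
from torch.nn.attention.flex_attention import create_block_mask

seq_len = 16
doc_ids = torch.tensor([0] * 10 + [1] * 6)  # two packed documents


def block_causal_mask(b, h, q_idx, kv_idx):
    causal = q_idx >= kv_idx                      # no attending to future tokens
    same_doc = doc_ids[q_idx] == doc_ids[kv_idx]  # no attending across documents
    return causal & same_doc


block_mask = create_block_mask(
    block_causal_mask, B=None, H=None, Q_LEN=seq_len, KV_LEN=seq_len, device="cpu"
)
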
torchtitan/models/llama3/model.py

Lines changed: 0 additions & 6 deletions
@@ -51,12 +51,6 @@ def update_from_config(self, job_config: JobConfig, tokenizer: Tokenizer) -> None:
                 "See https://github.com/pytorch/pytorch/issues/147879"
             )

-        if job_config.parallelism.context_parallel_degree > 1 and self.use_flex_attn:
-            raise ValueError(
-                "FlexAttention is not compatible with CP yet. "
-                "We are still working on this."
-            )
-
     def get_nparams_and_flops(self, model: nn.Module, seq_len: int) -> tuple[int, int]:
         nparams = sum(p.numel() for p in model.parameters())
         nparams_embedding = sum(

torchtitan/models/llama3/train_configs/debug_model.toml

Lines changed: 1 addition & 1 deletion
@@ -43,7 +43,7 @@ batch_size = 8
 seq_len = 2048
 max_norm = 1.0  # grad norm clipping
 steps = 10
-compile = false
+compile = true
 dataset = "c4_test"  # supported datasets: c4_test (2K), c4 (177M)

 [parallelism]

torchtitan/models/llama3/train_configs/llama3_8b.toml

Lines changed: 12 additions & 7 deletions
@@ -11,8 +11,9 @@ save_traces_folder = "profile_trace"
 profile_freq = 100

 [metrics]
-log_freq = 10
+log_freq = 50
 enable_tensorboard = true
+# enable_tensorboard = false
 save_tb_folder = "tb"

 [model]
@@ -27,22 +28,25 @@ lr = 3e-4
 eps = 1e-8

 [lr_scheduler]
-warmup_steps = 200  # lr scheduler warm up
+# warmup_steps = 200  # lr scheduler warm up
+warmup_steps = 600

 [training]
-batch_size = 1
+batch_size = 4
 seq_len = 8192
 max_norm = 1.0  # grad norm clipping
-steps = 1000
-compile = false
+# steps = 1000
+steps = 3000
+compile = true
 dataset = "c4"
+deterministic = true

 [parallelism]
 data_parallel_replicate_degree = 1
 data_parallel_shard_degree = -1
 tensor_parallel_degree = 1
 pipeline_parallel_degree = 1
-context_parallel_degree = 1
+context_parallel_degree = 4

 [checkpoint]
 enable_checkpoint = false
@@ -53,7 +57,8 @@ export_dtype = "float32"
 async_mode = "disabled"  # ["disabled", "async", "async_with_pinned_mem"]

 [activation_checkpoint]
-mode = "selective"  # ["none", "selective", "full"]
+# mode = "selective"  # ["none", "selective", "full"]
+mode = "full"
 selective_ac_option = "op"  # "int" = ac every positive int layer or 'op', ac based on ops policy

 [float8]

torchtitan/train.py

Lines changed: 22 additions & 1 deletion
@@ -11,10 +11,13 @@
 from typing import Any, Generator, Iterable, Optional

 import torch
-from torch.distributed.elastic.multiprocessing.errors import record

 import torchtitan.components.ft as ft
 import torchtitan.protocols.train_spec as train_spec_module
+from torch.distributed.elastic.multiprocessing.errors import record
+from torch.distributed.tensor.experimental._attention import (
+    FlexAttentionContiguousSharder,
+)

 from torchtitan.components.checkpoint import CheckpointManager
 from torchtitan.components.metrics import (
@@ -133,7 +136,9 @@ def __init__(self, job_config: JobConfig):

         # build model (using meta init)
         model_cls = self.train_spec.cls
+        # NOTE (xilunwu): need to store model_args.use_flex_attn for train_step
         model_args = self.train_spec.config[job_config.model.flavor]
+        self.model_args = model_args
         # set the model args from training job configs
         model_args.update_from_config(job_config, tokenizer)
@@ -319,13 +324,29 @@ def train_step(self, input_dict: dict[str, torch.Tensor], labels: torch.Tensor):
         # apply context parallelism if cp is enabled
         # ensure CP handles the separate freqs_cis buffer for each pp stage
         inputs = input_dict["input"]
+
+        # TODO: move this into `create_context_parallel_ctx`
+        # init block_mask for flex_attention
+        block_mask = None
+        if self.model_args.use_flex_attn:
+            from torchtitan.models.attention import FlexAttention
+
+            mask_mod = FlexAttention._get_causal_mask_mod()
+            batch_dimension = 1
+            seq_len = inputs.shape[1]
+            block_mask = FlexAttention.compiled_create_block_mask(
+                mask_mod, batch_dimension, None, seq_len, seq_len
+            )
+
         optional_context_parallel_ctx = (
             dist_utils.create_context_parallel_ctx(
                 cp_mesh=world_mesh["cp"],
                 cp_buffers=[inputs, labels] + [m.freqs_cis for m in model_parts],
                 cp_seq_dims=[1, 1] + [0 for _ in model_parts],
                 cp_no_restore_buffers={inputs, labels},
                 cp_rotate_method=self.job_config.parallelism.context_parallel_rotate_method,
+                block_mask=block_mask,
+                sharder=FlexAttentionContiguousSharder(),
             )
             if parallel_dims.cp_enabled
             else None

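The object returned by create_context_parallel_ctx is a context manager that has to wrap the forward and backward pass so the sequence-dim buffers (and, with this change, the BlockMask) are sharded across the CP ranks. torchtitan does this through its train_context helper; the sketch below shows the same control flow inline, with model, inputs, labels, and loss_fn as placeholder names.

# Minimal usage sketch: cp_ctx is the value of optional_context_parallel_ctx
# above (None when CP is disabled). In torchtitan proper this nesting lives
# inside the train_context helper rather than in a standalone function.
import contextlib


def run_step(model, inputs, labels, loss_fn, cp_ctx=None):
    with contextlib.ExitStack() as stack:
        if cp_ctx is not None:
            stack.enter_context(cp_ctx)  # shards the registered buffers on entry
        logits = model(inputs)           # FlexAttention runs with the sharded BlockMask
        loss = loss_fn(logits, labels)
        loss.backward()
    return loss
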