
Commit 30bd26c

[cp][flex_attention] integration test trial
ghstack-source-id: 9d38242
Pull Request resolved: #1160
ghstack-source-id: 53d803f
Pull Request resolved: #1228
1 parent 29a67ec commit 30bd26c

File tree: 10 files changed, +87 −22 lines

torchtitan/distributed/utils.py

Lines changed: 19 additions & 2 deletions

@@ -9,14 +9,17 @@
 import os
 from collections.abc import Generator, Iterable
 from datetime import timedelta
+from typing import Optional

 import torch
 import torch.distributed._functional_collectives as funcol
 import torch.distributed.distributed_c10d as c10d
 from torch import distributed as dist
 from torch.distributed.device_mesh import DeviceMesh
 from torch.distributed.tensor import DTensor
+from torch.distributed.tensor.experimental._attention import _FlexAttentionSharder
 from torch.nn.attention import SDPBackend
+from torch.nn.attention.flex_attention import BlockMask

 from torchtitan.models.attention import ScaledDotProductAttention
 from torchtitan.tools.logging import logger

@@ -156,22 +159,35 @@ def create_context_parallel_ctx(
     cp_seq_dims: list[int],
     cp_no_restore_buffers: set[torch.Tensor],
     cp_rotate_method: str,
+    sharder: Optional[_FlexAttentionSharder] = None,
 ):
     try:
         from torch.distributed.tensor.experimental import context_parallel
-        from torch.distributed.tensor.experimental._attention import set_rotate_method
+        from torch.distributed.tensor.experimental._attention import (
+            _DispatchMode,
+            _set_dispatch_mode,
+            set_rotate_method,
+        )
     except ImportError:
         print(
             f"PyTorch version {torch.__version__} does not include the experimental "
             "Context Parallel API. Please update to a newer version."
         )

     set_rotate_method(cp_rotate_method)
+    """
+    _set_dispatch_mode("torch_dispatch")
+    assert (
+        torch.distributed.tensor.experimental._attention._dispatch_mode
+        == _DispatchMode.TORCH_DISPATCH
+    )
+    """
     return context_parallel(
         cp_mesh,
         buffers=cp_buffers,
         buffer_seq_dims=cp_seq_dims,
         no_restore_buffers=cp_no_restore_buffers,
+        sharder=sharder,
     )

@@ -192,8 +208,9 @@ def context(cp_context: Generator[None, None, None] | None = None):
         if cp_context is not None:
             if SDPBackend.MATH in ScaledDotProductAttention.backends:
                 ScaledDotProductAttention.backends.remove(SDPBackend.MATH)
+            # TODO: add logic for flex-attention
             assert (
-                ScaledDotProductAttention.backends
+                ScaledDotProductAttention.backends or True
             ), "No valid SDPA backends with CP."
             stack.enter_context(cp_context)
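
Note on the new imports: BlockMask comes from the public torch.nn.attention.flex_attention module, while _FlexAttentionSharder is a private experimental hook in torch.distributed.tensor.experimental._attention. As a refresher, the public FlexAttention pieces the sharder operates on look roughly like this (a minimal sketch, assuming a CUDA device; shapes and names are illustrative, not torchtitan code):

import torch
from torch.nn.attention.flex_attention import create_block_mask, flex_attention

def causal_mask(b, h, q_idx, kv_idx):
    # Standard causal mask_mod: a query position attends to itself and earlier keys.
    return q_idx >= kv_idx

B, H, S, D = 1, 8, 2048, 64
q, k, v = (torch.randn(B, H, S, D, device="cuda") for _ in range(3))

# A BlockMask records which (query-block, key-block) tiles are non-empty, so the
# kernel can skip fully masked tiles; H=None broadcasts the mask across heads.
block_mask = create_block_mask(causal_mask, B, None, S, S, device="cuda")
out = flex_attention(q, k, v, block_mask=block_mask)

The sharder argument added to create_context_parallel_ctx is presumably what splits such a BlockMask, along with the sequence-dim buffers, across the CP mesh.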

torchtitan/experiments/llama4/__init__.py

Lines changed: 3 additions & 0 deletions

@@ -40,6 +40,9 @@
         rope_theta=500000,
         num_experts=16,
         interleave_moe_layer_step=1,
+        use_flex_attn=True,
+        attn_mask_type="block_causal",
+        # attn_mask_type="causal",
     ),
     "17bx128e": TransformerModelArgs(
         dim=5120,

torchtitan/experiments/llama4/model/args.py

Lines changed: 3 additions & 2 deletions

@@ -55,7 +55,7 @@ class TransformerModelArgs(BaseModelArgs):
     interleave_moe_layer_step: int = 2
     # token-choice
     top_k: int = 1
-    use_grouped_mm: bool = True  # grouped mm or for-loop for the experts computation
+    use_grouped_mm: bool = False  # grouped mm or for-loop for the experts computation
     load_balance_coeff: float | None = 1e-3

     def update_from_config(self, job_config: JobConfig, tokenizer: Tokenizer) -> None:

@@ -74,12 +74,13 @@ def update_from_config(self, job_config: JobConfig, tokenizer: Tokenizer) -> None:
                 "FlexAttention is not compatible with selective AC yet. "
                 "See https://github.com/pytorch/pytorch/issues/147879"
             )
-
+        """
         if job_config.parallelism.context_parallel_degree > 1 and self.use_flex_attn:
             raise ValueError(
                 "FlexAttention is not compatible with CP yet. "
                 "We are still working on this."
             )
+        """

     def get_nparams_and_flops(
         self, model: nn.Module, seq_len: int
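
Aside on the use_grouped_mm default flip above: the flag chooses between one grouped matmul over all experts and a plain per-expert loop for the MoE computation. The loop path it now defaults to is conceptually just the following (an illustrative sketch, not torchtitan's actual MoE code; names and shapes are made up):

import torch

def experts_forloop(x, expert_weights, token_to_expert):
    # x: (num_tokens, dim); expert_weights: (num_experts, dim, hidden); token_to_expert: (num_tokens,)
    out = x.new_zeros(x.shape[0], expert_weights.shape[-1])
    for e in range(expert_weights.shape[0]):
        picked = token_to_expert == e            # tokens routed to expert e
        out[picked] = x[picked] @ expert_weights[e]
    return out

tokens = experts_forloop(
    torch.randn(16, 8), torch.randn(4, 8, 32), torch.randint(0, 4, (16,))
)

The grouped-mm path fuses those per-expert matmuls into a single kernel; flipping the default to False here likely just sidesteps it for this integration test.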

torchtitan/experiments/llama4/train_configs/debug_model.toml

Lines changed: 1 addition & 1 deletion

@@ -41,7 +41,7 @@ batch_size = 8
 seq_len = 2048
 max_norm = 1.0  # grad norm clipping
 steps = 10
-compile = false
+compile = true
 dataset = "c4_test"  # supported datasets: c4_test (2K), c4 (177M)

 [parallelism]

torchtitan/experiments/llama4/train_configs/llama4_17bx16e.toml

Lines changed: 11 additions & 3 deletions

@@ -11,7 +11,7 @@ profile_freq = 100

 [metrics]
 log_freq = 10
-enable_tensorboard = false
+enable_tensorboard = true
 save_tb_folder = "tb"

 [model]

@@ -27,23 +27,31 @@ eps = 1e-15

 [lr_scheduler]
 warmup_steps = 600
+# warmup_steps = 20
 lr_min = 0.1

 [training]
-batch_size = 8
+# batch_size = 8
+batch_size = 4
 seq_len = 8192
+# seq_len = 16384
+# seq_len = 32768
+# seq_len = 65536
 max_norm = 1.0  # grad norm clipping
 steps = 3000
+# steps = 100
 compile = false
+# compile = true
 dataset = "c4"
+deterministic = true

 [parallelism]
 data_parallel_replicate_degree = 1
 data_parallel_shard_degree = -1
 tensor_parallel_degree = 8
 enable_async_tensor_parallel = false
 pipeline_parallel_degree = 1
-context_parallel_degree = 1
+context_parallel_degree = 4

 [checkpoint]
 enable_checkpoint = false
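
With tensor_parallel_degree = 8 and context_parallel_degree = 4 above, the remaining world size goes to FSDP sharding (data_parallel_shard_degree = -1 resolves to whatever is left over). A back-of-the-envelope check in Python, where only world_size is hypothetical and everything else comes from the TOML:

world_size = 64
dp_replicate, tp, cp, pp = 1, 8, 4, 1
dp_shard = world_size // (dp_replicate * tp * cp * pp)   # the -1 in the config resolves to 2 here
seq_len = 8192
tokens_per_cp_rank = seq_len // cp                       # each CP rank holds a 2048-token slice
assert dp_replicate * dp_shard * tp * cp * pp == world_size
print(dp_shard, tokens_per_cp_rank)                      # 2 2048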

torchtitan/models/llama3/__init__.py

Lines changed: 11 additions & 0 deletions

@@ -47,6 +47,17 @@
         multiple_of=1024,
         rope_theta=500000,
     ),
+    "8B_flex_attn": TransformerModelArgs(
+        dim=4096,
+        n_layers=32,
+        n_heads=32,
+        n_kv_heads=8,
+        ffn_dim_multiplier=1.3,
+        multiple_of=1024,
+        rope_theta=500000,
+        use_flex_attn=True,
+        attn_mask_type="block_causal",
+    ),
     "70B": TransformerModelArgs(
         dim=8192,
         n_layers=80,
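
The new flavor is picked up the same way as the existing ones: the trainer looks it up by name from the flavor table (see the train.py hunk further down). A minimal sketch of that lookup, assuming the table in this file is exported as llama3_configs:

# Assumes the flavor dict above is exported as `llama3_configs`.
from torchtitan.models.llama3 import llama3_configs

model_args = llama3_configs["8B_flex_attn"]
assert model_args.use_flex_attn
assert model_args.attn_mask_type == "block_causal"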

torchtitan/models/llama3/model.py

Lines changed: 0 additions & 6 deletions

@@ -51,12 +51,6 @@ def update_from_config(self, job_config: JobConfig, tokenizer: Tokenizer) -> None:
                 "See https://github.com/pytorch/pytorch/issues/147879"
             )

-        if job_config.parallelism.context_parallel_degree > 1 and self.use_flex_attn:
-            raise ValueError(
-                "FlexAttention is not compatible with CP yet. "
-                "We are still working on this."
-            )
-
     def get_nparams_and_flops(self, model: nn.Module, seq_len: int) -> tuple[int, int]:
         nparams = sum(p.numel() for p in model.parameters())
         nparams_embedding = sum(

torchtitan/models/llama3/train_configs/debug_model.toml

Lines changed: 1 addition & 1 deletion

@@ -43,7 +43,7 @@ batch_size = 8
 seq_len = 2048
 max_norm = 1.0  # grad norm clipping
 steps = 10
-compile = false
+compile = true
 dataset = "c4_test"  # supported datasets: c4_test (2K), c4 (177M)

 [parallelism]

torchtitan/models/llama3/train_configs/llama3_8b.toml

Lines changed: 14 additions & 6 deletions

@@ -8,11 +8,15 @@ description = "Llama 3 8B training"
 [profiling]
 enable_profiling = true
 save_traces_folder = "profile_trace"
-profile_freq = 100
+# profile_freq = 100
+profile_freq = 10
+enable_memory_snapshot = true
+save_memory_snapshot_folder = "memory_snapshot"

 [metrics]
 log_freq = 10
 enable_tensorboard = true
+# enable_tensorboard = false
 save_tb_folder = "tb"

 [model]

@@ -27,22 +31,25 @@ lr = 3e-4
 eps = 1e-8

 [lr_scheduler]
-warmup_steps = 200  # lr scheduler warm up
+# warmup_steps = 200  # lr scheduler warm up
+warmup_steps = 600

 [training]
-batch_size = 1
+batch_size = 4
 seq_len = 8192
 max_norm = 1.0  # grad norm clipping
-steps = 1000
+# steps = 1000
+steps = 20
 compile = false
 dataset = "c4"
+deterministic = false

 [parallelism]
 data_parallel_replicate_degree = 1
 data_parallel_shard_degree = -1
 tensor_parallel_degree = 1
 pipeline_parallel_degree = 1
-context_parallel_degree = 1
+context_parallel_degree = 4

 [checkpoint]
 enable_checkpoint = false

@@ -53,7 +60,8 @@ export_dtype = "float32"
 async_mode = "disabled"  # ["disabled", "async", "async_with_pinned_mem"]

 [activation_checkpoint]
-mode = "selective"  # ["none", "selective", "full"]
+# mode = "selective"  # ["none", "selective", "full"]
+mode = "full"
 selective_ac_option = "op"  # "int" = ac every positive int layer or 'op', ac based on ops policy

 [float8]

torchtitan/train.py

Lines changed: 24 additions & 1 deletion

@@ -11,10 +11,14 @@
 from typing import Any, Generator, Iterable, Optional

 import torch
-from torch.distributed.elastic.multiprocessing.errors import record

 import torchtitan.components.ft as ft
 import torchtitan.protocols.train_spec as train_spec_module
+from torch.distributed.elastic.multiprocessing.errors import record
+from torch.distributed.tensor.experimental._attention import (
+    _FlexAttentionSequentialSharder,
+)
+
 from torchtitan.components.checkpoint import CheckpointManager
 from torchtitan.components.metrics import (
     build_metrics_processor,

@@ -132,7 +136,9 @@ def __init__(self, job_config: JobConfig):

         # build model (using meta init)
         model_cls = self.train_spec.cls
+        # NOTE (xilunwu): need to store model_args.use_flex_attn for train_step
         model_args = self.train_spec.config[job_config.model.flavor]
+        self.model_args = model_args
         # set the model args from training job configs
         model_args.update_from_config(job_config, tokenizer)

@@ -323,13 +329,30 @@ def train_step(self, input_dict: dict[str, torch.Tensor], labels: torch.Tensor):
         # apply context parallelism if cp is enabled
         # ensure CP handles the separate freqs_cis buffer for each pp stage
         inputs = input_dict["input"]
+
+        # TODO: move this into `create_context_parallel_ctx`
+        # init block_mask for flex_attention
+        block_mask = None
+        if self.model_args.use_flex_attn:
+            from torchtitan.models.attention import FlexAttention
+
+            mask_mod = FlexAttention._get_causal_mask_mod()
+            batch_dimension = 1
+            seq_len = inputs.shape[1]
+            block_mask = FlexAttention.compiled_create_block_mask(
+                mask_mod, batch_dimension, None, seq_len, seq_len
+            )
+
         optional_context_parallel_ctx = (
             dist_utils.create_context_parallel_ctx(
                 cp_mesh=world_mesh["cp"],
                 cp_buffers=[inputs, labels] + [m.freqs_cis for m in model_parts],
                 cp_seq_dims=[1, 1] + [0 for _ in model_parts],
                 cp_no_restore_buffers={inputs, labels},
                 cp_rotate_method=self.job_config.parallelism.context_parallel_rotate_method,
+                sharder=_FlexAttentionSequentialSharder(
+                    mesh=world_mesh["cp"], block_mask=block_mask
+                ),
             )
             if parallel_dims.cp_enabled
             else None
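
The train_step change builds one BlockMask per step via torchtitan's FlexAttention helpers and hands it to the experimental _FlexAttentionSequentialSharder together with the CP mesh. In terms of the public API, the mask construction amounts to roughly the following (a sketch; it assumes the helper's mask_mod matches the stock causal one and that compiled_create_block_mask is simply a torch.compile'd create_block_mask):

import torch
from torch.nn.attention.flex_attention import create_block_mask

def causal_mask(b, h, q_idx, kv_idx):
    return q_idx >= kv_idx

seq_len = 8192                                            # inputs.shape[1] in train_step
compiled_create_block_mask = torch.compile(create_block_mask)

# Mirrors FlexAttention.compiled_create_block_mask(mask_mod, 1, None, seq_len, seq_len):
# batch dimension 1 and H=None broadcast the mask over batch and heads.
block_mask = compiled_create_block_mask(causal_mask, 1, None, seq_len, seq_len)

The sharder then receives this mask together with world_mesh["cp"], so each CP rank presumably works against the slice of the mask that matches its shard of the sequence.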
