
Commit b9e5954

[do NOT land] llama-3 8B w/ flex_attention

ghstack-source-id: c2d400e
Pull Request resolved: #1181
1 parent a4ed09c

File tree

2 files changed: +11 -6 lines


torchtitan/models/llama3/__init__.py

Lines changed: 11 additions, 0 deletions

@@ -47,6 +47,17 @@
         multiple_of=1024,
         rope_theta=500000,
     ),
+    "8B_flex_attn": TransformerModelArgs(
+        dim=4096,
+        n_layers=32,
+        n_heads=32,
+        n_kv_heads=8,
+        ffn_dim_multiplier=1.3,
+        multiple_of=1024,
+        rope_theta=500000,
+        use_flex_attn=True,
+        attn_mask_type="block_causal",
+    ),
     "70B": TransformerModelArgs(
         dim=8192,
         n_layers=80,
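
The new "8B_flex_attn" flavor matches "8B" except for the last two fields, which switch attention to FlexAttention with a block-causal mask. As a minimal sketch of what those two flags correspond to at the PyTorch level (not torchtitan's actual attention path; the shapes, the document_id tensor, and the block_causal helper below are made up for illustration), a block-causal mask can be built with create_block_mask and handed to flex_attention:

import torch
from torch.nn.attention.flex_attention import create_block_mask, flex_attention

# Illustrative shapes only; the 8B model itself uses dim=4096, n_heads=32, n_kv_heads=8.
B, H, S, D = 1, 8, 256, 64
device = "cuda" if torch.cuda.is_available() else "cpu"

# Hypothetical packed batch: document_id[i] is the document that token i belongs to.
document_id = torch.zeros(S, dtype=torch.int32, device=device)
document_id[S // 2 :] = 1  # two packed documents of equal length

def block_causal(b, h, q_idx, kv_idx):
    # Attend only within the same document, and only to current/earlier positions.
    return (document_id[q_idx] == document_id[kv_idx]) & (q_idx >= kv_idx)

block_mask = create_block_mask(block_causal, B=None, H=None, Q_LEN=S, KV_LEN=S, device=device)

q = torch.randn(B, H, S, D, device=device)
k = torch.randn(B, H, S, D, device=device)
v = torch.randn(B, H, S, D, device=device)
out = flex_attention(q, k, v, block_mask=block_mask)

In practice flex_attention is meant to be wrapped in torch.compile, so the block sparsity encoded in the mask turns into skipped work rather than merely masked-out scores.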

torchtitan/models/llama3/model.py

Lines changed: 0 additions & 6 deletions
@@ -51,12 +51,6 @@ def update_from_config(self, job_config: JobConfig, tokenizer: Tokenizer) -> None:
                 "See https://github.com/pytorch/pytorch/issues/147879"
             )

-        if job_config.parallelism.context_parallel_degree > 1 and self.use_flex_attn:
-            raise ValueError(
-                "FlexAttention is not compatible with CP yet. "
-                "We are still working on this."
-            )
-
     def get_nparams_and_flops(self, model: nn.Module, seq_len: int) -> tuple[int, int]:
         nparams = sum(p.numel() for p in model.parameters())
         nparams_embedding = sum(
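
For reference, the deleted block was only an up-front compatibility check. If a setup still needs to reject context parallelism together with FlexAttention, an equivalent standalone guard (the function name is hypothetical; the attribute names mirror the diff above) would be:

def check_cp_flex_attn(job_config, model_args) -> None:
    # Hypothetical helper reproducing the removed validation: reject CP + FlexAttention.
    if job_config.parallelism.context_parallel_degree > 1 and model_args.use_flex_attn:
        raise ValueError(
            "FlexAttention is not compatible with CP yet. "
            "We are still working on this."
        )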
