@@ -11,10 +11,13 @@
 from typing import Any, Generator, Iterable, Optional
 
 import torch
-from torch.distributed.elastic.multiprocessing.errors import record
 
 import torchtitan.components.ft as ft
 import torchtitan.protocols.train_spec as train_spec_module
+from torch.distributed.elastic.multiprocessing.errors import record
+from torch.distributed.tensor.experimental._attention import (
+    FlexAttentionContiguousSharder,
+)
 
 from torchtitan.components.checkpoint import CheckpointManager
 from torchtitan.components.metrics import (
@@ -133,7 +136,9 @@ def __init__(self, job_config: JobConfig):
 
         # build model (using meta init)
         model_cls = self.train_spec.cls
+        # NOTE (xilunwu): need to store model_args.use_flex_attn for train_step
         model_args = self.train_spec.config[job_config.model.flavor]
+        self.model_args = model_args
         # set the model args from training job configs
         model_args.update_from_config(job_config, tokenizer)
 
@@ -319,13 +324,29 @@ def train_step(self, input_dict: dict[str, torch.Tensor], labels: torch.Tensor):
         # apply context parallelism if cp is enabled
         # ensure CP handles the separate freqs_cis buffer for each pp stage
         inputs = input_dict["input"]
+
+        # TODO: move this into `create_context_parallel_ctx`
+        # init block_mask for flex_attention
+        block_mask = None
+        if self.model_args.use_flex_attn:
+            from torchtitan.models.attention import FlexAttention
+
+            mask_mod = FlexAttention._get_causal_mask_mod()
+            batch_dimension = 1
+            seq_len = inputs.shape[1]
+            block_mask = FlexAttention.compiled_create_block_mask(
+                mask_mod, batch_dimension, None, seq_len, seq_len
+            )
+
         optional_context_parallel_ctx = (
             dist_utils.create_context_parallel_ctx(
                 cp_mesh=world_mesh["cp"],
                 cp_buffers=[inputs, labels] + [m.freqs_cis for m in model_parts],
                 cp_seq_dims=[1, 1] + [0 for _ in model_parts],
                 cp_no_restore_buffers={inputs, labels},
                 cp_rotate_method=self.job_config.parallelism.context_parallel_rotate_method,
+                block_mask=block_mask,
+                sharder=FlexAttentionContiguousSharder(),
             )
             if parallel_dims.cp_enabled
             else None
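
For reference, the BlockMask built in train_step through torchtitan's FlexAttention wrapper is the same kind of object produced by PyTorch's public torch.nn.attention.flex_attention API (compiled_create_block_mask is presumably a torch.compile-wrapped create_block_mask). The sketch below shows the equivalent construction with the public API; it assumes a CUDA device, and the tensor sizes and the causal_mask_mod helper are made up for illustration rather than taken from this PR.

# Illustrative sketch, not part of the PR: build a causal BlockMask with the
# public flex_attention API and run attention with it.
import torch
from torch.nn.attention.flex_attention import create_block_mask, flex_attention


def causal_mask_mod(b, h, q_idx, kv_idx):
    # Each query position may attend only to itself and earlier key positions.
    return q_idx >= kv_idx


batch, heads, seq_len, head_dim = 1, 8, 2048, 64  # hypothetical sizes
q = torch.randn(batch, heads, seq_len, head_dim, device="cuda", dtype=torch.bfloat16)
k = torch.randn_like(q)
v = torch.randn_like(q)

# B=1 and H=None broadcast a single mask over all batches and heads, matching
# the (mask_mod, 1, None, seq_len, seq_len) call in the diff above.
block_mask = create_block_mask(causal_mask_mod, 1, None, seq_len, seq_len, device="cuda")
out = flex_attention(q, k, v, block_mask=block_mask)

In the diff itself the mask is not consumed directly at this point; it is passed to create_context_parallel_ctx together with a FlexAttentionContiguousSharder, presumably so that Context Parallel can shard the BlockMask contiguously along the sequence dimension before the sharded attention runs.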