Commit 9d99dc2: Generalize preprocessor (#224)

1 parent: 3daf079
6 files changed: +51 −44 lines


fast_llm/engine/base_model/base_model.py (+5 −1)

@@ -6,7 +6,7 @@
 import torch.nn
 
 from fast_llm.config import Configurable
-from fast_llm.engine.base_model.config import BaseModelArchitectureConfig, BaseModelConfig
+from fast_llm.engine.base_model.config import BaseModelArchitectureConfig, BaseModelConfig, Preprocessor
 from fast_llm.engine.config_utils.tensor_space import TensorSpace
 from fast_llm.engine.distributed.config import DistributedConfig, PhaseType
 from fast_llm.engine.distributed.distributed import Distributed
@@ -135,3 +135,7 @@ def get_tied_weights(self) -> dict[str, tuple[ParameterMeta, tuple[int, ...]]]:
     @abc.abstractmethod
     def loss_defs(self) -> list[LossDef]:
         pass
+
+    def add_preprocessor(self, preprocessor: Preprocessor):
+        # TODO: Generalize preprocessors.
+        raise NotImplementedError()

fast_llm/engine/base_model/config.py (+10)

@@ -1,3 +1,4 @@
+import abc
 import typing
 
 from fast_llm.config import Config, config_class
@@ -40,3 +41,12 @@ class BaseModelConfig(BaseModelArchitectureConfig):
 
     def get_architecture(self) -> BaseModelArchitectureConfig:
         return self.architecture_class.from_dict(self, strict=False)
+
+
+class Preprocessor(abc.ABC):
+    def preprocess_meta(self, kwargs: dict[str, typing.Any]) -> None:
+        pass
+
+    @abc.abstractmethod
+    def preprocess(self, batch, kwargs: dict[str, typing.Any]) -> None:
+        pass
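
The new Preprocessor interface is what the rest of the commit codes against: preprocess_meta may adjust the meta-stage kwargs (overriding it is optional), while preprocess must fill in the runtime kwargs for a batch. A minimal sketch of a user-defined implementation, assuming a hypothetical kwargs key ("document_mask" below is illustrative, not part of Fast-LLM):

import typing

import torch

from fast_llm.engine.base_model.config import Preprocessor


class DocumentMaskPreprocessor(Preprocessor):
    # Sketch of a custom preprocessor; the "document_mask" key is hypothetical.

    def preprocess_meta(self, kwargs: dict[str, typing.Any]) -> None:
        # Optional hook: nothing to do at the meta stage in this example.
        pass

    def preprocess(self, batch, kwargs: dict[str, typing.Any]) -> None:
        # `batch` is the token tensor passed in by the model; derive a tensor
        # from it and stash it in `kwargs` for downstream layers to consume.
        kwargs["document_mask"] = torch.ones_like(batch, dtype=torch.bool)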

fast_llm/layers/language_model/preprocessing.py (+5 −3)

@@ -3,6 +3,7 @@
 
 import torch
 
+from fast_llm.engine.base_model.config import Preprocessor
 from fast_llm.engine.config_utils.tensor_space import DefaultDimNames, TensorDim, TensorSpace
 from fast_llm.layers.language_model.config import LanguageModelBaseConfig, LanguageModelKwargs
 from fast_llm.layers.transformer.config import TransformerKwargs
@@ -12,7 +13,7 @@
 logger = logging.getLogger(__name__)
 
 
-class PositionEmbeddingPreprocessor:
+class PositionEmbeddingPreprocessor(Preprocessor):
     _scalar_dim: TensorDim
     _rotary_embedding_frequencies: torch.Tensor
     _position_ids: torch.Tensor
@@ -29,7 +30,7 @@ def __init__(
         self._distributed_config = self._tensor_space.distributed_config
         self._scalar_dim = self._tensor_space.get_tensor_dim(DefaultDimNames.scalar)
 
-    def create_tensors(self, sequence_length: int) -> None:
+    def _create_tensors(self, sequence_length: int) -> None:
         if sequence_length <= self._tensor_cache_max_sequence_length:
             return
         self._tensor_cache_max_sequence_length = sequence_length
@@ -39,7 +40,8 @@ def create_tensors(self, sequence_length: int) -> None:
             0, sequence_length, device=self._tensor_space.distributed.device, dtype=torch.int64
         )
 
-    def preprocess(self, kwargs: dict[str, typing.Any]) -> None:
+    def preprocess(self, batch, kwargs: dict[str, typing.Any]) -> None:
+        self._create_tensors(kwargs[TransformerKwargs.sequence_length])
         sequence_k = kwargs[TransformerKwargs.sequence_k_dim].size
         sequence_q = kwargs[TransformerKwargs.sequence_q_dim].size
         if (sequence_lengths := kwargs.get(TransformerKwargs.sequence_lengths)) is not None:
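
Note the pattern the rename enforces: create_tensors becomes the private _create_tensors and is now called lazily from preprocess, using the sequence length found in kwargs, so the model no longer has to drive the tensor cache itself. A standalone sketch of that grow-only cache pattern, independent of Fast-LLM's classes (names are illustrative):

import torch


class CachedPositionIds:
    # Grow-only cache: tensors are rebuilt only when a longer sequence is seen.

    def __init__(self, device: torch.device = torch.device("cpu")):
        self._device = device
        self._cached_max_length = -1
        self._position_ids: torch.Tensor | None = None

    def _create_tensors(self, sequence_length: int) -> None:
        if sequence_length <= self._cached_max_length:
            return  # Cache is already large enough.
        self._cached_max_length = sequence_length
        self._position_ids = torch.arange(0, sequence_length, device=self._device, dtype=torch.int64)

    def preprocess(self, sequence_length: int, sequence_q: int, sequence_k: int) -> torch.Tensor:
        # Called per micro-batch; slicing the cached tensor avoids reallocating it.
        self._create_tensors(sequence_length)
        return self._position_ids[sequence_k - sequence_q : sequence_k]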

fast_llm/layers/transformer/preprocessing.py (+14 −9)

@@ -4,6 +4,7 @@
 
 import torch
 
+from fast_llm.engine.base_model.config import Preprocessor
 from fast_llm.engine.config_utils.tensor_space import DefaultDimNames, TensorDim, TensorSpace
 from fast_llm.functional.rotary import convert_rotary_complex_to_real
 from fast_llm.layers.transformer.config import (
@@ -129,7 +130,7 @@ def get_rotary_frequencies(
     return frequencies
 
 
-class RotaryEmbeddingPreprocessor:
+class RotaryEmbeddingPreprocessor(Preprocessor):
     _scalar_dim: TensorDim
     _kv_channels_dim: TensorDim
     _rotary_embedding_frequencies: torch.Tensor
@@ -149,7 +150,7 @@ def __init__(
         self._scalar_dim = self._tensor_space.get_tensor_dim(DefaultDimNames.scalar)
         self._kv_channels_dim = self._tensor_space.get_tensor_dim(TransformerDimNames.kv_channels)
 
-    def create_tensors(self, sequence_length: int) -> None:
+    def _create_tensors(self, sequence_length: int) -> None:
         if sequence_length <= self._tensor_cache_max_sequence_length:
             return
         self._tensor_cache_max_sequence_length = sequence_length
@@ -161,7 +162,8 @@ def create_tensors(self, sequence_length: int) -> None:
             device=self._tensor_space.distributed.device,
         )
 
-    def preprocess(self, kwargs: dict[str, typing.Any]) -> None:
+    def preprocess(self, batch, kwargs: dict[str, typing.Any]) -> None:
+        self._create_tensors(kwargs[TransformerKwargs.sequence_length])
         sequence_k = kwargs[TransformerKwargs.sequence_k_dim].size
         kwargs[TransformerKwargs.rotary_freq_q] = self._rotary_embedding_frequencies[
             :, sequence_k - kwargs[TransformerKwargs.sequence_q_dim].size : sequence_k
@@ -189,7 +191,7 @@ def preprocess_meta(self, kwargs: dict[str, typing.Any]) -> None:
         )
 
 
-class BackupAttentionPreprocessor:
+class BackupAttentionPreprocessor(Preprocessor):
     _scalar_dim: TensorDim
     _kv_channels_dim: TensorDim
     _rotary_embedding_frequencies: torch.Tensor
@@ -208,7 +210,7 @@ def __init__(
        assert not self._config.do_use_flash_attention(self._distributed_config)
        self._scalar_dim = self._tensor_space.get_tensor_dim(DefaultDimNames.scalar)
 
-    def create_tensors(self, sequence_length: int) -> None:
+    def _create_tensors(self, sequence_length: int) -> None:
         if sequence_length <= self._tensor_cache_max_sequence_length:
             return
         self._tensor_cache_max_sequence_length = sequence_length
@@ -228,7 +230,8 @@ def create_tensors(self, sequence_length: int) -> None:
             device=self._tensor_space.distributed.device,
         )
 
-    def preprocess(self, kwargs: dict[str, typing.Any]) -> None:
+    def preprocess(self, batch, kwargs: dict[str, typing.Any]) -> None:
+        self._create_tensors(kwargs[TransformerKwargs.sequence_length])
         sequence_k = kwargs[TransformerKwargs.sequence_k_dim].size
         sequence_q = kwargs[TransformerKwargs.sequence_q_dim].size
         kwargs[TransformerKwargs.attention_mask] = self._mask[
@@ -264,14 +267,14 @@ def preprocess_meta(self, kwargs: dict[str, typing.Any]) -> None:
         )
 
 
-class FlashAttnVarlenPreprocessor:
+class FlashAttnVarlenPreprocessor(Preprocessor):
     def __init__(self, config: TransformerConfig, tensor_space: TensorSpace):
         self._config = config
         self._tensor_space = tensor_space
         self._distributed_config = self._tensor_space.distributed_config
         assert self._config.do_use_flash_attention(self._distributed_config)
 
-    def preprocess(self, kwargs: dict[str, typing.Any]) -> None:
+    def preprocess(self, batch, kwargs: dict[str, typing.Any]) -> None:
         """
         Prepares cu_seqlens_q and cu_seqlens_k for flash_attn_varlen_func:
         https://github.com/Dao-AILab/flash-attention/blob/main/flash_attn/flash_attn_interface.py#L1375
@@ -281,7 +284,9 @@ def preprocess(self, kwargs: dict[str, typing.Any]) -> None:
         also contain previous tokens from the first document in micro-sequence.
         We use individual sequence lengths of each document to (optionally) find the micro-sequences in the batch and compute the cumulative lengths.
         """
-        sequence_lengths = kwargs.get(TransformerKwargs.sequence_lengths)
+        if TransformerKwargs.sequence_lengths not in kwargs:
+            return
+        sequence_lengths = kwargs[TransformerKwargs.sequence_lengths]
         sequence_k = kwargs[TransformerKwargs.sequence_k_dim].size
         sequence_q = kwargs[TransformerKwargs.sequence_q_dim].size
         if sequence_q < kwargs[TransformerKwargs.sequence_length]:
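
As the docstring above states, the varlen preprocessor turns per-document sequence lengths into the cumulative boundaries expected by flash_attn_varlen_func. A minimal sketch of that computation for the simple case of a single micro-sequence covering the whole batch (not Fast-LLM's exact implementation, which also handles documents split across micro-sequences):

import torch


def build_cu_seqlens(sequence_lengths: list[torch.Tensor]) -> torch.Tensor:
    # Concatenate per-sample document lengths, then prefix-sum them.
    lengths = torch.cat([lens.to(torch.int32) for lens in sequence_lengths])
    # flash_attn_varlen_func expects int32 boundaries starting at 0,
    # e.g. lengths [3, 5, 2] -> cu_seqlens [0, 3, 8, 10].
    return torch.nn.functional.pad(torch.cumsum(lengths, dim=0, dtype=torch.int32), (1, 0))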

fast_llm/models/gpt/model.py (+16 −30)

@@ -5,6 +5,7 @@
 
 from fast_llm.data.data.gpt.data import GPTBatch
 from fast_llm.engine.base_model.base_model import BaseModel, Layer, LossDef
+from fast_llm.engine.base_model.config import Preprocessor
 from fast_llm.engine.config_utils.tensor_space import TensorDim
 from fast_llm.engine.distributed.config import DistributedConfig, DistributedDimNames, PhaseType
 from fast_llm.engine.distributed.distributed import Distributed
@@ -58,18 +59,17 @@ def __init__(
         for param in self.parameters():
             Assert.custom(isinstance, param, ParameterMeta)
             param.init_parameter = get_init_megatron(param, self._config.transformer)  # Noqa
+        self._preprocessors: list[Preprocessor] = []
         if self._config.use_absolute_position_embeddings:
-            self._position_embedding_preprocessor = PositionEmbeddingPreprocessor(self._config, self._tensor_space)
+            self._preprocessors.append(PositionEmbeddingPreprocessor(self._config, self._tensor_space))
         if self._config.transformer.rotary.enabled:
-            self._rotary_embedding_preprocessor = RotaryEmbeddingPreprocessor(
-                self._config.transformer.rotary, self._tensor_space
-            )
-        if not self._use_flash_attention:
-            self._backup_attention_preprocessor = BackupAttentionPreprocessor(
-                self._config.transformer, self._tensor_space
+            self._preprocessors.append(
+                RotaryEmbeddingPreprocessor(self._config.transformer.rotary, self._tensor_space)
             )
+        if self._use_flash_attention:
+            self._preprocessors.append(FlashAttnVarlenPreprocessor(self._config.transformer, self._tensor_space))
         else:
-            self._flash_varlen_preprocessor = FlashAttnVarlenPreprocessor(self._config.transformer, self._tensor_space)
+            self._preprocessors.append(BackupAttentionPreprocessor(self._config.transformer, self._tensor_space))
 
     def get_output_layers(self) -> list[Layer]:
         return [
@@ -207,12 +207,8 @@ def preprocess_meta(
             kwargs[LanguageModelKwargs.labels] = TensorMeta.from_dims(
                 hidden_dims[:2], tensor_name="labels", dtype=torch.int64
             )
-            if self._config.use_absolute_position_embeddings:
-                self._position_embedding_preprocessor.preprocess_meta(kwargs)
-            if self._config.transformer.rotary.enabled:
-                self._rotary_embedding_preprocessor.preprocess_meta(kwargs)
-            if not self._use_flash_attention:
-                self._backup_attention_preprocessor.preprocess_meta(kwargs)
+            for preprocessor in self._preprocessors:
+                preprocessor.preprocess_meta(kwargs)
             preprocessed_meta.append((tokens, kwargs))
 
         return preprocessed_meta
@@ -235,7 +231,6 @@ def preprocess(
         _, common_kwargs = preprocessed_meta[0]
         sequence_q = common_kwargs[TransformerKwargs.sequence_q_dim].size
         sequence_first = common_kwargs[TransformerKwargs.sequence_first]
-        sequence_length = common_kwargs[TransformerKwargs.sequence_length]
 
         batch.token_ids = batch.token_ids.to(
             device=self._tensor_space.distributed.device,
@@ -246,13 +241,6 @@ def preprocess(
             # Move the sequence dimension first to make sequence parallel ops more efficient.
             batch.token_ids = batch.token_ids.transpose(0, 1).contiguous()
 
-        if self._config.use_absolute_position_embeddings:
-            self._position_embedding_preprocessor.create_tensors(sequence_length)
-        if self._config.transformer.rotary.enabled:
-            self._rotary_embedding_preprocessor.create_tensors(sequence_length)
-        if not self._use_flash_attention:
-            self._backup_attention_preprocessor.create_tensors(sequence_length)
-
         preprocessed = []
         presents = None
         for i, (tokens_meta, kwargs_meta) in enumerate(preprocessed_meta):
@@ -264,8 +252,6 @@ def preprocess(
             tokens = batch.token_ids[:, sequence_k - sequence_q : sequence_k].contiguous()
             if batch.sequence_lengths is not None:
                 kwargs_meta[TransformerKwargs.sequence_lengths] = batch.sequence_lengths
-            if self._use_flash_attention:
-                self._flash_varlen_preprocessor.preprocess(kwargs_meta)
 
             # TODO: Add pasts/presents to meta input?
             # Use lists as pointers so `past_key_values` is populated during the previous micro_sequence.
@@ -300,12 +286,8 @@ def preprocess(
                     else:
                         labels[i, start : end + 1] = -100
                 kwargs[LanguageModelKwargs.labels] = labels
-            if self._config.use_absolute_position_embeddings:
-                self._position_embedding_preprocessor.preprocess(kwargs)
-            if self._config.transformer.rotary.enabled:
-                self._rotary_embedding_preprocessor.preprocess(kwargs)
-            if not self._use_flash_attention:
-                self._backup_attention_preprocessor.preprocess(kwargs)
+            for preprocessor in self._preprocessors:
+                preprocessor.preprocess(tokens, kwargs)
             preprocessed.append((tokens, kwargs))
 
         return preprocessed
@@ -379,6 +361,10 @@ def loss_defs(self) -> list[LossDef]:
             )
         return loss_defs
 
+    def add_preprocessor(self, preprocessor: Preprocessor):
+        assert not self._is_setup
+        self._preprocessors.append(preprocessor)
+
 
 class GPTModel[ConfigType: GPTModelConfig](FastLLMModel[ConfigType]):
     config_class: typing.ClassVar[type[GPTModelConfig]] = GPTModelConfig
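
With all preprocessors collected in self._preprocessors, the model applies them uniformly in preprocess_meta and preprocess, and the new add_preprocessor hook lets outside code attach additional ones before setup. A hypothetical registration sketch (NoOpPreprocessor and register_extra_preprocessing are illustrative, not part of this commit):

import typing

from fast_llm.engine.base_model.config import Preprocessor


class NoOpPreprocessor(Preprocessor):
    # Illustrative stand-in for a real preprocessor: adds nothing to kwargs.

    def preprocess(self, batch, kwargs: dict[str, typing.Any]) -> None:
        pass


def register_extra_preprocessing(model) -> None:
    # `model` is expected to be a GPTBaseModel (or any model overriding
    # add_preprocessor); registration must happen before the model is set up,
    # as enforced by the `assert not self._is_setup` above.
    model.add_preprocessor(NoOpPreprocessor())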

tests/test_attention.py (+1 −1)

@@ -84,6 +84,6 @@ def test_varlen_preprocessor():
            TransformerKwargs.sequence_length: sequence_length,
            TransformerKwargs.sequence_lengths: sequence_lengths,
        }
-        varlen_preprocessor.preprocess(kwargs)
+        varlen_preprocessor.preprocess(None, kwargs)
        Assert.all_equal(kwargs[TransformerKwargs.cu_seqlens_q], cumulative_sequences_q[micro_seq_idx])
        Assert.all_equal(kwargs[TransformerKwargs.cu_seqlens_k], cumulative_sequences_k[micro_seq_idx])
