@@ -6,5 +6,6 @@
 from fast_llm.data.data.gpt.data import GPTBatch
 from fast_llm.engine.base_model.base_model import BaseModel, Layer, LossDef
+from fast_llm.engine.base_model.config import Preprocessor
 from fast_llm.engine.config_utils.tensor_space import TensorDim
 from fast_llm.engine.distributed.config import DistributedConfig, DistributedDimNames, PhaseType
 from fast_llm.engine.distributed.distributed import Distributed
@@ -58,18 +59,17 @@ def __init__(
         for param in self.parameters():
             Assert.custom(isinstance, param, ParameterMeta)
             param.init_parameter = get_init_megatron(param, self._config.transformer)  # Noqa
+        self._preprocessors: list[Preprocessor] = []
         if self._config.use_absolute_position_embeddings:
-            self._position_embedding_preprocessor = PositionEmbeddingPreprocessor(self._config, self._tensor_space)
+            self._preprocessors.append(PositionEmbeddingPreprocessor(self._config, self._tensor_space))
         if self._config.transformer.rotary.enabled:
-            self._rotary_embedding_preprocessor = RotaryEmbeddingPreprocessor(
-                self._config.transformer.rotary, self._tensor_space
-            )
-        if not self._use_flash_attention:
-            self._backup_attention_preprocessor = BackupAttentionPreprocessor(
-                self._config.transformer, self._tensor_space
+            self._preprocessors.append(
+                RotaryEmbeddingPreprocessor(self._config.transformer.rotary, self._tensor_space)
             )
+        if self._use_flash_attention:
+            self._preprocessors.append(FlashAttnVarlenPreprocessor(self._config.transformer, self._tensor_space))
         else:
-            self._flash_varlen_preprocessor = FlashAttnVarlenPreprocessor(self._config.transformer, self._tensor_space)
+            self._preprocessors.append(BackupAttentionPreprocessor(self._config.transformer, self._tensor_space))
 
     def get_output_layers(self) -> list[Layer]:
         return [
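
The hunk above replaces the per-feature preprocessor attributes with a single `self._preprocessors` list typed against the `Preprocessor` class imported from `fast_llm.engine.base_model.config`. The later hunks only ever call `preprocess_meta(kwargs)` and `preprocess(tokens, kwargs)` on list members, and the separate `create_tensors(sequence_length)` calls disappear from the base model, so that state presumably moves inside each preprocessor. A minimal sketch of the interface this diff appears to rely on; the real class may declare more than this:

```python
# Hedged sketch only: the actual Preprocessor definition lives in
# fast_llm.engine.base_model.config and may differ in details.
import abc
import typing

import torch


class Preprocessor(abc.ABC):
    @abc.abstractmethod
    def preprocess_meta(self, kwargs: dict[str, typing.Any]) -> None:
        """Attach shape/meta information (tensor metas, dims) to `kwargs`."""

    @abc.abstractmethod
    def preprocess(self, tokens: torch.Tensor, kwargs: dict[str, typing.Any]) -> None:
        """Attach runtime tensors (position ids, masks, cu_seqlens, ...) to `kwargs`."""
```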
@@ -207,12 +207,8 @@ def preprocess_meta(
                 kwargs[LanguageModelKwargs.labels] = TensorMeta.from_dims(
                     hidden_dims[:2], tensor_name="labels", dtype=torch.int64
                 )
-            if self._config.use_absolute_position_embeddings:
-                self._position_embedding_preprocessor.preprocess_meta(kwargs)
-            if self._config.transformer.rotary.enabled:
-                self._rotary_embedding_preprocessor.preprocess_meta(kwargs)
-            if not self._use_flash_attention:
-                self._backup_attention_preprocessor.preprocess_meta(kwargs)
+            for preprocessor in self._preprocessors:
+                preprocessor.preprocess_meta(kwargs)
             preprocessed_meta.append((tokens, kwargs))
 
         return preprocessed_meta
@@ -235,7 +231,6 @@ def preprocess(
         _, common_kwargs = preprocessed_meta[0]
         sequence_q = common_kwargs[TransformerKwargs.sequence_q_dim].size
         sequence_first = common_kwargs[TransformerKwargs.sequence_first]
-        sequence_length = common_kwargs[TransformerKwargs.sequence_length]
 
         batch.token_ids = batch.token_ids.to(
             device=self._tensor_space.distributed.device,
@@ -246,13 +241,6 @@ def preprocess(
             # Move the sequence dimension first to make sequence parallel ops more efficient.
             batch.token_ids = batch.token_ids.transpose(0, 1).contiguous()
 
-        if self._config.use_absolute_position_embeddings:
-            self._position_embedding_preprocessor.create_tensors(sequence_length)
-        if self._config.transformer.rotary.enabled:
-            self._rotary_embedding_preprocessor.create_tensors(sequence_length)
-        if not self._use_flash_attention:
-            self._backup_attention_preprocessor.create_tensors(sequence_length)
-
         preprocessed = []
         presents = None
         for i, (tokens_meta, kwargs_meta) in enumerate(preprocessed_meta):
@@ -264,8 +252,6 @@ def preprocess(
                 tokens = batch.token_ids[:, sequence_k - sequence_q : sequence_k].contiguous()
             if batch.sequence_lengths is not None:
                 kwargs_meta[TransformerKwargs.sequence_lengths] = batch.sequence_lengths
-            if self._use_flash_attention:
-                self._flash_varlen_preprocessor.preprocess(kwargs_meta)
 
             # TODO: Add pasts/presents to meta input?
             # Use lists as pointers so `past_key_values` is populated during the previous micro_sequence.
@@ -300,12 +286,8 @@ def preprocess(
                         else:
                             labels[i, start : end + 1] = -100
                 kwargs[LanguageModelKwargs.labels] = labels
-            if self._config.use_absolute_position_embeddings:
-                self._position_embedding_preprocessor.preprocess(kwargs)
-            if self._config.transformer.rotary.enabled:
-                self._rotary_embedding_preprocessor.preprocess(kwargs)
-            if not self._use_flash_attention:
-                self._backup_attention_preprocessor.preprocess(kwargs)
+            for preprocessor in self._preprocessors:
+                preprocessor.preprocess(tokens, kwargs)
             preprocessed.append((tokens, kwargs))
 
         return preprocessed
@@ -379,6 +361,10 @@ def loss_defs(self) -> list[LossDef]:
             )
         return loss_defs
 
+    def add_preprocessor(self, preprocessor: Preprocessor):
+        assert not self._is_setup
+        self._preprocessors.append(preprocessor)
+
 
 class GPTModel[ConfigType: GPTModelConfig](FastLLMModel[ConfigType]):
     config_class: typing.ClassVar[type[GPTModelConfig]] = GPTModelConfig
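
The new `add_preprocessor` hook makes the preprocessing pipeline extensible from outside the base model, as long as registration happens before setup (hence the `assert not self._is_setup`). A hypothetical registration sketch, assuming the class edited above is `GPTBaseModel` in `fast_llm.models.gpt.model` and inventing `LoggingPreprocessor` purely for illustration:

```python
# Hypothetical example; only `add_preprocessor`, `preprocess_meta`, and
# `preprocess` come from the diff above. The base-model class name and module
# path are assumptions.
import typing

import torch

from fast_llm.engine.base_model.config import Preprocessor
from fast_llm.models.gpt.model import GPTBaseModel  # assumed location of the edited class


class LoggingPreprocessor(Preprocessor):
    """Invented preprocessor that just records what it saw in the kwargs."""

    def preprocess_meta(self, kwargs: dict[str, typing.Any]) -> None:
        kwargs.setdefault("debug_meta_keys", sorted(kwargs))

    def preprocess(self, tokens: torch.Tensor, kwargs: dict[str, typing.Any]) -> None:
        kwargs["debug_token_shape"] = tuple(tokens.shape)


def register_debug_preprocessor(model: GPTBaseModel) -> None:
    # Must run before setup: add_preprocessor asserts `not self._is_setup`.
    model.add_preprocessor(LoggingPreprocessor())
```

Once registered, the preprocessor runs alongside the built-in ones in both `preprocess_meta` and `preprocess`, since those methods now just iterate over `self._preprocessors`.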