@@ -14,7 +14,7 @@
 from fast_llm.functional.cross_entropy import cross_entropy_forward_backward
 from fast_llm.functional.linear import output_parallel_linear_backward, output_parallel_linear_forward
 from fast_llm.functional.dpo import compute_simplified_dpo_loss
-from fast_llm.layers.common.auxiliary_loss import z_loss
+from fast_llm.layers.common.auxiliary_loss import AuxiliaryLoss, z_loss
 from fast_llm.layers.language_model.config import (
     LanguageModelBaseConfig,
     LanguageModelDimNames,
@@ -25,7 +25,9 @@
 from fast_llm.layers.transformer.config import TransformerDimNames, TransformerKwargs
 from fast_llm.logging import log_distributed_tensor
 from fast_llm.tensor import ParameterMeta, TensorMeta, init_normal_
-from fast_llm.utils import div
+from fast_llm.utils import Assert, div
+
+OUTPUT_WEIGHTS = "output_weights"


 class LanguageModelHead[ConfigType: LanguageModelBaseConfig](Configurable[LanguageModelBaseConfig], Layer):
@@ -39,6 +41,7 @@ def __init__(
         self,
         config: LanguageModelBaseConfig,
         tensor_space: TensorSpace,
+        prediction_distance: int,
     ):
         super().__init__(config)
         self._debug_transformer = config.transformer.debug_transformer
@@ -57,23 +60,24 @@ def __init__(

         hidden_dim = self._tensor_space.get_tensor_dim(TransformerDimNames.hidden)

+        self._loss_name = LanguageModelLossNames.multi_token_prediction_loss(prediction_distance)
         self.final_norm = config.transformer.normalization.get_layer(hidden_dim)
         self._logits_scale_factor = config.logits_scale_factor
         self._z_loss_factor = config.logit_z_loss

-        # untie embedding weights
-        if not self._tie_word_embeddings:
-            vocab_dim = self._tensor_space.get_tensor_dim(
-                LanguageModelDimNames.vocab_tp if self._parallel_embeddings else LanguageModelDimNames.vocab
-            )
-            self.output_weights = ParameterMeta.from_dims(
-                (vocab_dim, hidden_dim),
-                init_method=init_normal_(
-                    std=config.init_method_std_embed,
-                    min_val=config.init_method_min_embed,
-                    max_val=config.init_method_max_embed,
-                ),
-            )
+        # Distance of the target token prediction
+        # 0: next-token prediction
+        # >0: multi-token prediction (MTP)
+        Assert.geq(prediction_distance, 0)
+        self._prediction_distance = prediction_distance
+        self.is_last_head = self._prediction_distance == config.prediction_heads - 1
+        if self._prediction_distance > 0:
+            assert (
+                not self._sequence_parallel_logits
+            ), "Sequence parallel logits not supported for multi-token prediction."
+            assert not self._cross_entropy_splits, "Cross-entropy splits not supported for multi-token prediction."
+
+        self._init_output_weights(hidden_dim, config)

         self._loss_function_type = config.loss_function_type
         if self._loss_function_type == LossFunctionType.cross_entropy:
@@ -97,6 +101,23 @@ def __init__(
         if hasattr(self, "output_weights"):
             self.output_weights = self._config.transformer.peft.apply_weight(self.output_weights)

+    def _init_output_weights(self, hidden_dim: TensorDim, config) -> None:
+        # Only the first head defines the output weights
+        if self._tie_word_embeddings or self._prediction_distance > 0:
+            return
+        # untie embedding weights
+        vocab_dim = self._tensor_space.get_tensor_dim(
+            LanguageModelDimNames.vocab_tp if self._parallel_embeddings else LanguageModelDimNames.vocab
+        )
+        self.output_weights = ParameterMeta.from_dims(
+            (vocab_dim, hidden_dim),
+            init_method=init_normal_(
+                std=config.init_method_std_embed,
+                min_val=config.init_method_min_embed,
+                max_val=config.init_method_max_embed,
+            ),
+        )
+
     def forward(
         self, input_: torch.Tensor, kwargs: dict, losses: dict | None = None, metrics: dict | None = None
     ) -> torch.Tensor:
@@ -107,33 +128,50 @@ def forward(
                 tensor_name="Loss",
                 reductions=((DistributedDimNames.data, ReduceOp.AVG),),  # noqa
             )
+        if not self.is_last_head:
+            # MTP: split the stacked input
+            shared_hidden, input_ = torch.unbind(input_, dim=0)
         # TODO: Pytorch copies the grads in backward for no reason (not sure if still the case)
         # TODO: Torch compile implementation sometimes break.
         # TODO: Double-check correctness, optimize a bit more.
         # TODO: Drop autograd entirely.
         # TODO: Skip cross-entropy backward if not needed.
         language_model_loss = self._forward(input_, kwargs, losses)
         if language_model_loss is not None:
-            losses[LanguageModelLossNames.language_model_loss].append(language_model_loss)
+            losses[self._loss_name].append(language_model_loss)
         # TODO: Return the model output when needed.
-        return language_model_loss
+        if self.is_last_head:
+            # Last head should return the loss for backward.
+            return language_model_loss
+        else:
+            # Backward hook to compute the gradient of the loss
+            shared_hidden = AuxiliaryLoss.apply(shared_hidden, language_model_loss, 1.0)
+            # MTP: Return shared_hidden to be used by the next head.
+            return shared_hidden

     def _forward_backward(
         self, input_: torch.Tensor, kwargs: dict, losses: dict | None = None
     ) -> tuple[torch.Tensor, torch.Tensor | None]:
-        labels = kwargs[LanguageModelKwargs.labels].flatten() if LanguageModelKwargs.labels in kwargs else None
+        labels = kwargs[LanguageModelKwargs.labels] if LanguageModelKwargs.labels in kwargs else None
+        # MTP: Shift the labels
+        labels = labels[:, self._prediction_distance :].flatten() if labels is not None else None
         if self._sequence_parallel_logits:
             labels = split_op(labels, self._tensor_space.distributed.tensor_group, 0)
         do_grad = labels is not None and self.training
         input_ = input_.detach().requires_grad_(do_grad)
         with torch.enable_grad():
-            ln_output = self.final_norm(input_)
+            # MTP: truncate the input
+            if self._prediction_distance > 0:
+                truncated_input = input_[:, : -self._prediction_distance, :].contiguous()
+            else:
+                truncated_input = input_
+            ln_output = self.final_norm(truncated_input)

         grad_output = kwargs[TransformerKwargs.grad_output] / (
             self._group_size if self._sequence_parallel_logits else 1
         )

-        output_weights = kwargs[WORD_EMBEDDINGS_WEIGHT] if self._tie_word_embeddings else self.output_weights
+        output_weights = self._get_output_weights(kwargs)
         loss, ln_output_grad = self._loss_fcn(
             ln_output.detach(), labels, output_weights, grad_output, kwargs, losses
         )
@@ -176,6 +214,13 @@ def _logits_dpo(



+    def _get_output_weights(self, kwargs: dict) -> torch.Tensor:
+        if self._tie_word_embeddings:
+            return kwargs[WORD_EMBEDDINGS_WEIGHT]
+        if self._prediction_distance > 0:
+            return kwargs[OUTPUT_WEIGHTS]
+        return self.output_weights
+
     def _logits_cross_entropy_forward_backward_split(
         self,
         input_: torch.Tensor,
@@ -195,6 +240,7 @@ def _logits_cross_entropy_forward_backward_split(
                 return None, None
         else:
             loss = None
+            # TODO MTP: allow a _cross_entropy_splits that is not a divisor of the sequence length
             split_size = div(labels.numel(), self._cross_entropy_splits)
             grad_output /= self._cross_entropy_splits
             logit_input = input_.flatten(0, -2)
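A note on the stacked input handled by torch.unbind(input_, dim=0) in forward: the assumption, sketched below with hypothetical names and shapes (not Fast-LLM code), is that the preceding layer stacks the shared trunk hidden state together with this head's own transformer output along a new leading dimension, so a non-final head can peel off its input while passing the shared hidden state on to the next head.

import torch

# Hypothetical layout: (batch, sequence, hidden). The actual tensors may be
# sequence-first; the stack/unbind logic is the same either way.
batch, seq_len, hidden = 2, 8, 16
shared_hidden = torch.randn(batch, seq_len, hidden)  # output of the shared trunk
head_hidden = torch.randn(batch, seq_len, hidden)    # this head's transformer output

# The previous layer is assumed to stack both along a new leading dim ...
stacked = torch.stack((shared_hidden, head_hidden), dim=0)  # (2, batch, seq_len, hidden)

# ... and the head splits them back apart, keeping shared_hidden for the next head.
shared_out, head_in = torch.unbind(stacked, dim=0)
assert torch.equal(shared_out, shared_hidden) and torch.equal(head_in, head_hidden)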
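AuxiliaryLoss.apply(shared_hidden, language_model_loss, 1.0) attaches a non-final head's loss to the shared hidden state, so its gradient is produced when the last head's returned loss is eventually backpropagated. Below is a hedged sketch of that pass-through pattern; the real AuxiliaryLoss in fast_llm.layers.common.auxiliary_loss may differ in its details.

import torch


class AuxiliaryLossSketch(torch.autograd.Function):
    """Pass `hidden` through unchanged; during backward, also feed a constant
    gradient (`scale`) into `loss` so the loss's own graph gets backpropagated."""

    @staticmethod
    def forward(ctx, hidden: torch.Tensor, loss: torch.Tensor, scale: float) -> torch.Tensor:
        ctx.scale = scale
        return hidden

    @staticmethod
    def backward(ctx, grad_hidden: torch.Tensor):
        # Gradient w.r.t. hidden passes through untouched; the loss receives `scale`;
        # the scale argument itself gets no gradient.
        return grad_hidden, grad_hidden.new_tensor(ctx.scale), None


# Usage: aux_loss contributes to hidden.grad even though only main_loss is backpropagated.
hidden = torch.randn(4, 8, requires_grad=True)
aux_loss = hidden.square().mean()
hooked_hidden = AuxiliaryLossSketch.apply(hidden, aux_loss, 1.0)
main_loss = hooked_hidden.sum()
main_loss.backward()
# hidden.grad now includes both d(main_loss)/d(hidden) and d(aux_loss)/d(hidden).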
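Finally, the label shift (labels[:, self._prediction_distance :]) and input truncation (input_[:, : -self._prediction_distance, :]) in _forward_backward keep shapes aligned: for prediction distance k, position t predicts the token k steps beyond the usual next-token target, so the first k label columns and the last k hidden-state positions have no counterpart and are dropped. A minimal standalone sketch, assuming a batch-first layout and that the labels kwarg already holds next-token targets:

import torch

batch, seq_len, hidden, k = 2, 8, 16, 2  # k = prediction_distance of this head

hidden_states = torch.randn(batch, seq_len, hidden)
next_token_labels = torch.randint(0, 100, (batch, seq_len))  # label at index t is token t + 1

# Head k pairs position t with token t + 1 + k: drop the first k labels ...
labels_k = next_token_labels[:, k:]  # (batch, seq_len - k)
# ... and the last k positions, which have no target at that distance.
hidden_k = hidden_states[:, :-k, :].contiguous() if k > 0 else hidden_states

assert hidden_k.shape[1] == labels_k.shape[1] == seq_len - k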