@@ -13,7 +13,7 @@
 from fast_llm.functional.config import CrossEntropyImpl, TritonConfig
 from fast_llm.functional.cross_entropy import cross_entropy_forward_backward
 from fast_llm.functional.linear import output_parallel_linear_backward, output_parallel_linear_forward
-from fast_llm.layers.common.auxiliary_loss import z_loss
+from fast_llm.layers.common.auxiliary_loss import AuxiliaryLoss, z_loss
 from fast_llm.layers.language_model.config import (
     LanguageModelBaseConfig,
     LanguageModelDimNames,
@@ -24,7 +24,9 @@
 from fast_llm.layers.transformer.config import TransformerDimNames, TransformerKwargs
 from fast_llm.logging import log_distributed_tensor
 from fast_llm.tensor import ParameterMeta, TensorMeta, init_normal_
-from fast_llm.utils import div
+from fast_llm.utils import Assert, div
+
+OUTPUT_WEIGHTS = "output_weights"
 
 
 class LanguageModelHead[ConfigType: LanguageModelBaseConfig](Configurable[LanguageModelBaseConfig], Layer):
@@ -38,6 +40,7 @@ def __init__(
         self,
         config: LanguageModelBaseConfig,
         tensor_space: TensorSpace,
+        prediction_distance: int,
     ):
         super().__init__(config)
         self._debug_transformer = config.transformer.debug_transformer
@@ -56,23 +59,24 @@ def __init__(
 
         hidden_dim = self._tensor_space.get_tensor_dim(TransformerDimNames.hidden)
 
+        self._loss_name = LanguageModelLossNames.multi_token_prediction_loss(prediction_distance)
         self.final_norm = config.transformer.normalization.get_layer(hidden_dim)
         self._logits_scale_factor = config.logits_scale_factor
         self._z_loss_factor = config.logit_z_loss
 
-        # untie embedding weights
-        if not self._tie_word_embeddings:
-            vocab_dim = self._tensor_space.get_tensor_dim(
-                LanguageModelDimNames.vocab_tp if self._parallel_embeddings else LanguageModelDimNames.vocab
-            )
-            self.output_weights = ParameterMeta.from_dims(
-                (vocab_dim, hidden_dim),
-                init_method=init_normal_(
-                    std=config.init_method_std_embed,
-                    min_val=config.init_method_min_embed,
-                    max_val=config.init_method_max_embed,
-                ),
-            )
+        # Distance of the target token prediction
+        # 0: next-token prediction
+        # >0: multi-token prediction (MTP)
+        Assert.geq(prediction_distance, 0)
+        self._prediction_distance = prediction_distance
+        self.is_last_head = self._prediction_distance == config.prediction_heads - 1
+        if self._prediction_distance > 0:
+            assert (
+                not self._sequence_parallel_logits
+            ), "Sequence parallel logits not supported for multi-token prediction."
+            assert not self._cross_entropy_splits, "Cross-entropy splits not supported for multi-token prediction."
+
+        self._init_output_weights(hidden_dim, config)
 
         self._cross_entropy_impl = config.cross_entropy_impl
         if self._cross_entropy_impl == CrossEntropyImpl.auto:
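Note: each `LanguageModelHead` is now tied to one `prediction_distance`. Distance 0 is the usual next-token head; a distance `d > 0` head is trained on the token `d + 1` positions ahead, since the label shift in `_forward_backward` below starts from labels that are already next-token targets. `is_last_head` marks the head with the largest distance, whose return value drives the backward pass. A plain-Python illustration of the mapping (assuming `config.prediction_heads = 3`; the number is illustrative, not from this diff):

```python
# Illustration only: which target each head trains on, for prediction_heads = 3.
prediction_heads = 3
for prediction_distance in range(prediction_heads):
    is_last_head = prediction_distance == prediction_heads - 1
    # Distance 0 is standard next-token prediction (target t + 1);
    # distance d is trained on the token d + 1 positions ahead.
    print(f"head {prediction_distance}: target t+{prediction_distance + 1}, is_last_head={is_last_head}")
```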
@@ -90,6 +94,23 @@ def __init__(
         if hasattr(self, "output_weights"):
             self.output_weights = self._config.transformer.peft.apply_weight(self.output_weights)
 
+    def _init_output_weights(self, hidden_dim: TensorDim, config) -> None:
+        # Only the first head defines the output weights
+        if self._tie_word_embeddings or self._prediction_distance > 0:
+            return
+        # untie embedding weights
+        vocab_dim = self._tensor_space.get_tensor_dim(
+            LanguageModelDimNames.vocab_tp if self._parallel_embeddings else LanguageModelDimNames.vocab
+        )
+        self.output_weights = ParameterMeta.from_dims(
+            (vocab_dim, hidden_dim),
+            init_method=init_normal_(
+                std=config.init_method_std_embed,
+                min_val=config.init_method_min_embed,
+                max_val=config.init_method_max_embed,
+            ),
+        )
+
     def forward(
         self, input_: torch.Tensor, kwargs: dict, losses: dict | None = None, metrics: dict | None = None
     ) -> torch.Tensor:
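Only the head with `prediction_distance == 0` (and untied embeddings) allocates output weights; the other heads return early from `_init_output_weights` and instead read the shared projection from `kwargs[OUTPUT_WEIGHTS]` via `_get_output_weights` below, presumably so each extra MTP head does not duplicate a vocab-by-hidden matrix. Rough memory math with illustrative numbers (not taken from this diff):

```python
# Illustration only: size of one untied output projection in bf16.
vocab_size, hidden_size, bytes_per_param = 32_000, 4_096, 2
one_matrix_gib = vocab_size * hidden_size * bytes_per_param / 1024**3
print(f"~{one_matrix_gib:.2f} GiB per head if every head owned its own weights")  # ~0.24 GiB
```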
@@ -100,33 +121,50 @@ def forward(
                 tensor_name="Loss",
                 reductions=((DistributedDimNames.data, ReduceOp.AVG),),  # noqa
             )
+        if not self.is_last_head:
+            # MTP: split the stacked input
+            shared_hidden, input_ = torch.unbind(input_, dim=0)
         # TODO: Pytorch copies the grads in backward for no reason (not sure if still the case)
         # TODO: Torch compile implementation sometimes break.
         # TODO: Double-check correctness, optimize a bit more.
         # TODO: Drop autograd entirely.
         # TODO: Skip cross-entropy backward if not needed.
         language_model_loss = self._forward(input_, kwargs, losses)
         if language_model_loss is not None:
-            losses[LanguageModelLossNames.language_model_loss].append(language_model_loss)
+            losses[self._loss_name].append(language_model_loss)
         # TODO: Return the model output when needed.
-        return language_model_loss
+        if self.is_last_head:
+            # Last head should return the loss for backward.
+            return language_model_loss
+        else:
+            # Backward hook to compute the gradient of the loss
+            shared_hidden = AuxiliaryLoss.apply(shared_hidden, language_model_loss, 1.0)
+            # MTP: Return shared_hidden to be used by the next head.
+            return shared_hidden
 
     def _forward_backward(
         self, input_: torch.Tensor, kwargs: dict, losses: dict | None = None
     ) -> tuple[torch.Tensor, torch.Tensor | None]:
-        labels = kwargs[LanguageModelKwargs.labels].flatten() if LanguageModelKwargs.labels in kwargs else None
+        labels = kwargs[LanguageModelKwargs.labels] if LanguageModelKwargs.labels in kwargs else None
+        # MTP: Shift the labels
+        labels = labels[:, self._prediction_distance :].flatten() if labels is not None else None
         if self._sequence_parallel_logits:
             labels = split_op(labels, self._tensor_space.distributed.tensor_group, 0)
         do_grad = labels is not None and self.training
         input_ = input_.detach().requires_grad_(do_grad)
         with torch.enable_grad():
-            ln_output = self.final_norm(input_)
+            # MTP: truncate the input
+            if self._prediction_distance > 0:
+                truncated_input = input_[:, : -self._prediction_distance, :].contiguous()
+            else:
+                truncated_input = input_
+            ln_output = self.final_norm(truncated_input)
 
         grad_output = kwargs[TransformerKwargs.grad_output] / (
             self._group_size if self._sequence_parallel_logits else 1
         )
 
-        output_weights = kwargs[WORD_EMBEDDINGS_WEIGHT] if self._tie_word_embeddings else self.output_weights
+        output_weights = self._get_output_weights(kwargs)
         loss, ln_output_grad = self._logits_cross_entropy_forward_backward_split(
             ln_output.detach(), labels, output_weights, grad_output, kwargs, losses
         )
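For every head except the last, `forward` returns `shared_hidden` instead of a loss, so the head's own loss has to reach the backward pass through `AuxiliaryLoss.apply(shared_hidden, language_model_loss, 1.0)`. A minimal sketch of that pass-through pattern, written here as a generic `torch.autograd.Function` (an assumption about how `fast_llm.layers.common.auxiliary_loss.AuxiliaryLoss` behaves, not its actual code):

```python
import torch


class AuxLossSketch(torch.autograd.Function):
    """Identity on `hidden`; in backward, also feeds `grad_scale` to `loss`."""

    @staticmethod
    def forward(ctx, hidden: torch.Tensor, loss: torch.Tensor, grad_scale: float) -> torch.Tensor:
        ctx.save_for_backward(loss)
        ctx.grad_scale = grad_scale
        return hidden

    @staticmethod
    def backward(ctx, grad_output: torch.Tensor):
        (loss,) = ctx.saved_tensors
        # Upstream gradient passes through `hidden` unchanged; the loss gets a
        # gradient of `grad_scale`, as if `grad_scale * loss` had been added to
        # the objective that produced `grad_output`.
        return grad_output, torch.full_like(loss, ctx.grad_scale), None


# Tiny check: gradients from both the downstream consumer and the attached
# loss reach `hidden`.
hidden = torch.randn(2, 3, requires_grad=True)
loss = hidden.pow(2).mean()
out = AuxLossSketch.apply(hidden, loss, 1.0)
out.sum().backward()
print(hidden.grad.shape)  # torch.Size([2, 3])
```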
@@ -137,6 +175,13 @@ def _forward_backward(
         else:
             return loss, None
 
+    def _get_output_weights(self, kwargs: dict) -> torch.Tensor:
+        if self._tie_word_embeddings:
+            return kwargs[WORD_EMBEDDINGS_WEIGHT]
+        if self._prediction_distance > 0:
+            return kwargs[OUTPUT_WEIGHTS]
+        return self.output_weights
+
     def _logits_cross_entropy_forward_backward_split(
         self,
         input_: torch.Tensor,
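The label shift (`labels[:, self._prediction_distance :]`) and the matching input truncation (`input_[:, : -self._prediction_distance, :]`) in `_forward_backward` keep the two tensors the same length, so position `t` of the truncated hidden states is scored against the token `prediction_distance + 1` steps ahead. A toy shape check, assuming the sequence dimension is dim 1 as in those slices and that `labels` are already ordinary next-token targets:

```python
import torch

prediction_distance = 2                          # e.g. the third head
tokens = torch.arange(10)                        # toy sequence 0..9
hidden_positions = tokens.unsqueeze(0)           # stand-in for hidden states, shape (1, 10)
labels = tokens.unsqueeze(0).roll(-1, dims=1)    # toy next-token labels

truncated = hidden_positions[:, :-prediction_distance]   # positions 0..7
shifted_labels = labels[:, prediction_distance:]         # their MTP targets

assert truncated.shape == shifted_labels.shape
# Position 0 is paired with token 3, i.e. t + prediction_distance + 1.
print(truncated[0, 0].item(), shifted_labels[0, 0].item())  # 0 3
```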
@@ -156,6 +201,7 @@ def _logits_cross_entropy_forward_backward_split(
                 return None, None
         else:
             loss = None
+            # TODO MTP: allow a _cross_entropy_splits that is not a divisor of the sequence length
             split_size = div(labels.numel(), self._cross_entropy_splits)
             grad_output /= self._cross_entropy_splits
             logit_input = input_.flatten(0, -2)
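The new TODO exists because `div` (imported from `fast_llm.utils` at the top of this diff) is assumed here to be an exact-division helper, so `_cross_entropy_splits` must evenly divide the number of labels. A sketch of the constraint, with a stand-in `div` reimplemented for illustration only:

```python
def div(x: int, y: int) -> int:
    # Assumed behaviour of fast_llm.utils.div: integer division that must be exact.
    assert x % y == 0, f"{x} is not divisible by {y}"
    return x // y


labels_numel = 8 * 2048                  # e.g. 8 sequences of length 2048
split_size = div(labels_numel, 4)        # 4096 labels per cross-entropy chunk
# div(labels_numel, 3) would raise; relaxing this is what the TODO asks for.
```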