Commit 9036fd2

Multi-token prediction (#179)

1 parent 21182c2 commit 9036fd2

10 files changed, +276 -67 lines changed

fast_llm/data/dataset/gpt/sampled.py

Lines changed: 1 addition & 0 deletions
@@ -128,6 +128,7 @@ def _sample(self) -> None:
         # Calculate basic stats.
         documents_per_epoch = document_sizes.numel()
         tokens_per_epoch = document_sizes.sum().item()
+        # TODO MTP: Produce more labels to provide labels for the multi-token prediction heads?
         # We produce sequences of length `self._sequence_length + 1` so the last token has a label,
         # but we also include that last label in the following sample,
         # so we need `sequence_length * num_samples + 1` tokens in total.
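A small worked illustration of the token accounting in that comment (the numbers below are made up, not taken from the commit): consecutive samples overlap by one token, so `num_samples` samples of `sequence_length + 1` tokens need `sequence_length * num_samples + 1` tokens per epoch.

```python
# Illustrative token accounting, assuming sequence_length=4 and num_samples=3.
sequence_length, num_samples = 4, 3

# sample 0: t0 t1 t2  t3  t4   (inputs t0..t3, labels t1..t4)
# sample 1: t4 t5 t6  t7  t8   (t4 is reused as the first input token)
# sample 2: t8 t9 t10 t11 t12
tokens_needed = sequence_length * num_samples + 1
assert tokens_needed == 13
```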

fast_llm/engine/checkpoint/external.py

Lines changed: 27 additions & 3 deletions
@@ -141,12 +141,16 @@ def import_weight(
         return weight


-class IgnoreWeightConverter(WeightConverter):
+class IgnoreImportWeightConverter(WeightConverter):
+    def __post_init__(self):
+        Assert.eq(len(self.fast_llm_name), 0)
+        Assert.gt(len(self.export_name), 0)
+
     def export_weight(
         self, weight: tuple[torch.Tensor | SafeTensorSlice, ...]
     ) -> tuple[torch.Tensor | SafeTensorSlice, ...]:
         raise RuntimeError(
-            f"IgnoreWeightConverter should not be used for export: {self.fast_llm_name}, {self.export_name}"
+            f"IgnoreImportWeightConverter should not be used for export: {self.fast_llm_name}, {self.export_name}"
         )

     def import_weight(

@@ -155,6 +159,24 @@ def import_weight(
         return ()


+class IgnoreExportWeightConverter(WeightConverter):
+    def __post_init__(self):
+        Assert.gt(len(self.fast_llm_name), 0)
+        Assert.eq(len(self.export_name), 0)
+
+    def export_weight(
+        self, weight: tuple[torch.Tensor | SafeTensorSlice, ...]
+    ) -> tuple[torch.Tensor | SafeTensorSlice, ...]:
+        return ()
+
+    def import_weight(
+        self, weight: tuple[torch.Tensor | SafeTensorSlice, ...]
+    ) -> tuple[torch.Tensor | SafeTensorSlice, ...]:
+        raise RuntimeError(
+            f"IgnoreExportWeightConverter should not be used for import: {self.fast_llm_name}, {self.export_name}"
+        )
+
+
 class CopyWeightConverter(WeightConverter):
     def export_weight(
         self, weight: tuple[torch.Tensor | SafeTensorSlice, ...]

@@ -198,7 +220,9 @@ def __init__(self, model: "FastLLMModel"):
             if weight_converter.fast_llm_name
         }
         self._import_converters = {
-            weight_converter.export_name[0]: weight_converter for weight_converter in weight_converters
+            weight_converter.export_name[0]: weight_converter
+            for weight_converter in weight_converters
+            if weight_converter.export_name
         }

     @classmethod
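A hedged sketch of how the two new converter classes might appear in a converter list; the weight names and the tuple-style constructor arguments are illustrative assumptions (consistent with the `__post_init__` asserts above), not code from the commit:

```python
# Hypothetical converter list entries; names are made up for illustration.
converters = [
    # External-only weight: dropped on import, never used for export.
    IgnoreImportWeightConverter((), ("lm_head.extra_bias",)),
    # Fast-LLM-only weight (e.g. an extra MTP head): dropped on export,
    # never used for import.
    IgnoreExportWeightConverter(("layers.12.output_weights",), ()),
]

# The new `if weight_converter.export_name` guard in `_import_converters`
# keeps entries with an empty export_name (IgnoreExportWeightConverter)
# out of the import lookup table, mirroring the existing
# `if weight_converter.fast_llm_name` guard on the export side.
```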

fast_llm/engine/checkpoint/state_dict.py

Lines changed: 4 additions & 2 deletions
@@ -56,7 +56,9 @@ def save(self, config: CheckpointSaveConfig, metadata: CheckpointMetadata) -> No
             saver.add_tensor(self._get_key(exported_name, shard_name), exported_tensor)

         for shard_name, shard_state_dict in state_dict.items():
-            assert not shard_state_dict, (shard_name, list(state_dict))
+            assert (
+                not shard_state_dict
+            ), f"Un-handled entries after conversion: {({k: list(v) for k, v in state_dict.items()})}"

         index = saver.finalize()
         if self._model.config.distributed.rank == 0:

@@ -90,7 +92,7 @@ def load(self, config: CheckpointLoadConfig, metadata: CheckpointMetadata) -> No
             context.mark_as_loaded(loaded, (parameter_name, shard_name))

         for shard_name, shard_state_dict in state_dict.items():
-            assert not shard_state_dict, (shard_name, list(state_dict))
+            assert not shard_state_dict, (shard_name, list(shard_state_dict))

     @classmethod
     @abc.abstractmethod
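For illustration, the reworded assertion in `save` now surfaces the leftover tensor names themselves rather than just the shard names; the shard and tensor names below are made up:

```python
# Hypothetical leftover state after conversion; names are illustrative.
state_dict = {"weights": {"model.extra_tensor": object()}, "optimizer": {}}

for shard_name, shard_state_dict in state_dict.items():
    assert (
        not shard_state_dict
    ), f"Un-handled entries after conversion: {({k: list(v) for k, v in state_dict.items()})}"
# AssertionError: Un-handled entries after conversion: {'weights': ['model.extra_tensor'], 'optimizer': []}
```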

fast_llm/layers/language_model/config.py

Lines changed: 12 additions & 0 deletions
@@ -22,6 +22,12 @@ class LanguageModelLossNames:
     language_model_loss = "language_model_loss"
     z_loss = "z_loss"

+    @staticmethod
+    def multi_token_prediction_loss(index: int) -> str:
+        if index == 0:
+            return LanguageModelLossNames.language_model_loss
+        return f"language_model_loss_{index}"
+

 class LanguageModelKwargs:
     position_ids = "position_ids"

@@ -57,6 +63,12 @@ class LanguageModelArchitectureConfig(BaseModelArchitectureConfig):
     tie_word_embeddings: bool = Field(
         default=True, desc="Tie the output weights (logits) with the vocabulary embedding.", hint=FieldHint.core
     )
+    prediction_heads: int = Field(
+        default=1,
+        desc="Number of multi-token prediction heads.",
+        hint=FieldHint.feature,
+        valid=check_field(Assert.gt, 0),
+    )

     def _validate(self) -> None:
         if self.use_position_embeddings is None:
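The new static method gives head 0 the existing next-token loss name and each additional head its own suffixed name, so existing single-head configurations keep their loss key. A quick check, assuming the import path of the file above:

```python
from fast_llm.layers.language_model.config import LanguageModelLossNames

assert LanguageModelLossNames.multi_token_prediction_loss(0) == "language_model_loss"
assert LanguageModelLossNames.multi_token_prediction_loss(1) == "language_model_loss_1"
assert LanguageModelLossNames.multi_token_prediction_loss(2) == "language_model_loss_2"
```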

fast_llm/layers/language_model/head.py

Lines changed: 66 additions & 20 deletions
@@ -13,7 +13,7 @@
 from fast_llm.functional.config import CrossEntropyImpl, TritonConfig
 from fast_llm.functional.cross_entropy import cross_entropy_forward_backward
 from fast_llm.functional.linear import output_parallel_linear_backward, output_parallel_linear_forward
-from fast_llm.layers.common.auxiliary_loss import z_loss
+from fast_llm.layers.common.auxiliary_loss import AuxiliaryLoss, z_loss
 from fast_llm.layers.language_model.config import (
     LanguageModelBaseConfig,
     LanguageModelDimNames,

@@ -24,7 +24,9 @@
 from fast_llm.layers.transformer.config import TransformerDimNames, TransformerKwargs
 from fast_llm.logging import log_distributed_tensor
 from fast_llm.tensor import ParameterMeta, TensorMeta, init_normal_
-from fast_llm.utils import div
+from fast_llm.utils import Assert, div
+
+OUTPUT_WEIGHTS = "output_weights"


 class LanguageModelHead[ConfigType: LanguageModelBaseConfig](Configurable[LanguageModelBaseConfig], Layer):

@@ -38,6 +40,7 @@ def __init__(
         self,
         config: LanguageModelBaseConfig,
         tensor_space: TensorSpace,
+        prediction_distance: int,
     ):
         super().__init__(config)
         self._debug_transformer = config.transformer.debug_transformer

@@ -56,23 +59,24 @@ def __init__(

         hidden_dim = self._tensor_space.get_tensor_dim(TransformerDimNames.hidden)

+        self._loss_name = LanguageModelLossNames.multi_token_prediction_loss(prediction_distance)
         self.final_norm = config.transformer.normalization.get_layer(hidden_dim)
         self._logits_scale_factor = config.logits_scale_factor
         self._z_loss_factor = config.logit_z_loss

-        # untie embedding weights
-        if not self._tie_word_embeddings:
-            vocab_dim = self._tensor_space.get_tensor_dim(
-                LanguageModelDimNames.vocab_tp if self._parallel_embeddings else LanguageModelDimNames.vocab
-            )
-            self.output_weights = ParameterMeta.from_dims(
-                (vocab_dim, hidden_dim),
-                init_method=init_normal_(
-                    std=config.init_method_std_embed,
-                    min_val=config.init_method_min_embed,
-                    max_val=config.init_method_max_embed,
-                ),
-            )
+        # Distance of the target token prediction
+        # 0: next-token prediction
+        # >0: multi-token prediction (MTP)
+        Assert.geq(prediction_distance, 0)
+        self._prediction_distance = prediction_distance
+        self.is_last_head = self._prediction_distance == config.prediction_heads - 1
+        if self._prediction_distance > 0:
+            assert (
+                not self._sequence_parallel_logits
+            ), "Sequence parallel logits not supported for multi-token prediction."
+            assert not self._cross_entropy_splits, "Cross-entropy splits not supported for multi-token prediction."
+
+        self._init_output_weights(hidden_dim, config)

         self._cross_entropy_impl = config.cross_entropy_impl
         if self._cross_entropy_impl == CrossEntropyImpl.auto:

@@ -90,6 +94,23 @@ def __init__(
         if hasattr(self, "output_weights"):
             self.output_weights = self._config.transformer.peft.apply_weight(self.output_weights)

+    def _init_output_weights(self, hidden_dim: TensorDim, config) -> None:
+        # Only the first head defines the output weights
+        if self._tie_word_embeddings or self._prediction_distance > 0:
+            return
+        # untie embedding weights
+        vocab_dim = self._tensor_space.get_tensor_dim(
+            LanguageModelDimNames.vocab_tp if self._parallel_embeddings else LanguageModelDimNames.vocab
+        )
+        self.output_weights = ParameterMeta.from_dims(
+            (vocab_dim, hidden_dim),
+            init_method=init_normal_(
+                std=config.init_method_std_embed,
+                min_val=config.init_method_min_embed,
+                max_val=config.init_method_max_embed,
+            ),
+        )
+
     def forward(
         self, input_: torch.Tensor, kwargs: dict, losses: dict | None = None, metrics: dict | None = None
     ) -> torch.Tensor:

@@ -100,33 +121,50 @@ def forward(
                 tensor_name="Loss",
                 reductions=((DistributedDimNames.data, ReduceOp.AVG),),  # noqa
             )
+        if not self.is_last_head:
+            # MTP: split the stacked input
+            shared_hidden, input_ = torch.unbind(input_, dim=0)
         # TODO: Pytorch copies the grads in backward for no reason (not sure if still the case)
         # TODO: Torch compile implementation sometimes break.
         # TODO: Double-check correctness, optimize a bit more.
         # TODO: Drop autograd entirely.
         # TODO: Skip cross-entropy backward if not needed.
         language_model_loss = self._forward(input_, kwargs, losses)
         if language_model_loss is not None:
-            losses[LanguageModelLossNames.language_model_loss].append(language_model_loss)
+            losses[self._loss_name].append(language_model_loss)
         # TODO: Return the model output when needed.
-        return language_model_loss
+        if self.is_last_head:
+            # Last head should return the loss for backward.
+            return language_model_loss
+        else:
+            # Backward hook to compute the gradient of the loss
+            shared_hidden = AuxiliaryLoss.apply(shared_hidden, language_model_loss, 1.0)
+            # MTP: Return shared_hidden to be used by the next head.
+            return shared_hidden

     def _forward_backward(
         self, input_: torch.Tensor, kwargs: dict, losses: dict | None = None
     ) -> tuple[torch.Tensor, torch.Tensor | None]:
-        labels = kwargs[LanguageModelKwargs.labels].flatten() if LanguageModelKwargs.labels in kwargs else None
+        labels = kwargs[LanguageModelKwargs.labels] if LanguageModelKwargs.labels in kwargs else None
+        # MTP: Shift the labels
+        labels = labels[:, self._prediction_distance :].flatten() if labels is not None else None
         if self._sequence_parallel_logits:
             labels = split_op(labels, self._tensor_space.distributed.tensor_group, 0)
         do_grad = labels is not None and self.training
         input_ = input_.detach().requires_grad_(do_grad)
         with torch.enable_grad():
-            ln_output = self.final_norm(input_)
+            # MTP: truncate the input
+            if self._prediction_distance > 0:
+                truncated_input = input_[:, : -self._prediction_distance, :].contiguous()
+            else:
+                truncated_input = input_
+            ln_output = self.final_norm(truncated_input)

         grad_output = kwargs[TransformerKwargs.grad_output] / (
             self._group_size if self._sequence_parallel_logits else 1
         )

-        output_weights = kwargs[WORD_EMBEDDINGS_WEIGHT] if self._tie_word_embeddings else self.output_weights
+        output_weights = self._get_output_weights(kwargs)
         loss, ln_output_grad = self._logits_cross_entropy_forward_backward_split(
             ln_output.detach(), labels, output_weights, grad_output, kwargs, losses
         )

@@ -137,6 +175,13 @@ def _forward_backward(
         else:
             return loss, None

+    def _get_output_weights(self, kwargs: dict) -> torch.Tensor:
+        if self._tie_word_embeddings:
+            return kwargs[WORD_EMBEDDINGS_WEIGHT]
+        if self._prediction_distance > 0:
+            return kwargs[OUTPUT_WEIGHTS]
+        return self.output_weights
+
     def _logits_cross_entropy_forward_backward_split(
         self,
         input_: torch.Tensor,

@@ -156,6 +201,7 @@ def _logits_cross_entropy_forward_backward_split(
             return None, None
         else:
             loss = None
+        # TODO MTP: allow a _cross_entropy_splits that is not a divisor of the sequence length
         split_size = div(labels.numel(), self._cross_entropy_splits)
         grad_output /= self._cross_entropy_splits
         logit_input = input_.flatten(0, -2)
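Putting the head changes together, here is a simplified, hypothetical sketch of what one prediction head now does per forward pass. It ignores tensor/sequence parallelism, the fused cross-entropy path, and the ParameterMeta machinery, and uses batch-first toy shapes; only the unbind/shift/truncate pattern mirrors the diff above.

```python
import torch


def prediction_head_sketch(head_input, labels, distance, is_last_head, norm, output_weights):
    # Non-last heads receive a stack of (shared hidden state, this head's transformer output).
    if not is_last_head:
        shared_hidden, hidden = torch.unbind(head_input, dim=0)
    else:
        shared_hidden, hidden = None, head_input
    # Shift the labels by the prediction distance ...
    shifted_labels = labels[:, distance:]
    # ... and truncate the hidden states by the same amount so shapes line up.
    if distance > 0:
        hidden = hidden[:, :-distance, :].contiguous()
    logits = norm(hidden) @ output_weights.t()
    loss = torch.nn.functional.cross_entropy(logits.flatten(0, -2), shifted_labels.flatten())
    # The last head returns its loss; earlier heads hand the shared hidden state
    # to the next head (Fast-LLM attaches the loss gradient to it via AuxiliaryLoss).
    return loss if is_last_head else shared_hidden


# Toy usage with made-up shapes and two prediction heads (distances 0 and 1).
batch, seq, hidden_size, vocab = 2, 8, 16, 32
norm = torch.nn.LayerNorm(hidden_size)
weights = torch.randn(vocab, hidden_size)
labels = torch.randint(0, vocab, (batch, seq))
stacked = torch.randn(2, batch, seq, hidden_size)  # (shared_hidden, transformer_output)
shared = prediction_head_sketch(stacked, labels, distance=0, is_last_head=False, norm=norm, output_weights=weights)
loss = prediction_head_sketch(shared, labels, distance=1, is_last_head=True, norm=norm, output_weights=weights)
```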

fast_llm/layers/transformer/transformer.py

Lines changed: 11 additions & 4 deletions
@@ -6,7 +6,7 @@
 from fast_llm.core.distributed import set_generator
 from fast_llm.engine.base_model.base_model import Layer
 from fast_llm.engine.config_utils.run import log_pipeline_parallel_main_rank
-from fast_llm.engine.config_utils.tensor_space import TensorSpace
+from fast_llm.engine.config_utils.tensor_space import TensorDim, TensorSpace
 from fast_llm.layers.transformer.attention import Attention
 from fast_llm.layers.transformer.config import TransformerConfig, TransformerDimNames, TransformerKwargs
 from fast_llm.layers.transformer.mixture_of_experts import MixtureOfExpertMLP

@@ -27,11 +27,14 @@ def __init__(
         config: TransformerConfig,
         tensor_space: TensorSpace,
         layer_index: int,
+        return_input: bool = False,
     ):
         super().__init__()
         self._config = config
         self._tensor_space = tensor_space
         self._dropout_p = self._config.hidden_dropout
+        # For multi-token prediction, return a stack of shared_hidden and transformer_output.
+        self._return_input = return_input

         self._layer_index = layer_index
         self._debug_mode = self._config.debug_transformer or self._config.debug_transformer_memory

@@ -63,9 +66,10 @@ def name(self) -> str:
         return f"Transformer layer {self._layer_index}"

     def _get_meta(self, tensor: torch.Tensor, name: str, kwargs: dict):
-        return TensorMeta.from_dims(
-            kwargs[TransformerKwargs.hidden_dims], tensor_name=f"{self.name} {name}", dtype=tensor.dtype
-        )
+        dims = kwargs[TransformerKwargs.hidden_dims]
+        if self._return_input:
+            dims = (TensorDim("stacked_input_output", 2),) + dims
+        return TensorMeta.from_dims(dims, tensor_name=f"{self.name} {name}", dtype=tensor.dtype)

     def _debug_log(self, tensor: torch.Tensor | None, name: str, kwargs: dict[str, typing.Any], *, bias=None) -> None:
         if self._config.debug_transformer_memory:

@@ -103,6 +107,7 @@ def forward(
             )
         if self._debug_mode:
             self._debug_log(None, "Begin", kwargs)
+        fw_input = input_
         hidden_states = self.norm_1(input_)
         if self._debug_mode:
             self._debug_log(hidden_states, "Norm 1", kwargs)

@@ -123,4 +128,6 @@ def forward(
         hidden_states = self._bias_dropout_add(hidden_states, bias, input_)
         if self._debug_mode:
             self._debug_log(None, "MLP residual", kwargs, bias=bias)
+        if self._return_input:
+            hidden_states = torch.stack((fw_input, hidden_states), dim=0)
         return hidden_states
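The `return_input` path is what feeds the extra heads: a layer constructed with `return_input=True` stacks its own input on top of its output, and each non-final head unbinds that stack again (see `head.py` above). A toy illustration with made-up batch-first shapes:

```python
import torch

batch, seq, hidden = 2, 8, 16
layer_input = torch.randn(batch, seq, hidden)   # shared hidden state entering the layer
layer_output = torch.randn(batch, seq, hidden)  # this layer's transformer output

# With return_input=True the layer returns both, stacked along a new leading dim,
# which is why _get_meta prepends TensorDim("stacked_input_output", 2).
stacked = torch.stack((layer_input, layer_output), dim=0)
assert stacked.shape == (2, batch, seq, hidden)

# The consuming head splits the stack back apart, mirroring head.py.
shared_hidden, transformer_output = torch.unbind(stacked, dim=0)
assert shared_hidden.shape == (batch, seq, hidden)
```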

fast_llm/models/gpt/config.py

Lines changed: 1 addition & 0 deletions
@@ -35,6 +35,7 @@ class Starcoder2GPTHuggingfaceCheckpointFormat(GPTHuggingfaceCheckpointFormat):
 class LlamaGPTHuggingfaceCheckpointFormat(GPTHuggingfaceCheckpointFormat):
     name: typing.ClassVar[str] = "llama"

+
 class Qwen2GPTHuggingfaceCheckpointFormat(GPTHuggingfaceCheckpointFormat):
     name: typing.ClassVar[str] = "qwen2"
