
Commit 54b686a

Merge remote-tracking branch 'origin/main' into toby/dpo
2 parents f7796d4 + 58b6f8a, commit 54b686a

File tree: 16 files changed, +554, -122 lines

fast_llm/csrc/data.cpp (+60, -1)

```diff
@@ -27,7 +27,7 @@
 
 /*
 Helper methods for fast index mapping builds.
-Changes for Fast-LLM: Use int16 for dataset index, add verbose argument to build_sample_idx.
+Changes for Fast-LLM: Use int16 for dataset index, add verbose argument to build_sample_idx, add build_sample_idx_padded
 */
 
 #include <iostream>
@@ -129,6 +129,65 @@ py::array build_sample_idx(const py::array_t<int32_t>& sizes_,
 
 }
 
+py::array build_padded_token_cumsum(const py::array_t<int32_t>& sizes_,
+                                    const int32_t seq_length,
+                                    const int32_t token_cumsum_rate,
+                                    const int64_t offset
+) {
+    /*
+    Build token cumsums at regular intervals from document sizes with padding in mind.
+    We inject 0 or more padding tokens at the end of every sequence to fill the sequence length.
+    */
+    int32_t seq_size = 0;
+    int64_t sizes_idx = 0;
+    int32_t samples = 0;
+    auto sizes = sizes_.unchecked<1>();
+    std::vector<int64_t> token_cumsum;
+
+    int64_t cumsum = offset;
+
+    while (sizes_idx < sizes.size()) {
+        int32_t size = sizes[sizes_idx];
+        if (size > seq_length) {
+            // Skip sequences that are too long, to avoid truncations
+            if (samples % token_cumsum_rate==0) token_cumsum.push_back(cumsum);
+            sizes_idx += 1;
+            samples += 1;
+        } else if (seq_size + size > seq_length) {
+            // add padded tokens if a document does not fit in current sequence and start a new sequence
+            cumsum += seq_length - seq_size;
+            seq_size = 0;
+        } else {
+            // Increment here to account for padding. This ensures that the stored values match the beginning of the next document.
+            if (samples % token_cumsum_rate==0) token_cumsum.push_back(cumsum);
+            seq_size += size;
+            cumsum += size;
+            sizes_idx += 1;
+            samples += 1;
+        }
+    }
+
+    // Add a final (padded) entry so we know how many tokens there are in total.
+    cumsum += seq_length - seq_size;
+    token_cumsum.push_back(cumsum);
+
+
+    int64_t* token_cumsum_result = new int64_t[token_cumsum.size()];
+    memcpy(token_cumsum_result, token_cumsum.data(), token_cumsum.size() * sizeof(int64_t));
+
+    py::capsule free_when_done(token_cumsum_result, [](void *mem_) {
+        int64_t *mem = reinterpret_cast<int64_t*>(mem_);
+        delete[] mem;
+    });
+
+    const auto byte_size = sizeof(int64_t);
+    return py::array(std::vector<int64_t>{token_cumsum.size()},
+                     {byte_size},
+                     token_cumsum_result,
+                     free_when_done);
+}
+
 PYBIND11_MODULE(data, m) {
     m.def("build_sample_idx", &build_sample_idx);
+    m.def("build_padded_token_cumsum", &build_padded_token_cumsum);
 }
```
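
For readers who prefer Python, here is a minimal NumPy sketch of the same cumulative-sum logic (illustrative only; the function name and the interpretation of `sizes` as a 1-D array of document token counts are assumptions, not the extension's API):

```python
import numpy as np

# Illustrative re-statement of the build_padded_token_cumsum logic above: walk
# the document sizes, padding out each sequence instead of truncating, and
# record the running token count every `token_cumsum_rate` documents.
def padded_token_cumsum(sizes: np.ndarray, seq_length: int, token_cumsum_rate: int, offset: int = 0) -> np.ndarray:
    seq_size, samples, cumsum = 0, 0, offset
    result, idx = [], 0
    while idx < len(sizes):
        size = int(sizes[idx])
        if size > seq_length:
            # Documents longer than a sequence are skipped rather than truncated.
            if samples % token_cumsum_rate == 0:
                result.append(cumsum)
            idx += 1
            samples += 1
        elif seq_size + size > seq_length:
            # Document does not fit: pad out the current sequence and start a new one.
            cumsum += seq_length - seq_size
            seq_size = 0
        else:
            if samples % token_cumsum_rate == 0:
                result.append(cumsum)
            seq_size += size
            cumsum += size
            idx += 1
            samples += 1
    # Final padded entry gives the total number of tokens.
    result.append(cumsum + seq_length - seq_size)
    return np.asarray(result, dtype=np.int64)
```

With this sketch, `padded_token_cumsum(np.array([3, 5, 9, 2]), seq_length=8, token_cumsum_rate=1)` returns `[0, 3, 8, 8, 16]`: the 9-token document is skipped, and the final entry counts two full sequences of 8 tokens (the second one padded).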

fast_llm/data/data/gpt/config.py (+9)

```diff
@@ -57,6 +57,15 @@ class GPTDataConfig(DataConfig, GPTLegacyConfig):
         desc="Multiprocessing context. Do not touch.",
         hint=FieldHint.expert,
     )
+    truncate_documents: bool = Field(
+        default=True,
+        desc=(
+            "If enabled, documents may be truncated while being packed to fit the sequence length."
+            "Otherwise, sequences will be padded such that every document lies entirely within a sample"
+            " (and documents exceeding the sequence length will be skipped altogether)."
+        ),
+        hint=FieldHint.feature,
+    )
 
     def _validate(self) -> None:
         if not self.datasets:
```
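
To illustrate the behavioural difference this flag describes, here is a stand-alone sketch of the two packing modes (illustrative only, not Fast-LLM code; the helper name is made up):

```python
# Toy packing of document lengths into fixed-length samples, showing the two
# modes described by `truncate_documents`.
def pack(doc_lengths: list[int], seq_length: int, truncate_documents: bool) -> list[list[int]]:
    samples: list[list[int]] = []
    current: list[int] = []
    used = 0
    for length in doc_lengths:
        if truncate_documents:
            # Split documents across sample boundaries until nothing is left.
            while length > 0:
                take = min(length, seq_length - used)
                current.append(take)
                used += take
                length -= take
                if used == seq_length:
                    samples.append(current)
                    current, used = [], 0
        else:
            if length > seq_length:
                continue  # documents longer than a sample are skipped entirely
            if used + length > seq_length:
                samples.append(current)  # remaining positions become padding
                current, used = [], 0
            current.append(length)
            used += length
    if current:
        samples.append(current)
    return samples
```

For example, `pack([3, 5, 9, 2], 8, True)` yields `[[3, 5], [8], [1, 2]]` (the 9-token document is split across samples), while `pack([3, 5, 9, 2], 8, False)` yields `[[3, 5], [2]]`: the 9-token document is dropped and the second sample is padded up to length 8.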

fast_llm/data/data/gpt/data.py (+1)

```diff
@@ -135,6 +135,7 @@ def setup(
             sequence_length=self._max_sequence_length,
             vocab_size=self._vocab_size,
             tokenizer=self._tokenizer,
+            truncate_documents=self._config.truncate_documents,
             cross_document_attention=self._cross_document_attention,
         )
         dataset = self._config.datasets[dataset_name].build_and_sample(sampling)
```

fast_llm/data/dataset/gpt/config.py (+1)

```diff
@@ -81,6 +81,7 @@ class GPTSamplingData(SamplingData):
     sequence_length: int
     vocab_size: int
     tokenizer: "Tokenizer"
+    truncate_documents: bool = True
     cross_document_attention: bool = True
 
```

fast_llm/data/dataset/gpt/sampled.py (+131, -54)

Large diffs are not rendered by default.

fast_llm/engine/checkpoint/external.py (+27, -3)

```diff
@@ -141,12 +141,16 @@ def import_weight(
         return weight
 
 
-class IgnoreWeightConverter(WeightConverter):
+class IgnoreImportWeightConverter(WeightConverter):
+    def __post_init__(self):
+        Assert.eq(len(self.fast_llm_name), 0)
+        Assert.gt(len(self.export_name), 0)
+
     def export_weight(
         self, weight: tuple[torch.Tensor | SafeTensorSlice, ...]
     ) -> tuple[torch.Tensor | SafeTensorSlice, ...]:
         raise RuntimeError(
-            f"IgnoreWeightConverter should not be used for export: {self.fast_llm_name}, {self.export_name}"
+            f"IgnoreImportWeightConverter should not be used for export: {self.fast_llm_name}, {self.export_name}"
         )
 
     def import_weight(
@@ -155,6 +159,24 @@ def import_weight(
         return ()
 
 
+class IgnoreExportWeightConverter(WeightConverter):
+    def __post_init__(self):
+        Assert.gt(len(self.fast_llm_name), 0)
+        Assert.eq(len(self.export_name), 0)
+
+    def export_weight(
+        self, weight: tuple[torch.Tensor | SafeTensorSlice, ...]
+    ) -> tuple[torch.Tensor | SafeTensorSlice, ...]:
+        return ()
+
+    def import_weight(
+        self, weight: tuple[torch.Tensor | SafeTensorSlice, ...]
+    ) -> tuple[torch.Tensor | SafeTensorSlice, ...]:
+        raise RuntimeError(
+            f"IgnoreExportWeightConverter should not be used for import: {self.fast_llm_name}, {self.export_name}"
+        )
+
+
 class CopyWeightConverter(WeightConverter):
     def export_weight(
         self, weight: tuple[torch.Tensor | SafeTensorSlice, ...]
@@ -198,7 +220,9 @@ def __init__(self, model: "FastLLMModel"):
             if weight_converter.fast_llm_name
         }
         self._import_converters = {
-            weight_converter.export_name[0]: weight_converter for weight_converter in weight_converters
+            weight_converter.export_name[0]: weight_converter
+            for weight_converter in weight_converters
+            if weight_converter.export_name
         }
 
     @classmethod
```
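
As a rough usage sketch of the split converters (the positional `(fast_llm_name, export_name)` construction and the weight names below are assumptions for illustration, inferred from the fields checked in `__post_init__`, not a confirmed API):

```python
# Hypothetical converter list: a weight that exists only in the external
# checkpoint is dropped on import, and a Fast-LLM-only weight is dropped on
# export. The asymmetric __post_init__ checks above enforce that each class is
# only constructed with names on the side it actually handles.
weight_converters = [
    IgnoreImportWeightConverter((), ("model.rotary_emb.inv_freq",)),        # external-only name (example)
    IgnoreExportWeightConverter(("layers.0.some_internal_buffer",), ()),    # Fast-LLM-only name (example)
]
```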

fast_llm/engine/checkpoint/state_dict.py (+4, -2)

```diff
@@ -56,7 +56,9 @@ def save(self, config: CheckpointSaveConfig, metadata: CheckpointMetadata) -> No
             saver.add_tensor(self._get_key(exported_name, shard_name), exported_tensor)
 
         for shard_name, shard_state_dict in state_dict.items():
-            assert not shard_state_dict, (shard_name, list(state_dict))
+            assert (
+                not shard_state_dict
+            ), f"Un-handled entries after conversion: {({k: list(v) for k, v in state_dict.items()})}"
 
         index = saver.finalize()
         if self._model.config.distributed.rank == 0:
@@ -90,7 +92,7 @@ def load(self, config: CheckpointLoadConfig, metadata: CheckpointMetadata) -> No
             context.mark_as_loaded(loaded, (parameter_name, shard_name))
 
         for shard_name, shard_state_dict in state_dict.items():
-            assert not shard_state_dict, (shard_name, list(state_dict))
+            assert not shard_state_dict, (shard_name, list(shard_state_dict))
 
     @classmethod
     @abc.abstractmethod
```

fast_llm/layers/language_model/config.py (+12)

```diff
@@ -22,6 +22,12 @@ class LanguageModelLossNames:
     language_model_loss = "language_model_loss"
     z_loss = "z_loss"
 
+    @staticmethod
+    def multi_token_prediction_loss(index: int) -> str:
+        if index == 0:
+            return LanguageModelLossNames.language_model_loss
+        return f"language_model_loss_{index}"
+
 
 class LanguageModelKwargs:
     position_ids = "position_ids"
@@ -59,6 +65,12 @@ class LanguageModelArchitectureConfig(BaseModelArchitectureConfig):
     tie_word_embeddings: bool = Field(
         default=True, desc="Tie the output weights (logits) with the vocabulary embedding.", hint=FieldHint.core
     )
+    prediction_heads: int = Field(
+        default=1,
+        desc="Number of multi-token prediction heads.",
+        hint=FieldHint.feature,
+        valid=check_field(Assert.gt, 0),
+    )
 
     def _validate(self) -> None:
         if self.use_position_embeddings is None:
```
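
To make the naming scheme concrete, the static method above resolves as follows (a quick illustration; the import path is the one used by head.py):

```python
from fast_llm.layers.language_model.config import LanguageModelLossNames

# Head 0 keeps the standard loss name; later MTP heads get an indexed name.
assert LanguageModelLossNames.multi_token_prediction_loss(0) == "language_model_loss"
assert LanguageModelLossNames.multi_token_prediction_loss(1) == "language_model_loss_1"
assert LanguageModelLossNames.multi_token_prediction_loss(2) == "language_model_loss_2"
```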

fast_llm/layers/language_model/head.py (+66, -20)

```diff
@@ -14,7 +14,7 @@
 from fast_llm.functional.cross_entropy import cross_entropy_forward_backward
 from fast_llm.functional.linear import output_parallel_linear_backward, output_parallel_linear_forward
 from fast_llm.functional.dpo import compute_simplified_dpo_loss
-from fast_llm.layers.common.auxiliary_loss import z_loss
+from fast_llm.layers.common.auxiliary_loss import AuxiliaryLoss, z_loss
 from fast_llm.layers.language_model.config import (
     LanguageModelBaseConfig,
     LanguageModelDimNames,
@@ -25,7 +25,9 @@
 from fast_llm.layers.transformer.config import TransformerDimNames, TransformerKwargs
 from fast_llm.logging import log_distributed_tensor
 from fast_llm.tensor import ParameterMeta, TensorMeta, init_normal_
-from fast_llm.utils import div
+from fast_llm.utils import Assert, div
+
+OUTPUT_WEIGHTS = "output_weights"
 
 
 class LanguageModelHead[ConfigType: LanguageModelBaseConfig](Configurable[LanguageModelBaseConfig], Layer):
@@ -39,6 +41,7 @@ def __init__(
         self,
         config: LanguageModelBaseConfig,
         tensor_space: TensorSpace,
+        prediction_distance: int,
     ):
         super().__init__(config)
         self._debug_transformer = config.transformer.debug_transformer
@@ -57,23 +60,24 @@ def __init__(
 
         hidden_dim = self._tensor_space.get_tensor_dim(TransformerDimNames.hidden)
 
+        self._loss_name = LanguageModelLossNames.multi_token_prediction_loss(prediction_distance)
         self.final_norm = config.transformer.normalization.get_layer(hidden_dim)
         self._logits_scale_factor = config.logits_scale_factor
         self._z_loss_factor = config.logit_z_loss
 
-        # untie embedding weights
-        if not self._tie_word_embeddings:
-            vocab_dim = self._tensor_space.get_tensor_dim(
-                LanguageModelDimNames.vocab_tp if self._parallel_embeddings else LanguageModelDimNames.vocab
-            )
-            self.output_weights = ParameterMeta.from_dims(
-                (vocab_dim, hidden_dim),
-                init_method=init_normal_(
-                    std=config.init_method_std_embed,
-                    min_val=config.init_method_min_embed,
-                    max_val=config.init_method_max_embed,
-                ),
-            )
+        # Distance of the target token prediction
+        # 0: next-token prediction
+        # >0: multi-token prediction (MTP)
+        Assert.geq(prediction_distance, 0)
+        self._prediction_distance = prediction_distance
+        self.is_last_head = self._prediction_distance == config.prediction_heads - 1
+        if self._prediction_distance > 0:
+            assert (
+                not self._sequence_parallel_logits
+            ), "Sequence parallel logits not supported for multi-token prediction."
+            assert not self._cross_entropy_splits, "Cross-entropy splits not supported for multi-token prediction."
+
+        self._init_output_weights(hidden_dim, config)
 
         self._loss_function_type = config.loss_function_type
         if self._loss_function_type == LossFunctionType.cross_entropy:
@@ -97,6 +101,23 @@ def __init__(
         if hasattr(self, "output_weights"):
             self.output_weights = self._config.transformer.peft.apply_weight(self.output_weights)
 
+    def _init_output_weights(self, hidden_dim: TensorDim, config) -> None:
+        # Only the first head defines the output weights
+        if self._tie_word_embeddings or self._prediction_distance > 0:
+            return
+        # untie embedding weights
+        vocab_dim = self._tensor_space.get_tensor_dim(
+            LanguageModelDimNames.vocab_tp if self._parallel_embeddings else LanguageModelDimNames.vocab
+        )
+        self.output_weights = ParameterMeta.from_dims(
+            (vocab_dim, hidden_dim),
+            init_method=init_normal_(
+                std=config.init_method_std_embed,
+                min_val=config.init_method_min_embed,
+                max_val=config.init_method_max_embed,
+            ),
+        )
+
     def forward(
         self, input_: torch.Tensor, kwargs: dict, losses: dict | None = None, metrics: dict | None = None
     ) -> torch.Tensor:
@@ -107,33 +128,50 @@ def forward(
                 tensor_name="Loss",
                 reductions=((DistributedDimNames.data, ReduceOp.AVG),),  # noqa
             )
+        if not self.is_last_head:
+            # MTP: split the stacked input
+            shared_hidden, input_ = torch.unbind(input_, dim=0)
         # TODO: Pytorch copies the grads in backward for no reason (not sure if still the case)
         # TODO: Torch compile implementation sometimes break.
        # TODO: Double-check correctness, optimize a bit more.
         # TODO: Drop autograd entirely.
         # TODO: Skip cross-entropy backward if not needed.
         language_model_loss = self._forward(input_, kwargs, losses)
         if language_model_loss is not None:
-            losses[LanguageModelLossNames.language_model_loss].append(language_model_loss)
+            losses[self._loss_name].append(language_model_loss)
         # TODO: Return the model output when needed.
-        return language_model_loss
+        if self.is_last_head:
+            # Last head should return the loss for backward.
+            return language_model_loss
+        else:
+            # Backward hook to compute the gradient of the loss
+            shared_hidden = AuxiliaryLoss.apply(shared_hidden, language_model_loss, 1.0)
+            # MTP: Return shared_hidden to be used by the next head.
+            return shared_hidden
 
     def _forward_backward(
         self, input_: torch.Tensor, kwargs: dict, losses: dict | None = None
     ) -> tuple[torch.Tensor, torch.Tensor | None]:
-        labels = kwargs[LanguageModelKwargs.labels].flatten() if LanguageModelKwargs.labels in kwargs else None
+        labels = kwargs[LanguageModelKwargs.labels] if LanguageModelKwargs.labels in kwargs else None
+        # MTP: Shift the labels
+        labels = labels[:, self._prediction_distance :].flatten() if labels is not None else None
         if self._sequence_parallel_logits:
             labels = split_op(labels, self._tensor_space.distributed.tensor_group, 0)
         do_grad = labels is not None and self.training
         input_ = input_.detach().requires_grad_(do_grad)
         with torch.enable_grad():
-            ln_output = self.final_norm(input_)
+            # MTP: truncate the input
+            if self._prediction_distance > 0:
+                truncated_input = input_[:, : -self._prediction_distance, :].contiguous()
+            else:
+                truncated_input = input_
+            ln_output = self.final_norm(truncated_input)
 
         grad_output = kwargs[TransformerKwargs.grad_output] / (
             self._group_size if self._sequence_parallel_logits else 1
         )
 
-        output_weights = kwargs[WORD_EMBEDDINGS_WEIGHT] if self._tie_word_embeddings else self.output_weights
+        output_weights = self._get_output_weights(kwargs)
         loss, ln_output_grad = self._loss_fcn(
             ln_output.detach(), labels, output_weights, grad_output, kwargs, losses
         )
@@ -176,6 +214,13 @@ def _logits_dpo(
 
 
 
+    def _get_output_weights(self, kwargs: dict) -> torch.Tensor:
+        if self._tie_word_embeddings:
+            return kwargs[WORD_EMBEDDINGS_WEIGHT]
+        if self._prediction_distance > 0:
+            return kwargs[OUTPUT_WEIGHTS]
+        return self.output_weights
+
     def _logits_cross_entropy_forward_backward_split(
         self,
         input_: torch.Tensor,
@@ -195,6 +240,7 @@ def _logits_cross_entropy_forward_backward_split(
             return None, None
         else:
             loss = None
+            # TODO MTP: allow a _cross_entropy_splits that is not a divisor of the sequence length
            split_size = div(labels.numel(), self._cross_entropy_splits)
             grad_output /= self._cross_entropy_splits
             logit_input = input_.flatten(0, -2)
```
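
To see why the label shift and input truncation in `_forward_backward` line up, here is a small stand-alone PyTorch sketch (illustrative shapes only, not the head implementation):

```python
import torch

# A head at prediction_distance d predicts the token d positions further out,
# so its labels are shifted by d and its hidden states drop the last d positions.
batch, seq, hidden, d = 2, 8, 16, 2

hidden_states = torch.randn(batch, seq, hidden)
labels = torch.randint(0, 100, (batch, seq))  # next-token labels for distance 0

shifted_labels = labels[:, d:].flatten()                                        # (batch * (seq - d),)
truncated = hidden_states[:, :-d, :].contiguous() if d > 0 else hidden_states   # (batch, seq - d, hidden)

# One logit row per label, as the cross-entropy call requires.
assert truncated.flatten(0, -2).shape[0] == shifted_labels.shape[0]
```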
