
Commit 58b6f8a

sohamparikh, jlamypoirier, and sohamp authored
create samples with padding to avoid truncations (#186)
Co-authored-by: Joel Lamy-Poirier <[email protected]>
Co-authored-by: soham.parikh <[email protected]>
1 parent 9036fd2 commit 58b6f8a

7 files changed: +270 additions, -44 deletions


fast_llm/csrc/data.cpp

Lines changed: 60 additions & 1 deletion
@@ -27,7 +27,7 @@
 
 /*
 Helper methods for fast index mapping builds.
-Changes for Fast-LLM: Use int16 for dataset index, add verbose argument to build_sample_idx.
+Changes for Fast-LLM: Use int16 for dataset index, add verbose argument to build_sample_idx, add build_sample_idx_padded
 */
 
 #include <iostream>
@@ -129,6 +129,65 @@ py::array build_sample_idx(const py::array_t<int32_t>& sizes_,
 
 }
 
+py::array build_padded_token_cumsum(const py::array_t<int32_t>& sizes_,
+                                    const int32_t seq_length,
+                                    const int32_t token_cumsum_rate,
+                                    const int64_t offset
+                                    ) {
+    /*
+    Build token cumsums at regular intervals from document sizes with padding in mind.
+    We inject 0 or more padding tokens at the end of every sequence to fill the sequence length.
+    */
+    int32_t seq_size = 0;
+    int64_t sizes_idx = 0;
+    int32_t samples = 0;
+    auto sizes = sizes_.unchecked<1>();
+    std::vector<int64_t> token_cumsum;
+
+    int64_t cumsum = offset;
+
+    while (sizes_idx < sizes.size()) {
+        int32_t size = sizes[sizes_idx];
+        if (size > seq_length) {
+            // Skip sequences that are too long, to avoid truncations
+            if (samples % token_cumsum_rate==0) token_cumsum.push_back(cumsum);
+            sizes_idx += 1;
+            samples += 1;
+        } else if (seq_size + size > seq_length) {
+            // add padded tokens if a document does not fit in current sequence and start a new sequence
+            cumsum += seq_length - seq_size;
+            seq_size = 0;
+        } else {
+            // Increment here to account for padding. This ensures that the stored values match the beginning of the next document.
+            if (samples % token_cumsum_rate==0) token_cumsum.push_back(cumsum);
+            seq_size += size;
+            cumsum += size;
+            sizes_idx += 1;
+            samples += 1;
+        }
+    }
+
+    // Add a final (padded) entry so we know how many tokens there are in total.
+    cumsum += seq_length - seq_size;
+    token_cumsum.push_back(cumsum);
+
+
+    int64_t* token_cumsum_result = new int64_t[token_cumsum.size()];
+    memcpy(token_cumsum_result, token_cumsum.data(), token_cumsum.size() * sizeof(int64_t));
+
+    py::capsule free_when_done(token_cumsum_result, [](void *mem_) {
+        int64_t *mem = reinterpret_cast<int64_t*>(mem_);
+        delete[] mem;
+    });
+
+    const auto byte_size = sizeof(int64_t);
+    return py::array(std::vector<int64_t>{token_cumsum.size()},
+                     {byte_size},
+                     token_cumsum_result,
+                     free_when_done);
+}
+
 PYBIND11_MODULE(data, m) {
     m.def("build_sample_idx", &build_sample_idx);
+    m.def("build_padded_token_cumsum", &build_padded_token_cumsum);
 }
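
For readers who prefer Python, the new kernel can be paraphrased as the NumPy sketch below (an illustration only, with hypothetical names; the shipped implementation is the pybind11-exposed C++ above):

import numpy as np

def padded_token_cumsum_sketch(sizes, seq_length, token_cumsum_rate, offset=0):
    # Cumulative token counts recorded every `token_cumsum_rate` documents, counting the
    # padding inserted whenever a document would not fit in the rest of a sequence
    # (mirrors build_padded_token_cumsum above).
    token_cumsum = []
    seq_size = 0      # tokens already placed in the current sequence
    cumsum = offset   # running token count, padding included
    samples = 0       # documents consumed so far (placed or skipped)
    idx = 0
    while idx < len(sizes):
        size = int(sizes[idx])
        if size > seq_length:
            # Too long for any sequence: skipped to avoid truncation.
            if samples % token_cumsum_rate == 0:
                token_cumsum.append(cumsum)
            idx += 1
            samples += 1
        elif seq_size + size > seq_length:
            # Does not fit: pad out the current sequence and start a new one.
            cumsum += seq_length - seq_size
            seq_size = 0
        else:
            # Fits: record where this document starts, then place it.
            if samples % token_cumsum_rate == 0:
                token_cumsum.append(cumsum)
            seq_size += size
            cumsum += size
            idx += 1
            samples += 1
    # Final padded entry gives the total token count, padding included.
    cumsum += seq_length - seq_size
    token_cumsum.append(cumsum)
    return np.array(token_cumsum, dtype=np.int64)

For example, padded_token_cumsum_sketch(np.array([3, 4, 6, 2]), seq_length=8, token_cumsum_rate=1) returns [0, 3, 8, 14, 16]: the third document does not fit after the first two (7 of 8 slots used), so one padding token closes that sequence and the document starts at offset 8; the trailing 16 is the total token count, padding included.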

fast_llm/data/data/gpt/config.py

Lines changed: 9 additions & 0 deletions
@@ -57,6 +57,15 @@ class GPTDataConfig(DataConfig, GPTLegacyConfig):
         desc="Multiprocessing context. Do not touch.",
         hint=FieldHint.expert,
     )
+    truncate_documents: bool = Field(
+        default=True,
+        desc=(
+            "If enabled, documents may be truncated while being packed to fit the sequence length."
+            "Otherwise, sequences will be padded such that every document lies entirely within a sample"
+            " (and documents exceeding the sequence length will be skipped altogether)."
+        ),
+        hint=FieldHint.feature,
+    )
 
     def _validate(self) -> None:
         if not self.datasets:
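
To make the flag's semantics concrete, here is a toy packing sketch (illustrative only, with a hypothetical helper; this is not Fast-LLM's sampler): with truncate_documents disabled, each sample holds whole documents, leftover space becomes padding, and over-long documents are dropped.

def pack_without_truncation(doc_sizes, sample_len):
    # Toy model of truncate_documents=False: whole documents only, leftover space
    # becomes padding, and documents longer than sample_len are skipped.
    samples, current, used = [], [], 0
    for i, size in enumerate(doc_sizes):
        if size > sample_len:
            continue  # skipped altogether, as the description above says
        if used + size > sample_len:
            current.append(("pad", sample_len - used))
            samples.append(current)
            current, used = [], 0
        current.append((f"doc{i}", size))
        used += size
    if current:
        if used < sample_len:
            current.append(("pad", sample_len - used))
        samples.append(current)
    return samples

print(pack_without_truncation([5, 4, 3, 9], sample_len=8))
# [[('doc0', 5), ('pad', 3)], [('doc1', 4), ('doc2', 3), ('pad', 1)]]

With the default truncate_documents=True, the same stream is packed back to back instead, so doc1 would be split across the first and second samples.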

fast_llm/data/data/gpt/data.py

Lines changed: 1 addition & 0 deletions
@@ -121,6 +121,7 @@ def setup(
             sequence_length=self._max_sequence_length,
             vocab_size=self._vocab_size,
             tokenizer=self._tokenizer,
+            truncate_documents=self._config.truncate_documents,
             cross_document_attention=self._cross_document_attention,
         )
         dataset = self._config.datasets[dataset_name].build_and_sample(sampling)

fast_llm/data/dataset/gpt/config.py

Lines changed: 1 addition & 0 deletions
@@ -71,6 +71,7 @@ class GPTSamplingData(SamplingData):
     sequence_length: int
     vocab_size: int
     tokenizer: "Tokenizer"
+    truncate_documents: bool = True
     cross_document_attention: bool = True
 
 
fast_llm/data/dataset/gpt/sampled.py

Lines changed: 122 additions & 43 deletions
@@ -12,12 +12,12 @@
 from fast_llm.data.dataset.abstract import SampledDataset
 from fast_llm.data.dataset.gpt.config import GPTSamplingData, ShufflingType
 from fast_llm.data.dataset.gpt.indexed import GPTIndexedDataset
-from fast_llm.engine.config_utils.data_type import get_unsigned_integer_type
+from fast_llm.engine.config_utils.data_type import DataType, get_unsigned_integer_type
 from fast_llm.engine.config_utils.run import log_main_rank
 from fast_llm.utils import Assert
 
 try:
-    from fast_llm.csrc.data import build_sample_idx  # noqa
+    from fast_llm.csrc.data import build_padded_token_cumsum, build_sample_idx  # noqa
 
     _extension_available = True
 except ImportError:
@@ -89,6 +89,7 @@ def __init__(
         self._sequence_length = sampling.sequence_length
         self._cross_document_attention = sampling.cross_document_attention
         self._config = sampling.config
+        self._truncate_documents = sampling.truncate_documents
         self._device = torch.device("cuda" if self._config.gpu else "cpu")
 
         if sampling.cache_directory is None:
@@ -124,15 +125,35 @@ def _sample(self) -> None:
         """
         # Get the document sizes, the main information needed for sampling.
         document_sizes = torch.from_numpy(self._indexed_dataset.get_document_sizes()).to(self._device)
-
-        # Calculate basic stats.
         documents_per_epoch = document_sizes.numel()
         tokens_per_epoch = document_sizes.sum().item()
+
+        # Calculate basic stats.
+        if not self._truncate_documents:
+            assert _extension_available, (
+                "The C++ extension for dataset sampling is missing."
+                " Please make sure Fast-LLM is installed correctly."
+            )
+            long_docs_filter = document_sizes > self._sequence_length + 1
+            ignored_documents = sum(long_docs_filter)
+            if ignored_documents:
+                log_main_rank(
+                    f" > {ignored_documents}/{documents_per_epoch} documents are longer than {self._sequence_length+1} tokens and will be ignored.",
+                    log_fn=logger.warning,
+                )
+            tokens_per_epoch = document_sizes[~long_docs_filter].sum().item()
+            if tokens_per_epoch == 0:
+                raise RuntimeError(
+                    f" > No documents shorter than {self._sequence_length+1} tokens found in dataset {self._indexed_dataset.name}."
+                )
         # TODO MTP: Produce more labels to provide labels for the multi-token prediction heads?
         # We produce sequences of length `self._sequence_length + 1` so the last token has a label,
-        # but we also include that last label in the following sample,
+        # but in case of truncations we also include that last label in the following sample,
         # so we need `sequence_length * num_samples + 1` tokens in total.
-        num_epochs = math.ceil((self._sequence_length * self._num_samples + 1) / tokens_per_epoch)
+        num_epochs = math.ceil(
+            ((self._sequence_length + 1 - self._truncate_documents) * self._num_samples + 1 * self._truncate_documents)
+            / tokens_per_epoch
+        )
 
         # Prepare for shuffling.
         generator = torch.Generator(device=self._device)
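
The new num_epochs numerator leans on Python's bool-to-int coercion; a minimal sketch of how it collapses in each mode (the sequence_length and num_samples values are placeholders):

sequence_length, num_samples = 2048, 10_000

# truncate_documents=True: (L + 1 - 1) * N + 1, i.e. the old sequence_length * num_samples + 1.
tokens_needed_truncating = (sequence_length + 1 - True) * num_samples + 1 * True

# truncate_documents=False: (L + 1 - 0) * N + 0, i.e. a full sequence_length + 1 tokens per sample.
tokens_needed_padding = (sequence_length + 1 - False) * num_samples + 1 * False

assert tokens_needed_truncating == sequence_length * num_samples + 1
assert tokens_needed_padding == (sequence_length + 1) * num_samples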
@@ -154,13 +175,17 @@ def _sample(self) -> None:
             "num_samples": self._num_samples,
             "unshuffled_epochs": unshuffled_epochs,
             "sequence_length": self._sequence_length,
+            "truncate_documents": self._truncate_documents,
             "config": self._config.to_serialized(),
         }
         self._load_yaml_data(yaml_data)
 
         if self._yaml_path is not None:
             if self._yaml_path.is_file():
                 loaded_yaml_data = yaml.safe_load(self._yaml_path.open("r"))
+                unshuffled_tokens = loaded_yaml_data.pop("unshuffled_tokens", None)
+                if unshuffled_tokens is not None:
+                    self._unshuffled_tokens = unshuffled_tokens
                 if loaded_yaml_data != yaml_data:
                     raise RuntimeError(
                         f"Invalid dataset cache for dataset {self.name}."
@@ -172,9 +197,6 @@ def _sample(self) -> None:
                 # Dataset is already sampled, skip.
                 logger.info(f"Using existing sampling for dataset {self.name}")
                 return
-            else:
-                self._yaml_path.parent.mkdir(parents=True, exist_ok=True)
-                yaml.safe_dump(yaml_data, self._yaml_path.open("w"))
 
         if shuffled_documents > 1e8:
             warnings.warn(
@@ -232,51 +254,78 @@ def _sample(self) -> None:
         # So it is enough to pre-compute the (zero-padded) token cumsum at regular intervals `TOKEN_CUMSUM_RATE`.
         # Using `TOKEN_CUMSUM_RATE > 1` reduces pre-computation overhead at the cost of runtime computation.
         # Equivalent to `torch.hstack((0, document_sizes[all_document_index].cumsum()[::TOKEN_CUMSUM_RATE]))`
+        if unshuffled_epochs > 0:
+            token_cumsum_unshuffled, num_tokens_unshuffled = self._get_token_cumsum(
+                document_sizes,
+                offset=0,
+                # TODO: Allowing for max 100% extra tokens for padding, is that enough?
+                dtype=get_unsigned_integer_type((2 - self._truncate_documents) * tokens_per_epoch * num_epochs),
+            )
+            if self._truncate_documents:
+                num_tokens_unshuffled = tokens_per_epoch * unshuffled_epochs
+            self._token_cumsum_unshuffled.save(token_cumsum_unshuffled)
+        else:
+            num_tokens_unshuffled = 0
+        self._unshuffled_tokens = num_tokens_unshuffled
+
+        if self._yaml_path is not None:
+            yaml_data["unshuffled_tokens"] = num_tokens_unshuffled
+            self._yaml_path.parent.mkdir(parents=True, exist_ok=True)
+            yaml.safe_dump(yaml_data, self._yaml_path.open("w"))
+
         if shuffled_epochs > 0:
-            token_cumsum_shuffled = self._get_token_cumsum(
+            token_cumsum_shuffled, num_tokens_shuffled = self._get_token_cumsum(
                 document_sizes[
                     # Torch indexing only works with int32 or int64
                     document_shuffling.to(
                         dtype=torch.int64 if document_shuffling.dtype == torch.int64 else torch.int32
                     )
                 ],
-                offset=unshuffled_epochs * tokens_per_epoch,
-                dtype=get_unsigned_integer_type(tokens_per_epoch * num_epochs).torch,
+                offset=num_tokens_unshuffled,
+                # TODO: Allowing for max 100% extra tokens for padding, is that enough?
+                dtype=get_unsigned_integer_type((2 - self._truncate_documents) * tokens_per_epoch * num_epochs),
             )
-            self._token_cumsum_shuffled.save(token_cumsum_shuffled.numpy(force=self._config.gpu))
+            self._token_cumsum_shuffled.save(token_cumsum_shuffled)
             self._document_shuffling.save(
-                document_shuffling[: (token_cumsum_shuffled.numel() + 1) * TOKEN_CUMSUM_RATE].numpy(
+                document_shuffling[: (token_cumsum_shuffled.size + 1) * TOKEN_CUMSUM_RATE].numpy(
                     force=self._config.gpu
                 )
            )
            # Free memory
-            del token_cumsum_shuffled
            del document_shuffling
 
-        if unshuffled_epochs > 0:
-            token_cumsum_unshuffled = self._get_token_cumsum(
-                document_sizes, offset=0, dtype=get_unsigned_integer_type(tokens_per_epoch * num_epochs).torch
+    def _get_token_cumsum(self, sizes: torch.Tensor, offset: int, dtype: DataType) -> tuple[np.ndarray, int | None]:
+        if self._truncate_documents:
+            # Create the output tensor.
+            out = sizes.new_empty(sizes.numel() // TOKEN_CUMSUM_RATE + 1, dtype=dtype.torch)
+            # Get partial sums for regular intervals, excluding the last incomplete interval.
+            torch.sum(
+                sizes[: sizes.numel() - sizes.numel() % TOKEN_CUMSUM_RATE].view(-1, TOKEN_CUMSUM_RATE),
+                dim=1,
+                out=out[1:],
             )
-            self._token_cumsum_unshuffled.save(token_cumsum_unshuffled.numpy(force=self._config.gpu))
-
-    def _get_token_cumsum(self, sizes: torch.Tensor, offset: int, dtype: torch.dtype) -> torch.Tensor:
-        # Create the output tensor.
-        out = sizes.new_empty(sizes.numel() // TOKEN_CUMSUM_RATE + 1, dtype=dtype)
-        # Get partial sums for regular intervals, excluding the last incomplete interval.
-        torch.sum(
-            sizes[: sizes.numel() - sizes.numel() % TOKEN_CUMSUM_RATE].view(-1, TOKEN_CUMSUM_RATE), dim=1, out=out[1:]
-        )
-        # Pad with the begin offset
-        out[0] = offset
-        # Calculate the cumsum.
-        out.cumsum_(0)
-        # Crop unnecessary entries.
-        return out[
-            : torch.clamp_min_(
-                torch.searchsorted(out, self._num_samples * self._sequence_length, side="right"),
-                0,
+            # Pad with the begin offset
+            out[0] = offset
+            # Calculate the cumsum.
+            out.cumsum_(0)
+            # Crop unnecessary entries.
+            out = out[
+                : torch.clamp_min_(
+                    torch.searchsorted(out, self._num_samples * self._sequence_length, side="right"),
+                    0,
+                )
+            ]
+            return out.numpy(force=self._config.gpu), None
+        else:
+            # TODO: dynamically handle int64 or int32 in CPP
+            out = build_padded_token_cumsum(
+                sizes.cpu().numpy(), (self._sequence_length + 1), TOKEN_CUMSUM_RATE, offset
            )
-        ]
+            num_tokens = out[-1]
+            out = out[:-1][
+                : np.clip(np.searchsorted(out, self._num_samples * (self._sequence_length + 1), side="right"), 0, None)
+            ]
+            return out, num_tokens
 
     def __len__(self) -> int:
         return self._num_samples
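
Downstream, __getitem__ (next hunks) only needs a searchsorted over these checkpoints to find the document a sample starts in. A rough sketch of that lookup, reusing the [0, 3, 8, 14, 16] cumsum from the earlier example with TOKEN_CUMSUM_RATE = 1 (names are illustrative, and the real code also adds a document offset for the unshuffled part):

import numpy as np

TOKEN_CUMSUM_RATE = 1
token_start_array = np.array([0, 3, 8, 14], dtype=np.int64)  # padded cumsum, final total (16) dropped

token_start = 8  # first token of sample index 1 when sequence_length + 1 == 8
token_start_cumsum_index = np.searchsorted(token_start_array, token_start, side="right").item() - 1
document_sampling_index = token_start_cumsum_index * TOKEN_CUMSUM_RATE
token_count = int(token_start_array[token_start_cumsum_index])

print(document_sampling_index, token_count)  # 2 8 -> start reading at document 2, 8 tokens already consumed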
@@ -288,7 +337,9 @@ def __getitem__(self, index: int) -> typing.Any:
         The returned sample is ready to be concatenated, then fed to a `GPTModel` (see `GPTModel.preprocess`).
         """
         self._lazy_load()
-        token_start = index * self._sequence_length
+        # tokens at the boundary are included in only one sample when we pack without truncations
+        # in case of packing with truncations, the last token from the previous sample is also the first token of the next sample
+        token_start = index * (self._sequence_length + 1 - self._truncate_documents)
         token_end = token_start + self._sequence_length + 1
 
         if token_start < self._unshuffled_tokens:
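
A quick way to see what the new token_start formula changes (a sketch; L stands for self._sequence_length):

L = 7  # each fetched sample spans L + 1 = 8 tokens (inputs plus one label)

# truncate_documents=True: consecutive windows overlap by one token, the shared boundary label.
starts_truncating = [i * L for i in range(3)]                     # [0, 7, 14]
windows_truncating = [(s, s + L + 1) for s in starts_truncating]  # [(0, 8), (7, 15), (14, 22)]

# truncate_documents=False: windows are disjoint, each exactly one padded sequence.
starts_padding = [i * (L + 1) for i in range(3)]                  # [0, 8, 16]
windows_padding = [(s, s + L + 1) for s in starts_padding]        # [(0, 8), (8, 16), (16, 24)]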
@@ -302,6 +353,7 @@ def __getitem__(self, index: int) -> typing.Any:
         token_start_cumsum_index = np.searchsorted(token_start_array, token_start, side="right").item() - 1
 
         document_sampling_index = token_start_cumsum_index * TOKEN_CUMSUM_RATE + token_start_array_document_offset
+
         token_count = token_start_array[token_start_cumsum_index]
 
         token_ids = []
@@ -314,6 +366,25 @@ def __getitem__(self, index: int) -> typing.Any:
                 document_index = self._document_shuffling[document_sampling_index - self._unshuffled_documents].item()
 
             document_size = self._indexed_dataset.get_document_size(document_index)
+
+            if not self._truncate_documents:
+                if document_size > self._sequence_length + 1:
+                    # Document too long, ignore
+                    document_sampling_index += 1
+                    continue
+                tokens_in_sample = token_count % (self._sequence_length + 1)
+                if document_size + tokens_in_sample > self._sequence_length + 1:
+                    # Document belongs to the next sample, need to account for padding.
+                    padding_size = self._sequence_length + 1 - tokens_in_sample
+                    if token_count > token_start:
+                        # Add padding tokens to current sample
+                        token_ids.append(np.full((padding_size,), -100, dtype=np.int64))
+                        Assert.eq(token_count + padding_size, token_end)
+                        break
+                    else:
+                        # Move on to the next sample.
+                        token_count += padding_size
+
             # Determine if the document belongs to the requested sample.
             if token_count + document_size >= token_start:
                 # Determine which part of the document belong to the sample, and add it to the list.
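
In the padded path, the filler appended at the end of a sample is a run of -100 tokens, presumably chosen to match the usual ignore index so the loss skips them. A small sketch of a finished sample when the placed documents stop three tokens short (toy token values):

import numpy as np

sequence_length = 7  # a fetched sample holds sequence_length + 1 = 8 tokens
doc_a = np.array([11, 12, 13], dtype=np.int64)
doc_b = np.array([21, 22], dtype=np.int64)
padding = np.full((3,), -100, dtype=np.int64)  # fills the 3 unused slots at the end of the sample

sample = np.concatenate([doc_a, doc_b, padding])
assert len(sample) == sequence_length + 1
# sample -> [11, 12, 13, 21, 22, -100, -100, -100]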
@@ -343,7 +414,9 @@ def __getitem__(self, index: int) -> typing.Any:
             )
         token_ids = np.concatenate(token_ids, dtype=np.int64)
         loss_masking_spans = (
-            np.stack(loss_masking_spans, dtype=np.int32) if self._config.use_loss_masking_spans else None
+            (np.stack(loss_masking_spans, dtype=np.int32) if loss_masking_spans else np.array([]))
+            if self._config.use_loss_masking_spans
+            else None
         )
         Assert.eq(len(token_ids), self._sequence_length + 1)
 
@@ -357,9 +430,12 @@ def _lazy_load(self):
         if not hasattr(self, "_documents_per_epoch"):
             self._load_yaml_data(yaml.safe_load(self._yaml_path.open("r")))
 
-    def _load_yaml_data(self, data: dict[str, typing.Any]):
+    def _load_yaml_data(self, data: dict[str, typing.Any]) -> None:
         self._documents_per_epoch = data["dataset"]["documents_per_epoch"]
-        self._unshuffled_tokens = data["unshuffled_epochs"] * data["dataset"]["tokens_per_epoch"]
+        if unshuffled_tokens := data.get("unshuffled_tokens") is not None:
+            self._unshuffled_tokens = unshuffled_tokens
+        else:
+            self._unshuffled_tokens = data["unshuffled_epochs"] * data["dataset"]["tokens_per_epoch"]
         self._unshuffled_documents = data["unshuffled_epochs"] * self._documents_per_epoch
 
 
@@ -380,9 +456,12 @@ def __init__(
         self._indexed_dataset = indexed_dataset
         self._num_samples = sampling.num_samples
         self._sequence_length = sampling.sequence_length
+        if not sampling.truncate_documents:
+            raise NotImplementedError(
+                "Legacy sampling only supports document truncation. Please use the latest dataset format."
+            )
         self._cross_document_attention = sampling.cross_document_attention
         self._config = sampling.config
-        self._tokenizer = sampling.tokenizer
 
         if sampling.cache_directory is None:
             log_main_rank(
@@ -498,7 +577,7 @@ def __getitem__(self, idx: int) -> typing.Any:
                 for span in sample.loss_masking_spans:
                     spans.append(span + offset)
                 offset += len(sample.token_ids)
-            spans = np.stack(spans, dtype=np.int32)
+            spans = np.stack(spans, dtype=np.int32) if spans else np.array([])
         else:
             spans = None
         sequence_lengths = (
