
Commit 3c0199f

packing disabled: filter sequences longer than seq length
1 parent 54b686a commit 3c0199f

File tree

1 file changed (+20 -5 lines changed)

fast_llm/data/dataset/gpt/sampled.py (+20 -5)
@@ -96,6 +96,8 @@ def __init__(
 
         if self._config.enable_packing and self._config.use_preference_loss_masking_spans:
             raise NotImplementedError("Packing currently not implemented with preference loss masking.")
+        if not self._config.enable_packing and self._truncate_documents:
+            raise NotImplementedError("If packing is disabled, document truncation must also be disabled.")
 
         if sampling.cache_directory is None:
             self._document_shuffling = MemmapArray()
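
Note on the new guard: with packing disabled, every sample is exactly one document, so truncating documents to fit the sequence length no longer applies; over-long documents are filtered out instead (see the later hunks). A minimal standalone sketch of the flag interaction, using a hypothetical stand-in for the real config object:

# Sketch of the guard above; SamplingConfig is a hypothetical
# stand-in for Fast-LLM's actual configuration class.
from dataclasses import dataclass

@dataclass
class SamplingConfig:
    enable_packing: bool = True
    truncate_documents: bool = True

def validate(config: SamplingConfig) -> None:
    # Packing off means one whole document per sample, so truncation
    # must be off too; long documents are dropped rather than cut.
    if not config.enable_packing and config.truncate_documents:
        raise NotImplementedError("If packing is disabled, document truncation must also be disabled.")

validate(SamplingConfig(enable_packing=False, truncate_documents=False))   # ok
# validate(SamplingConfig(enable_packing=False, truncate_documents=True))  # raises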
@@ -122,7 +124,9 @@ def __init__(
             self._token_cumsum_shuffled = MemmapArray(base_path.with_name(base_path.name + "_shuffled_cumsum.npy"))
             self._token_cumsum_unshuffled = MemmapArray(base_path.with_name(base_path.name + "_unshuffled_cumsum.npy"))
 
-            self._document_sizes = MemmapArray(base_path.with_name(base_path.name + "_doc_sizes.npy"))
+            if not self._config.enable_packing:
+                self._document_sizes = MemmapArray(base_path.with_name(base_path.name + "_doc_sizes.npy"))
+                self._doc_length_filtered_indicies = MemmapArray(base_path.with_name(base_path.name + "_doc_length_filtered_indices.npy"))
 
             self._yaml_path = base_path.with_suffix(".yaml")
         # Sample or validate the dataset of a given rank.
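
`MemmapArray` is Fast-LLM's internal wrapper for arrays cached on disk. Assuming it behaves roughly like a lazily memory-mapped `.npy` file with a `save()` method, the new `_doc_length_filtered_indices.npy` cache entry added here could be approximated as:

# Rough approximation of the presumed MemmapArray behavior using numpy;
# the real Fast-LLM class may differ in API and caching details.
import numpy as np
from pathlib import Path

class MemmapArraySketch:
    def __init__(self, path: Path):
        self._path = path
        self._array = None

    def save(self, array: np.ndarray) -> None:
        np.save(self._path, array)

    def __getitem__(self, item):
        if self._array is None:
            # mmap_mode="r" maps the file instead of loading it fully.
            self._array = np.load(self._path, mmap_mode="r")
        return self._array[item]

base_path = Path("cache/sampled")  # hypothetical cache prefix
filtered = MemmapArraySketch(base_path.with_name(base_path.name + "_doc_length_filtered_indices.npy"))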
@@ -168,6 +172,7 @@ def _sample(self) -> None:
                 / tokens_per_epoch
             )
         else:
+            documents_per_epoch = (~long_docs_filter).sum().item()
             num_epochs = math.ceil(self._num_samples / documents_per_epoch)
 
         # Prepare for shuffling.
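
With packing disabled, an epoch is measured in documents rather than tokens, and only documents that fit within the sequence length count toward it. A toy check of the new `documents_per_epoch` line, assuming `long_docs_filter` is a boolean mask marking documents that exceed the sequence length (the exact comparison used to build the mask is defined elsewhere in `_sample`):

# Toy check of documents_per_epoch under the no-packing branch.
import math
import torch

sequence_length = 8
document_sizes = torch.tensor([3, 12, 5, 9, 8, 2])
long_docs_filter = document_sizes > sequence_length  # assumed construction

documents_per_epoch = (~long_docs_filter).sum().item()    # 4 documents fit
num_samples = 10
num_epochs = math.ceil(num_samples / documents_per_epoch)  # ceil(10 / 4) = 3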
@@ -310,6 +315,18 @@ def _sample(self) -> None:
             # Free memory
             del document_shuffling
         else:
+            # Indices of all documents shorter than the sequence length.
+            doc_length_filtered_indicies = torch.nonzero(~long_docs_filter, as_tuple=True)[0]
+            self._doc_length_filtered_indicies.save(
+                doc_length_filtered_indicies.numpy(
+                    force=self._config.gpu
+                )
+            )
+
+            # # apply shuffling on doc_length_filtered_indicies
+            # document_shuffling_length_filtered_indices = torch.gather(
+            #     doc_length_filtered_indicies, dim=0, index=document_shuffling.to(torch.int64)
+            # )
         if shuffled_epochs > 0:
             self._document_shuffling.save(
                 document_shuffling[:self._num_samples].numpy(
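
`torch.nonzero(~long_docs_filter, as_tuple=True)[0]` yields the original indices of every document that passed the length filter, and `Tensor.numpy(force=...)` copies the tensor to host memory first when it lives on the GPU (hence the `force=self._config.gpu` tie-in). The commented-out `torch.gather` block sketches composing the shuffling permutation with the filter, but is left disabled in this commit. A toy demonstration:

# Recovering kept-document indices from the boolean mask.
import torch

long_docs_filter = torch.tensor([False, True, False, True, False, False])
doc_length_filtered_indicies = torch.nonzero(~long_docs_filter, as_tuple=True)[0]
print(doc_length_filtered_indicies)  # tensor([0, 2, 4, 5])

# force=True makes .numpy() work for CUDA tensors by copying to the CPU;
# for a CPU tensor it is a no-op and the plain call suffices.
as_numpy = doc_length_filtered_indicies.numpy(force=True)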
@@ -321,8 +338,6 @@ def _sample(self) -> None:
             # yaml_data["unshuffled_tokens"] = num_tokens_unshuffled
             self._yaml_path.parent.mkdir(parents=True, exist_ok=True)
             yaml.safe_dump(yaml_data, self._yaml_path.open("w"))
-
-
 
     def _get_token_cumsum(self, sizes: torch.Tensor, offset: int, dtype: DataType) -> tuple[np.ndarray, int | None]:
         if self._truncate_documents:
@@ -454,9 +469,9 @@ def __getitem__(self, index: int) -> typing.Any:
             return GPTSample(token_ids=token_ids, loss_masking_spans=loss_masking_spans, sequence_lengths=sequence_lengths)
         else:
             if index < self._unshuffled_documents:
-                document_index = index % self._documents_per_epoch
+                document_index = self._doc_length_filtered_indicies[index % self._documents_per_epoch]
             else:
-                document_index = self._document_shuffling[index - self._unshuffled_documents].item()
+                document_index = self._doc_length_filtered_indicies[self._document_shuffling[index - self._unshuffled_documents].item()]
 
             sample = self._indexed_dataset.get(
                 document_index,
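
The `__getitem__` change adds one level of indirection: the epoch-relative index (unshuffled branch) or the shuffling permutation entry (shuffled branch) now selects a slot in `_doc_length_filtered_indicies`, which holds the original document id. A toy version of the two branches under the same attribute layout:

# Toy remapping mirroring the two branches above.
import torch

doc_length_filtered_indicies = torch.tensor([0, 2, 4, 5])  # kept document ids
documents_per_epoch = len(doc_length_filtered_indicies)
unshuffled_documents = 4                          # one unshuffled epoch
document_shuffling = torch.tensor([3, 0, 2, 1])   # permutation for shuffled epochs

def document_index_for(index: int) -> int:
    if index < unshuffled_documents:
        # Unshuffled epochs walk the filtered list in order.
        return doc_length_filtered_indicies[index % documents_per_epoch].item()
    # Shuffled epochs go through the permutation, then the filter.
    return doc_length_filtered_indicies[
        document_shuffling[index - unshuffled_documents].item()
    ].item()

print([document_index_for(i) for i in range(8)])  # [0, 2, 4, 5, 5, 0, 4, 2]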
