@@ -96,6 +96,8 @@ def __init__(
 
         if self._config.enable_packing and self._config.use_preference_loss_masking_spans:
             raise NotImplementedError("Packing currently not implemented with preference loss masking.")
+        if not self._config.enable_packing and self._truncate_documents:
+            raise NotImplementedError("If packing is disabled, document truncation must also be disabled.")
 
         if sampling.cache_directory is None:
             self._document_shuffling = MemmapArray()
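
Note on the new guard: truncation is only meaningful when documents are packed into fixed-length sequences, so with packing disabled the sampler now also requires truncation to be disabled. A minimal sketch of the resulting configuration matrix, using hypothetical booleans standing in for self._config.enable_packing and self._truncate_documents:

    # Sketch only; the two flags stand in for the config attributes above.
    for enable_packing in (True, False):
        for truncate_documents in (True, False):
            # Mirrors the new guard: disabled packing forbids truncation.
            allowed = enable_packing or not truncate_documents
            print(f"packing={enable_packing}, truncate={truncate_documents}: "
                  f"{'ok' if allowed else 'rejected'}")
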
@@ -122,7 +124,9 @@ def __init__(
             self._token_cumsum_shuffled = MemmapArray(base_path.with_name(base_path.name + "_shuffled_cumsum.npy"))
             self._token_cumsum_unshuffled = MemmapArray(base_path.with_name(base_path.name + "_unshuffled_cumsum.npy"))
 
-            self._document_sizes = MemmapArray(base_path.with_name(base_path.name + "_doc_sizes.npy"))
+            if not self._config.enable_packing:
+                self._document_sizes = MemmapArray(base_path.with_name(base_path.name + "_doc_sizes.npy"))
+                self._doc_length_filtered_indicies = MemmapArray(base_path.with_name(base_path.name + "_doc_length_filtered_indices.npy"))
 
             self._yaml_path = base_path.with_suffix(".yaml")
         # Sample or validate the dataset of a given rank.
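
The new cache entry follows the existing naming scheme: each MemmapArray (which appears to be the project's memory-mapped array wrapper) lives in a sibling file derived from base_path with pathlib's with_name. A quick illustration of the path derivation only, with a hypothetical cache path:

    from pathlib import Path

    base_path = Path("/tmp/cache/sample_rank_0")  # hypothetical cache base path
    print(base_path.with_name(base_path.name + "_doc_sizes.npy"))
    # /tmp/cache/sample_rank_0_doc_sizes.npy
    print(base_path.with_name(base_path.name + "_doc_length_filtered_indices.npy"))
    # /tmp/cache/sample_rank_0_doc_length_filtered_indices.npy
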
@@ -168,6 +172,7 @@ def _sample(self) -> None:
                 / tokens_per_epoch
             )
         else:
+            documents_per_epoch = (~long_docs_filter).sum().item()
             num_epochs = math.ceil(self._num_samples / documents_per_epoch)
 
         # Prepare for shuffling.
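
When packing is disabled, every sample is a single document, so an epoch can only contain documents that pass the length filter; the added line derives documents_per_epoch by counting the negated mask. A small self-contained example of the same arithmetic, with a made-up mask and sample count:

    import math

    import torch

    # Hypothetical mask: True marks documents longer than the sequence length.
    long_docs_filter = torch.tensor([False, True, False, False, True])
    num_samples = 7
    documents_per_epoch = (~long_docs_filter).sum().item()  # 3 usable documents
    num_epochs = math.ceil(num_samples / documents_per_epoch)  # ceil(7 / 3) = 3
    print(documents_per_epoch, num_epochs)
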
@@ -310,6 +315,18 @@ def _sample(self) -> None:
             # Free memory
             del document_shuffling
         else:
+            # Indices of all documents that fit within the sequence length.
+            doc_length_filtered_indicies = torch.nonzero(~long_docs_filter, as_tuple=True)[0]
+            self._doc_length_filtered_indicies.save(
+                doc_length_filtered_indicies.numpy(
+                    force=self._config.gpu
+                )
+            )
+
+            # # apply shuffling on doc_length_filtered_indicies
+            # document_shuffling_length_filtered_indices = torch.gather(
+            #     doc_length_filtered_indicies, dim=0, index=document_shuffling.to(torch.int64)
+            # )
         if shuffled_epochs > 0:
             self._document_shuffling.save(
                 document_shuffling[: self._num_samples].numpy(
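
In the no-packing branch, torch.nonzero on the negated filter recovers the original positions of the usable documents, and the result is persisted through the MemmapArray so __getitem__ can translate sample indices later; numpy(force=self._config.gpu) lets the tensor be copied off the GPU when sampling ran there. A minimal sketch of the index computation with a made-up mask:

    import torch

    # Hypothetical mask: True marks documents that are too long to use.
    long_docs_filter = torch.tensor([False, True, False, False, True])
    doc_length_filtered_indicies = torch.nonzero(~long_docs_filter, as_tuple=True)[0]
    print(doc_length_filtered_indicies)  # tensor([0, 2, 3]): original ids of usable documents
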
@@ -321,8 +338,6 @@ def _sample(self) -> None:
             # yaml_data["unshuffled_tokens"] = num_tokens_unshuffled
             self._yaml_path.parent.mkdir(parents=True, exist_ok=True)
             yaml.safe_dump(yaml_data, self._yaml_path.open("w"))
-
-
 
     def _get_token_cumsum(self, sizes: torch.Tensor, offset: int, dtype: DataType) -> tuple[np.ndarray, int | None]:
         if self._truncate_documents:
@@ -454,9 +469,9 @@ def __getitem__(self, index: int) -> typing.Any:
             return GPTSample(token_ids=token_ids, loss_masking_spans=loss_masking_spans, sequence_lengths=sequence_lengths)
         else:
             if index < self._unshuffled_documents:
-                document_index = index % self._documents_per_epoch
+                document_index = self._doc_length_filtered_indicies[index % self._documents_per_epoch]
             else:
-                document_index = self._document_shuffling[index - self._unshuffled_documents].item()
+                document_index = self._doc_length_filtered_indicies[self._document_shuffling[index - self._unshuffled_documents].item()]
 
             sample = self._indexed_dataset.get(
                 document_index,
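
The __getitem__ change adds one level of indirection: the (possibly shuffled) sample position now selects a slot in the saved filtered-index array, which in turn holds the original document id handed to the indexed dataset. A small sketch of the shuffled branch with made-up state:

    import torch

    # Hypothetical state mirroring the sampler after _sample() has run.
    doc_length_filtered_indicies = torch.tensor([0, 2, 3])  # original ids of usable documents
    document_shuffling = torch.tensor([2, 0, 1])  # shuffled slots into the filtered list
    unshuffled_documents = 0

    index = 0  # first requested sample, falls in the shuffled range
    document_index = doc_length_filtered_indicies[document_shuffling[index - unshuffled_documents].item()]
    print(document_index.item())  # 3: the third usable document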