from fast_llm.data.dataset.abstract import SampledDataset
from fast_llm.data.dataset.gpt.config import GPTSamplingData, ShufflingType
from fast_llm.data.dataset.gpt.indexed import GPTIndexedDataset
- from fast_llm.engine.config_utils.data_type import get_unsigned_integer_type
+ from fast_llm.engine.config_utils.data_type import DataType, get_unsigned_integer_type
from fast_llm.engine.config_utils.run import log_main_rank
from fast_llm.utils import Assert

try:
-     from fast_llm.csrc.data import build_sample_idx  # noqa
+     from fast_llm.csrc.data import build_padded_token_cumsum, build_sample_idx  # noqa

    _extension_available = True
except ImportError:
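
For context, here is the optional-extension pattern above in isolation. Only the names shown in this diff are real; the `except` body is cut off here, and setting the flag to `False` is an assumption:

    try:
        from fast_llm.csrc.data import build_padded_token_cumsum, build_sample_idx  # noqa

        _extension_available = True
    except ImportError:
        _extension_available = False

The no-truncation sampling path added below hard-depends on `build_padded_token_cumsum`, which is why `_sample` asserts `_extension_available` before using it.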
@@ -89,6 +89,7 @@ def __init__(
        self._sequence_length = sampling.sequence_length
        self._cross_document_attention = sampling.cross_document_attention
        self._config = sampling.config
+         self._truncate_documents = sampling.truncate_documents
        self._device = torch.device("cuda" if self._config.gpu else "cpu")

        if sampling.cache_directory is None:
@@ -124,15 +125,35 @@ def _sample(self) -> None:
        """
        # Get the document sizes, the main information needed for sampling.
        document_sizes = torch.from_numpy(self._indexed_dataset.get_document_sizes()).to(self._device)
-
-         # Calculate basic stats.
        documents_per_epoch = document_sizes.numel()
        tokens_per_epoch = document_sizes.sum().item()
+
+         # Calculate basic stats.
+         if not self._truncate_documents:
+             assert _extension_available, (
+                 "The C++ extension for dataset sampling is missing."
+                 " Please make sure Fast-LLM is installed correctly."
+             )
+             long_docs_filter = document_sizes > self._sequence_length + 1
+             ignored_documents = long_docs_filter.sum().item()
+             if ignored_documents:
+                 log_main_rank(
+                     f" > {ignored_documents}/{documents_per_epoch} documents are longer than {self._sequence_length + 1} tokens and will be ignored.",
+                     log_fn=logger.warning,
+                 )
+             tokens_per_epoch = document_sizes[~long_docs_filter].sum().item()
+             if tokens_per_epoch == 0:
+                 raise RuntimeError(
+                     f" > No documents shorter than {self._sequence_length + 1} tokens found in dataset {self._indexed_dataset.name}."
+                 )
        # TODO MTP: Produce more labels to provide labels for the multi-token prediction heads?
        # We produce sequences of length `self._sequence_length + 1` so the last token has a label,
-         # but we also include that last label in the following sample,
+         # but in case of truncations we also include that last label in the following sample,
        # so we need `sequence_length * num_samples + 1` tokens in total.
-         num_epochs = math.ceil((self._sequence_length * self._num_samples + 1) / tokens_per_epoch)
+         num_epochs = math.ceil(
+             ((self._sequence_length + 1 - self._truncate_documents) * self._num_samples + 1 * self._truncate_documents)
+             / tokens_per_epoch
+         )

        # Prepare for shuffling.
        generator = torch.Generator(device=self._device)
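
To make the new `num_epochs` formula concrete, here is a quick check with invented numbers. With truncation, consecutive samples share one boundary token, so `num_samples` samples need `sequence_length * num_samples + 1` tokens; without truncation, each sample is a disjoint block of `sequence_length + 1` tokens:

    sequence_length, num_samples = 4, 3
    for truncate in (True, False):
        tokens_needed = (sequence_length + 1 - truncate) * num_samples + 1 * truncate
        print(truncate, tokens_needed)  # True -> 13 (4*3 + 1), False -> 15 (5*3)

`num_epochs` is then the ceiling of this demand over `tokens_per_epoch`.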
@@ -154,13 +175,17 @@ def _sample(self) -> None:
            "num_samples": self._num_samples,
            "unshuffled_epochs": unshuffled_epochs,
            "sequence_length": self._sequence_length,
+             "truncate_documents": self._truncate_documents,
            "config": self._config.to_serialized(),
        }
        self._load_yaml_data(yaml_data)

        if self._yaml_path is not None:
            if self._yaml_path.is_file():
                loaded_yaml_data = yaml.safe_load(self._yaml_path.open("r"))
+                 unshuffled_tokens = loaded_yaml_data.pop("unshuffled_tokens", None)
+                 if unshuffled_tokens is not None:
+                     self._unshuffled_tokens = unshuffled_tokens
                if loaded_yaml_data != yaml_data:
                    raise RuntimeError(
                        f"Invalid dataset cache for dataset {self.name}."
@@ -172,9 +197,6 @@ def _sample(self) -> None:
                # Dataset is already sampled, skip.
                logger.info(f"Using existing sampling for dataset {self.name}")
                return
-             else:
-                 self._yaml_path.parent.mkdir(parents=True, exist_ok=True)
-                 yaml.safe_dump(yaml_data, self._yaml_path.open("w"))

        if shuffled_documents > 1e8:
            warnings.warn(
@@ -232,51 +254,78 @@ def _sample(self) -> None:
        # So it is enough to pre-compute the (zero-padded) token cumsum at regular intervals `TOKEN_CUMSUM_RATE`.
        # Using `TOKEN_CUMSUM_RATE > 1` reduces pre-computation overhead at the cost of runtime computation.
        # Equivalent to `torch.hstack((0, document_sizes[all_document_index].cumsum()[::TOKEN_CUMSUM_RATE]))`
+         if unshuffled_epochs > 0:
+             token_cumsum_unshuffled, num_tokens_unshuffled = self._get_token_cumsum(
+                 document_sizes,
+                 offset=0,
+                 # TODO: Allowing for max 100% extra tokens for padding, is that enough?
+                 dtype=get_unsigned_integer_type((2 - self._truncate_documents) * tokens_per_epoch * num_epochs),
+             )
+             if self._truncate_documents:
+                 num_tokens_unshuffled = tokens_per_epoch * unshuffled_epochs
+             self._token_cumsum_unshuffled.save(token_cumsum_unshuffled)
+         else:
+             num_tokens_unshuffled = 0
+         self._unshuffled_tokens = num_tokens_unshuffled
+
+         if self._yaml_path is not None:
+             yaml_data["unshuffled_tokens"] = num_tokens_unshuffled
+             self._yaml_path.parent.mkdir(parents=True, exist_ok=True)
+             yaml.safe_dump(yaml_data, self._yaml_path.open("w"))
+
        if shuffled_epochs > 0:
-             token_cumsum_shuffled = self._get_token_cumsum(
+             token_cumsum_shuffled, num_tokens_shuffled = self._get_token_cumsum(
                document_sizes[
                    # Torch indexing only works with int32 or int64
                    document_shuffling.to(
                        dtype=torch.int64 if document_shuffling.dtype == torch.int64 else torch.int32
                    )
                ],
-                 offset=unshuffled_epochs * tokens_per_epoch,
-                 dtype=get_unsigned_integer_type(tokens_per_epoch * num_epochs).torch,
+                 offset=num_tokens_unshuffled,
+                 # TODO: Allowing for max 100% extra tokens for padding, is that enough?
+                 dtype=get_unsigned_integer_type((2 - self._truncate_documents) * tokens_per_epoch * num_epochs),
            )
-             self._token_cumsum_shuffled.save(token_cumsum_shuffled.numpy(force=self._config.gpu))
+             self._token_cumsum_shuffled.save(token_cumsum_shuffled)
            self._document_shuffling.save(
-                 document_shuffling[: (token_cumsum_shuffled.numel() + 1) * TOKEN_CUMSUM_RATE].numpy(
+                 document_shuffling[: (token_cumsum_shuffled.size + 1) * TOKEN_CUMSUM_RATE].numpy(
                    force=self._config.gpu
                )
            )
            # Free memory
-             del token_cumsum_shuffled
            del document_shuffling

-         if unshuffled_epochs > 0:
-             token_cumsum_unshuffled = self._get_token_cumsum(
-                 document_sizes, offset=0, dtype=get_unsigned_integer_type(tokens_per_epoch * num_epochs).torch
+     def _get_token_cumsum(self, sizes: torch.Tensor, offset: int, dtype: DataType) -> tuple[np.ndarray, int | None]:
+         if self._truncate_documents:
+             # Create the output tensor.
+             out = sizes.new_empty(sizes.numel() // TOKEN_CUMSUM_RATE + 1, dtype=dtype.torch)
+             # Get partial sums for regular intervals, excluding the last incomplete interval.
+             torch.sum(
+                 sizes[: sizes.numel() - sizes.numel() % TOKEN_CUMSUM_RATE].view(-1, TOKEN_CUMSUM_RATE),
+                 dim=1,
+                 out=out[1:],
            )
-             self._token_cumsum_unshuffled.save(token_cumsum_unshuffled.numpy(force=self._config.gpu))
-
-     def _get_token_cumsum(self, sizes: torch.Tensor, offset: int, dtype: torch.dtype) -> torch.Tensor:
-         # Create the output tensor.
-         out = sizes.new_empty(sizes.numel() // TOKEN_CUMSUM_RATE + 1, dtype=dtype)
-         # Get partial sums for regular intervals, excluding the last incomplete interval.
-         torch.sum(
-             sizes[: sizes.numel() - sizes.numel() % TOKEN_CUMSUM_RATE].view(-1, TOKEN_CUMSUM_RATE), dim=1, out=out[1:]
-         )
-         # Pad with the begin offset
-         out[0] = offset
-         # Calculate the cumsum.
-         out.cumsum_(0)
-         # Crop unnecessary entries.
-         return out[
-             : torch.clamp_min_(
-                 torch.searchsorted(out, self._num_samples * self._sequence_length, side="right"),
-                 0,
+             # Pad with the begin offset
+             out[0] = offset
+             # Calculate the cumsum.
+             out.cumsum_(0)
+             # Crop unnecessary entries.
+             out = out[
+                 : torch.clamp_min_(
+                     torch.searchsorted(out, self._num_samples * self._sequence_length, side="right"),
+                     0,
+                 )
+             ]
+             return out.numpy(force=self._config.gpu), None
+         else:
+             # TODO: dynamically handle int64 or int32 in CPP
+             out = build_padded_token_cumsum(
+                 sizes.cpu().numpy(), (self._sequence_length + 1), TOKEN_CUMSUM_RATE, offset
            )
-         ]
+             num_tokens = out[-1]
+             out = out[:-1][
+                 : np.clip(np.searchsorted(out, self._num_samples * (self._sequence_length + 1), side="right"), 0, None)
+             ]
+             return out, num_tokens

    def __len__(self) -> int:
        return self._num_samples
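
For intuition, the truncating branch of `_get_token_cumsum` amounts to a strided cumulative sum over document sizes. Here is a standalone numpy sketch with made-up inputs; the no-truncation branch instead defers to the C++ `build_padded_token_cumsum`, whose last output element carries the total padded token count, as the unpacking above shows:

    import numpy as np

    TOKEN_CUMSUM_RATE = 2
    sizes = np.array([3, 5, 2, 4, 6])  # document lengths, in sampling order
    offset = 0
    # Drop the trailing incomplete interval, sum each interval, then cumsum from `offset`.
    trimmed = sizes[: sizes.size - sizes.size % TOKEN_CUMSUM_RATE]
    out = np.concatenate(([offset], trimmed.reshape(-1, TOKEN_CUMSUM_RATE).sum(1))).cumsum()
    print(out)  # [0 8 14]: starting offsets of documents 0, 2 and 4 in the token stream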
@@ -288,7 +337,9 @@ def __getitem__(self, index: int) -> typing.Any:
        The returned sample is ready to be concatenated, then fed to a `GPTModel` (see `GPTModel.preprocess`).
        """
        self._lazy_load()
-         token_start = index * self._sequence_length
+         # Tokens at the boundary are included in only one sample when we pack without truncations;
+         # when packing with truncations, the last token of a sample is also the first token of the next one.
+         token_start = index * (self._sequence_length + 1 - self._truncate_documents)
        token_end = token_start + self._sequence_length + 1

        if token_start < self._unshuffled_tokens:
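
The new stride shows up directly in the token windows that samples read; a quick illustration with hypothetical numbers:

    sequence_length = 4
    for truncate in (True, False):
        stride = sequence_length + 1 - truncate
        print([(i * stride, i * stride + sequence_length + 1) for i in range(3)])
    # True:  [(0, 5), (4, 9), (8, 13)]   -> adjacent windows share one boundary token
    # False: [(0, 5), (5, 10), (10, 15)] -> windows are disjoint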
@@ -302,6 +353,7 @@ def __getitem__(self, index: int) -> typing.Any:
            token_start_cumsum_index = np.searchsorted(token_start_array, token_start, side="right").item() - 1

            document_sampling_index = token_start_cumsum_index * TOKEN_CUMSUM_RATE + token_start_array_document_offset
+
            token_count = token_start_array[token_start_cumsum_index]

        token_ids = []
@@ -314,6 +366,25 @@ def __getitem__(self, index: int) -> typing.Any:
                document_index = self._document_shuffling[document_sampling_index - self._unshuffled_documents].item()

            document_size = self._indexed_dataset.get_document_size(document_index)
+
+             if not self._truncate_documents:
+                 if document_size > self._sequence_length + 1:
+                     # Document too long, ignore it.
+                     document_sampling_index += 1
+                     continue
+                 tokens_in_sample = token_count % (self._sequence_length + 1)
+                 if document_size + tokens_in_sample > self._sequence_length + 1:
+                     # The document belongs to the next sample, so we need to account for padding.
+                     padding_size = self._sequence_length + 1 - tokens_in_sample
+                     if token_count > token_start:
+                         # Add padding tokens to the current sample.
+                         token_ids.append(np.full((padding_size,), -100, dtype=np.int64))
+                         Assert.eq(token_count + padding_size, token_end)
+                         break
+                     else:
+                         # Move on to the next sample.
+                         token_count += padding_size
+
            # Determine if the document belongs to the requested sample.
            if token_count + document_size >= token_start:
                # Determine which part of the document belongs to the sample, and add it to the list.
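
A toy re-enactment of the padding rule above, outside the class (document sizes invented, zeros standing in for real tokens):

    sequence_length = 7                  # each sample holds sequence_length + 1 = 8 tokens
    doc_sizes = [5, 4, 8, 2]             # document lengths in sampling order
    sample, samples = [], []
    for size in doc_sizes:
        if size > sequence_length + 1:   # overlong document: skipped, never truncated
            continue
        if len(sample) + size > sequence_length + 1:
            # Close the current sample with -100 padding; the document opens the next one.
            sample += [-100] * (sequence_length + 1 - len(sample))
            samples.append(sample)
            sample = []
        sample += [0] * size
    # samples[0]: five document tokens followed by three -100 padding tokens.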
@@ -343,7 +414,9 @@ def __getitem__(self, index: int) -> typing.Any:
                )
        token_ids = np.concatenate(token_ids, dtype=np.int64)
        loss_masking_spans = (
-             np.stack(loss_masking_spans, dtype=np.int32) if self._config.use_loss_masking_spans else None
+             (np.stack(loss_masking_spans, dtype=np.int32) if loss_masking_spans else np.array([]))
+             if self._config.use_loss_masking_spans
+             else None
        )
        Assert.eq(len(token_ids), self._sequence_length + 1)
@@ -357,9 +430,12 @@ def _lazy_load(self):
        if not hasattr(self, "_documents_per_epoch"):
            self._load_yaml_data(yaml.safe_load(self._yaml_path.open("r")))

-     def _load_yaml_data(self, data: dict[str, typing.Any]):
+     def _load_yaml_data(self, data: dict[str, typing.Any]) -> None:
        self._documents_per_epoch = data["dataset"]["documents_per_epoch"]
-         self._unshuffled_tokens = data["unshuffled_epochs"] * data["dataset"]["tokens_per_epoch"]
+         if (unshuffled_tokens := data.get("unshuffled_tokens")) is not None:
+             self._unshuffled_tokens = unshuffled_tokens
+         else:
+             self._unshuffled_tokens = data["unshuffled_epochs"] * data["dataset"]["tokens_per_epoch"]
        self._unshuffled_documents = data["unshuffled_epochs"] * self._documents_per_epoch

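
The parentheses around the walrus assignment in `_load_yaml_data` are load-bearing: without them, `:=` would capture the boolean result of `is not None` rather than the value. This is generic Python precedence, demonstrated outside this repo:

    data = {"unshuffled_tokens": 42}
    (v := data.get("unshuffled_tokens")) is not None   # v == 42, the value
    (w := data.get("unshuffled_tokens") is not None)   # w == True, just the boolean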
@@ -380,9 +456,12 @@ def __init__(
        self._indexed_dataset = indexed_dataset
        self._num_samples = sampling.num_samples
        self._sequence_length = sampling.sequence_length
+         if not sampling.truncate_documents:
+             raise NotImplementedError(
+                 "Legacy sampling only supports document truncation. Please use the latest dataset format."
+             )
        self._cross_document_attention = sampling.cross_document_attention
        self._config = sampling.config
-         self._tokenizer = sampling.tokenizer

        if sampling.cache_directory is None:
            log_main_rank(
@@ -498,7 +577,7 @@ def __getitem__(self, idx: int) -> typing.Any:
                for span in sample.loss_masking_spans:
                    spans.append(span + offset)
                offset += len(sample.token_ids)
-             spans = np.stack(spans, dtype=np.int32)
+             spans = np.stack(spans, dtype=np.int32) if spans else np.array([])
        else:
            spans = None
        sequence_lengths = (
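
The `if spans else np.array([])` fallbacks introduced in this hunk and in `__getitem__` above exist because `np.stack` rejects an empty sequence; this is standard numpy behavior:

    import numpy as np

    try:
        np.stack([])
    except ValueError as e:
        print(e)          # "need at least one array to stack"
    print(np.array([]))   # the empty, shape-(0,) fallback used instead

Note the fallback is a float64 empty array; if downstream code expects int32 spans, `np.array([], dtype=np.int32)` may be the safer placeholder.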