
Commit 40c96c8

committed
dataset changes for dpo
1 parent 422a78b commit 40c96c8

File tree

fast_llm/data/data/gpt/data.py
fast_llm/data/dataset/gpt/memmap.py
fast_llm/data/dataset/gpt/sampled.py
fast_llm/data/tokenizer.py

4 files changed: +66 -29 lines changed

fast_llm/data/data/gpt/data.py

Lines changed: 17 additions & 2 deletions
@@ -32,20 +32,34 @@ class GPTBatch:
     token_ids: torch.Tensor
     loss_masking_spans: list[torch.Tensor] | None = None
     sequence_lengths: list[torch.Tensor] | None = None
+    chosen_loss_masking_spans: list[torch.Tensor] | None = None
+    rejected_loss_masking_spans: list[torch.Tensor] | None = None


 def gpt_data_collate_fn(
-    batch: list[GPTSample], use_loss_masking_spans: bool, cross_document_attention: bool
+    batch: list[GPTSample],
+    use_loss_masking_spans: bool,
+    cross_document_attention: bool,
+    use_preference_loss_masking_spans: bool
 ) -> GPTBatch:
     stacked_ids = np.stack([sample.token_ids for sample in batch])
     stacked_spans = None
     sequence_lengths = None
+    stacked_chosen_spans = None
+    stacked_rejected_spans = None
     if use_loss_masking_spans:
         stacked_spans = [torch.from_numpy(sample.loss_masking_spans) for sample in batch]
+    if use_preference_loss_masking_spans:
+        stacked_chosen_spans = [torch.from_numpy(sample.chosen_loss_masking_spans) for sample in batch]
+        stacked_rejected_spans = [torch.from_numpy(sample.rejected_loss_masking_spans) for sample in batch]
     if not cross_document_attention:
         sequence_lengths = [torch.tensor(sample.sequence_lengths) for sample in batch]
     return GPTBatch(
-        token_ids=torch.from_numpy(stacked_ids), loss_masking_spans=stacked_spans, sequence_lengths=sequence_lengths
+        token_ids=torch.from_numpy(stacked_ids),
+        loss_masking_spans=stacked_spans,
+        sequence_lengths=sequence_lengths,
+        chosen_loss_masking_spans=stacked_chosen_spans,
+        rejected_loss_masking_spans=stacked_rejected_spans
     )


@@ -169,6 +183,7 @@ def get_iterator(
                 gpt_data_collate_fn,
                 use_loss_masking_spans=self._config.sampling.use_loss_masking_spans,
                 cross_document_attention=self._cross_document_attention,
+                use_preference_loss_masking_spans=self._config.sampling.use_preference_loss_masking_spans
             ),
             multiprocessing_context=self._config.multiprocessing_context.value if num_workers > 0 else None,
         )
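
Note: the following is a minimal standalone sketch of the collation behaviour this change enables; FakeSample and collate are illustrative stand-ins for GPTSample and gpt_data_collate_fn (called with use_preference_loss_masking_spans=True), not the actual Fast-LLM classes.

# Illustrative sketch only: FakeSample stands in for GPTSample, and collate mirrors
# the preference branch of gpt_data_collate_fn when use_preference_loss_masking_spans=True.
import dataclasses

import numpy as np
import torch


@dataclasses.dataclass
class FakeSample:
    token_ids: np.ndarray
    chosen_loss_masking_spans: np.ndarray    # [start, end] of the chosen response
    rejected_loss_masking_spans: np.ndarray  # [start, end] of the rejected response


def collate(batch: list[FakeSample]):
    # Token ids are stacked into one (batch, seq) tensor; the preference spans stay
    # as a per-sample list of tensors, matching the new GPTBatch fields.
    token_ids = torch.from_numpy(np.stack([s.token_ids for s in batch]))
    chosen = [torch.from_numpy(s.chosen_loss_masking_spans) for s in batch]
    rejected = [torch.from_numpy(s.rejected_loss_masking_spans) for s in batch]
    return token_ids, chosen, rejected


batch = [
    FakeSample(np.arange(8, dtype=np.int64), np.array([2, 4]), np.array([5, 7])),
    FakeSample(np.arange(8, 16, dtype=np.int64), np.array([1, 3]), np.array([4, 7])),
]
token_ids, chosen, rejected = collate(batch)
print(token_ids.shape)  # torch.Size([2, 8])
print(chosen[0])        # tensor([2, 4])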

fast_llm/data/dataset/gpt/memmap.py

Lines changed: 14 additions & 14 deletions
@@ -106,7 +106,7 @@ def _init(self, name: str, prefix: pathlib.Path | str, num_documents: int | None
                     dtype=np.int32,
                     count=2,
                     offset=chosen_span_offset + idx * 2 * np.dtype(np.int32).itemsize,
-                ).reshape(-1, 2)
+                )
             )

         rejected_span_offset = offset + self._document_sizes.nbytes + self._pointers.nbytes + np.array(self._chosen_spans).nbytes
@@ -117,7 +117,7 @@ def _init(self, name: str, prefix: pathlib.Path | str, num_documents: int | None
                     dtype=np.int32,
                     count=2,
                     offset=rejected_span_offset + idx * 2 * np.dtype(np.int32).itemsize,
-                ).reshape(-1, 2)
+                )
             )

         self._bin_buffer_mmap = np.memmap(self._prefix.with_suffix(".bin"), mode="r", order="C")
@@ -169,30 +169,30 @@ def get(
         chosen_spans = self._chosen_spans[idx]

         # filter spans that are outside the range of the selected tokens in the document
-        chosen_sample_spans = chosen_spans[
-            (chosen_spans[:, 0] < offset + len(token_ids)) & (chosen_spans[:, 1] >= offset)
-        ]
+        chosen_spans = chosen_spans[
+            (chosen_spans[0] < offset + len(token_ids)) & (chosen_spans[1] >= offset)
+        ][0]

         # subtract by offset to normalize span boundaries
-        chosen_spans[:, 0] = np.maximum(chosen_spans[:, 0], offset) - offset  # offset
-        chosen_spans[:, 1] = np.minimum(chosen_spans[:, 1], offset + len(token_ids) - 1) - offset
+        chosen_spans[0] = np.maximum(chosen_spans[0], offset) - offset  # offset
+        chosen_spans[1] = np.minimum(chosen_spans[1], offset + len(token_ids) - 1) - offset

         rejected_spans = self._rejected_spans[idx]

         # filter spans that are outside the range of the selected tokens in the document
-        rejected_sample_spans = rejected_spans[
-            (rejected_spans[:, 0] < offset + len(token_ids)) & (rejected_spans[:, 1] >= offset)
-        ]
+        rejected_spans = rejected_spans[
+            (rejected_spans[0] < offset + len(token_ids)) & (rejected_spans[1] >= offset)
+        ][0]

         # subtract by offset to normalize span boundaries
-        rejected_spans[:, 0] = np.maximum(rejected_spans[:, 0], offset) - offset  # offset
-        rejected_spans[:, 1] = np.minimum(rejected_spans[:, 1], offset + len(token_ids) - 1) - offset
+        rejected_spans[0] = np.maximum(rejected_spans[0], offset) - offset  # offset
+        rejected_spans[1] = np.minimum(rejected_spans[1], offset + len(token_ids) - 1) - offset

         return GPTSample(
             token_ids=token_ids,
             loss_masking_spans=sample_spans,
-            chosen_loss_masking_spans=chosen_sample_spans,
-            rejected_loss_masking_spans=rejected_sample_spans
+            chosen_loss_masking_spans=chosen_spans,
+            rejected_loss_masking_spans=rejected_spans
         )

     @property
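
Note: the get() changes above clamp a single inclusive [start, end] span to the selected token window and shift it into window coordinates. A small numpy sketch of that arithmetic, using a hypothetical clip_and_shift helper:

# Hypothetical helper mirroring the span normalization in the get() method above:
# clamp an inclusive [start, end] span to the selected token window, then shift it
# so that position 0 is the first selected token.
import numpy as np


def clip_and_shift(span: np.ndarray, offset: int, num_tokens: int) -> np.ndarray:
    start = max(int(span[0]), offset) - offset
    end = min(int(span[1]), offset + num_tokens - 1) - offset
    return np.array([start, end], dtype=np.int32)


# A span covering document positions 5..12, read through a 6-token window starting
# at document position 4, becomes [1, 5] in window coordinates.
print(clip_and_shift(np.array([5, 12], dtype=np.int32), offset=4, num_tokens=6))  # [1 5]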

fast_llm/data/dataset/gpt/sampled.py

Lines changed: 20 additions & 11 deletions
@@ -120,6 +120,9 @@ def __init__(
         # contains cumulative sum of document sizes grouped by TOKEN_CUMSUM_RATE in shuffled order
         self._token_cumsum_shuffled = MemmapArray(base_path.with_name(base_path.name + "_shuffled_cumsum.npy"))
         self._token_cumsum_unshuffled = MemmapArray(base_path.with_name(base_path.name + "_unshuffled_cumsum.npy"))
+
+        self._document_sizes = MemmapArray(base_path.with_name(base_path.name + "_shuffled_cumsum.npy"))
+
         self._yaml_path = base_path.with_suffix(".yaml")
         # Sample or validate the dataset of a given rank.
         if sampling.distributed.config.rank == sampling.get_next_rank():
@@ -132,11 +135,11 @@ def _sample(self) -> None:
         Create a `GPTSampledDataset` with the requested parameters.
         """
         # Get the document sizes, the main information needed for sampling.
-        self.document_sizes = torch.from_numpy(self._indexed_dataset.get_document_sizes()).to(self._device)
+        document_sizes = torch.from_numpy(self._indexed_dataset.get_document_sizes()).to(self._device)

         # Calculate basic stats.
-        documents_per_epoch = self.document_sizes.numel()
-        tokens_per_epoch = self.document_sizes.sum().item()
+        documents_per_epoch = document_sizes.numel()
+        tokens_per_epoch = document_sizes.sum().item()
         # We produce sequences of length `self._sequence_length + 1` so the last token has a label,
         # but we also include that last label in the following sample,
         # so we need `sequence_length * num_samples + 1` tokens in total.
@@ -160,7 +163,7 @@
             "dataset": {
                 "name": self._indexed_dataset.name,
                 "documents_per_epoch": documents_per_epoch,
-                "tokens_per_epoch": tokens_per_epoch,
+                "tokens_per_epoch": tokens_per_epoch
             },
             "num_samples": self._num_samples,
             "unshuffled_epochs": unshuffled_epochs,
@@ -247,7 +250,7 @@
         if self._config.enable_packing:
             if shuffled_epochs > 0:
                 token_cumsum_shuffled = self._get_token_cumsum(
-                    self.document_sizes[
+                    document_sizes[
                         # Torch indexing only works with int32 or int64
                         document_shuffling.to(
                             dtype=torch.int64 if document_shuffling.dtype == torch.int64 else torch.int32
@@ -268,15 +271,17 @@

             if unshuffled_epochs > 0:
                 token_cumsum_unshuffled = self._get_token_cumsum(
-                    self.document_sizes, offset=0, dtype=get_unsigned_integer_type(tokens_per_epoch * num_epochs).torch
+                    document_sizes, offset=0, dtype=get_unsigned_integer_type(tokens_per_epoch * num_epochs).torch
                 )
                 self._token_cumsum_unshuffled.save(token_cumsum_unshuffled.numpy(force=self._config.gpu))
         else:
-            self._document_shuffling.save(
-                document_shuffling[:self._num_samples].numpy(
-                    force=self._config.gpu
+            if shuffled_epochs > 0:
+                self._document_shuffling.save(
+                    document_shuffling[:self._num_samples].numpy(
+                        force=self._config.gpu
+                    )
                 )
-            )
+            self._document_sizes.save(document_sizes.numpy(force=self._config.gpu))

     def _get_token_cumsum(self, sizes: torch.Tensor, offset: int, dtype: torch.dtype) -> torch.Tensor:
         # Create the output tensor.
@@ -385,11 +390,15 @@ def __getitem__(self, index: int) -> typing.Any:
         sample = self._indexed_dataset.get(
             document_index,
             offset=0,
-            length=self.document_sizes[document_index],
+            length=self._document_sizes[document_index],
             use_loss_masking_spans=self._config.use_loss_masking_spans,
             use_preference_loss_masking_spans=self._config.use_preference_loss_masking_spans
         )

+        chosen_loss_masking_span_end = sample.chosen_loss_masking_spans[1] + 1
+        sequence_lengths = np.array([chosen_loss_masking_span_end, len(sample.token_ids) - chosen_loss_masking_span_end])
+        sample.sequence_lengths = sequence_lengths
+
         return sample

     @property
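
Note: the sequence_lengths added in __getitem__ split each sample into the segment up to and including the end of the chosen span and the remainder. A short sketch of that arithmetic with made-up values:

# Sketch of the sequence_lengths computation added in __getitem__ above,
# assuming chosen_loss_masking_spans is an inclusive [start, end] pair.
import numpy as np

token_ids = np.arange(10)        # hypothetical sample of 10 tokens
chosen_span = np.array([3, 6])   # hypothetical chosen span (inclusive end)

chosen_end = chosen_span[1] + 1  # first position after the chosen segment
sequence_lengths = np.array([chosen_end, len(token_ids) - chosen_end])
print(sequence_lengths)          # [7 3] -- the two parts sum to len(token_ids)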

fast_llm/data/tokenizer.py

Lines changed: 15 additions & 2 deletions
@@ -22,6 +22,8 @@ def __init__(self, config: TokenizerConfig):
             raise ValueError("Tokenizer does not have an BOS token.")
         self.eod_id = self.tokenizer.eos_token_id
         self.bod_id = self.tokenizer.bos_token_id
+        self.eod_token = self.tokenizer.eos_token
+        self.bod_token = self.tokenizer.bos_token

     @property
     def vocab_size(self) -> int:
@@ -52,6 +54,9 @@ def tokenize_with_spans(
         token_spans = []
         char_pos = 0
         beginning_of_text = True
+        if text.startswith(self.bod_token):
+            beginning_of_text = False
+
         for start, end in char_spans:
             if char_pos < start:
                 curr_text = text[char_pos:start]
@@ -60,7 +65,11 @@
                 input_ids.extend(tokenized_text)
             curr_text = text[start : end + 1]
             if end >= len(text) - 1:
-                tokenized_text = self.tokenize(curr_text, begin=beginning_of_text, end=True)
+                tokenized_text = self.tokenize(
+                    curr_text,
+                    begin=beginning_of_text,
+                    end=True if not curr_text.endswith(self.eod_token) else False
+                )
             else:
                 tokenized_text = self.tokenize(curr_text, begin=beginning_of_text, end=False)
             beginning_of_text = False
@@ -69,7 +78,11 @@
             char_pos = end + 1
         if char_pos < len(text):
             curr_text = text[char_pos:]
-            tokenized_text = self.tokenize(curr_text, begin=beginning_of_text, end=True)
+            tokenized_text = self.tokenize(
+                curr_text,
+                begin=beginning_of_text,
+                end=True if not curr_text.endswith(self.eod_token) else False
+            )
             input_ids.extend(tokenized_text)
         return input_ids, token_spans
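
Note: the changed call sites only request an explicit EOS when the trailing chunk does not already end with the EOS string, and the new startswith check suppresses a duplicate BOS. A toy sketch of the EOS guard, with a made-up word-level tokenize standing in for Tokenizer.tokenize:

# Toy sketch of the EOS guard; BOS, EOS and tokenize() are made-up stand-ins,
# not the real Tokenizer API.
BOS, EOS = "<s>", "</s>"


def tokenize(text: str, begin: bool, end: bool) -> list[str]:
    # Word-level stand-in for Tokenizer.tokenize: optionally wrap with BOS/EOS.
    return ([BOS] if begin else []) + text.split() + ([EOS] if end else [])


def tokenize_tail(text: str, beginning_of_text: bool) -> list[str]:
    # Mirrors the changed call sites: only request an EOS when the chunk does not
    # already end with the EOS string.
    return tokenize(text, begin=beginning_of_text, end=not text.endswith(EOS))


print(tokenize_tail("good answer", beginning_of_text=False))       # ['good', 'answer', '</s>']
print(tokenize_tail("good answer </s>", beginning_of_text=False))  # ['good', 'answer', '</s>'] (no duplicate EOS)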
