Commit 818a162

Author: Toby Liang (committed)
Message: initial dpo updates
1 parent: fae0102

File tree: 8 files changed (+309, -91 lines)


fast_llm/data/dataset/gpt/config.py

Lines changed: 10 additions & 0 deletions
@@ -57,6 +57,16 @@ class GPTSamplingConfig(SamplingConfig):
         desc="Read loss masking spans from the dataset.",
         hint=FieldHint.feature,
     )
+    use_preference_loss_masking_spans: bool | None = Field(
+        default=None,
+        desc="Read preference loss masking spans from the dataset.",
+        hint=FieldHint.feature,
+    )
+    enable_packing: bool | None = Field(
+        default=True,
+        desc="Whether to enable packing.",
+        hint=FieldHint.feature,
+    )
     shuffle: ShufflingType | None = Field(
         default=None,
         desc="Shuffling strategy.",

fast_llm/data/dataset/gpt/fim.py

Lines changed: 2 additions & 0 deletions
@@ -20,6 +20,8 @@ def __init__(
     ):
         if sampling.config.use_loss_masking_spans:
             raise NotImplementedError("FIM is currently not compatible with loss masking.")
+        if sampling.config.use_preference_loss_masking_spans:
+            raise NotImplementedError("FIM is currently not compatible with preference loss masking.")
         self._config = config
         self._dataset = dataset
         self._seed = sampling.config.seed

fast_llm/data/dataset/gpt/memmap.py

Lines changed: 98 additions & 6 deletions
@@ -34,13 +34,17 @@ def _init(self, name: str, prefix: pathlib.Path | str, num_documents: int | None
         self._name = name
         self._prefix = pathlib.Path(prefix)
         self._has_spans = 0
+        self._has_preference_spans = False
 
         with self._prefix.with_suffix(".idx").open("rb") as stream:
             Assert.eq(stream.read(9), MEMMAP_INDEX_HEADER)
             self._version = struct.unpack("<Q", stream.read(8))[0]
-            assert self._version in [1, 2], f"Unsupported version for gpt_memmap dataset: {self._version}."
+            assert self._version in [1, 2, 3], f"Unsupported version for gpt_memmap dataset: {self._version}."
             if self._version == 2:
                 self._has_spans = struct.unpack("<B", stream.read(1))[0]
+            if self._version == 3:
+                self._has_spans = struct.unpack("<B", stream.read(1))[0]
+                self._has_preference_spans = struct.unpack("<B", stream.read(1))[0]
 
             self._dtype = MEMMAP_DTYPES[struct.unpack("<B", stream.read(1))[0]].numpy
             self._num_documents = struct.unpack("<Q", stream.read(8))[0]
@@ -52,16 +56,21 @@ def _init(self, name: str, prefix: pathlib.Path | str, num_documents: int | None
 
         self._index_bin_buffer_mmap = np.memmap(self._prefix.with_suffix(".idx"), mode="r", order="C")
         self._index_bin_buffer = memoryview(self._index_bin_buffer_mmap)
+
+        # read document sizes
         self._document_sizes = np.frombuffer(
             self._index_bin_buffer, dtype=np.int32, count=self._num_documents, offset=offset
         )
+
+        # read pointers
         self._pointers = np.frombuffer(
             self._index_bin_buffer,
             dtype=np.int64,
             count=self._num_documents,
             offset=offset + self._document_sizes.nbytes,
         )
 
+        # read spans
         self._spans = None
         if self._has_spans and self._version == 2:
             self._spans = []
@@ -83,6 +92,34 @@
                     ).reshape(-1, 2)
                 )
 
+        # read preference spans
+        self._chosen_spans = None
+        self._rejected_spans = None
+        if self._has_preference_spans:
+            self._chosen_spans = []
+            self._rejected_spans = []
+            chosen_span_offset = offset + self._document_sizes.nbytes + self._pointers.nbytes
+            for idx in range(self._num_documents):
+                self._chosen_spans.append(
+                    np.frombuffer(
+                        self._index_bin_buffer,
+                        dtype=np.int32,
+                        count=2,
+                        offset=chosen_span_offset + idx * 2 * np.dtype(np.int32).itemsize,
+                    ).reshape(-1, 2)
+                )
+
+            # the rejected block starts right after all chosen pairs
+            rejected_span_offset = chosen_span_offset + np.array(self._chosen_spans).nbytes
+            for idx in range(self._num_documents):
+                self._rejected_spans.append(
+                    np.frombuffer(
+                        self._index_bin_buffer,
+                        dtype=np.int32,
+                        count=2,
+                        offset=rejected_span_offset + idx * 2 * np.dtype(np.int32).itemsize,
+                    ).reshape(-1, 2)
+                )
+
         self._bin_buffer_mmap = np.memmap(self._prefix.with_suffix(".bin"), mode="r", order="C")
         self._bin_buffer = memoryview(self._bin_buffer_mmap)
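
Note the layout these reads imply: version 3 stores exactly one (start, end) pair of int32 per document for the chosen spans, followed by one pair per document for the rejected spans. A toy byte-offset calculation under that assumption (all names local to this sketch):

    import numpy as np

    num_documents = 4
    pair_bytes = 2 * np.dtype(np.int32).itemsize  # one (start, end) pair = 8 bytes

    # Chosen pairs start where the document sizes and pointers end
    # (chosen_span_offset above); rejected pairs follow the whole chosen block.
    chosen_block_bytes = num_documents * pair_bytes
    print(chosen_block_bytes)  # 32, so rejected_span_offset = chosen_span_offset + 32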

@@ -105,7 +142,7 @@ def __del__(self):
         del self._index_bin_buffer_mmap
 
     def get(
-        self, idx: int, offset: int = 0, length: int | None = None, use_loss_masking_spans: bool = False
+        self, idx: int, offset: int = 0, length: int | None = None, use_loss_masking_spans: bool = False, use_preference_loss_masking_spans: bool = False
     ) -> GPTSample:
         token_ids = np.frombuffer(
             self._bin_buffer,
@@ -116,13 +153,47 @@
         sample_spans = None
         if use_loss_masking_spans and self._spans is not None:
             sample_spans = self._spans[idx]
-            # adjust the spans for the offset and length
+
+            # filter spans that are outside the range of the selected tokens in the document
             sample_spans = sample_spans[
                 (sample_spans[:, 0] < offset + len(token_ids)) & (sample_spans[:, 1] >= offset)
             ]
+
+            # subtract the offset to normalize span boundaries to the sampled window
             sample_spans[:, 0] = np.maximum(sample_spans[:, 0], offset) - offset
             sample_spans[:, 1] = np.minimum(sample_spans[:, 1], offset + len(token_ids) - 1) - offset
-        return GPTSample(token_ids=token_ids, loss_masking_spans=sample_spans)
+
+        chosen_sample_spans = None
+        rejected_sample_spans = None
+        if use_preference_loss_masking_spans and self._chosen_spans is not None and self._rejected_spans is not None:
+            # filter spans that are outside the range of the selected tokens in the document
+            chosen_sample_spans = self._chosen_spans[idx]
+            chosen_sample_spans = chosen_sample_spans[
+                (chosen_sample_spans[:, 0] < offset + len(token_ids)) & (chosen_sample_spans[:, 1] >= offset)
+            ]
+
+            # subtract the offset to normalize span boundaries to the sampled window
+            chosen_sample_spans[:, 0] = np.maximum(chosen_sample_spans[:, 0], offset) - offset
+            chosen_sample_spans[:, 1] = np.minimum(chosen_sample_spans[:, 1], offset + len(token_ids) - 1) - offset
+
+            # same filtering and normalization for the rejected spans
+            rejected_sample_spans = self._rejected_spans[idx]
+            rejected_sample_spans = rejected_sample_spans[
+                (rejected_sample_spans[:, 0] < offset + len(token_ids)) & (rejected_sample_spans[:, 1] >= offset)
+            ]
+            rejected_sample_spans[:, 0] = np.maximum(rejected_sample_spans[:, 0], offset) - offset
+            rejected_sample_spans[:, 1] = np.minimum(rejected_sample_spans[:, 1], offset + len(token_ids) - 1) - offset
+
+        return GPTSample(
+            token_ids=token_ids,
+            loss_masking_spans=sample_spans,
+            chosen_loss_masking_spans=chosen_sample_spans,
+            rejected_loss_masking_spans=rejected_sample_spans,
+        )
 
     @property
     def name(self) -> str:
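
For intuition, here is the filter-and-clamp arithmetic from get() run standalone on made-up spans (plain numpy; spans are inclusive (start, end) token indices):

    import numpy as np

    spans = np.array([[2, 5], [8, 12], [20, 25]], dtype=np.int32)
    offset, n_tokens = 4, 10  # the sampled window covers tokens 4..13

    # Keep only spans overlapping the window...
    spans = spans[(spans[:, 0] < offset + n_tokens) & (spans[:, 1] >= offset)]
    # ...then clamp to the window and shift to window-relative coordinates.
    spans[:, 0] = np.maximum(spans[:, 0], offset) - offset
    spans[:, 1] = np.minimum(spans[:, 1], offset + n_tokens - 1) - offset
    print(spans)  # [[0 1] [4 8]]; the third span fell outside the window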
@@ -157,6 +228,8 @@ def write_dataset(cls, prefix: pathlib.Path | str, documents: typing.Iterable[GPTSample]
         # number of spans for each document
         num_spans = []
         spans = []
+        chosen_spans = []
+        rejected_spans = []
 
         prefix = pathlib.Path(prefix)
         prefix.parent.mkdir(parents=True, exist_ok=True)
@@ -182,6 +255,10 @@ def write_dataset(cls, prefix: pathlib.Path | str, documents: typing.Iterable[GPTSample]
                 if document.loss_masking_spans is not None:
                     num_spans.append(len(document.loss_masking_spans))
                     spans.append(document.loss_masking_spans)
+                if document.chosen_loss_masking_spans is not None:
+                    chosen_spans.append(document.chosen_loss_masking_spans)
+                if document.rejected_loss_masking_spans is not None:
+                    rejected_spans.append(document.rejected_loss_masking_spans)
                 offset += doc_length * np.dtype(dtype).itemsize
                 num_documents += 1
 
@@ -193,15 +270,26 @@ def write_dataset(cls, prefix: pathlib.Path | str, documents: typing.Iterable[GPTSample]
             spans = np.vstack(spans, dtype=np.int32)
         else:
             spans = np.array(spans, dtype=np.int32)
+        chosen_spans = np.array(chosen_spans, dtype=np.int32).reshape(-1, 2)
+        rejected_spans = np.array(rejected_spans, dtype=np.int32).reshape(-1, 2)
 
         # Write the index file (.idx)
         with prefix.with_suffix(".idx").open("wb") as idx_stream:
             idx_stream.write(MEMMAP_INDEX_HEADER)
             # Indicates the version
             # Version 2 optionally adds loss-masking spans
-            idx_stream.write(struct.pack("<Q", 2))
+            # Version 3 optionally adds chosen/rejected preference spans
+            idx_stream.write(struct.pack("<Q", 3))
             # Flag to indicate whether loss-masking spans are present
             idx_stream.write(struct.pack("<B", 1 if spans.size > 0 else 0))
+            # Flag to indicate whether preference loss-masking spans are present
+            idx_stream.write(struct.pack("<B", 1 if chosen_spans.size > 0 and rejected_spans.size > 0 else 0))
             # Data type
             idx_stream.write(struct.pack("<B", MEMMAP_DTYPES_INV[DataType.from_numpy(dtype.type)]))
             # "Number of sequences", same as documents in our case
@@ -216,5 +304,9 @@ def write_dataset(cls, prefix: pathlib.Path | str, documents: typing.Iterable[GPTSample]
             idx_stream.write(num_spans.tobytes(order="C"))
             # Span indices for each document
             idx_stream.write(spans.tobytes(order="C"))
+            # Chosen span indices for each document
+            idx_stream.write(chosen_spans.tobytes(order="C"))
+            # Rejected span indices for each document
+            idx_stream.write(rejected_spans.tobytes(order="C"))
             # Document indices, unused but needed for compatibility with Megatron-LM
             idx_stream.write(np.arange(num_documents + 1, dtype=np.int64).tobytes(order="C"))
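
Finally, a hedged end-to-end sketch of writing a preference dataset with the new fields. The class name GPTMemmapDataset and the GPTSample import path are assumptions; the field names come from this diff:

    import numpy as np
    from fast_llm.data.dataset.gpt.memmap import GPTMemmapDataset  # assumed class name
    from fast_llm.data.dataset.gpt.sampled import GPTSample  # assumed import path

    documents = [
        GPTSample(
            token_ids=np.array([5, 11, 42, 7, 9, 3], dtype=np.int32),
            chosen_loss_masking_spans=np.array([[2, 3]], dtype=np.int32),    # span of the chosen response
            rejected_loss_masking_spans=np.array([[4, 5]], dtype=np.int32),  # span of the rejected response
        )
    ]
    GPTMemmapDataset.write_dataset("data/dpo_shard_0", documents)

Note that the count=2 reads in _init expect exactly one chosen and one rejected (start, end) pair per document.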
