@@ -34,13 +34,17 @@ def _init(self, name: str, prefix: pathlib.Path | str, num_documents: int | None
34
34
self ._name = name
35
35
self ._prefix = pathlib .Path (prefix )
36
36
self ._has_spans = 0
37
+ self ._has_preference_spans = False
37
38
38
39
with self ._prefix .with_suffix (".idx" ).open ("rb" ) as stream :
39
40
Assert .eq (stream .read (9 ), MEMMAP_INDEX_HEADER )
40
41
self ._version = struct .unpack ("<Q" , stream .read (8 ))[0 ]
41
- assert self ._version in [1 , 2 ], f"Unsupported version for gpt_memmap dataset: { self ._version } ."
42
+ assert self ._version in [1 , 2 , 3 ], f"Unsupported version for gpt_memmap dataset: { self ._version } ."
42
43
if self ._version == 2 :
43
44
self ._has_spans = struct .unpack ("<B" , stream .read (1 ))[0 ]
45
+ if self ._version == 3 :
46
+ self ._has_spans = struct .unpack ("<B" , stream .read (1 ))[0 ]
47
+ self ._has_preference_spans = struct .unpack ("<B" , stream .read (1 ))[0 ]
44
48
45
49
self ._dtype = MEMMAP_DTYPES [struct .unpack ("<B" , stream .read (1 ))[0 ]].numpy
46
50
self ._num_documents = struct .unpack ("<Q" , stream .read (8 ))[0 ]
@@ -52,16 +56,21 @@ def _init(self, name: str, prefix: pathlib.Path | str, num_documents: int | None
52
56
53
57
self ._index_bin_buffer_mmap = np .memmap (self ._prefix .with_suffix (".idx" ), mode = "r" , order = "C" )
54
58
self ._index_bin_buffer = memoryview (self ._index_bin_buffer_mmap )
59
+
60
+ # read document sizes
55
61
self ._document_sizes = np .frombuffer (
56
62
self ._index_bin_buffer , dtype = np .int32 , count = self ._num_documents , offset = offset
57
63
)
64
+
65
+ # read pointers
58
66
self ._pointers = np .frombuffer (
59
67
self ._index_bin_buffer ,
60
68
dtype = np .int64 ,
61
69
count = self ._num_documents ,
62
70
offset = offset + self ._document_sizes .nbytes ,
63
71
)
64
72
73
+ # read spans
65
74
self ._spans = None
66
75
if self ._has_spans and self ._version == 2 :
67
76
self ._spans = []
@@ -83,6 +92,34 @@ def _init(self, name: str, prefix: pathlib.Path | str, num_documents: int | None
83
92
).reshape (- 1 , 2 )
84
93
)
85
94
95
+ # read preference spans
96
+ self ._chosen_spans = None
97
+ self ._rejected_spans = None
98
+ if self ._has_preference_spans :
99
+ self ._chosen_spans = []
100
+ self ._rejected_spans = []
101
+ chosen_span_offset = offset + self ._document_sizes .nbytes + self ._pointers .nbytes
102
+ for idx in range (self ._num_documents ):
103
+ self ._chosen_spans .append (
104
+ np .frombuffer (
105
+ self ._index_bin_buffer ,
106
+ dtype = np .int32 ,
107
+ count = 2 ,
108
+ offset = chosen_span_offset + idx * 2 * np .dtype (np .int32 ).itemsize ,
109
+ ).reshape (- 1 , 2 )
110
+ )
111
+
112
+ rejected_span_offset = offset + self ._document_sizes .nbytes + self ._pointers .nbytes + np .array (self ._chosen_spans ).nbytes
113
+ for idx in range (self ._num_documents ):
114
+ self ._rejected_spans .append (
115
+ np .frombuffer (
116
+ self ._index_bin_buffer ,
117
+ dtype = np .int32 ,
118
+ count = 2 ,
119
+ offset = rejected_span_offset + idx * 2 * np .dtype (np .int32 ).itemsize ,
120
+ ).reshape (- 1 , 2 )
121
+ )
122
+
86
123
self ._bin_buffer_mmap = np .memmap (self ._prefix .with_suffix (".bin" ), mode = "r" , order = "C" )
87
124
self ._bin_buffer = memoryview (self ._bin_buffer_mmap )
88
125
@@ -105,7 +142,7 @@ def __del__(self):
105
142
del self ._index_bin_buffer_mmap
106
143
107
144
def get (
108
- self , idx : int , offset : int = 0 , length : int | None = None , use_loss_masking_spans : bool = False
145
+ self , idx : int , offset : int = 0 , length : int | None = None , use_loss_masking_spans : bool = False , use_preference_loss_masking_spans : bool = False
109
146
) -> GPTSample :
110
147
token_ids = np .frombuffer (
111
148
self ._bin_buffer ,
@@ -116,13 +153,47 @@ def get(
116
153
sample_spans = None
117
154
if use_loss_masking_spans and self ._spans is not None :
118
155
sample_spans = self ._spans [idx ]
119
- # adjust the spans for the offset and length
156
+
157
+ # filter spans that are outside the range of the selected tokens in the document
120
158
sample_spans = sample_spans [
121
159
(sample_spans [:, 0 ] < offset + len (token_ids )) & (sample_spans [:, 1 ] >= offset )
122
160
]
123
- sample_spans [:, 0 ] = np .maximum (sample_spans [:, 0 ], offset ) - offset
161
+
162
+ # subtract by offset to normalize span boundaries
163
+ sample_spans [:, 0 ] = np .maximum (sample_spans [:, 0 ], offset ) - offset # offset
124
164
sample_spans [:, 1 ] = np .minimum (sample_spans [:, 1 ], offset + len (token_ids ) - 1 ) - offset
125
- return GPTSample (token_ids = token_ids , loss_masking_spans = sample_spans )
165
+
166
+ chosen_spans = None
167
+ rejected_spans = None
168
+ if use_preference_loss_masking_spans and self ._chosen_spans is not None and self ._rejected_spans is not None :
169
+ chosen_spans = self ._chosen_spans [idx ]
170
+
171
+ # filter spans that are outside the range of the selected tokens in the document
172
+ chosen_sample_spans = chosen_spans [
173
+ (chosen_spans [:, 0 ] < offset + len (token_ids )) & (chosen_spans [:, 1 ] >= offset )
174
+ ]
175
+
176
+ # subtract by offset to normalize span boundaries
177
+ chosen_spans [:, 0 ] = np .maximum (chosen_spans [:, 0 ], offset ) - offset # offset
178
+ chosen_spans [:, 1 ] = np .minimum (chosen_spans [:, 1 ], offset + len (token_ids ) - 1 ) - offset
179
+
180
+ rejected_spans = self ._rejected_spans [idx ]
181
+
182
+ # filter spans that are outside the range of the selected tokens in the document
183
+ rejected_sample_spans = rejected_spans [
184
+ (rejected_spans [:, 0 ] < offset + len (token_ids )) & (rejected_spans [:, 1 ] >= offset )
185
+ ]
186
+
187
+ # subtract by offset to normalize span boundaries
188
+ rejected_spans [:, 0 ] = np .maximum (rejected_spans [:, 0 ], offset ) - offset # offset
189
+ rejected_spans [:, 1 ] = np .minimum (rejected_spans [:, 1 ], offset + len (token_ids ) - 1 ) - offset
190
+
191
+ return GPTSample (
192
+ token_ids = token_ids ,
193
+ loss_masking_spans = sample_spans ,
194
+ chosen_loss_masking_spans = chosen_sample_spans ,
195
+ rejected_loss_masking_spans = rejected_sample_spans
196
+ )
126
197
127
198
@property
128
199
def name (self ) -> str :
@@ -157,6 +228,8 @@ def write_dataset(cls, prefix: pathlib.Path | str, documents: typing.Iterable[GP
157
228
# number of spans for each document
158
229
num_spans = []
159
230
spans = []
231
+ chosen_spans = []
232
+ rejected_spans = []
160
233
161
234
prefix = pathlib .Path (prefix )
162
235
prefix .parent .mkdir (parents = True , exist_ok = True )
@@ -182,6 +255,10 @@ def write_dataset(cls, prefix: pathlib.Path | str, documents: typing.Iterable[GP
182
255
if document .loss_masking_spans is not None :
183
256
num_spans .append (len (document .loss_masking_spans ))
184
257
spans .append (document .loss_masking_spans )
258
+ if document .chosen_loss_masking_spans is not None :
259
+ chosen_spans .append (document .chosen_loss_masking_spans )
260
+ if document .rejected_loss_masking_spans is not None :
261
+ rejected_spans .append (document .rejected_loss_masking_spans )
185
262
offset += doc_length * np .dtype (dtype ).itemsize
186
263
num_documents += 1
187
264
@@ -193,15 +270,26 @@ def write_dataset(cls, prefix: pathlib.Path | str, documents: typing.Iterable[GP
193
270
spans = np .vstack (spans , dtype = np .int32 )
194
271
else :
195
272
spans = np .array (spans , dtype = np .int32 )
273
+ # if len(chosen_spans) > 0:
274
+ # chosen_spans = np.vstack(chosen_spans, dtype=np.int32)
275
+ # else:
276
+ chosen_spans = np .array (chosen_spans , dtype = np .int32 ).reshape (- 1 , 2 )
277
+ # if len(rejected_spans) > 0:
278
+ # rejected_spans = np.vstack(rejected_spans, dtype=np.int32)
279
+ # else:
280
+ rejected_spans = np .array (rejected_spans , dtype = np .int32 ).reshape (- 1 , 2 )
196
281
197
282
# Write the index file (.idx)
198
283
with prefix .with_suffix (".idx" ).open ("wb" ) as idx_stream :
199
284
idx_stream .write (MEMMAP_INDEX_HEADER )
200
285
# Indicates the version
201
286
# Version 2 optionally adds loss-masking spans
202
- idx_stream .write (struct .pack ("<Q" , 2 ))
287
+ # Version 3 optionally adds chosen/rejected spans
288
+ idx_stream .write (struct .pack ("<Q" , 3 ))
203
289
# Flag to indicate whether loss-masking spans are present
204
290
idx_stream .write (struct .pack ("<B" , 1 if spans .size > 0 else 0 ))
291
+ # Flag to indicate whether preference loss-masking spans are present
292
+ idx_stream .write (struct .pack ("<B" , 1 if chosen_spans .size > 0 and rejected_spans .size > 0 else 0 ))
205
293
# Data type
206
294
idx_stream .write (struct .pack ("<B" , MEMMAP_DTYPES_INV [DataType .from_numpy (dtype .type )]))
207
295
# "Number of sequences", same as documents in our case
@@ -216,5 +304,9 @@ def write_dataset(cls, prefix: pathlib.Path | str, documents: typing.Iterable[GP
216
304
idx_stream .write (num_spans .tobytes (order = "C" ))
217
305
# Span indices for each document
218
306
idx_stream .write (spans .tobytes (order = "C" ))
307
+ # Chosen indices for each document
308
+ idx_stream .write (chosen_spans .tobytes (order = "C" ))
309
+ # Rejected indices for each document
310
+ idx_stream .write (rejected_spans .tobytes (order = "C" ))
219
311
# Document indices, unused but needed for compatibility with Megatron-LM
220
312
idx_stream .write (np .arange (num_documents + 1 , dtype = np .int64 ).tobytes (order = "C" ))
0 commit comments