DPO #223

Merged: 56 commits, May 13, 2025.

Commits (56):
818a162
initial dpo updates
Mar 28, 2025
422a78b
Merge branch 'main' into toby/dpo
Mar 28, 2025
40c96c8
dataset changes for dpo
tobyzl2 Apr 3, 2025
f7796d4
adding dpo loss
tobyzl2 Apr 3, 2025
54b686a
Merge remote-tracking branch 'origin/main' into toby/dpo
tobyzl2 Apr 3, 2025
3c0199f
packing disabled filter sequences longer than seq length
tobyzl2 Apr 4, 2025
0e1335b
disable no packing for legacy sampling
tobyzl2 Apr 4, 2025
0e09098
adding dpo tests
tobyzl2 Apr 9, 2025
edca385
Merge branch 'main' of https://github.com/ServiceNow/Fast-LLM into to…
tobyzl2 Apr 9, 2025
1075176
small fix
tobyzl2 Apr 10, 2025
4156349
span tokenization updates
tobyzl2 Apr 10, 2025
9669211
enable chosen/rejected text for preparator
tobyzl2 Apr 10, 2025
257d236
removing assert
tobyzl2 Apr 10, 2025
aa8a871
moving dpo loss call
tobyzl2 Apr 10, 2025
d08bf4d
renaming
tobyzl2 Apr 10, 2025
b410210
padding fix
tobyzl2 Apr 13, 2025
366a20b
dpo config changes
tobyzl2 Apr 13, 2025
dca842e
memmap version fixes
tobyzl2 Apr 13, 2025
ca86694
removing dpo flags and new sampling class
tobyzl2 Apr 14, 2025
aa94f9a
removing extra lines
tobyzl2 Apr 14, 2025
7f37038
small data configuration updates
tobyzl2 Apr 15, 2025
0d7ccbd
update test case
tobyzl2 Apr 15, 2025
5fd1c86
logp span using index instead
tobyzl2 Apr 15, 2025
63db041
small updates
tobyzl2 Apr 15, 2025
dab6dab
small fix
tobyzl2 Apr 15, 2025
41fb3e3
fixing fim
tobyzl2 Apr 15, 2025
3d77986
adding checks for chosen/rej spans in memmap dataset
tobyzl2 Apr 15, 2025
905bc00
refactor to preprocessor
tobyzl2 Apr 16, 2025
1db18f9
merge
tobyzl2 Apr 16, 2025
52c8f9f
moving puse_pref_loss_spans to sampling parameters and combining samp…
tobyzl2 Apr 18, 2025
6ea086b
merge
tobyzl2 Apr 18, 2025
f51eedc
merge
tobyzl2 Apr 21, 2025
9067b6a
dpo loss enabling flag
tobyzl2 Apr 23, 2025
062ce88
check for config compatibility
tobyzl2 Apr 24, 2025
f53ac56
full dpo changes
tobyzl2 Apr 28, 2025
92f28ee
adding distillation model check
tobyzl2 Apr 28, 2025
2b2515f
update dpo test cases
tobyzl2 Apr 28, 2025
bd9142f
Fixing sampled for dpo
tobyzl2 Apr 28, 2025
4f26100
test case fixes
tobyzl2 Apr 28, 2025
41cc7fe
adding preference logps test case
tobyzl2 Apr 29, 2025
a6950f1
small fix
tobyzl2 Apr 29, 2025
fb9803d
higher mbs fixes
tobyzl2 Apr 30, 2025
723f30e
test higher mbs
tobyzl2 Apr 30, 2025
8063a21
small change
tobyzl2 Apr 30, 2025
c3a8ebb
updates
tobyzl2 Apr 30, 2025
db5242f
small changes
tobyzl2 Apr 30, 2025
ab139ca
small changes
tobyzl2 Apr 30, 2025
63041aa
remove comments
tobyzl2 Apr 30, 2025
e1c92f4
Merge branch 'main' of https://github.com/ServiceNow/Fast-LLM into to…
tobyzl2 Apr 30, 2025
e60ad62
maxlen consistency
tobyzl2 May 1, 2025
85613f7
remove comments
tobyzl2 May 1, 2025
2742692
refactoring
tobyzl2 May 2, 2025
29c9a4b
Merge branch 'main' of https://github.com/ServiceNow/Fast-LLM into to…
tobyzl2 May 2, 2025
8b837c0
fix
tobyzl2 May 2, 2025
16136ac
fix
tobyzl2 May 8, 2025
b64626c
merge
tobyzl2 May 13, 2025
Changes from all commits:
14 changes: 13 additions & 1 deletion fast_llm/data/data/gpt/data.py
@@ -32,18 +32,29 @@ class GPTBatch:
token_ids: torch.Tensor
loss_masking_spans: list[torch.Tensor] | None = None
sequence_lengths: list[torch.Tensor] | None = None
chosen_spans: list[torch.Tensor] | None = None
rejected_spans: list[torch.Tensor] | None = None


def gpt_data_collate_fn(batch: list[GPTSample], sampling_parameters: GPTSamplingParameters) -> GPTBatch:
stacked_ids = np.stack([sample.token_ids for sample in batch])
stacked_spans = None
sequence_lengths = None
stacked_chosen_spans = None
stacked_rejected_spans = None
if sampling_parameters.use_loss_masking_spans:
stacked_spans = [torch.from_numpy(sample.loss_masking_spans) for sample in batch]
if sampling_parameters.use_preference_loss_spans:
stacked_chosen_spans = [torch.from_numpy(sample.chosen_span) for sample in batch]
stacked_rejected_spans = [torch.from_numpy(sample.rejected_span) for sample in batch]
if not sampling_parameters.cross_document_attention:
sequence_lengths = [torch.tensor(sample.sequence_lengths) for sample in batch]
return GPTBatch(
token_ids=torch.from_numpy(stacked_ids), loss_masking_spans=stacked_spans, sequence_lengths=sequence_lengths
token_ids=torch.from_numpy(stacked_ids),
loss_masking_spans=stacked_spans,
sequence_lengths=sequence_lengths,
chosen_spans=stacked_chosen_spans,
rejected_spans=stacked_rejected_spans,
)
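
The collate change above stacks token ids into one tensor while keeping the per-sample chosen/rejected spans as a list of small tensors, one [start, end] pair each. A minimal standalone sketch of that behaviour with made-up data, mirroring the logic rather than calling the real gpt_data_collate_fn (plain dicts stand in for GPTSample here):

import numpy as np
import torch

# Two toy samples: token ids plus one chosen and one rejected span each.
samples = [
    {"token_ids": np.arange(8, dtype=np.int64),
     "chosen_span": np.array([2, 4], dtype=np.int32),
     "rejected_span": np.array([5, 7], dtype=np.int32)},
    {"token_ids": np.arange(8, 16, dtype=np.int64),
     "chosen_span": np.array([1, 3], dtype=np.int32),
     "rejected_span": np.array([4, 7], dtype=np.int32)},
]

token_ids = torch.from_numpy(np.stack([s["token_ids"] for s in samples]))
chosen_spans = [torch.from_numpy(s["chosen_span"]) for s in samples]
rejected_spans = [torch.from_numpy(s["rejected_span"]) for s in samples]

print(token_ids.shape)  # torch.Size([2, 8])
print(chosen_spans[0])  # tensor([2, 4], dtype=torch.int32)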


@@ -149,6 +160,7 @@ def get_iterator(
sampling_parameters = self._sampling_parameters[dataset_name]
Assert.in_range_incl(batch_config.sequence_length, 1, sampling_parameters.sequence_length)
log_main_rank(f"Initializing {dataset_name} dataset iterator from sample {consumed_samples}...")

return iter(
torch.utils.data.DataLoader(
self._datasets[dataset_name], # noqa
1 change: 1 addition & 0 deletions fast_llm/data/dataset/gpt/config.py
@@ -73,6 +73,7 @@ class GPTSamplingParameters(SamplingParameters):
sequence_length: int
vocab_size: int
use_loss_masking_spans: bool = False
use_preference_loss_spans: bool = False
cross_document_attention: bool = True
# How many extra tokens to add to the sequence length.
# This is used to provide labels even for the last tokens in the sequence.
3 changes: 3 additions & 0 deletions fast_llm/data/dataset/gpt/fim.py
@@ -20,8 +20,11 @@ def __init__(
):
if sampling.parameters.use_loss_masking_spans:
raise NotImplementedError("FIM is currently not compatible with loss masking.")
if sampling.parameters.use_preference_loss_spans:
raise NotImplementedError("FIM is currently not compatible with preference loss masking.")
self._config = config
self._dataset = dataset

self._seed = sampling.config.seed
self._tokenizer = sampling.tokenizer
if self._tokenizer is None:
114 changes: 106 additions & 8 deletions fast_llm/data/dataset/gpt/memmap.py
@@ -34,13 +34,16 @@ def _init(self, name: str, prefix: pathlib.Path | str, num_documents: int | None
self._name = name
self._prefix = pathlib.Path(prefix)
self._has_spans = 0
self._has_preference_spans = False

with self._prefix.with_suffix(".idx").open("rb") as stream:
Assert.eq(stream.read(9), MEMMAP_INDEX_HEADER, msg=f"File: {stream.name}")
self._version = struct.unpack("<Q", stream.read(8))[0]
assert self._version in [1, 2], f"Unsupported version for gpt_memmap dataset: {self._version}."
if self._version == 2:
assert self._version in [1, 2, 3], f"Unsupported version for gpt_memmap dataset: {self._version}."
if self._version >= 2:
self._has_spans = struct.unpack("<B", stream.read(1))[0]
if self._version >= 3:
self._has_preference_spans = struct.unpack("<B", stream.read(1))[0]

self._dtype = MEMMAP_DTYPES[struct.unpack("<B", stream.read(1))[0]].numpy
self._num_documents = struct.unpack("<Q", stream.read(8))[0]
@@ -52,18 +55,23 @@ def _init(self, name: str, prefix: pathlib.Path | str, num_documents: int | None

self._index_bin_buffer_mmap = np.memmap(self._prefix.with_suffix(".idx"), mode="r", order="C")
self._index_bin_buffer = memoryview(self._index_bin_buffer_mmap)

# read document sizes
self._document_sizes = np.frombuffer(
self._index_bin_buffer, dtype=np.int32, count=self._num_documents, offset=offset
)

# read pointers
self._pointers = np.frombuffer(
self._index_bin_buffer,
dtype=np.int64,
count=self._num_documents,
offset=offset + self._document_sizes.nbytes,
)

# read spans
self._spans = None
if self._has_spans and self._version == 2:
if self._has_spans and self._version >= 2:
self._spans = []
self._num_spans = np.frombuffer(
self._index_bin_buffer,
@@ -83,6 +91,36 @@ def _init(self, name: str, prefix: pathlib.Path | str, num_documents: int | None
).reshape(-1, 2)
)

# read preference spans
self._chosen_spans = None
self._rejected_spans = None
if self._has_preference_spans and self._version >= 3:
(Review comment from a collaborator on the line above: "let's just set self._has_preference_spans=False for other versions.")

self._chosen_spans = []
self._rejected_spans = []
chosen_span_offset = offset + self._document_sizes.nbytes + self._pointers.nbytes
for idx in range(self._num_documents):
self._chosen_spans.append(
np.frombuffer(
self._index_bin_buffer,
dtype=np.int32,
count=2,
offset=chosen_span_offset + idx * 2 * np.dtype(np.int32).itemsize,
)
)

rejected_span_offset = (
offset + self._document_sizes.nbytes + self._pointers.nbytes + np.array(self._chosen_spans).nbytes
)
for idx in range(self._num_documents):
self._rejected_spans.append(
np.frombuffer(
self._index_bin_buffer,
dtype=np.int32,
count=2,
offset=rejected_span_offset + idx * 2 * np.dtype(np.int32).itemsize,
)
)
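
Each document contributes exactly one chosen and one rejected span, stored as two int32 values, so the reader above can address them with a fixed stride into the index buffer. A self-contained sketch of that addressing, using a fake in-memory buffer and an offset of 0 in place of the real chosen_span_offset:

import numpy as np

num_documents = 3
# Pretend this is the chosen-span region of the .idx file: 2 int32 values per document.
buffer = np.arange(num_documents * 2, dtype=np.int32).tobytes()

chosen_span_offset = 0  # in the real file: offset + document_sizes.nbytes + pointers.nbytes
chosen_spans = [
    np.frombuffer(
        buffer,
        dtype=np.int32,
        count=2,
        offset=chosen_span_offset + idx * 2 * np.dtype(np.int32).itemsize,
    )
    for idx in range(num_documents)
]
print(chosen_spans[1])  # [2 3]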

self._bin_buffer_mmap = np.memmap(self._prefix.with_suffix(".bin"), mode="r", order="C")
self._bin_buffer = memoryview(self._bin_buffer_mmap)

Expand All @@ -105,7 +143,12 @@ def __del__(self):
del self._index_bin_buffer_mmap

def get(
self, idx: int, offset: int = 0, length: int | None = None, use_loss_masking_spans: bool = False
self,
idx: int,
offset: int = 0,
length: int | None = None,
use_loss_masking_spans: bool = False,
use_preference_loss_spans: bool = False,
) -> GPTSample:
token_ids = np.frombuffer(
self._bin_buffer,
@@ -116,13 +159,53 @@ def get(
sample_spans = None
if use_loss_masking_spans and self._spans is not None:
sample_spans = self._spans[idx]
# adjust the spans for the offset and length

# filter spans that are outside the range of the selected tokens in the document
sample_spans = sample_spans[
(sample_spans[:, 0] < offset + len(token_ids)) & (sample_spans[:, 1] >= offset)
]
sample_spans[:, 0] = np.maximum(sample_spans[:, 0], offset) - offset

# subtract by offset to normalize span boundaries
sample_spans[:, 0] = np.maximum(sample_spans[:, 0], offset) - offset # offset
sample_spans[:, 1] = np.minimum(sample_spans[:, 1], offset + len(token_ids) - 1) - offset
return GPTSample(token_ids=token_ids, loss_masking_spans=sample_spans)

chosen_span = None
rejected_span = None

if use_preference_loss_spans:
if not self._has_preference_spans:
raise ValueError("No preference spans found in memmap dataset.")
elif self._has_preference_spans and self._chosen_spans is None:
raise ValueError("Failed to read chosen spans from memmap dataset.")
elif self._has_preference_spans and self._rejected_spans is None:
raise ValueError("Failed to read rejected spans from memmap dataset.")
else:
chosen_span = self._chosen_spans[idx]

# filter spans that are outside the range of the selected tokens in the document
chosen_span = chosen_span[(chosen_span[0] < offset + len(token_ids)) & (chosen_span[1] >= offset)][0]

# subtract by offset to normalize span boundaries
chosen_span[0] = np.maximum(chosen_span[0], offset) - offset # offset
chosen_span[1] = np.minimum(chosen_span[1], offset + len(token_ids) - 1) - offset

rejected_span = self._rejected_spans[idx]

# filter spans that are outside the range of the selected tokens in the document
rejected_span = rejected_span[
(rejected_span[0] < offset + len(token_ids)) & (rejected_span[1] >= offset)
][0]

# subtract by offset to normalize span boundaries
rejected_span[0] = np.maximum(rejected_span[0], offset) - offset # offset
rejected_span[1] = np.minimum(rejected_span[1], offset + len(token_ids) - 1) - offset

return GPTSample(
token_ids=token_ids,
loss_masking_spans=sample_spans,
chosen_span=chosen_span,
rejected_span=rejected_span,
)
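
Loss-masking spans and preference spans get the same normalization in get(): spans that do not overlap the selected [offset, offset + length) token window are dropped, and the survivors are clipped and shifted so they index into the returned token_ids. A standalone sketch of that transformation (normalize_spans is an illustrative helper, not a method of the dataset):

import numpy as np

def normalize_spans(spans: np.ndarray, offset: int, num_tokens: int) -> np.ndarray:
    # keep spans that overlap the selected token window
    spans = spans[(spans[:, 0] < offset + num_tokens) & (spans[:, 1] >= offset)]
    # clip to the window and shift so boundaries are relative to the returned tokens
    spans[:, 0] = np.maximum(spans[:, 0], offset) - offset
    spans[:, 1] = np.minimum(spans[:, 1], offset + num_tokens - 1) - offset
    return spans

spans = np.array([[2, 5], [8, 12], [20, 25]], dtype=np.int32)
print(normalize_spans(spans, offset=4, num_tokens=10))
# rows [0 1] and [4 8]; the [20, 25] span falls outside the window and is dropped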

@property
def name(self) -> str:
@@ -157,6 +240,8 @@ def write_dataset(cls, prefix: pathlib.Path | str, documents: typing.Iterable[GP
# number of spans for each document
num_spans = []
spans = []
chosen_spans = []
rejected_spans = []

prefix = pathlib.Path(prefix)
prefix.parent.mkdir(parents=True, exist_ok=True)
@@ -182,6 +267,10 @@ def write_dataset(cls, prefix: pathlib.Path | str, documents: typing.Iterable[GP
if document.loss_masking_spans is not None:
num_spans.append(len(document.loss_masking_spans))
spans.append(document.loss_masking_spans)
if document.chosen_span is not None:
chosen_spans.append(document.chosen_span)
if document.rejected_span is not None:
rejected_spans.append(document.rejected_span)
offset += doc_length * np.dtype(dtype).itemsize
num_documents += 1

@@ -193,15 +282,20 @@ def write_dataset(cls, prefix: pathlib.Path | str, documents: typing.Iterable[GP
spans = np.vstack(spans, dtype=np.int32)
else:
spans = np.array(spans, dtype=np.int32)
chosen_spans = np.array(chosen_spans, dtype=np.int32).reshape(-1, 2)
rejected_spans = np.array(rejected_spans, dtype=np.int32).reshape(-1, 2)

# Write the index file (.idx)
with prefix.with_suffix(".idx").open("wb") as idx_stream:
idx_stream.write(MEMMAP_INDEX_HEADER)
# Indicates the version
# Version 2 optionally adds loss-masking spans
idx_stream.write(struct.pack("<Q", 2))
# Version 3 optionally adds chosen/rejected spans
idx_stream.write(struct.pack("<Q", 3))
# Flag to indicate whether loss-masking spans are present
idx_stream.write(struct.pack("<B", 1 if spans.size > 0 else 0))
# Flag to indicate whether preference loss-masking spans are present
idx_stream.write(struct.pack("<B", 1 if chosen_spans.size > 0 and rejected_spans.size > 0 else 0))
# Data type
idx_stream.write(struct.pack("<B", MEMMAP_DTYPES_INV[DataType.from_numpy(dtype.type)]))
# "Number of sequences", same as documents in our case
@@ -216,5 +310,9 @@ def write_dataset(cls, prefix: pathlib.Path | str, documents: typing.Iterable[GP
idx_stream.write(num_spans.tobytes(order="C"))
# Span indices for each document
idx_stream.write(spans.tobytes(order="C"))
# Chosen indices for each document
idx_stream.write(chosen_spans.tobytes(order="C"))
# Rejected indices for each document
idx_stream.write(rejected_spans.tobytes(order="C"))
# Document indices, unused but needed for compatibility with Megatron-LM
idx_stream.write(np.arange(num_documents + 1, dtype=np.int64).tobytes(order="C"))
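
Taken together, the read and write sides above fix the start of a version-3 .idx file as: the 9-byte MEMMAP_INDEX_HEADER magic, a little-endian uint64 version, one byte flagging loss-masking spans, one byte flagging preference spans (version 3 only), one byte for the token dtype code, and a uint64 document count. A minimal round-trip sketch of just that prefix; the magic value and dtype code below are stand-ins, not Fast-LLM's actual constants:

import struct

MEMMAP_INDEX_HEADER = b"MMIDIDX\x00\x00"  # stand-in 9-byte magic
DTYPE_CODE = 8                            # stand-in for MEMMAP_DTYPES_INV[DataType.from_numpy(dtype.type)]

header = b"".join((
    MEMMAP_INDEX_HEADER,
    struct.pack("<Q", 3),     # version 3 adds chosen/rejected spans
    struct.pack("<B", 1),     # loss-masking spans present
    struct.pack("<B", 1),     # preference spans present
    struct.pack("<B", DTYPE_CODE),
    struct.pack("<Q", 1234),  # number of documents
))

# Read it back the way _init does (version 1 files would omit both span flags).
assert header[:9] == MEMMAP_INDEX_HEADER
pos = 9
(version,) = struct.unpack_from("<Q", header, pos)
pos += 8
(has_spans,) = struct.unpack_from("<B", header, pos)
pos += 1
has_preference_spans = 0
if version >= 3:
    (has_preference_spans,) = struct.unpack_from("<B", header, pos)
    pos += 1
(dtype_code,) = struct.unpack_from("<B", header, pos)
pos += 1
(num_documents,) = struct.unpack_from("<Q", header, pos)
print(version, has_spans, has_preference_spans, num_documents)  # 3 1 1 1234
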
78 changes: 74 additions & 4 deletions fast_llm/data/dataset/gpt/sampled.py
@@ -30,6 +30,8 @@
class GPTSample:
token_ids: np.ndarray
loss_masking_spans: np.ndarray | None = None
chosen_span: np.ndarray | None = None
rejected_span: np.ndarray | None = None
sequence_lengths: np.ndarray | None = None


@@ -112,6 +114,14 @@ def __init__(
self._token_cumsum_shuffled = MemmapArray(base_path.with_name(base_path.name + "_shuffled_cumsum.npy"))
self._token_cumsum_unshuffled = MemmapArray(base_path.with_name(base_path.name + "_unshuffled_cumsum.npy"))
self._yaml_path = base_path.with_suffix(".yaml")

# keep document sizes and len filtered docs for preference loss masking
if self._parameters.use_preference_loss_spans:
self._document_sizes = MemmapArray(base_path.with_name(base_path.name + "_doc_sizes.npy"))
self._doc_length_filtered_indicies = MemmapArray(
base_path.with_name(base_path.name + "_doc_length_filtered_indices.npy")
)

# Sample or validate the dataset of a given rank.
if sampling.distributed.config.rank == sampling.get_next_rank():
self._sample()
@@ -145,10 +155,14 @@ def _sample(self) -> None:
raise RuntimeError(
f" > No documents shorter than {self._parameters.sequence_length+1} tokens found in dataset {self._indexed_dataset.name}."
)

# We produce sequences of length `self._sequence_length + extra_tokens` so the last token has a label for all prediction heads,
# but in case of truncations we also include those last labels in the following sample,
# so we need `sequence_length * num_samples + extra_tokens` tokens in total.
if self._truncate_documents:
if self._parameters.use_preference_loss_spans:
documents_per_epoch = (~long_docs_filter).sum().item()
num_epochs = math.ceil(self._parameters.num_samples / documents_per_epoch)
elif self._truncate_documents:
num_epochs = math.ceil(
(self._parameters.sequence_length * self._parameters.num_samples + self._parameters.extra_tokens)
/ tokens_per_epoch
@@ -187,8 +201,8 @@ def _sample(self) -> None:

if self._yaml_path is not None and self._yaml_path.is_file():
loaded_yaml_data = yaml.safe_load(self._yaml_path.open("r"))
self._load_yaml_data(loaded_yaml_data)
if not self._truncate_documents:
self._load_yaml_data(yaml_data)
if not self._truncate_documents and not self._parameters.use_preference_loss_spans:
del loaded_yaml_data["unshuffled_tokens"]

if loaded_yaml_data != yaml_data:
@@ -251,6 +265,24 @@ def _sample(self) -> None:
else:
raise NotImplementedError(f"Unknown shuffling type: {self._config.shuffle}")

if self._parameters.use_preference_loss_spans:
yaml_data["unshuffled_tokens"] = 0 # not used, ignore

# index of all documents less than seq length long
doc_length_filtered_indicies = torch.nonzero(~long_docs_filter, as_tuple=True)[0]
self._doc_length_filtered_indicies.save(doc_length_filtered_indicies.numpy(force=self._config.gpu))

# apply shuffling on doc_length_filtered_indicies
if shuffled_epochs > 0:
self._document_shuffling.save(
document_shuffling[: self._parameters.num_samples].numpy(force=self._config.gpu)
)
self._document_sizes.save(document_sizes.numpy(force=self._config.gpu))
if self._yaml_path is not None:
self._yaml_path.parent.mkdir(parents=True, exist_ok=True)
yaml.safe_dump(yaml_data, self._yaml_path.open("w"))
return
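
Because each preference sample is one whole (padded) document, the sampler above works at document granularity: drop documents that cannot fit in the sequence length, count the survivors per epoch, and shuffle their indices. A toy sketch of that bookkeeping; the exact comparison behind long_docs_filter is not visible in this hunk, so "longer than sequence_length + 1 tokens" is assumed here:

import math

import torch

sequence_length = 8
num_samples = 10
document_sizes = torch.tensor([5, 12, 7, 9, 3])

# assumed definition: documents too long to fit without truncation
long_docs_filter = document_sizes > sequence_length + 1

doc_length_filtered_indices = torch.nonzero(~long_docs_filter, as_tuple=True)[0]
documents_per_epoch = (~long_docs_filter).sum().item()
num_epochs = math.ceil(num_samples / documents_per_epoch)

print(doc_length_filtered_indices)      # tensor([0, 2, 3, 4])
print(documents_per_epoch, num_epochs)  # 4 3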

# To get a sample on the fly we need to know where it begins,
# and this is a non-trivial information because the documents have variable length.
# The starting point `(document[idx], token[idx])` corresponds to the `(idx * sequence_length)` th token, i.e.
@@ -349,6 +381,40 @@ def __getitem__(self, index: int) -> typing.Any:
The returned sample is ready to be concatenated, then fed to a `GPTModel` (see `GPTModel.preprocess`).
"""
self._lazy_load()

if self._parameters.use_preference_loss_spans:
if index < self._unshuffled_documents:
document_index = self._doc_length_filtered_indicies[index % self._documents_per_epoch]
else:
document_index = self._doc_length_filtered_indicies[
self._document_shuffling[index - self._unshuffled_documents].item()
]

sample = self._indexed_dataset.get(
document_index,
offset=0,
length=self._document_sizes[document_index],
use_loss_masking_spans=self._parameters.use_loss_masking_spans,
use_preference_loss_spans=self._parameters.use_preference_loss_spans,
)

chosen_span_end = sample.chosen_span[1] + 1
sequence_lengths = [
chosen_span_end,
len(sample.token_ids) - chosen_span_end,
]

# compute padding size
padding = np.full((self._parameters.sequence_length + 1,), 0)
padding[: len(sample.token_ids)] = sample.token_ids
sequence_lengths.append(self._parameters.sequence_length - len(sample.token_ids))
sample.token_ids = padding

if not self._parameters.cross_document_attention:
sample.sequence_lengths = np.array(sequence_lengths)

return sample
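
In the preference branch of __getitem__ above, a sample is a single document: the end of the chosen span splits it into two segments, the token ids are right-padded to sequence_length + 1, and the padding is appended as a third segment length; the segment lengths are only attached to the sample when cross_document_attention is disabled. A toy re-run of that arithmetic:

import numpy as np

sequence_length = 10
token_ids = np.arange(7)        # a 7-token document
chosen_span = np.array([3, 5])  # inclusive [start, end] of the chosen completion

chosen_span_end = int(chosen_span[1] + 1)
sequence_lengths = [chosen_span_end, len(token_ids) - chosen_span_end]

# right-pad with zeros to sequence_length + 1, as in the hunk above
padded = np.full((sequence_length + 1,), 0)
padded[: len(token_ids)] = token_ids
sequence_lengths.append(sequence_length - len(token_ids))

print(padded)            # [0 1 2 3 4 5 6 0 0 0 0]
print(sequence_lengths)  # [6, 1, 3]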

# tokens at the boundary are included in only one sample when we pack without truncations
# in case of packing with truncations, the last token from the previous sample is also the first token of the next sample
sample_length = (
@@ -454,7 +520,9 @@ def _lazy_load(self):
def _load_yaml_data(self, data: dict[str, typing.Any]) -> None:
self._documents_per_epoch = data["dataset"]["documents_per_epoch"]

if "unshuffled_tokens" not in data:
if self._parameters.use_preference_loss_spans:
data["unshuffled_tokens"] = 0 # not used, ignore
elif "unshuffled_tokens" not in data:
# Backward compatibility
# TODO v0.x: Remove
assert self._truncate_documents
@@ -485,6 +553,8 @@ def __init__(
)
self._config = sampling.config
self._parameters = sampling.parameters
if self._parameters.use_preference_loss_spans:
raise NotImplementedError("Legacy sampling does not support preference loss masking.")

if sampling.cache_directory is None:
log_main_rank(