DPO #223

Open · tobyzl2 wants to merge 55 commits into main

Conversation


@tobyzl2 tobyzl2 commented Apr 3, 2025

Description

This PR implements Direct Preference Optimization (DPO) training in Fast-LLM. DPO fine-tunes the model directly on preference data (chosen vs. rejected responses), optimizing it to better align with the desired output behavior.

Closes #209
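
For context, here is a minimal sketch of the pairwise DPO objective this kind of training optimizes (PyTorch). The function name and tensor layout are illustrative rather than the Fast-LLM API; the "reference free" variant benchmarked below presumably corresponds to dropping the reference-model terms, and the exact simplification used is defined by this PR's compute_simplified_dpo_loss.

```python
import torch
import torch.nn.functional as F


def pairwise_dpo_loss(
    chosen_logps: torch.Tensor,      # summed log-probs over the chosen response, shape (batch,)
    rejected_logps: torch.Tensor,    # summed log-probs over the rejected response, shape (batch,)
    ref_chosen_logps: torch.Tensor,  # same quantities under a frozen reference model
    ref_rejected_logps: torch.Tensor,
    beta: float = 0.1,               # the beta parameter this PR makes configurable
) -> torch.Tensor:
    """-log sigmoid(beta * ((logp_c - logp_r) - (ref_c - ref_r))), averaged over the batch."""
    policy_logratio = chosen_logps - rejected_logps
    reference_logratio = ref_chosen_logps - ref_rejected_logps
    # softplus(-x) is a numerically stable form of -log(sigmoid(x)).
    return F.softplus(-beta * (policy_logratio - reference_logratio)).mean()
```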

Throughput Numbers (reference free - 04/14/25)

| Model | Tok/Sec/GPU | Seq Len | Num Steps | # GPUs |
|---|---|---|---|---|
| Mistral 7b | 10168.015 | 8192 | 400 | 32 |
| Apriel 5b SFT | 14722.243 | 8192 | 6000 | 32 |

Throughput Numbers (with reference model - 04/29/25)

| Model | Tok/Sec/GPU | Seq Len | Num Steps | # GPUs |
|---|---|---|---|---|
| Mistral 7b | 7837.249 | 8192 | 400 | 32 |
| Apriel 5b SFT | 11331.242 | 8192 | 6000 | 32 |

🔍 Type of change

Select all that apply:

  • 🐛 Bug fix (non-breaking change that addresses a specific issue)
  • 🚀 New feature (non-breaking change that adds functionality)
  • ⚠️ Breaking change (a change that could affect existing functionality)
  • 📈 Performance improvement/optimization (improves speed, memory usage, or efficiency)
  • 🛠️ Code refactor (non-functional changes that improve code readability, structure, etc.)
  • 📦 Dependency bump (updates dependencies, including Dockerfile or package changes)
  • 📝 Documentation change (updates documentation, including new content or typo fixes)
  • 🔧 Infrastructure/Build change (affects build process, CI/CD, or dependencies)

📝 Changes

List the key changes introduced in this PR:

  1. Introduced DPO training with a simplified DPO loss.
  2. Allows users to configure DPO parameters (the beta value); see the illustrative sketch after this list.
  3. Allows packing to be turned off (DPO is currently implemented without packing).
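
Purely as illustration of item 2, a sketch of the new options; the dataclass stand-in, field names, and defaults are assumptions, while the real code uses Fast-LLM's Field(...)/FieldHint descriptors visible in the diffs further down:

```python
import dataclasses
import enum


class LossFunctionType(str, enum.Enum):
    # Mirrors the enum added in this PR.
    cross_entropy = "cross_entropy"
    dpo = "dpo"


@dataclasses.dataclass
class DPOConfigSketch:
    """Hypothetical stand-in for the DPO-related config this PR exposes."""

    loss_function: LossFunctionType = LossFunctionType.cross_entropy
    # Scales the margin between the chosen and rejected log-ratios; the default is assumed.
    beta: float = 0.1
```

With something like this, enabling DPO amounts to setting the loss function to dpo and picking a beta value.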

✅ Checklist

Make sure the following tasks are completed before submitting the PR:

General

  • 📜 I have read and followed the contributing guidelines.
  • 🏷️ I am using a clear and descriptive PR title that summarizes the key change or feature introduced.
  • 🎉 The functionality is complete, and I have tested the changes.
  • 📝 I have updated the documentation if needed.
  • ⚠️ The change does not introduce any new issues (e.g., runtime warnings, type checker errors, linting problems, unhandled edge cases).
  • 🧩 I have commented my code, especially in hard-to-understand areas.

Dependencies and Configuration

  • 🐋 I have updated the Docker configuration or dependencies, if applicable.
  • 🔄 I have ensured compatibility with the existing setup after dependency changes.

Testing

  • 🧪 I have added or updated tests to cover my changes.
  • ✔️ New and existing tests pass locally with my changes.
  • 🚦 I have tested these changes on GPUs and verified training stability.
  • 🏋️ I have tested the changes on realistic training workloads, if applicable.

Performance Impact

  • 📊 I have run benchmarks where applicable to evaluate the performance impact.
  • ✅ The benchmarks show no performance regression.
  • 🚀 The benchmarks indicate a potential performance improvement.
  • ⚠️ The benchmarks indicate a potential performance degradation.
  • 📈 I have provided benchmark results and detailed any performance impact below, if applicable.

📊 Performance Impact Details

If there is any impact on performance, describe it and provide benchmark results, if applicable:


🗒️ Additional Notes

Include any additional context, information, or considerations here, such as known issues, follow-up tasks, or backward compatibility concerns.

@tobyzl2 tobyzl2 changed the title Toby/dpo DPO Apr 3, 2025
@tobyzl2 tobyzl2 requested a review from sohamparikh April 3, 2025 22:55

@tscholak (Collaborator) left a comment

This looks great already; a few functional tests would be good.
Maybe extend:

def get_test_dataset(

@tobyzl2 tobyzl2 marked this pull request as ready for review April 9, 2025 01:20

@jlamypoirier (Collaborator) left a comment

I'm worried that too many unrelated things need to be set correctly for this to work (dataset format, sampling config, loss function), and if they aren't, things will crash too late with a cryptic error. Let's try to simplify this a bit.

desc="Read preference loss masking spans from the dataset.",
hint=FieldHint.feature,
)
enable_packing: bool | None = Field(

Collaborator:

What's packing?

Author:

Made some changes to remove this flag. Packing means placing multiple documents in the same sequence, as in pretraining; for DPO I wanted a way to pad to the end of the sequence instead (toy illustration below).
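
A toy illustration of the difference (not the actual sampling code):

```python
# Packing (pretraining): concatenate documents so one sequence holds several of them.
docs = [[1, 2, 3], [4, 5], [6, 7, 8, 9]]
seq_len, pad_token = 8, 0

packed = [tok for doc in docs for tok in doc][:seq_len]
print(packed)  # [1, 2, 3, 4, 5, 6, 7, 8] -> documents share one sequence

# No packing (as for DPO here): one document per sequence, padded out to seq_len.
padded = docs[0] + [pad_token] * (seq_len - len(docs[0]))
print(padded)  # [1, 2, 3, 0, 0, 0, 0, 0]
```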

@@ -57,6 +57,16 @@ class GPTSamplingConfig(SamplingConfig):
desc="Read loss masking spans from the dataset.",
hint=FieldHint.feature,
)
use_preference_loss_masking_spans: bool | None = Field(

Collaborator:

Does it make sense to have both normal and preference loss masking spans?

Collaborator:

Also, this might be better suited for GPTSamplingData. That way, the trainer could set this value based on the training objective and avoid the complicated relationship between the parameters.

Author:

Just made some changes to remove this extra flag; it's now read automatically from the flag stored in the saved memmap dataset.


class LossFunctionType(str, enum.Enum):
    cross_entropy = "cross_entropy"
    dpo = "dpo"

Collaborator:

Missing newline (please make sure to enable pre-commit)

loss, grad = compute_simplified_dpo_loss(
    logits.flatten(0, -2),
    labels,
    kwargs[LanguageModelKwargs.chosen_spans],

Collaborator:

How do we ensure it's there? Seems like this will crash unless:

  • All datasets have both chosen and rejected spans.
  • The sampling config for all datasets is set to use these spans.

It's a bad idea to wait this late for a crash; we should aim to do the check sooner.

Author:

Yeah I agree, let me see where I can add these checks so that we can detect it earlier.

@tobyzl2 (Author) commented Apr 14, 2025

> I'm worried that too many unrelated things need to be set correctly for this to work (dataset format, sampling config, loss function), and if they aren't, things will crash too late with a cryptic error. Let's try to simplify this a bit.

Yeah, this makes sense. Let me try to do a bit of refactoring so that, from the user's perspective, only one configuration needs to be specified (something like training objective = dpo) and the other flags are set automatically instead of having to be specified one by one.
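
A rough sketch of what that could look like (training_objective is a hypothetical attribute name; use_dpo_loss and use_preference_loss_masking_spans are the flags that appear in the diff below):

```python
def _validate(self) -> None:
    # Hypothetical: a single training-objective switch drives the dependent flags,
    # so users don't have to set dataset/sampling/loss options one by one.
    if self.training_objective == "dpo":
        self.model.base_model.use_dpo_loss = True
        self.batch.use_preference_loss_masking_spans = True
    # Fail fast at config-validation time instead of deep inside the loss computation.
    if self.model.base_model.use_dpo_loss:
        assert self.batch.use_preference_loss_masking_spans, (
            "DPO training requires datasets sampled with chosen/rejected preference spans."
        )
```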

@tobyzl2 tobyzl2 closed this Apr 14, 2025
@tobyzl2 tobyzl2 reopened this Apr 14, 2025

@jlamypoirier (Collaborator) left a comment

@tobyzl2 are you planning on adding distillation here? I suggest merging this one first since it's almost ready.

@@ -196,6 +201,9 @@ def _validate(self) -> None:
if self.model.base_model.distillation_model is not None:
# TODO: Support loss masking for distillation?
assert not self.batch.use_loss_masking_spans
assert self.model.base_model.use_dpo_loss == self.batch.use_preference_loss_masking_spans

Collaborator:

Seems to make the two parameters redundant

if target is not None:
if self._config.distillation_model is None:
if self._config.distillation_model is None or self._use_dpo_loss:

Collaborator:

Is the check for self._use_dpo_loss redundant, since it doesn't support distillation?

@tscholak tscholak mentioned this pull request May 9, 2025
@jlamypoirier (Collaborator) commented:

Are we planning on merging this?

@tobyzl2 (Author) commented May 12, 2025

> Are we planning on merging this?

Yes, just checking whether we need to wait for #255 to merge first, @tscholak?

Successfully merging this pull request may close: Direct Preference Optimization (DPO) support.