DPO #223
base: main
Conversation
This looks great already; a few functional tests would be good. Maybe extend:
Line 278 in 5ba1f0f
def get_test_dataset(
I'm worried that too many unrelated things need to be set correctly for this to work (dataset format, sampling config, loss function), and if they aren't, it will crash too late with a cryptic error. Let's try to simplify this a bit.
fast_llm/data/dataset/gpt/config.py
Outdated
    desc="Read preference loss masking spans from the dataset.",
    hint=FieldHint.feature,
)
enable_packing: bool | None = Field(
What's packing?
Made some changes to remove this flag. Packing means putting multiple documents in the same sequence, as in pretraining; for DPO I wanted a way to pad to the end of the sequence instead.
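To make the distinction concrete, a minimal sketch (function names are hypothetical, not the Fast-LLM API):

```python
# Hypothetical illustration of packing vs. padding to the sequence length;
# function names are made up and this is not the Fast-LLM implementation.
def pack_documents(docs: list[list[int]], sequence_length: int) -> list[int]:
    # Pretraining-style packing: concatenate documents into one sequence.
    sequence: list[int] = []
    for doc in docs:
        if len(sequence) + len(doc) > sequence_length:
            break
        sequence.extend(doc)
    return sequence


def pad_document(doc: list[int], sequence_length: int, pad_token: int = 0) -> list[int]:
    # DPO-style alternative: one document per sequence, padded to the end.
    return doc[:sequence_length] + [pad_token] * max(0, sequence_length - len(doc))
```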
fast_llm/data/dataset/gpt/config.py
Outdated
@@ -57,6 +57,16 @@ class GPTSamplingConfig(SamplingConfig):
        desc="Read loss masking spans from the dataset.",
        hint=FieldHint.feature,
    )
    use_preference_loss_masking_spans: bool | None = Field(
Does it make sense to have both normal and preference loss masking spans?
Also, this might be better suited for GPTSamplingData. This way the trainer could set this value based on the training objective and avoid the complicated relationship between the parameters.
Just made some changes to remove this extra flag; it's now read automatically from the flag stored in the saved memmap dataset.
fast_llm/functional/config.py
Outdated
class LossFunctionType(str, enum.Enum):
    cross_entropy = "cross_entropy"
    dpo = "dpo"
Missing newline (please make sure to enable pre-commit)
loss, grad = compute_simplified_dpo_loss(
    logits.flatten(0, -2),
    labels,
    kwargs[LanguageModelKwargs.chosen_spans],
How do we ensure it's there? Seems like this will crash unless:
- All datasets have both chosen and rejected spans.
- The sampling config for all datasets is set to use these spans.
It's a bad idea to wait so long for a crash; we should aim to do the check sooner.
Yeah I agree, let me see where I can add these checks so that we can detect it earlier.
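For illustration, a check along these lines could fail fast at dataset/sampling setup instead of inside the loss function (the attribute and function names here are hypothetical, not the actual Fast-LLM entry points):

```python
# Hypothetical early validation; attribute and function names are illustrative.
def check_preference_spans(dataset, using_dpo: bool) -> None:
    # Fail at setup time instead of with a KeyError inside the loss computation.
    if using_dpo and not getattr(dataset, "has_preference_spans", False):
        raise ValueError(
            f"DPO training requires chosen/rejected spans, but dataset "
            f"{getattr(dataset, 'name', '<unnamed>')} does not provide them."
        )
```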
Yeah, this makes sense. Let me try to do a bit of refactoring so that, from the user's perspective, they only have to specify one configuration (something like training objective = dpo), and it automatically sets the other flags without having to specify each one manually.
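A rough sketch of that idea (the enum and field names are hypothetical, not the final design): a single training-objective switch from which the dependent flags are derived.

```python
import enum


# Hypothetical sketch of a single training-objective switch that derives the
# dependent flags, so users never have to keep them consistent by hand.
class TrainingObjective(str, enum.Enum):
    language_modeling = "language_modeling"
    dpo = "dpo"


class TrainerConfigSketch:
    def __init__(self, objective: TrainingObjective = TrainingObjective.language_modeling):
        self.objective = objective
        # Derived settings; the user only ever sets `objective`.
        self.use_dpo_loss = objective == TrainingObjective.dpo
        self.use_preference_loss_masking_spans = objective == TrainingObjective.dpo
```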
@tobyzl2 are you planning on adding distillation here? I suggest merging this one first since it's almost ready.
fast_llm/models/gpt/config.py
Outdated
@@ -196,6 +201,9 @@ def _validate(self) -> None:
        if self.model.base_model.distillation_model is not None:
            # TODO: Support loss masking for distillation?
            assert not self.batch.use_loss_masking_spans
        assert self.model.base_model.use_dpo_loss == self.batch.use_preference_loss_masking_spans
Seems to make the two parameters redundant
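One way to resolve that redundancy (a fragment-level sketch only, not the PR's actual code) would be to derive the batch flag from the model setting during validation instead of asserting equality:

```python
# Sketch only: set the dependent flag during validation rather than requiring
# the user to keep both parameters in sync and asserting that they match.
def _validate(self) -> None:
    if self.model.base_model.use_dpo_loss:
        self.batch.use_preference_loss_masking_spans = True
    super()._validate()
```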
if target is not None:
-    if self._config.distillation_model is None:
+    if self._config.distillation_model is None or self._use_dpo_loss:
The check for self._use_dpo_loss seems redundant since it doesn't support distillation?
Are we planning on merging this?
Description
This PR introduces the implementation of Direct Preference Optimization (DPO) training in Fast-LLM. DPO enhances model fine-tuning by directly incorporating user preferences into the optimization process, so that the model better aligns with the desired output behavior.
Closes #209
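For context, a minimal reference-free DPO loss in PyTorch (a sketch of the standard formulation, not the compute_simplified_dpo_loss implementation added in this PR):

```python
import torch
import torch.nn.functional as F


def dpo_loss_sketch(
    chosen_logps: torch.Tensor,    # summed log-probs of the chosen responses, shape (batch,)
    rejected_logps: torch.Tensor,  # summed log-probs of the rejected responses, shape (batch,)
    beta: float = 0.1,
) -> torch.Tensor:
    # Reference-free DPO objective: maximize the log-probability margin between
    # the chosen and rejected completions, scaled by beta.
    return -F.logsigmoid(beta * (chosen_logps - rejected_logps)).mean()
```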
Throughput Numbers (reference free - 04/14/25)
Throughput Numbers (with reference model - 04/29/25)