
Frozen reference model support for DPO, distillation, etc. #212


Closed · 3 of 4 tasks
tscholak opened this issue Mar 27, 2025 · 2 comments
Labels: enhancement (New feature or request)

tscholak (Collaborator) commented Mar 27, 2025

🎯 Goal (What & Why)

Add support for loading and executing a (potentially very large) frozen reference model in parallel with the main trainable model.

This is a prerequisite for implementing Direct Preference Optimization (DPO) #209 with normalization and for distillation-style training #214. The reference model must be completely frozen (no gradients, no optimizer), support parallelism (data, tensor, etc.), and be configurable independently from the primary model.

🚀 Execution Plan

Step 1: What is the smallest working version?

  • Extend the Fast-LLM training stack to accept an optional reference_model config block:
    • Separate architecture and checkpoint path
    • May use HuggingFace or pre-converted checkpoints
  • Load the reference model at the beginning of training:
    • With the same distributed settings as the target model (the one being trained)
    • Distribute it across the same GPUs
  • Ensure reference model:
    • Runs in torch.no_grad() mode
    • Never receives gradients or enters optimizer step
    • Uses same tokenizer as main model and has the same vocabulary size.
  • At each training step:
    • Forward the input batch through the reference model
    • Expose its logits by providing a utility function to access ref model outputs for loss computation (e.g. get_ref_logits(batch))
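
As a rough illustration of this step (not Fast-LLM's actual API; the HuggingFace usage and helper names below are assumptions mirroring the utility named above), a minimal sketch of loading a frozen reference model and exposing its logits could look like this:

```python
# Illustrative sketch only -- not Fast-LLM's actual API. Assumes a HuggingFace-format
# checkpoint and plain PyTorch; `get_ref_logits` mirrors the utility named in the plan.
import torch
from transformers import AutoModelForCausalLM


def load_reference_model(checkpoint_path: str, device: torch.device) -> torch.nn.Module:
    """Load the frozen reference model once, at the beginning of training."""
    ref_model = AutoModelForCausalLM.from_pretrained(checkpoint_path, torch_dtype=torch.bfloat16)
    ref_model.to(device)
    ref_model.eval()                 # inference behavior (e.g. no dropout)
    ref_model.requires_grad_(False)  # no gradients; the model is never passed to an optimizer
    return ref_model


@torch.no_grad()
def get_ref_logits(ref_model: torch.nn.Module, batch: dict[str, torch.Tensor]) -> torch.Tensor:
    """Forward the input batch through the reference model and return its logits."""
    return ref_model(input_ids=batch["input_ids"], attention_mask=batch.get("attention_mask")).logits
```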

Step 2: What additional optimizations are possible (but optional)?

  • Allow reference model to use a subset of available GPUs instead of the same ones
  • Allow reference model to use a different distributed/parallelism setup than the main model

📌 Acceptance Criteria

  • Reference model can be configured and loaded independently
  • Reference model runs in forward pass alongside main model
  • Reference model is fully frozen (no optimizer, no gradients)
  • Reference model outputs are usable in custom loss functions
  • Training with a reference model does not break pretraining or SFT workflows
  • Documented with examples for both DPO and distillation use cases
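
For context on how the frozen reference outputs would be consumed, here is a minimal sketch of the standard DPO objective (Rafailov et al., 2023), assuming per-sequence log-probabilities have already been gathered from the policy and reference logits; this is illustrative, not the planned Fast-LLM implementation:

```python
# Minimal DPO loss sketch showing where the frozen reference model's outputs enter.
# `policy_*` / `ref_*` are per-sequence summed log-probabilities of the chosen and
# rejected completions; gathering them from logits is omitted here.
import torch
import torch.nn.functional as F


def dpo_loss(policy_chosen_logps: torch.Tensor, policy_rejected_logps: torch.Tensor,
             ref_chosen_logps: torch.Tensor, ref_rejected_logps: torch.Tensor,
             beta: float = 0.1) -> torch.Tensor:
    # Log-ratios of policy vs. frozen reference for chosen and rejected completions.
    chosen_ratio = policy_chosen_logps - ref_chosen_logps
    rejected_ratio = policy_rejected_logps - ref_rejected_logps
    # -log sigmoid(beta * (chosen - rejected)); the reference terms carry no gradient.
    return -F.logsigmoid(beta * (chosen_ratio - rejected_ratio)).mean()
```

A distillation loss would use the reference (teacher) logits analogously, e.g. a KL divergence between the teacher's and student's token distributions; in both cases the reference terms contribute no gradient.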

🛠️ Project Management

  • Assign the project to the Fast-LLM project.
  • Set the Estimate field (in days) in the GitHub project.
  • Use the Size field to categorize the PR size (Large).
  • Assign an owner when opening the issue.
jlamypoirier (Collaborator) commented:
Plan looks reasonable; the implementation should be quite straightforward. Some comments:

  • Extend the Fast-LLM training stack to accept an optional reference_model config block:

    • Separate architecture and checkpoint path
    • May use HuggingFace or pre-converted checkpoints
  • Load the reference model at the beginning of training:

    • With the same distributed settings as the target model (the one being trained)
    • Distribute it across the same GPUs

We can add a reference_model: PretrainedGPTModelConfig field and some bool flag to activate it. (Ideally we would make it optional or a dynamic class #126 so we don't have to create the class unnecessarily, but I'm not sure how easy that is.)
That should be enough config-wise, but this will create two separate distributed configs, and we need to make sure the one for the reference model isn't defined (maybe using #205).
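
As a rough shape for that config (plain dataclasses standing in for Fast-LLM's actual config classes; the field names are illustrative assumptions):

```python
# Illustrative only: plain dataclasses standing in for Fast-LLM's config system.
from dataclasses import dataclass, field


@dataclass
class ReferenceModelConfig:  # stand-in for PretrainedGPTModelConfig
    checkpoint_path: str = ""
    checkpoint_format: str = "huggingface"  # or a pre-converted Fast-LLM checkpoint


@dataclass
class TrainerConfig:
    # Optional frozen reference model; the flag avoids instantiating it unnecessarily.
    use_reference_model: bool = False
    reference_model: ReferenceModelConfig = field(default_factory=ReferenceModelConfig)
```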

  • Ensure reference model:

    • Runs in torch.no_grad() mode
    • Never receives gradients or enters optimizer step

We just need to set it up with forward-only support; it will take care of everything.

  • Uses same tokenizer as main model and has the same vocabulary size.

We can check the vocab size, but the model itself has no knowledge of the tokenizer so we can't easily check that the reference model was trained with the same tokenizer.

  • At each training step:

    • Forward the input batch through the reference model
    • Expose its logits by providing a utility function to access ref model outputs for loss computation (e.g. get_ref_logits(batch))

We can add a reference_model.forward() somewhere in the training loop and add the output to kwargs to make it available to the model. This won't work with pipeline parallelism, but I don't think we care at this stage.
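
Concretely, that wiring might look something like the following (names such as forward_and_loss are hypothetical; the point is only that the reference logits are computed under no_grad and handed to the loss via kwargs):

```python
import torch


def training_step(model, ref_model, batch, optimizer):
    # Frozen reference forward pass; no autograd graph is built for it.
    with torch.no_grad():
        ref_logits = ref_model(input_ids=batch["input_ids"]).logits

    # The reference logits ride along as an extra kwarg so the loss can consume them.
    loss = model.forward_and_loss(batch, reference_logits=ref_logits)  # hypothetical method
    loss.backward()
    optimizer.step()
    optimizer.zero_grad()
    return loss.detach()
```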

Step 2: What additional optimizations are possible (but optional)?

  • Allow reference model to use a subset of available GPUs instead of the same ones
  • Allow reference model to use a different distributed/parallelism setup than the main model

Also pipeline-parallelism support and a distillation example.

jlamypoirier (Collaborator) commented:
After a deeper look, the challenges seem to be the following:

  • The natural place to run the reference model is deep inside the schedule runner. This might be solvable with a preprocessing hook.
  • Advanced schedules (micro-sequences, breadth-first) will be difficult to support. We can ignore them for now.
  • We don't have a structure for "running a model in inference mode". (The closest is the HuggingFace runner, but it won't work here.) I'll try making a simple one; otherwise I'll need to pollute the trainer and runner with lots of hacks and bloat.
  • The LM head is supposed to be able to return logits, but this hasn't been tested recently and probably won't work, so some extra work could be needed there.
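
One possible shape for the preprocessing hook mentioned in the first bullet, purely to illustrate the idea (the hook signature is an assumption, not an existing Fast-LLM interface):

```python
import torch


def make_reference_preprocessing_hook(ref_model):
    """Return a hook that injects frozen reference logits into the batch kwargs
    before the main model's scheduled forward pass."""
    @torch.no_grad()
    def hook(batch, kwargs):
        kwargs["reference_logits"] = ref_model(input_ids=batch["input_ids"]).logits
        return kwargs
    return hook
```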
