
Reference model support for distillation, etc. #216

Merged
merged 13 commits into main from reference_model on Apr 7, 2025

Conversation

jlamypoirier
Collaborator

@jlamypoirier jlamypoirier commented Apr 1, 2025

✨ Description

Fixes #212

Add support for a reference model that can be used for distillation, etc.

  • Config is specified as reference_models: dict[str, PretrainedFastLLMModelConfig]. We probably won't need more than one such model, but this reduces config bloat in the default case.
  • The model is run as a data preprocessor and adds its output to kwargs (see the sketch after this description).
  • Keeping the hack that sends the logits through kwargs, but we'll want something better eventually.
  • Extracted an InferenceRunner class from the Hugging Face wrapper that can be used for Fast-LLM inference without the need to go through Hugging Face.
  • Some tweaks to Distributed to deal with the distributed config being copied in the reference model.

It seems to be working and the tests pass, but we will need an actual use case to properly test it.
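
For illustration, a minimal sketch of the preprocessor pattern described above, assuming simplified class and method signatures that are not the actual Fast-LLM interfaces; only the `{name}_logits` kwargs key mirrors the code in this PR:

import typing

import torch


class ReferenceModelPreprocessor:
    # Sketch: run a frozen reference model on the batch and expose its logits
    # to downstream losses through `kwargs`.
    def __init__(self, name: str, inference_runner: typing.Any):
        self._name = name  # User-defined name, e.g. "teacher".
        self._runner = inference_runner  # Assumed to wrap the reference Fast-LLM model (InferenceRunner).

    def preprocess(self, batch: torch.Tensor, kwargs: dict[str, typing.Any]) -> None:
        # No gradients are needed for the frozen reference model.
        with torch.no_grad():
            preprocess_kwargs = dict(kwargs)
            self._runner.forward(batch, preprocess_kwargs)  # Assumed call signature.
        # Same hack as in the PR: pass the logits through kwargs. TODO: Improve.
        kwargs[f"{self._name}_logits"] = preprocess_kwargs["logits"]

Downstream losses can then look up the reference outputs under the reference model's name.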

/app$ fast-llm train gpt training.train_iters=10 data.format=random reference_models.teacher={}
2025-04-04 03:13:01,628 Using the legacy dataset definition format. Specify it through `data.datasets` instead.
2025-04-04 03:13:10,274 Setting random seeds...
2025-04-04 03:13:12,993 Creating model...
2025-04-04 03:13:13,031   Splitting the model into 14 stages...
2025-04-04 03:13:13,033   Total parameters: 181,543,936 
2025-04-04 03:13:13,034 Weight buffer placement:
{1: 0, 2: 1, 3: 2, 4: 3, 5: 4, 6: 5, 7: 6, 8: 7, 9: 8, 10: 9, 11: 10, 12: 11, 13: 12, 0: 13}
2025-04-04 03:13:13,034 Grad buffer placement:
{1: 0, 2: 1, 3: 2, 4: 3, 5: 4, 6: 5, 7: 6, 8: 7, 9: 8, 10: 9, 11: 10, 12: 11, 13: 12, 0: 13}
2025-04-04 03:13:13,034 Creating `teacher` reference model...
2025-04-04 03:13:13,132   Splitting the model into 14 stages...
2025-04-04 03:13:13,134   Total parameters: 181,543,936 
2025-04-04 03:13:13,135 Weight buffer placement:
{1: 0, 2: 1, 3: 2, 4: 3, 5: 4, 6: 5, 7: 6, 8: 7, 9: 8, 10: 9, 11: 10, 12: 11, 13: 12, 0: 13}
2025-04-04 03:13:13,135 Grad buffer placement:
{1: 0, 2: 1, 3: 2, 4: 3, 5: 4, 6: 5, 7: 6, 8: 7, 9: 8, 10: 9, 11: 10, 12: 11, 13: 12, 0: 13}
2025-04-04 03:13:13,136 Command run:
/usr/local/bin/fast-llm train gpt training.train_iters=10 data.format=random 'reference_models.teacher={}'
2025-04-04 03:13:13,139 
------- fast_llm.models.gpt.config.GPTTrainerConfig --------
[...]
reference_models:
  teacher:
    model:
      [...]
[...]
--------------------------- end ----------------------------
2025-04-04 03:13:13,139 Setting up model...
2025-04-04 03:13:13,139 >>> Allocating 14 weight buffers (692.54 MiB)
2025-04-04 03:13:13,547 >>> Allocating 14 grad buffers (692.54 MiB)
2025-04-04 03:13:13,547 >>> Allocating 4 shards (2,770.14 MiB)
2025-04-04 03:13:13,548 Total allocated: 4,155.21 MiB
2025-04-04 03:13:13,553 Setting up `teacher` reference model...
2025-04-04 03:13:13,553 >>> Allocating 14 weight buffers (692.54 MiB)
2025-04-04 03:13:13,554 >>> Allocating 1 shards (692.54 MiB)
2025-04-04 03:13:13,554 Total allocated: 1,385.07 MiB
2025-04-04 03:13:13,568 Preparing datasets...
/app/fast_llm/data/data/gpt/data.py:99: UserWarning: The following datasets are defined but not used: validation, test. Ensure this is intentional, or update the configuration accordingly.
  warnings.warn(
2025-04-04 03:13:13,568 Preparing dataset. This may take several minutes.
/app/fast_llm/data/data/gpt/data.py:110: UserWarning: Using the dataset directory for the index cache.
  warnings.warn(f"Using the dataset directory for the index cache.")
2025-04-04 03:13:13,580 Initializing training state from scratch...
2025-04-04 03:13:13,917 No pretrained checkpoint specified for `teacher` reference model, using a freshly initialized model...
2025-04-04 03:13:13,931 done with setup ...
2025-04-04 03:13:13,932 After initial setup:  allocated 5,540.28 MiB | max allocated 5,540.28 MiB | reserved 5,550.00 MiB | max reserved 5,550.00 MiB | global max reserved 5,550.00 MiB
2025-04-04 03:13:13,932 Initializing training dataset iterator from sample 0...
2025-04-04 03:13:13,966 Training ...
2025-04-04 03:13:16,172 Data loading took 2,068.80 ms

Select all that apply:

  • 🐛 Bug fix (non-breaking change that addresses a specific issue)
  • 🚀 New feature (non-breaking change that adds functionality)
  • ⚠️ Breaking change (a change that could affect existing functionality)
  • 📈 Performance improvement/optimization (improves speed, memory usage, or efficiency)
  • 🛠️ Code refactor (non-functional changes that improve code readability, structure, etc.)
  • 📦 Dependency bump (updates dependencies, including Dockerfile or package changes)
  • 📝 Documentation change (updates documentation, including new content or typo fixes)
  • 🔧 Infrastructure/Build change (affects build process, CI/CD, or dependencies)

@jlamypoirier jlamypoirier changed the base branch from main to preprocessor April 3, 2025 23:57
@jlamypoirier jlamypoirier changed the title [Prototype] Reference model support for distillation, etc. Reference model support for distillation, etc. Apr 4, 2025
@jlamypoirier jlamypoirier marked this pull request as ready for review April 4, 2025 02:54
Base automatically changed from preprocessor to main April 6, 2025 18:27
distributed_config=self._distributed_config,
phase=PhaseType.inference,
)
assert fast_llm_model.config.distributed.world_size == 1
Collaborator

I'd like us to work on this in a separate PR soon, because this is needed for running generative benchmarks during training.

Comment on lines +388 to +392
# TODO: Add support.
Assert.eq(self.model.distributed.pipeline_parallel, 1)
# TODO: Check if these work.
Assert.eq(self.model.distributed.tensor_parallel, 1)
Assert.eq(self.model.distributed.sequence_data_parallel, 1)
Collaborator

It would be good to lift these restrictions because the reference models can and will be large, and the sequences will be long.

Collaborator

Is the distributed config shared between the student and the teacher models? If so, are we currently unable to use reference models when training a model with pp, tp, or sdp > 1?

Collaborator Author

PP would be really hard, but TP and SDP won't be that hard. The tricky part will be to make the loss functions that use the reference outputs compatible.

f" using a freshly initialized model...",
log_fn=logger.warning,
)
reference_model.fast_llm_model.initialize_weights()
Collaborator

Let's just error out here unless we are in debug mode. I don't see a single use case for this other than testing.

Collaborator Author

We don't really have a "debug mode"...

Comment on lines +28 to +29
# TODO: Improve.
kwargs[f"{self._name}_logits"] = preprocess_kwargs["logits"]
Collaborator

This works well for now, since we'll need the logits downstream.

One small note: because the reference model has a user-defined name, the distillation loss in #214 and the DPO loss in #209 will need to use that name to access the logits. So the user-defined name should also be passed as an input to the loss functions, which complicates configuration.

Collaborator Author

Maybe not; we can probably get away with a hard-coded name or a field default on the loss side, just like we enforce the training dataset to be called "training".
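
As an illustration of that hard-coded-name option, a minimal sketch of how a distillation loss could consume the logits stored by the reference-model preprocessor, using a "teacher" default; the function name and signature are hypothetical and not taken from #214:

import typing

import torch
import torch.nn.functional as F


def distillation_loss(
    student_logits: torch.Tensor,
    kwargs: dict[str, typing.Any],
    temperature: float = 1.0,
    reference_name: str = "teacher",  # Field default, mirroring the "training" dataset convention.
) -> torch.Tensor:
    # The reference-model preprocessor stored its output under f"{name}_logits".
    teacher_logits = kwargs[f"{reference_name}_logits"]
    # Standard KL distillation between temperature-softened distributions.
    student_log_probs = F.log_softmax(student_logits / temperature, dim=-1)
    teacher_log_probs = F.log_softmax(teacher_logits / temperature, dim=-1)
    return (
        F.kl_div(student_log_probs, teacher_log_probs, reduction="batchmean", log_target=True)
        * temperature**2
    )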

Collaborator

@tscholak tscholak left a comment

Thanks. This unblocks #214 and #209 with a reference model.

Before we merge, can we address the limitations of this approach? I'm seeing tp = 1, pp = 1, and sdp = 1 assertions, which gives me pause. In many cases, we'll want to distill from large models that won't fit on a single GPU, so FSDP alone won't cut it. Let's discuss how we can lift those constraints in a follow-up.

@jlamypoirier
Collaborator Author

> Before we merge, can we address the limitations of this approach? I'm seeing tp = 1, pp = 1, and sdp = 1 assertions, which gives me pause. In many cases, we'll want to distill from large models that won't fit on a single GPU, so FSDP alone won't cut it. Let's discuss how we can lift those constraints in a follow-up.

This kind of inference doesn't take much memory, so FSDP alone should be enough for models well above 100B parameters.
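
As a rough back-of-envelope (assumptions: bf16 weights, no gradients or optimizer state since the reference model is frozen, and activation memory ignored):

def reference_weight_memory_gib(num_params: float, num_gpus: int, bytes_per_param: int = 2) -> float:
    # bf16 weights only, fully sharded across the data-parallel group (FSDP / ZeRO-3 style).
    return num_params * bytes_per_param / num_gpus / 2**30


# Example: a 100B-parameter teacher in bf16 across 16 GPUs is ~11.6 GiB of weights per GPU.
print(f"{reference_weight_memory_gib(100e9, 16):.1f} GiB")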

@jlamypoirier jlamypoirier merged commit 5ba1f0f into main Apr 7, 2025
2 checks passed
@jlamypoirier jlamypoirier deleted the reference_model branch April 7, 2025 23:05
@tscholak
Collaborator

tscholak commented Apr 7, 2025

> Before we merge, can we address the limitations of this approach? I'm seeing tp = 1, pp = 1, and sdp = 1 assertions, which gives me pause. In many cases, we'll want to distill from large models that won't fit on a single GPU, so FSDP alone won't cut it. Let's discuss how we can lift those constraints in a follow-up.

> This kind of inference doesn't take much memory, so FSDP alone should be enough for models well above 100B parameters.

OK, thanks, and I guess we are free to use any ZeRO level here? Or are there limitations as well? I guess with ZeRO-0 we wouldn't be able to fit a 70B model, say.

Also, can you answer the question about whether the distributed configs of the reference and student models are linked?

@jlamypoirier
Collaborator Author

> OK, thanks, and I guess we are free to use any ZeRO level here? Or are there limitations as well? I guess with ZeRO-0 we wouldn't be able to fit a 70B model, say.

> Also, can you answer the question about whether the distributed configs of the reference and student models are linked?

The distributed configs are the same, at least for now. We can explore other options later, but having different configs would be really complicated.

We should be able to use any ZeRO level, as long as we're not doing generative inference (ZeRO-3 is very inefficient for that).
Note that the multi-stage config doesn't have to be the same, so we can use a different ZeRO level for the two models.
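
To illustrate the idea of giving the reference model its own multi-stage settings, a hypothetical override; the `multi_stage.zero_stage` keys below are illustrative placeholders and may not match the actual Fast-LLM config fields:

# Hypothetical config fragment (field names are placeholders, not verified against Fast-LLM):
config_overrides = {
    "model": {"multi_stage": {"zero_stage": 2}},  # Student: ZeRO-2 for training.
    "reference_models": {
        "teacher": {"model": {"multi_stage": {"zero_stage": 3}}},  # Frozen teacher: ZeRO-3 to save memory.
    },
}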

@tscholak tscholak mentioned this pull request Apr 18, 2025