Reference model support for distillation, etc. #216
Conversation
distributed_config=self._distributed_config,
phase=PhaseType.inference,
)
assert fast_llm_model.config.distributed.world_size == 1
I'd like us to work on this in a separate PR soon, because this is needed for running generative benchmarks during training.
# TODO: Add support.
Assert.eq(self.model.distributed.pipeline_parallel, 1)
# TODO: Check if these work.
Assert.eq(self.model.distributed.tensor_parallel, 1)
Assert.eq(self.model.distributed.sequence_data_parallel, 1)
It would be good to lift these restrictions because the reference models can and will be large, and the sequences will be long.
Is the distributed config shared between the student and the teacher models? If so, are we currently unable to use reference models when training a model with pp, tp, or sdp > 1?
PP would be really hard, but TP and SDP won't be that hard. The tricky part will be to make loss functions that use it compatible.
f" using a freshly initialized model...", | ||
log_fn=logger.warning, | ||
) | ||
reference_model.fast_llm_model.initialize_weights() |
let's just error out here unless we are in debug mode. I don't see a single use case for this other than testing.
We don't really have a "debug mode"...
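A minimal sketch of the behavior being discussed, assuming a hypothetical allow_uninitialized_reference_model opt-in flag in place of a debug mode; this is not the current Fast-LLM API:

```python
# Minimal sketch: fail fast unless the user explicitly opts in to random reference weights.
# `allow_uninitialized_reference_model` is a hypothetical flag, not an existing Fast-LLM option.
def maybe_initialize_reference_model(reference_model, checkpoint_found: bool, allow_uninitialized: bool = False) -> None:
    if checkpoint_found:
        return
    if not allow_uninitialized:
        # A randomly initialized reference model is only useful for testing, so refuse by default.
        raise RuntimeError(
            "No checkpoint provided for the reference model; "
            "set allow_uninitialized_reference_model=True (testing only) to use random weights."
        )
    reference_model.fast_llm_model.initialize_weights()
```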
# TODO: Improve.
kwargs[f"{self._name}_logits"] = preprocess_kwargs["logits"]
This works well for now, since we'll need the logits downstream.
One small note: because the reference model has a user-defined name, the distillation loss in #214 and the DPO loss in #209 will need to use that name to access the logits. So the user-defined name should also be passed as an input to the loss functions, which complicates configuration.
Maybe not: we can probably get away with a hard-coded name or a field default on the loss side, just like we enforce the training dataset to be called "training".
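For illustration, here is a sketch of how a loss could read the reference logits from kwargs under a fixed default key, mirroring the hard-coded "training" dataset name. The default name "reference" and the exact loss form are assumptions, not what #214 or #209 actually implement:

```python
import torch

DEFAULT_REFERENCE_MODEL_NAME = "reference"  # hypothetical field default on the loss config

def distillation_loss(student_logits: torch.Tensor, kwargs: dict, temperature: float = 1.0) -> torch.Tensor:
    # Look up the teacher logits stored by the reference-model preprocessor under "<name>_logits".
    teacher_logits = kwargs[f"{DEFAULT_REFERENCE_MODEL_NAME}_logits"]
    teacher_probs = torch.softmax(teacher_logits / temperature, dim=-1)
    student_log_probs = torch.log_softmax(student_logits / temperature, dim=-1)
    # Cross-entropy against the teacher distribution; the actual loss in #214 may differ.
    return -(teacher_probs * student_log_probs).sum(dim=-1).mean()
```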
Thanks. This unblocks #214 and #209 with a reference model.
Before we merge, can we address the limitations of this approach? I'm seeing tp = 1, pp = 1, and sdp = 1 assertions, which gives me pause. In many cases, we'll want to distill from large models that won't fit on a single GPU, so FSDP alone won't cut it. Let's discuss how we can lift those constraints in a follow-up.
This kind of inference doesn't take much memory, so FSDP alone should be enough for models well above 100B parameters.
Ok, thanks, and I guess we are free to use any ZeRO level here? Or are there limitations as well? I guess with ZeRO-0 we wouldn't be able to fit a 70B model, say. Also, can you answer the question about whether the distributed configs of the reference and student models are linked?
The distributed configs are the same, at least for now. We can explore other options later, but having different configs would be really complicated. We should be able to use any ZeRO level, as long as we're not doing generative inference (ZeRO-3 is very inefficient for that).
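As a rough sanity check of the memory claim, with illustrative assumptions (bf16 weights, one 8 x 80 GB node, inference only so no gradients or optimizer state):

```python
# Back-of-the-envelope memory estimate for a sharded reference model (assumptions as above).
params = 100e9            # 100B-parameter reference model
bytes_per_param = 2       # bf16 weights
num_gpus = 8              # e.g. one 8 x 80 GB node
total_weight_gb = params * bytes_per_param / 1e9   # ~200 GB of weights in total
per_gpu_weight_gb = total_weight_gb / num_gpus     # ~25 GB per GPU with FSDP / ZeRO-3 sharding
print(f"~{per_gpu_weight_gb:.0f} GB of weights per GPU")
```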
✨ Description
Fixes #212
Add support for a reference model that can be used for distillation, etc.
- reference_models: dict[str, PretrainedFastLLMModelConfig]: reference models are configured through this new field, keyed by a user-defined name. We probably won't need more than one such model, but this reduces config bloat in the default case.
- InferenceRunner: a new class, split out of the Hugging Face wrapper, that can be used for Fast-LLM inference without the need to go through Hugging Face.
- Distributed: adjusted to deal with the distributed config being copied in the reference model.

It seems to be working and the tests pass, but we will need an actual use case to properly test it.
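A sketch of what the resulting configuration might look like. Only the reference_models field and the PretrainedFastLLMModelConfig class name come from this PR; the container class and the checkpoint field name are placeholders for illustration:

```python
import dataclasses

@dataclasses.dataclass
class PretrainedFastLLMModelConfig:  # stand-in for the real Fast-LLM config class
    pretrained_checkpoint_path: str = ""  # hypothetical field name

@dataclasses.dataclass
class TrainerConfig:  # hypothetical container, for illustration only
    # Keyed by a user-defined name; empty by default, so the common case adds no config bloat.
    reference_models: dict[str, PretrainedFastLLMModelConfig] = dataclasses.field(default_factory=dict)

# A single teacher for distillation would then be declared as:
config = TrainerConfig(
    reference_models={"teacher": PretrainedFastLLMModelConfig(pretrained_checkpoint_path="/path/to/teacher")}
)
```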