Add stochastic mixer for supernet training #373
base: main
Conversation
Implements a stochastic mixer layer that randomly samples from multiple mixer options during training, enabling supernet training where different architecture variants (e.g., attention vs. Mamba) are trained with different data subsets.

Key components:
- StochasticMixerConfig: Configuration for the stochastic sampling strategy (uniform or weighted) with a configurable main_mixer_index for inference
- StochasticMixer: Layer implementation with distributed RNG support
- Checkpoint conversion: Apriel converter handles stochastic mixers
- Beam search tool: Hierarchical beam search for optimal mixer placement

The beam search tool finds which layers benefit most from expensive mixers (e.g., full attention) vs. efficient mixers (e.g., linear attention) by evaluating different configurations using Fast-LLM's evaluation system.

🤖 Generated with [Claude Code](https://claude.com/claude-code)
Co-Authored-By: Claude <[email protected]>
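For readers skimming the diff, here is a minimal sketch of the core idea. It is illustrative only, not the actual StochasticMixer class: the mixer container, sampling-probability buffer, and main-mixer handling below are assumptions.

```python
import torch
import torch.nn as nn


class ToyStochasticMixer(nn.Module):
    """Illustrative sketch: holds several candidate mixers and runs exactly one per forward pass."""

    def __init__(self, mixers: dict[str, nn.Module], sampling_probs: list[float], main_mixer: str):
        super().__init__()
        self.mixers = nn.ModuleDict(mixers)
        self.register_buffer("sampling_probs", torch.tensor(sampling_probs))
        self.main_mixer = main_mixer  # used deterministically at inference time

    def forward(self, hidden_states: torch.Tensor) -> torch.Tensor:
        if self.training:
            # Sample one mixer for this step; the other mixers receive no gradient this step.
            idx = torch.multinomial(self.sampling_probs, num_samples=1).item()
            name = list(self.mixers.keys())[idx]
        else:
            name = self.main_mixer
        return self.mixers[name](hidden_states)
```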
- Fix Assert.gt_len AttributeError by moving validation to the _validate() method
- Add AttentionConfig import to models/auto.py for proper registration
- Mark all mixer parameters with allow_no_grad=True, since only one mixer is active per forward pass

🤖 Generated with [Claude Code](https://claude.com/claude-code)
Co-Authored-By: Claude <[email protected]>
Fixed nested config structure bug in AprielStochasticMixerConverter.import_config that was causing validation errors when loading Apriel checkpoints. The converter was returning the entire block config (with mixer, mlp, and normalization keys) instead of just the mixer config, causing these fields to be incorrectly nested under the mixer field during import.

🤖 Generated with [Claude Code](https://claude.com/claude-code)
Co-Authored-By: Claude <[email protected]>
jlamypoirier left a comment:
Looks good, some minor comments
    with set_generator(generator):
        # Sample from categorical distribution
        idx = torch.multinomial(self._sampling_probs, num_samples=1).item()
This requires a costly cuda sync. How about we sample for all layers at once during preprocessing?
now done during preprocessing
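A sketch of what batching the sampling into preprocessing can look like. The function name, arguments, and the place where the result is stashed are assumptions, not the actual Fast-LLM code; the point is a single multinomial call and a single host sync for all layers.

```python
import torch


def sample_mixer_names_for_all_layers(
    num_layers: int,
    mixer_names: list[str],
    sampling_probs: torch.Tensor,
    generator: torch.Generator | None = None,
) -> list[str]:
    """Draw one mixer choice per layer in a single call, avoiding a per-layer sync."""
    # One multinomial call with num_samples=num_layers, then one .tolist() transfer.
    indices = torch.multinomial(
        sampling_probs, num_samples=num_layers, replacement=True, generator=generator
    ).tolist()
    return [mixer_names[i] for i in indices]
```

The sampled names can then be stored in the batch kwargs during preprocessing and read back by each layer in its forward pass, which is what the named-mixer change below relies on.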
    mixer_idx = self._sample_mixer_index()

    if self._debug.enabled:
        logger.debug(f"StochasticMixer selecting mixer {mixer_idx}: {type(self.mixers[mixer_idx]).__name__}")
Ambiguous if multiple mixers share the same type. Use named mixers instead?
now using named mixers. we retrieve mixer_name from kwargs (line 151) and use it for logging (line 160) and accessing the correct mixer (line 163).
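Roughly the shape of that dispatch, as a hedged sketch (the kwargs key and the module layout are assumptions based on the description above, not the merged code):

```python
import logging

import torch
import torch.nn as nn

logger = logging.getLogger(__name__)


class NamedMixerDispatch(nn.Module):
    """Illustrative sketch: look up the pre-sampled mixer by name instead of by index."""

    def __init__(self, mixers: dict[str, nn.Module]):
        super().__init__()
        self.mixers = nn.ModuleDict(mixers)

    def forward(self, hidden_states: torch.Tensor, **kwargs) -> torch.Tensor:
        # The name was chosen once per layer during preprocessing (see the batched sampling above).
        mixer_name = kwargs["mixer_name"]
        logger.debug("StochasticMixer selecting mixer %s", mixer_name)
        return self.mixers[mixer_name](hidden_states)
```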
    we need to preprocess for all of them. This includes things like
    attention masks, rotary embeddings, etc.
    """
    for mixer in self.mixers:
There could be name conflicts. Consider namespacing?
now namespaced. see lines 214-216 where we prefix with f"{mixer_name}/{loss_def.name}".
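The namespacing boils down to something like the sketch below. The LossDef fields here are stand-ins, not Fast-LLM's actual definition; only the `f"{mixer_name}/{loss_def.name}"` prefixing mirrors the change.

```python
import dataclasses
import typing


@dataclasses.dataclass
class LossDef:
    # Stand-in with only the fields this sketch needs; the real LossDef lives in Fast-LLM.
    name: str
    count: int = 1


def namespaced_loss_definitions(mixers: dict[str, typing.Any], count: int = 1) -> list[LossDef]:
    """Collect loss definitions from every sub-mixer, prefixing each name with the mixer's name."""
    definitions = []
    for mixer_name, mixer in mixers.items():
        for loss_def in mixer.get_loss_definitions(count=count):
            # Namespace as "<mixer_name>/<loss_name>" so sub-mixers cannot collide.
            definitions.append(dataclasses.replace(loss_def, name=f"{mixer_name}/{loss_def.name}"))
    return definitions
```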
    return int(expected_usage)

    def get_loss_definitions(self, count: int = 1) -> list[LossDef]:
This is a bit dangerous: there could be name conflicts, and the counts will be wrong for averaging. Not sure how to fix it, though.
Acknowledged. The current approach ensures we allocate space for all possible losses, but you're right that the counts won't match actual usage since only one mixer runs per forward pass. We could track which mixer was used and only record its losses, but that adds complexity. I think what we have is good enough for now.
    return converter_class.mixer_converter_class.export_config(inference_mixer)

    @classmethod
    def get_converters(
How about import? I don't think it will work.
import doesn't work as usual for stochastic mixers. we use drop_on_export=True for non-main mixers, and HF checkpoints only contain the main mixer. I think the correct way to handle this is to either support stochastic mixers in hf (out of scope) or initialize all other mixers randomly while importing only the main mixer.
    mixer_converter_class.get_converters(
        mixer,
        f"{fast_llm_prefix}.mixers.{mixer_index}",
        hf_prefix if is_main_mixer else None,
Just use hf_prefix; drop_on_export handles the rest.
now uses just hf_prefix without the mixer name prefix.
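Pieced together from the snippets in this thread, the per-mixer loop now looks roughly like the sketch below. The function and parameter names are assumptions inferred from the quoted diff, not the exact signatures in the converter.

```python
import typing


def stochastic_mixer_converters(
    mixers: dict[str, typing.Any],
    converter_classes: dict[str, typing.Any],
    main_mixer_name: str,
    fast_llm_prefix: str,
    hf_prefix: str,
    drop_on_export: bool = False,
) -> list:
    """Sketch: every sub-mixer shares the same HF prefix; non-main mixers are dropped on export."""
    converters = []
    for mixer_name, mixer in mixers.items():
        is_main_mixer = mixer_name == main_mixer_name
        converters += converter_classes[mixer_name].get_converters(
            mixer,
            f"{fast_llm_prefix}.mixers.{mixer_name}",
            hf_prefix,
            drop_on_export=drop_on_export or not is_main_mixer,
        )
    return converters
```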
        f"{hf_prefix}.{block_index}",
        drop_on_export,
    )
    match config:
I don't think match is warranted here, since it involves a (slow) initialization of configs.
uses if-else with instance type check now.
tests/utils/model_configs.py
    ModelTestingGroup.convert: ModelTestingGroupAction.normal,
    ModelTestingGroup.generate: ModelTestingGroupAction.not_implemented,
    ModelTestingGroup.megatron: ModelTestingGroupAction.not_implemented,
    ModelTestingGroup.distributed: ModelTestingGroupAction.normal,
Let's leave as unimportant. All this tests is the consistency of stochastic sampling, and I don't think that warrants the overhead of testing every time.
set to ModelTestingGroupAction.unimportant now
- Add _is_lossy_hf_conversion() utility to detect when HF conversion drops weights
- Skip incompatible tests (test_converted_round_trip, test_load_pretrained) for lossy conversions
- Check converters for IgnoreExportWeightConverter instances
- Factor out config loading into _load_config_from_test_dir() and _load_config_from_checkpoint()
- Export main_mixer_type in stochastic mixer config for HF compatibility
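The lossiness check presumably amounts to scanning the converter list for IgnoreExportWeightConverter instances, roughly as sketched here (the signature is assumed; the class-name comparison keeps the sketch self-contained, while the real code can use isinstance directly):

```python
def _is_lossy_hf_conversion(converters: list) -> bool:
    """Sketch: treat a conversion as lossy if any converter drops weights on export."""
    # IgnoreExportWeightConverter is used for the non-main mixers of a stochastic mixer,
    # whose weights are not written to the HF checkpoint.
    return any(type(converter).__name__ == "IgnoreExportWeightConverter" for converter in converters)
```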
# Conflicts:
#   fast_llm/models/gpt/conversion/apriel.py
…ess only selected mixer, remove caching
Summary
Implements a stochastic mixer layer for supernet training, enabling random sampling from multiple mixer options (e.g., attention vs. Mamba) during training. Includes checkpoint conversion support and a hierarchical beam search tool for finding optimal mixer placement post-training.
Implementation Details
Stochastic Mixer (fast_llm/layers/decoder/stochastic_mixer.py)
- Randomly samples one of the configured mixers at each training step
- Uses main_mixer_index for deterministic behavior at inference

Configuration (fast_llm/layers/decoder/config.py)
- StochasticMixerConfig: list-based mixer configuration with sampling strategy
- main_mixer_index: specifies which mixer to use during inference and which receives pretrained weights during checkpoint conversion

Checkpoint Conversion (fast_llm/models/gpt/conversion/apriel.py)
- AprielStochasticMixerConverter: handles conversion between Fast-LLM and Apriel formats
- Only the main_mixer_index weights are exported/imported (other mixers are randomly initialized during supernet training)

Beam Search Tool (tools/supernet_beam_search.py)
- Hierarchical beam search over mixer placements; modifies main_mixer_index in-place for each candidate

Tests (tests/utils/model_configs.py)
- stochastic_mixer test configuration with FA/Mamba mixers, converted via AprielHybridSSMCheckpointFormat

Use Case
Supernet Training: Train a model where each layer can be either full attention or Mamba, with random sampling at each step. After training, use beam search to find which specific layers benefit most from full attention vs. Mamba, given a budget constraint (e.g., "I can afford 4 FA layers").
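To make the search concrete, here is a toy sketch of the beam idea. It is not the actual tools/supernet_beam_search.py implementation: the hierarchical scheduling is omitted and the evaluation hook is a black-box callable standing in for Fast-LLM's evaluation system.

```python
from typing import Callable


def beam_search_placements(
    num_layers: int,
    budget: int,
    beam_width: int,
    evaluate: Callable[[frozenset], float],
) -> frozenset:
    """Toy beam search over which layers get the expensive mixer (e.g., full attention)."""
    beam = [frozenset()]
    for _ in range(budget):
        # Extend every placement in the beam by one more expensive layer.
        candidates = {
            placement | {layer}
            for placement in beam
            for layer in range(num_layers)
            if layer not in placement
        }
        # evaluate() would switch the main mixer at those layers and score the supernet;
        # keep only the beam_width best-scoring placements.
        beam = sorted(candidates, key=evaluate, reverse=True)[:beam_width]
    return beam[0]
```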
Testing
Run the stochastic mixer tests:
    pytest tests/models/test_checkpoint.py::test_checkpoint_and_eval tests/models/test_checkpoint.py::test_conversion -k "stochastic_mixer" -v

Example beam search usage:
    fast-llm tools/supernet_beam_search.py \
        training_config=path/to/supernet_config.yaml \
        budgets=[4,8] \
        beam_width=12 \
        score_metric="lm_eval/accuracy" \
        output_path=results.json

🤖 Generated with Claude Code
Co-Authored-By: Claude [email protected]