20 changes: 10 additions & 10 deletions .github/ISSUE_TEMPLATE/feature_request.md
@@ -8,26 +8,26 @@ assignees: ''
---

# 🎯 **Goal (What & Why)**
> **Clearly state the purpose of this feature.**
> **Clearly state the purpose of this feature.**
> _(Example: Add FP8 support using torchao to improve training throughput by 1.5x.)_
# 🚀 **Execution Plan**
> _(This section may start as an incomplete draft but must be defined before implementation begins.)_
> _(This section may start as an incomplete draft but must be defined before implementation begins.)_
### **Step 1: What is the smallest working version?**
> _(Describe the simplest way to implement this feature with minimal effort.)_
> _(Describe the simplest way to implement this feature with minimal effort.)_
### **Step 2: What additional optimizations are possible (but optional)?**
> _(List potential refinements that can be added in later PRs if needed.)_
### **Step 2: What additional optimizations are possible (but optional)?**
> _(List potential refinements that can be added in later PRs if needed.)_
# 📌 **Acceptance Criteria** (Must-Haves for Completion)
* The feature must be **functional and tested**.
* The implementation must be **documented in practical terms**.
* The PR must include a **performance/impact summary**.
* **No refactors unless directly necessary** for feature completion.
* The feature must be **functional and tested**.
* The implementation must be **documented in practical terms**.
* The PR must include a **performance/impact summary**.
* **No refactors unless directly necessary** for feature completion.

# 🛠️ **Project Management**
- [ ] **Add the issue to the Fast-LLM project.**
- [ ] **Set the `Estimate` field (in days) in the GitHub project.**
- [ ] **Use the `Size` field to categorize the PR size (Small/Medium/Large).**
- [ ] **Assign an owner when opening the issue.**
- [ ] **Assign an owner when opening the issue.**
14 changes: 7 additions & 7 deletions .github/workflows/manual-build.yml
@@ -34,12 +34,12 @@ jobs:
sudo rm -rf /usr/share/dotnet || true
sudo rm -rf /opt/ghc || true
sudo rm -rf /usr/local/.ghcup || true
- name: Checkout repository
uses: actions/checkout@v4
with:
ref: ${{ inputs.commit_sha != '' && inputs.commit_sha || inputs.branch }}

- name: Get commit info
id: commit_info
run: |
@@ -48,7 +48,7 @@ jobs:
echo "full_sha=${COMMIT_SHA}" >> $GITHUB_OUTPUT
echo "short_sha=${COMMIT_SHORT}" >> $GITHUB_OUTPUT
echo "Building from commit: ${COMMIT_SHA}"
- name: Docker meta
id: meta
uses: docker/metadata-action@v5
@@ -59,18 +59,18 @@
type=raw,value=${{ inputs.branch }}-${{ inputs.tag_suffix }}
type=raw,value=${{ inputs.branch }}-${{ inputs.tag_suffix }}-${{ steps.commit_info.outputs.short_sha }}
type=raw,value=latest-${{ inputs.tag_suffix }},enable=${{ inputs.branch == 'main' && inputs.commit_sha == '' }}
- name: Set up Docker Buildx
uses: docker/setup-buildx-action@v3

- name: Login to GHCR
if: ${{ inputs.push_image }}
uses: docker/login-action@v3
with:
registry: ghcr.io
username: ${{ github.repository_owner }}
password: ${{ secrets.GITHUB_TOKEN }}

- name: Build and push
uses: docker/build-push-action@v6
with:
@@ -80,7 +80,7 @@
labels: ${{ steps.meta.outputs.labels }}
cache-from: type=registry,ref=ghcr.io/servicenow/fast-llm:cache
cache-to: type=registry,ref=ghcr.io/servicenow/fast-llm:cache,mode=max

- name: Output build info
run: |
echo "Built Docker image with tags:"
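For reference, the checkout ref and the image tags this workflow produces follow a simple fallback-and-concatenation scheme. Here is a minimal sketch of the equivalent logic in Python; the function names, standalone framing, and example values (such as the `cuda` suffix) are illustrative, and only the tag rules themselves come from the workflow above:

```python
def resolve_ref(commit_sha: str, branch: str) -> str:
    # Mirrors `inputs.commit_sha != '' && inputs.commit_sha || inputs.branch`:
    # an explicit commit SHA wins, otherwise fall back to the branch name.
    return commit_sha if commit_sha else branch


def image_tags(branch: str, tag_suffix: str, short_sha: str, commit_sha: str) -> list[str]:
    # Mirrors the docker/metadata-action `tags` rules in the workflow.
    tags = [
        f"{branch}-{tag_suffix}",
        f"{branch}-{tag_suffix}-{short_sha}",
    ]
    # `latest-<suffix>` is only attached when building the tip of main.
    if branch == "main" and not commit_sha:
        tags.append(f"latest-{tag_suffix}")
    return tags


print(image_tags("main", "cuda", "abc1234", ""))
# ['main-cuda', 'main-cuda-abc1234', 'latest-cuda']
```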
28 changes: 28 additions & 0 deletions fast_llm/engine/checkpoint/huggingface.py
@@ -73,10 +73,38 @@ def _serialize_metadata(self, config: CheckpointSaveMetadataConfig, metadata: Ch
"format": "pt",
}

def _initialize_missing_parameters(self) -> None:
# Parameters that exist in the model but not in the checkpoint import converters
missing_params = set(self._export_converters.keys()) - {
weight_converter.fast_llm_name[0]
for weight_converter in self._import_converters.values()
if weight_converter.fast_llm_name
}

print(f"[INIT DEBUG] Checking for missing parameters in HuggingFace checkpoint...")
print(f"[INIT DEBUG] Model has {len(self._export_converters)} parameters")
print(f"[INIT DEBUG] Checkpoint has {len(self._import_converters)} parameters")
print(f"[INIT DEBUG] Missing: {len(missing_params)} parameters")

if missing_params:
logger.warning(
f"Initializing {len(missing_params)} parameters not in HuggingFace checkpoint"
)
print(f"[INIT DEBUG] Initializing {len(missing_params)} parameters:")
for param in sorted(missing_params)[:5]: # Show first 5
print(f"[INIT DEBUG] {param}")
if len(missing_params) > 5:
print(f"[INIT DEBUG] ... and {len(missing_params) - 5} more")
for stage in self._model._stages:
stage.initialize_weights_for_parameters(missing_params)

def load(self, config: CheckpointLoadConfig) -> dict[str, typing.Any] | None:
print(f"[INIT DEBUG] HuggingfaceStateDictCheckpointHandler.load() called")
assert not config.optimizer_state
metadata = self._model.config.load_metadata(config)
self._model.config.base_model.compare_architecture(metadata.config.base_model, logger.warning)
# Initialize parameters not covered by import converters
self._initialize_missing_parameters()
super().load(config)

def save(self, config: CheckpointSaveConfig, metadata: CheckpointMetadata) -> None:
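The core of `_initialize_missing_parameters` is a set difference: parameters the model exports minus parameters the checkpoint's import converters can supply. A self-contained sketch of that logic, assuming converters keyed by parameter name that expose a (possibly empty) `fast_llm_name` tuple — the `WeightConverter` stand-in and the example parameter names are hypothetical:

```python
from dataclasses import dataclass


@dataclass
class WeightConverter:
    # Assumed shape: a tuple of Fast-LLM parameter names, possibly empty.
    fast_llm_name: tuple[str, ...]


def find_missing_parameters(
    export_converters: dict[str, object],
    import_converters: dict[str, WeightConverter],
) -> set[str]:
    # Parameters the model knows how to export...
    model_params = set(export_converters.keys())
    # ...minus parameters the checkpoint can actually provide.
    checkpoint_params = {
        converter.fast_llm_name[0]
        for converter in import_converters.values()
        if converter.fast_llm_name
    }
    return model_params - checkpoint_params


# Example: a hybrid model whose extra mixer weights are absent from the checkpoint.
missing = find_missing_parameters(
    {"mixer.attn.weight": None, "mixer.mamba.weight": None},
    {"model.attn.weight": WeightConverter(("mixer.attn.weight",))},
)
assert missing == {"mixer.mamba.weight"}
```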
11 changes: 11 additions & 0 deletions fast_llm/engine/multi_stage/stage_base.py
@@ -159,7 +159,15 @@ def _replace(module: torch.nn.Module):
Assert.eq(i, len(self._parameter_metas))
assert not tied_parameter_duplicate_buffers, tied_parameter_duplicate_buffers.keys()

def initialize_weights_for_parameters(self, parameter_names: set[str]) -> None:
"""Initialize only the specified parameters. Used for partial initialization after checkpoint load."""
self._initialize_weights_internal(lambda meta: meta.tensor_name in parameter_names)

def initialize_weights(self) -> None:
"""Initialize all weights."""
self._initialize_weights_internal(lambda meta: True)

def _initialize_weights_internal(self, should_initialize: typing.Callable) -> None:
# TODO: Avoid all the _on_device checks
assert self._is_setup
with torch.no_grad():
@@ -180,6 +188,9 @@ def initialize_weights(self) -> None:
]

for meta in metas:
# Skip parameters we shouldn't initialize
if not should_initialize(meta):
continue
if meta.tensor_name in self._tied_parameter_duplicates:
# Initialization is not managed by this stage.
continue
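Both entry points added here funnel into one internal routine that takes a predicate over parameter metas. A simplified sketch of the pattern with a stand-in meta type (not Fast-LLM's actual classes), assuming only the filtering behavior shown in the diff:

```python
import typing
from dataclasses import dataclass


@dataclass
class ParamMeta:
    tensor_name: str


class Stage:
    def __init__(self, metas: list[ParamMeta]):
        self._parameter_metas = metas

    def initialize_weights(self) -> None:
        # Full initialization: the predicate accepts every parameter.
        self._initialize_weights_internal(lambda meta: True)

    def initialize_weights_for_parameters(self, parameter_names: set[str]) -> None:
        # Partial initialization: only parameters missing from a loaded checkpoint.
        self._initialize_weights_internal(lambda meta: meta.tensor_name in parameter_names)

    def _initialize_weights_internal(
        self, should_initialize: typing.Callable[[ParamMeta], bool]
    ) -> None:
        for meta in self._parameter_metas:
            if not should_initialize(meta):
                continue
            print(f"initializing {meta.tensor_name}")


stage = Stage([ParamMeta("attn.weight"), ParamMeta("mamba.weight")])
stage.initialize_weights_for_parameters({"mamba.weight"})  # only the second is initialized
```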
90 changes: 89 additions & 1 deletion fast_llm/layers/decoder/config.py
@@ -1,16 +1,24 @@
import enum
import typing

from fast_llm.config import Field, FieldHint, check_field, config_class
from fast_llm.engine.config_utils.parameter import combine_lr_scales
from fast_llm.engine.config_utils.tensor_dim import TensorDim
from fast_llm.engine.distributed.config import DistributedConfig
from fast_llm.layers.block.config import BlockConfig
from fast_llm.layers.block.config import BlockConfig, BlockKwargs
from fast_llm.layers.common.normalization.config import NormalizationConfig
from fast_llm.layers.common.peft.config import PeftConfig
from fast_llm.utils import Assert

if typing.TYPE_CHECKING:
from fast_llm.layers.decoder.block import BlockWithBias, DecoderBlock
from fast_llm.layers.decoder.stochastic_mixer import StochasticMixer


class StochasticMixerKwargs(BlockKwargs):
"""Kwargs keys for stochastic mixer."""

mixer_name = "stochastic_mixer_name"


@config_class()
@@ -55,6 +63,13 @@ def _from_dict(cls, default: dict[str, typing.Any], strict: bool = True) -> typi
return super()._from_dict(default, strict=strict)


class SamplingStrategy(str, enum.Enum):
"""Strategy for sampling mixers in a stochastic mixer."""

uniform = "uniform"
weighted = "weighted"


@config_class(registry=True)
class MixerConfig(BlockWithBiasConfig):
"""
@@ -71,6 +86,79 @@ def _from_dict(cls, default: dict[str, typing.Any], strict: bool = True) -> typi
return super()._from_dict(default, strict=strict)


@config_class(dynamic_type={MixerConfig: "stochastic"})
class StochasticMixerConfig(MixerConfig):
"""
Stochastic mixer that samples from multiple mixer options during training (uniformly or with configured weights).

For supernet training, each forward pass randomly selects one mixer to execute,
training all mixers with different subsets of data.
"""

_abstract = False

mixers: dict[str, MixerConfig] = Field(
desc="Dict of mixer options to sample from (must contain at least 1). "
"Keys are mixer names used for debugging and namespacing.",
hint=FieldHint.architecture,
)

sampling_strategy: SamplingStrategy = Field(
default=SamplingStrategy.uniform,
desc="Strategy for sampling mixers during training.",
hint=FieldHint.feature,
)

sampling_weights: dict[str, float] | None = Field(
default=None,
desc="Sampling probability for each mixer by name (must sum to 1.0). "
"Only used when sampling_strategy='weighted'. "
"If None with uniform strategy, all mixers have equal probability.",
hint=FieldHint.feature,
)

main_mixer_name: str | None = Field(
default=None,
desc="Name of the main mixer. "
"Used for inference/eval, checkpoint loading (receives pretrained weights), "
"and checkpoint saving (only this mixer is exported). "
"If None, uses the first mixer in the dict.",
hint=FieldHint.feature,
)

def _validate(self) -> None:
super()._validate()

# Validate mixers dict is not empty
Assert.gt(len(self.mixers), 0)

# Set main_mixer_name to first mixer if not specified
if self.main_mixer_name is None:
with self._set_implicit_default():
self.main_mixer_name = next(iter(self.mixers.keys()))

# Validate main mixer name exists
if self.main_mixer_name not in self.mixers:
raise ValueError(f"main_mixer_name '{self.main_mixer_name}' not found in mixers")

# Validate sampling weights
if self.sampling_weights is not None:
Assert.eq(set(self.sampling_weights.keys()), set(self.mixers.keys()))
# Check sum is close to 1.0
weight_sum = sum(self.sampling_weights.values())
if abs(weight_sum - 1.0) > 1e-5:
raise ValueError(f"Sampling weights must sum to 1.0, got {weight_sum}")
# Check all weights are non-negative
if any(w < 0 for w in self.sampling_weights.values()):
raise ValueError("All sampling weights must be non-negative")

@property
def layer_class(self) -> "type[StochasticMixer]":
from fast_llm.layers.decoder.stochastic_mixer import StochasticMixer

return StochasticMixer


@config_class(dynamic_type={BlockConfig: "decoder"})
class DecoderBlockConfig(BlockConfig):
_abstract = False
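To make the new config concrete, here is a hypothetical block configuration that exercises the validation rules above. The nested mixer types and names are illustrative; only `type: stochastic`, `mixers`, `sampling_strategy`, `sampling_weights`, and `main_mixer_name` come from this diff:

```python
# Hypothetical decoder-block mixer config for supernet training: each forward pass
# samples one mixer; "attention" is the main mixer used for eval and export.
stochastic_mixer_config = {
    "type": "stochastic",
    "mixers": {
        "attention": {"type": "attention"},  # illustrative nested mixer configs
        "mamba": {"type": "mamba"},
    },
    "sampling_strategy": "weighted",
    # Must cover exactly the keys of `mixers` and sum to 1.0, per _validate().
    "sampling_weights": {"attention": 0.25, "mamba": 0.75},
    # Receives the pretrained weights on load; the only mixer exported on save.
    "main_mixer_name": "attention",
}
```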