
Config update mechanism, keep track of explicitly set config parameters #205


Merged
merged 16 commits on Apr 14, 2025
345 changes: 217 additions & 128 deletions fast_llm/config.py

Large diffs are not rendered by default.

11 changes: 1 addition & 10 deletions fast_llm/data/data/config.py
@@ -1,18 +1,9 @@
import typing

from fast_llm.config import Config, Field, FieldHint, FieldUpdate, config_class
from fast_llm.config import Config, Field, config_class
from fast_llm.data.dataset.config import SamplingConfig, SamplingData


@config_class()
class SamplingDefaultConfig(SamplingConfig):
seed: int = FieldUpdate(
default=784569,
desc="Seed for random sampling.",
hint=FieldHint.feature,
)


@config_class()
class DataConfig(Config):
_abstract = True
12 changes: 2 additions & 10 deletions fast_llm/data/data/gpt/config.py
@@ -3,27 +3,19 @@

from fast_llm.config import Field, FieldHint, FieldUpdate, check_field, config_class
from fast_llm.data.config import MultiprocessingContext, TokenizerConfig
from fast_llm.data.data.config import DataConfig, SamplingDefaultConfig
from fast_llm.data.data.config import DataConfig
from fast_llm.data.dataset.gpt.config import (
GPTLegacyConfig,
GPTLegacyDatasetConfig,
GPTSampledDatasetConfig,
GPTSamplingConfig,
ShufflingType,
)
from fast_llm.engine.distributed.config import PhaseType
from fast_llm.utils import Assert

logger = logging.getLogger(__name__)


@config_class()
class GPTSamplingDefaultConfig(SamplingDefaultConfig, GPTSamplingConfig):
gpu: bool = FieldUpdate(default=True)
use_loss_masking_spans: bool = FieldUpdate(default=False)
shuffle: ShufflingType = FieldUpdate(default=ShufflingType.epoch)


@config_class()
class GPTDataConfig(DataConfig, GPTLegacyConfig):
"""
@@ -45,7 +37,7 @@ class GPTDataConfig(DataConfig, GPTLegacyConfig):
desc="Configuration for the dataset(s).",
hint=FieldHint.core,
)
sampling: GPTSamplingDefaultConfig = FieldUpdate(default_factory=GPTSamplingDefaultConfig)
sampling: GPTSamplingConfig = FieldUpdate(default_factory=GPTSamplingConfig)
data_sample_warn_time_ms: float = Field(
default=1000,
desc="Warn if a sample takes too long to load.",
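
Note: the deleted SamplingDefaultConfig / GPTSamplingDefaultConfig classes existed only to hold real default values while the base classes used None to mean "not set". Once the config machinery keeps track of which fields were set explicitly, that sentinel is unnecessary. A standalone sketch of the idea, using plain dataclasses rather than Fast-LLM's actual Config class (the `explicit` bookkeeping field is hypothetical):

import dataclasses


@dataclasses.dataclass
class SamplingSketch:
    # Real defaults can now live on the one config class.
    seed: int = 784569
    gpu: bool = True
    shuffle: str = "epoch"
    # Hypothetical bookkeeping: names of the fields the caller actually provided.
    explicit: set[str] = dataclasses.field(default_factory=set)

    @classmethod
    def from_dict(cls, values: dict) -> "SamplingSketch":
        # Only keys present in the input dict count as explicitly set.
        return cls(**values, explicit=set(values))


config = SamplingSketch.from_dict({"gpu": False})
print(config.seed, config.explicit)  # 784569 {'gpu'}
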
26 changes: 9 additions & 17 deletions fast_llm/data/dataset/config.py
@@ -5,7 +5,7 @@
import pathlib
import typing

from fast_llm.config import Config, Field, FieldHint, FieldVerboseLevel, check_field, config_class
from fast_llm.config import Config, Field, FieldHint, UpdateType, check_field, config_class
from fast_llm.data.dataset.abstract import SamplableDataset, SampledDataset
from fast_llm.utils import Assert, normalize_probabilities

@@ -16,20 +16,12 @@

@config_class()
class SamplingConfig(Config):
seed: int | None = Field(
default=None,
seed: int = Field(
default=784569,
desc="Seed for random sampling.",
hint=FieldHint.feature,
)

@property
def updates(self) -> dict[str, typing.Any]:
return {
key: value
for key, value in self.to_serialized(verbose=FieldVerboseLevel.everything).items()
if value is not None
}


@dataclasses.dataclass(kw_only=True)
class SamplingData:
@@ -43,10 +35,10 @@ class SamplingData:
# Using a mutable rather than an int so it's shared with all copies made with `update`.
_rank_counter: typing.Iterator[int] = itertools.count

def update(self, config: SamplingConfig, **kwargs):
if config_updates := config.updates:
kwargs["config"] = self.config.to_copy(config_updates)
return dataclasses.replace(self, **kwargs) if kwargs else self
def update_config(self, update: SamplingConfig):
return dataclasses.replace(
self, config=self.config.from_dict(self.config, update, update_type=UpdateType.update)
)

def get_next_rank(self) -> int:
# Counter that loops over ranks to try to distribute workloads evenly between ranks.
@@ -162,7 +154,7 @@ class SampledDatasetUpdateConfig(SampledDatasetConfig):
Only explicitly set parameters (not None) will be updated, other will still be taken from `build_and_sample`'s argument.
"""

_abstract = False
_abstract = True
sampling: SamplingConfig = Field(
default_factory=SamplingConfig,
desc="Optional override to sampling configuration parameters.",
@@ -175,7 +167,7 @@ class SampledDatasetUpdateConfig(SampledDatasetConfig):
)

def build_and_sample(self, data: SamplingData) -> SampledDataset:
return self.dataset.build_and_sample(data.update(self.sampling))
return self.dataset.build_and_sample(data.update_config(self.sampling))


@config_class()
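
Note: the removed `updates` property approximated "explicitly set" with "is not None", which breaks as soon as a field has a real default. The new `update_config` delegates the merge to `Config.from_dict(..., update_type=UpdateType.update)`. A minimal sketch of the assumed merge semantics, using plain dicts instead of the actual UpdateType machinery:

def apply_update(base: dict, override: dict) -> dict:
    # Only keys the override actually provides are applied; nested dicts
    # (nested configs) are merged recursively rather than replaced wholesale.
    merged = dict(base)
    for key, value in override.items():
        if isinstance(value, dict) and isinstance(merged.get(key), dict):
            merged[key] = apply_update(merged[key], value)
        else:
            merged[key] = value
    return merged


base = {"seed": 784569, "gpu": True, "shuffle": "epoch"}
override = {"gpu": False}  # the only parameter the user set explicitly
print(apply_update(base, override))
# {'seed': 784569, 'gpu': False, 'shuffle': 'epoch'}
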
19 changes: 10 additions & 9 deletions fast_llm/data/dataset/gpt/config.py
@@ -45,20 +45,20 @@ class ShufflingType(str, enum.Enum):

@config_class()
class GPTSamplingConfig(SamplingConfig):
gpu: bool | None = Field(
default=None,
gpu: bool = Field(
default=True,
desc="Enable fast sampling on GPU."
" Note that random sampling works differently on GPU,"
" so the sample won't match the CPU equivalent.",
hint=FieldHint.feature,
)
use_loss_masking_spans: bool | None = Field(
default=None,
use_loss_masking_spans: bool = Field(
default=False,
desc="Read loss masking spans from the dataset.",
hint=FieldHint.feature,
)
shuffle: ShufflingType | None = Field(
default=None,
shuffle: ShufflingType = Field(
default=ShufflingType.epoch,
desc="Shuffling strategy.",
hint=FieldHint.feature,
)
@@ -211,6 +211,7 @@ def build(self) -> "GPTDatasetSlice":

@config_class()
class GPTSampledDatasetUpdateConfig(SampledDatasetUpdateConfig, GPTSampledDatasetConfig):
_abstract = False
type_: typing.ClassVar[str | None] = "sampled"
sampling: GPTSamplingConfig = FieldUpdate(default_factory=GPTSamplingConfig)
dataset: GPTSampledDatasetConfig = FieldUpdate(default_factory=GPTSampledDatasetConfig)
@@ -485,8 +486,8 @@ def build_and_sample(self, sampling: GPTSamplingData) -> SampledDataset:
"type": "slice",
# TODO: this duplicates memmap datasets for each phase.
"dataset": {"type": "memmap", "path": prefix},
"begin": phase_splits[phase_index],
"end": phase_splits[phase_index + 1],
"begin": float(phase_splits[phase_index]),
"end": float(phase_splits[phase_index + 1]),
}
for prefix in dataset_prefixes
]
@@ -505,7 +506,7 @@ def build_and_sample(self, sampling: GPTSamplingData) -> SampledDataset:
dataset_config = {
"type": "fim",
"dataset": dataset_config,
**self.fim.to_serialized(),
**self.fim.to_dict(),
}
# Legacy sampling config
dataset_config = {
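
Note: with SampledDatasetUpdateConfig now abstract, GPTSampledDatasetUpdateConfig is the concrete wrapper registered as type "sampled". A hedged usage sketch as a plain Python dict (the memmap path is invented for illustration); only the keys given under "sampling" count as explicit overrides, everything else keeps the values passed down through build_and_sample:

dataset_config = {
    "type": "sampled",
    # Explicit overrides only: seed and shuffle are replaced, the rest is inherited.
    "sampling": {"seed": 1234, "shuffle": "epoch"},
    "dataset": {"type": "memmap", "path": "/path/to/dataset"},  # hypothetical path
}
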
2 changes: 1 addition & 1 deletion fast_llm/data/dataset/gpt/sampled.py
@@ -176,7 +176,7 @@ def _sample(self) -> None:
"unshuffled_epochs": unshuffled_epochs,
"sequence_length": self._sequence_length,
"truncate_documents": self._truncate_documents,
"config": self._config.to_serialized(),
"config": self._config.to_dict(),
}
if self._truncate_documents:
yaml_data["unshuffled_tokens"] = tokens_per_epoch * unshuffled_epochs
2 changes: 1 addition & 1 deletion fast_llm/data/preparator/gpt_memmap/prepare.py
@@ -281,7 +281,7 @@ def generate_config_yaml_for_sharded_dst(self, dataset_configs: list[GPTMemmapDa
def _save_dataset_config(cls, dataset_config: GPTIndexedDatasetConfig, output_path: pathlib.Path) -> None:
logger.info(f"Saving config to {output_path}")
yaml.safe_dump(
dataset_config.to_serialized(),
dataset_config.to_dict(),
output_path.open("w"),
)

5 changes: 3 additions & 2 deletions fast_llm/engine/checkpoint/config.py
@@ -164,8 +164,9 @@ class CheckpointStateSaveConfigBase(CheckpointSaveConfigBase, CheckpointStateCon

def _validate(self) -> None:
if self.optimizer_state is None:
# TODO: Make sure it's a type
self.optimizer_state = self.format.support_optimizer
with self._set_implicit_default():
# TODO: Make sure it's a type
self.optimizer_state = self.format.support_optimizer
super()._validate()
if self.optimizer_state:
assert self.format.support_optimizer
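
Note: several _validate methods now wrap derived defaults in _set_implicit_default(), so values filled in during validation are not mistaken for user-set parameters. A standalone sketch of the pattern, not Fast-LLM's actual implementation:

import contextlib


class TrackedConfig:
    def __init__(self, **explicit):
        super().__setattr__("_explicit", set())
        super().__setattr__("_implicit", False)
        with self._set_implicit_default():
            self.optimizer_state = None  # class default, not user-set
        for key, value in explicit.items():
            setattr(self, key, value)  # recorded as explicitly set

    def __setattr__(self, name, value):
        if not name.startswith("_") and not self._implicit:
            self._explicit.add(name)
        super().__setattr__(name, value)

    @contextlib.contextmanager
    def _set_implicit_default(self):
        self._implicit = True
        try:
            yield
        finally:
            self._implicit = False

    def _validate(self):
        if self.optimizer_state is None:
            with self._set_implicit_default():
                # Derived default: filled in without being flagged as explicit.
                self.optimizer_state = True


config = TrackedConfig()
config._validate()
print(config.optimizer_state, config._explicit)  # True set()
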
8 changes: 3 additions & 5 deletions fast_llm/engine/checkpoint/distributed.py
@@ -32,7 +32,7 @@ def load_metadata(cls, config: CheckpointLoadMetadataConfig) -> CheckpointMetada
return CheckpointMetadata.from_dict(yaml.safe_load((config.path / "metadata.yaml").open("r")))

def save(self, config: CheckpointSaveConfig, metadata: CheckpointMetadata) -> None:
serialized_metadata = metadata.to_serialized()
serialized_metadata = metadata.to_dict()
if self._model.config.distributed.rank == 0:
yaml.safe_dump(serialized_metadata, (config.path / "metadata.yaml").open("w"))
safetensors.torch.save_file(
@@ -50,10 +50,8 @@ def load(self, config: CheckpointLoadConfig, metadata: CheckpointMetadata) -> No
Assert.leq(set(self.get_shard_names(config)), set(metadata.shards))
Assert.eq(metadata.shards[: len(shard_names)], list(shard_names))

same_format = (
loaded_config.to_serialized(verbose=None) == self._model.config.to_serialized(verbose=None)
and config.optimizer_state
)
# Using `log_fn=bool` sets the output to true if the error list is non-empty.
same_format = config.optimizer_state and not loaded_config.compare(self._model.config, log_fn=bool)
# Make sure all nodes agree on which loading scheme to use.
# Note: they may not agree before the broadcast because of the rank comparison, but that's ok.
same_format = broadcast_scalar(same_format, torch.uint8, self._model.distributed.world_group)
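
Note: the equality check on two serialized dicts is replaced by compare(..., log_fn=bool). The sketch below illustrates the assumed contract rather than the real method: the comparison collects mismatches and hands them to log_fn, so passing bool turns "any differences?" into a plain truth value instead of logging or raising.

def compare(left: dict, right: dict, log_fn=print):
    errors = [
        f"{key}: {left.get(key)!r} != {right.get(key)!r}"
        for key in sorted(left.keys() | right.keys())
        if left.get(key) != right.get(key)
    ]
    # With the default log_fn this prints the differences; with log_fn=bool it
    # simply returns True when there are any.
    return log_fn(errors) if errors else None


same_format = not compare({"world_size": 8}, {"world_size": 8}, log_fn=bool)
print(same_format)  # True: the configurations match
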
6 changes: 3 additions & 3 deletions fast_llm/engine/checkpoint/external.py
@@ -7,14 +7,14 @@
import torch

from fast_llm import __version__
from fast_llm.config import MISSING
from fast_llm.config import MISSING, get_nested_dict_value, set_nested_dict_value
from fast_llm.engine.base_model.config import BaseModelArchitectureConfig
from fast_llm.engine.checkpoint.config import CheckpointLoadMetadataConfig
from fast_llm.engine.checkpoint.state_dict import StateDictCheckpointHandler
from fast_llm.engine.multi_stage.config import CheckpointMetadata, FastLLMModelConfig
from fast_llm.engine.multi_stage.fast_llm_model import FastLLMModel
from fast_llm.tensor import SafeTensorSlice
from fast_llm.utils import Assert, get_nested_dict_value, set_nested_dict_value
from fast_llm.utils import Assert

logger = logging.getLogger(__name__)

@@ -232,7 +232,7 @@ def load_metadata(cls, config: CheckpointLoadMetadataConfig) -> CheckpointMetada
fast_llm_version=__version__,
model=cls._model_class,
format=config.format,
config=cls._model_class.from_dict({"base_model": imported_model_config.to_serialized()}),
config=cls._model_class.from_dict({"base_model": imported_model_config.to_dict()}),
shards=["weights"],
)

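
Note: get_nested_dict_value and set_nested_dict_value are now imported from fast_llm.config instead of fast_llm.utils. For context, a minimal sketch of what helpers with these names typically do (not the actual implementations; the key path is invented for illustration):

def get_nested_dict_value(data: dict, keys: tuple):
    # Follow the key path down the nested dict and return the leaf value.
    for key in keys:
        data = data[key]
    return data


def set_nested_dict_value(data: dict, keys: tuple, value) -> None:
    # Create intermediate dicts as needed, then set the leaf value.
    for key in keys[:-1]:
        data = data.setdefault(key, {})
    data[keys[-1]] = value


config = {}
set_nested_dict_value(config, ("base_model", "transformer", "num_layers"), 24)
print(get_nested_dict_value(config, ("base_model", "transformer", "num_layers")))  # 24
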
2 changes: 1 addition & 1 deletion fast_llm/engine/checkpoint/huggingface.py
@@ -34,7 +34,7 @@ def _serialize_metadata(self, config: CheckpointSaveMetadataConfig, metadata: Ch
huggingface_config = self._export_config(self._model.config.base_model)
self._save_config(config.path, huggingface_config)
return {
"fast_llm_metadata": metadata.to_serialized(),
"fast_llm_metadata": metadata.to_dict(),
"model_config": huggingface_config,
"format": "pt",
}
2 changes: 1 addition & 1 deletion fast_llm/engine/checkpoint/state_dict.py
@@ -71,7 +71,7 @@ def _save_serialized_metadata(self, config: CheckpointSaveMetadataConfig, metada
def _serialize_metadata(
self, config: CheckpointSaveMetadataConfig, metadata: CheckpointMetadata
) -> dict[str, typing.Any]:
return metadata.to_serialized()
return metadata.to_dict()

def load(self, config: CheckpointLoadConfig, metadata: CheckpointMetadata) -> None:
with SafeLoad(self._model, shard_names=self.get_shard_names(config), timeout=config.timeout) as context:
4 changes: 2 additions & 2 deletions fast_llm/engine/config_utils/run.py
@@ -147,8 +147,8 @@ def __init__(
self._is_pipeline_parallel_main_rank = (
self._distributed_config.data_rank == 0 and self._distributed_config.tensor_rank == 0
)
config_dict = config.to_serialized()
config_dict_verbose = config.to_serialized(verbose=FieldVerboseLevel.performance)
config_dict = config.to_dict()
config_dict_verbose = config.to_dict(verbose=FieldVerboseLevel.performance)

if self._config.experiment_dir is not None:
self._experiment_directory = self._config.experiment_dir.resolve()
3 changes: 2 additions & 1 deletion fast_llm/engine/distributed/config.py
@@ -285,7 +285,7 @@ def _validate(self) -> None:

self.tensor_rank = self.rank % self.tensor_parallel
if self.tensor_parallel == 1:
self.sequence_tensor_parallel = False
with self._set_implicit_default():
self.sequence_tensor_parallel = False

if self.reference_config is not None:
self.reference_config.validate()
5 changes: 3 additions & 2 deletions fast_llm/engine/inference/config.py
@@ -5,6 +5,7 @@

import transformers

from fast_llm.config import FieldVerboseLevel
from fast_llm.engine.checkpoint.config import CheckpointLoadMetadataConfig, FastLLMCheckpointFormat
from fast_llm.engine.multi_stage.config import FastLLMModelConfig

@@ -90,12 +91,12 @@ def __eq__(self, other) -> bool:

def to_dict(self) -> dict[str, typing.Any]:
out = super().to_dict()
out["fast_llm_config"] = self.fast_llm_config.to_serialized(verbose=None)
out["fast_llm_config"] = self.fast_llm_config.to_dict(verbose=FieldVerboseLevel.everything)
return out

def to_diff_dict(self) -> dict[str, typing.Any]:
out = super().to_diff_dict()
out["fast_llm_config"] = self.fast_llm_config.to_serialized()
out["fast_llm_config"] = self.fast_llm_config.to_dict(verbose=FieldVerboseLevel.explicit)
return out

def to_json_file(self, json_file_path: str | os.PathLike, use_diff: bool = True) -> None:
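
Note: to_serialized(verbose=None) and to_serialized() become to_dict(verbose=FieldVerboseLevel.everything) and to_dict(verbose=FieldVerboseLevel.explicit), which makes the intent explicit: dump every field, or only the fields the user actually set. A standalone sketch of the two levels, using plain dicts and strings rather than the real FieldVerboseLevel enum:

def to_dict(values: dict, explicit: set, verbose: str = "explicit") -> dict:
    if verbose == "everything":
        return dict(values)  # full dump, e.g. for checkpoint metadata or logging
    # "explicit": keep only user-set fields, the diff-style dump.
    return {key: value for key, value in values.items() if key in explicit}


values = {"seed": 784569, "gpu": False, "shuffle": "epoch"}
explicit = {"gpu"}
print(to_dict(values, explicit, "everything"))  # {'seed': 784569, 'gpu': False, 'shuffle': 'epoch'}
print(to_dict(values, explicit))                # {'gpu': False}
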
7 changes: 2 additions & 5 deletions fast_llm/engine/schedule/config.py
@@ -79,10 +79,6 @@ def setup(self, distributed_config: DistributedConfig) -> None:
def num_inputs(self) -> int:
return self.sequential_micro_batches * self.num_micro_sequences

@property
def _is_setup(self) -> bool:
return hasattr(self, "_distributed")

def _validate(self) -> None:
# Use the distributed properties to determine the batch size and its breakdown.
# Requires post-processed distributed config args
@@ -133,7 +129,8 @@ def _validate(self) -> None:
" Use at your own risk."
)
if self.micro_sequence_length is None:
self.micro_sequence_length = self.sequence_length
with self._set_implicit_default():
self.micro_sequence_length = self.sequence_length
self.num_micro_sequences = div(self.sequence_length, self.micro_sequence_length)
super()._validate()

4 changes: 3 additions & 1 deletion fast_llm/engine/training/config.py
@@ -53,7 +53,8 @@ class IntervalConfig(Config):

def _validate(self) -> None:
if self.interval:
self.offset %= self.interval
with self._set_implicit_default():
self.offset %= self.interval
super()._validate()

def enabled(self, iteration: int | None = None) -> bool:
@@ -120,6 +121,7 @@ class WandbAlertConfig(IntervalConfig):
"The update may be posted by email and/or slack depending on the Wandb account configuration.",
hint=FieldHint.feature,
)
post_alerts: bool = Field(init=False, repr=False)

def _validate(self) -> None:
if self.status_updates is None:
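
Note: post_alerts becomes Field(init=False, repr=False), a derived field that cannot be passed to the constructor and is filled in during validation. The hunk is truncated here, so the derivation from status_updates below is an assumption; a sketch with plain dataclasses:

import dataclasses


@dataclasses.dataclass
class WandbAlertSketch:
    status_updates: bool | None = None
    # Derived field: excluded from __init__, so it can never be set explicitly.
    post_alerts: bool = dataclasses.field(init=False, repr=False)

    def _validate(self) -> None:
        # Assumed rule: post alerts exactly when status updates are enabled.
        self.post_alerts = bool(self.status_updates)


alert = WandbAlertSketch()
alert._validate()
print(alert.post_alerts)  # False
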
2 changes: 1 addition & 1 deletion fast_llm/engine/training/wandb.py
@@ -40,7 +40,7 @@ def __init__(self, config: WandbConfig, run: Run, experiment_config: Config):
if wandb_path is not None:
yaml.safe_dump(wandb_config, wandb_path.open("w"))
# TODO: Does wandb work with nested configs?
self._wandb = wandb.init(config=experiment_config.to_serialized(), **wandb_config)
self._wandb = wandb.init(config=experiment_config.to_dict(), **wandb_config)
else:
self._wandb = None
