Reference model support for distillation, etc. #216

Merged: 13 commits, Apr 7, 2025
3 changes: 2 additions & 1 deletion fast_llm/engine/config_utils/tensor_space.py
@@ -119,8 +119,9 @@ def __init__(self, distributed_config: DistributedConfig):
self.add_tensor_dim(TensorDim(DefaultDimNames.scalar, 1))

def setup(self, distributed: "Distributed") -> None:
assert distributed.config is self._distributed_config
assert not self._is_setup
if distributed.config is not self._distributed_config:
distributed.config.compare(self._distributed_config, ValueError)
self._is_setup = True
self._distributed = distributed

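This hunk replaces the identity assertion with an equivalence check, so a `TensorSpace` can be set up with a `Distributed` whose config is an identical copy of its own. A minimal sketch of the accepted cases, assuming `Config.compare(other, exc_type)` raises `exc_type` on any mismatch (its exact semantics are not shown in this diff):

def accepts(tensor_space: "TensorSpace", distributed: "Distributed") -> bool:
    # Hypothetical helper for illustration only.
    if distributed.config is tensor_space._distributed_config:
        return True  # same object: the only case accepted before this PR
    try:
        # An identical copy (e.g. a reference model's config) is now also accepted.
        distributed.config.compare(tensor_space._distributed_config, ValueError)
        return True
    except ValueError:
        return False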
134 changes: 77 additions & 57 deletions fast_llm/engine/distributed/config.py
@@ -251,6 +251,12 @@ class DistributedConfig(Config):
desc="Ensure the initialization is the same for any distributed configuration.",
hint=FieldHint.testing,
)
reference_config: "DistributedConfig|None" = Field(
default=None,
init=False,
desc="Pointer to the distributed config this one is an identical copy of.",
hint=FieldHint.derived,
)

def _validate(self) -> None:
if self.world_size is None:
@@ -281,76 +287,90 @@ def _validate(self) -> None:
if self.tensor_parallel == 1:
self.sequence_tensor_parallel = False

Removed:

self.distributed_dims = {}

self.add_distributed_dim(
    DistributedDim(name=DistributedDimNames.world, size=self.world_size, rank=self.rank, id_=None, parent=None)
)
self.add_distributed_dim(
    DistributedDim(
        name=DistributedDimNames.data,
        size=self.data_parallel,
        rank=self.data_rank,
        id_=f"x_{self.pipeline_rank}_{self.tensor_rank}",
        parent=DistributedDimNames.world,
    )
)
self.add_distributed_dim(
    DistributedDim(
        name=DistributedDimNames.pipeline,
        size=self.pipeline_parallel,
        rank=self.pipeline_rank,
        id_=f"x_{self.data_rank}_{self.tensor_rank}",
        parent=DistributedDimNames.world,
    )
)
self.add_distributed_dim(
    DistributedDim(
        name=DistributedDimNames.tensor,
        size=self.tensor_parallel,
        rank=self.tensor_rank,
        id_=f"x_{self.data_rank}_{self.pipeline_rank}",
        parent=DistributedDimNames.world,
    )
)
self.add_distributed_dim(
    DistributedDim(
        name=DistributedDimNames.sequence_data,
        size=self.sequence_data_parallel,
        rank=self.sequence_data_rank,
        id_=f"{self.batch_data_rank}_{self.pipeline_rank}_{self.tensor_rank}",
        parent=DistributedDimNames.data,
    )
)
self.add_distributed_dim(
    DistributedDim(
        name=DistributedDimNames.batch_data,
        size=self.batch_data_parallel,
        rank=self.batch_data_rank,
        id_=f"{self.sequence_data_rank}_{self.pipeline_rank}_{self.tensor_rank}",
        parent=DistributedDimNames.data,
    )
)
self.add_distributed_dim(
    DistributedDim(
        name=DistributedDimNames.tensor_and_sequence_data,
        size=self.sequence_data_parallel * self.tensor_parallel,
        rank=self.tensor_rank + self.sequence_data_rank * self.tensor_parallel,
        id_=f"{self.batch_data_rank}_{self.pipeline_rank}",
        parent=(
            DistributedDimNames.tensor
            if self.sequence_data_parallel == 1
            else DistributedDimNames.sequence_data if self.tensor_parallel == 1 else DistributedDimNames.world
        ),
    )
)

Added:

if self.reference_config is not None:
    self.reference_config.validate()
    if self.reference_config.reference_config is not None:
        self.reference_config = self.reference_config.reference_config
    assert self.reference_config.reference_config is None
    self.compare(self.reference_config, ValueError)
    self.distributed_dims = self.reference_config.distributed_dims
else:
    self.distributed_dims = {}

    self._add_distributed_dim(
        DistributedDim(
            name=DistributedDimNames.world, size=self.world_size, rank=self.rank, id_=None, parent=None
        )
    )
    self._add_distributed_dim(
        DistributedDim(
            name=DistributedDimNames.data,
            size=self.data_parallel,
            rank=self.data_rank,
            id_=f"x_{self.pipeline_rank}_{self.tensor_rank}",
            parent=DistributedDimNames.world,
        )
    )
    self._add_distributed_dim(
        DistributedDim(
            name=DistributedDimNames.pipeline,
            size=self.pipeline_parallel,
            rank=self.pipeline_rank,
            id_=f"x_{self.data_rank}_{self.tensor_rank}",
            parent=DistributedDimNames.world,
        )
    )
    self._add_distributed_dim(
        DistributedDim(
            name=DistributedDimNames.tensor,
            size=self.tensor_parallel,
            rank=self.tensor_rank,
            id_=f"x_{self.data_rank}_{self.pipeline_rank}",
            parent=DistributedDimNames.world,
        )
    )
    self._add_distributed_dim(
        DistributedDim(
            name=DistributedDimNames.sequence_data,
            size=self.sequence_data_parallel,
            rank=self.sequence_data_rank,
            id_=f"{self.batch_data_rank}_{self.pipeline_rank}_{self.tensor_rank}",
            parent=DistributedDimNames.data,
        )
    )
    self._add_distributed_dim(
        DistributedDim(
            name=DistributedDimNames.batch_data,
            size=self.batch_data_parallel,
            rank=self.batch_data_rank,
            id_=f"{self.sequence_data_rank}_{self.pipeline_rank}_{self.tensor_rank}",
            parent=DistributedDimNames.data,
        )
    )
    self._add_distributed_dim(
        DistributedDim(
            name=DistributedDimNames.tensor_and_sequence_data,
            size=self.sequence_data_parallel * self.tensor_parallel,
            rank=self.tensor_rank + self.sequence_data_rank * self.tensor_parallel,
            id_=f"{self.batch_data_rank}_{self.pipeline_rank}",
            parent=(
                DistributedDimNames.tensor
                if self.sequence_data_parallel == 1
                else (
                    DistributedDimNames.sequence_data
                    if self.tensor_parallel == 1
                    else DistributedDimNames.world
                )
            ),
        )
    )

super()._validate()

Assert.in_range(self.rank, 0, self.world_size)
Assert.in_range(self.local_rank, 0, self.local_world_size)

def add_distributed_dim(self, distributed_dim: DistributedDim) -> None:
def _add_distributed_dim(self, distributed_dim: DistributedDim) -> None:
if distributed_dim.name in self.distributed_dims:
Assert.eq(distributed_dim, self.distributed_dims[distributed_dim.name])
else:
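Taken together, the new `reference_config` field and the new `_validate` branch mean a copied `DistributedConfig` (for example, the one handed to a distillation reference model) does not rebuild its process-group dimensions; it checks that it matches the original and then borrows its `distributed_dims`. A hedged usage sketch; how Fast-LLM actually produces the copy is not shown in this diff, so `copy.deepcopy` and the manual field assignment are stand-ins:

import copy

main_config = DistributedConfig()          # e.g. the trained model's distributed config
reference = copy.deepcopy(main_config)     # config handed to the reference (teacher) model
reference.reference_config = main_config   # derived field (init=False), set after construction
reference.validate()                       # validates main_config, compares, raises ValueError on mismatch

assert reference.distributed_dims is main_config.distributed_dims  # shared rather than rebuilt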
8 changes: 8 additions & 0 deletions fast_llm/engine/distributed/distributed.py
@@ -33,6 +33,7 @@ class Distributed[ConfigType: DistributedConfig](Configurable[ConfigType]):

def __init__(self, config: DistributedConfig, use_cpu: bool = False):
super().__init__(config)
assert self._config.reference_config is None
self._use_cpu = use_cpu

if self._use_cpu:
@@ -148,6 +149,13 @@ def add_group(self, distributed_dim: DistributedDim) -> ProcessGroup | None:
distributed_dim.setup(group)
return group

def check_config(self, config: DistributedConfig) -> None:
# Allows using this `Distributed` on a model with a distributed config that is a copy of `self._config`
if config.reference_config is None:
Assert.is_(config, self._config)
else:
Assert.is_(config.reference_config, self._config)

def set_step(self, step: int, phase: PhaseType) -> None:
"""
Reseed pytorch for a given training step.
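`check_config` is what downstream components now call instead of asserting config identity: the original config passes directly, and a copy passes through its `reference_config` pointer. Continuing the hypothetical `main_config` / `reference` pair from the previous sketch:

distributed = Distributed(main_config)   # only allowed for an original config (reference_config is None)
distributed.check_config(main_config)    # passes: same object
distributed.check_config(reference)      # passes: reference.reference_config is main_config
# Distributed(reference) would fail the new assertion in __init__.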
fast_llm/engine/inference/model.py
@@ -4,57 +4,38 @@

import transformers.modeling_outputs

from fast_llm.config import NoAutoValidate
from fast_llm.engine.checkpoint.config import CheckpointLoadConfig, FastLLMCheckpointFormat
from fast_llm.engine.distributed.config import PhaseType
from fast_llm.engine.huggingface.config import HuggingfaceModelConfig
from fast_llm.engine.inference.config import HuggingfaceModelConfig
from fast_llm.engine.inference.runner import InferenceRunner
from fast_llm.engine.multi_stage.config import StageMode
from fast_llm.engine.multi_stage.fast_llm_model import FastLLMModel
from fast_llm.engine.schedule.config import BatchConfig, ScheduleConfig
from fast_llm.engine.schedule.runner import ScheduleRunner
from fast_llm.engine.schedule.schedule import Schedule


class HuggingfacePreTrainedModel(transformers.PreTrainedModel):
config_class: typing.ClassVar[type[HuggingfaceModelConfig]] = HuggingfaceModelConfig
model_class: typing.ClassVar[type[FastLLMModel]] = FastLLMModel
runner_class: typing.ClassVar[type[InferenceRunner]] = InferenceRunner
config: HuggingfaceModelConfig
# base_model_prefix = ""
# _no_split_modules = None
# _supports_cache_class = False
# _tied_weights_keys = []

def __init__(self, config: HuggingfaceModelConfig, fast_llm_model: FastLLMModel, **kwargs):
assert self.model_class.config_class is config.model_config_class
assert self.runner_class.model_class.config_class is config.model_config_class
assert config.fast_llm_config is fast_llm_model.config
assert isinstance(config, self.config_class)

super().__init__(config, **kwargs)
self._fast_llm_config = config.fast_llm_config
self._fast_llm_model = fast_llm_model

self._inference_runner = self.runner_class(fast_llm_model)
if not fast_llm_model.is_setup:
fast_llm_model.setup(mode=StageMode.inference)
self._inference_runner.setup()
# Transformers needs to be able to inspect the base model.
self.fast_llm_base_model = self._fast_llm_model.base_model
self._distributed_config = self._fast_llm_config.distributed
self.fast_llm_base_model = fast_llm_model.base_model
# TODO: Support distributed models?
assert self._distributed_config.world_size == 1
self._schedule_config = ScheduleConfig()
# We only need a basic schedule and don't care about dimensions.
# TODO: Sort things out.
with NoAutoValidate():
self._batch_config = BatchConfig()
self._batch_config.setup(self._distributed_config)
self._batch_config.validate()
self._runner = ScheduleRunner(
config=self._schedule_config, multi_stage=self._fast_llm_model, distributed_config=self._distributed_config
)
self._runner.setup(self._fast_llm_model.distributed)
# TODO: Random state? (Distributed.set_step)
self._schedule = Schedule(
multi_stage=self._fast_llm_model,
batch_config=self._batch_config,
schedule_config=self._schedule_config,
distributed_config=self._distributed_config,
phase=PhaseType.inference,
)
assert fast_llm_model.config.distributed.world_size == 1

Collaborator review comment:
I'd like us to work on this in a separate PR soon, because this is needed for running generative benchmarks during training.


with transformers.modeling_utils.no_init_weights():
self.post_init()

@@ -79,7 +60,7 @@ def from_pretrained(
config_updates[("distributed", "training_dtype")] = torch_dtype

# Create the model
fast_llm_model = cls.model_class.from_pretrained(
fast_llm_model = cls.runner_class.model_class.from_pretrained(
pretrained_model_name_or_path, config_updates=config_updates, mode=mode
)
config = cls.config_class(fast_llm_model.config)
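The Hugging Face wrapper now delegates schedule and runner setup to the new `InferenceRunner` (next file) and keeps only what `transformers` needs to inspect. A hedged construction sketch mirroring what `from_pretrained` does once the Fast-LLM model is loaded; `fast_llm_model` is assumed to come from elsewhere (e.g. `cls.runner_class.model_class.from_pretrained`):

config = HuggingfacePreTrainedModel.config_class(fast_llm_model.config)
hf_model = HuggingfacePreTrainedModel(config, fast_llm_model)
# __init__ wraps the model in an InferenceRunner, sets the model up in StageMode.inference
# if it is not set up yet, and exposes fast_llm_base_model for transformers to inspect.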
58 changes: 58 additions & 0 deletions fast_llm/engine/inference/runner.py
@@ -0,0 +1,58 @@
import abc
import typing

from fast_llm.config import NoAutoValidate
from fast_llm.engine.distributed.config import PhaseType
from fast_llm.engine.multi_stage.fast_llm_model import FastLLMModel
from fast_llm.engine.schedule.config import BatchConfig, ScheduleConfig
from fast_llm.engine.schedule.runner import ScheduleRunner
from fast_llm.engine.schedule.schedule import Schedule


class InferenceRunner(abc.ABC):
model_class: typing.ClassVar[type[FastLLMModel]] = FastLLMModel

def __init__(self, fast_llm_model: FastLLMModel):
assert isinstance(fast_llm_model, self.model_class)
self._fast_llm_model = fast_llm_model
# We only need a basic schedule and don't care about dimensions.
self._schedule_config = ScheduleConfig()
# TODO: Sort things out.
with NoAutoValidate():
self._batch_config = BatchConfig()
self._batch_config.setup(self._fast_llm_model.config.distributed)
self._batch_config.validate()
self._runner = ScheduleRunner(
config=self._schedule_config,
multi_stage=self._fast_llm_model,
distributed_config=self._fast_llm_model.config.distributed,
)
# TODO: Random state? (Distributed.set_step)
self._schedule = Schedule(
multi_stage=self._fast_llm_model,
batch_config=self._batch_config,
schedule_config=self._schedule_config,
distributed_config=self._fast_llm_model.config.distributed,
phase=PhaseType.inference,
)

@property
def fast_llm_model(self) -> FastLLMModel:
return self._fast_llm_model

def setup(self):
self._runner.setup(self._fast_llm_model.distributed)

def forward(
self, input_, kwargs: dict, *, iteration: int = 1, return_metrics: bool = False
) -> tuple[dict[str, float | int], dict[str, typing.Any] | None]:
# TODO: Return an actual model output.
reduced_losses, update_successful, metrics = self._runner.run_step(
iter((((input_, kwargs),),)),
self._schedule,
iteration=iteration,
return_metrics=return_metrics,
preprocessed=True,
)
assert update_successful
return reduced_losses, metrics
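`InferenceRunner` bundles the minimal batch config, schedule, and `ScheduleRunner` needed for a single forward step. A usage sketch under the same assumptions as above (a loaded single-GPU `FastLLMModel` and an already preprocessed `(input_, kwargs)` pair):

runner = InferenceRunner(fast_llm_model)
if not fast_llm_model.is_setup:
    fast_llm_model.setup(mode=StageMode.inference)  # StageMode from fast_llm.engine.multi_stage.config
runner.setup()                                      # binds the schedule runner to the model's Distributed

losses, metrics = runner.forward(input_, kwargs, iteration=1, return_metrics=True)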
2 changes: 1 addition & 1 deletion fast_llm/engine/multi_stage/config.py
@@ -29,7 +29,7 @@
from fast_llm.utils import Assert

if typing.TYPE_CHECKING:
from fast_llm.engine.huggingface.model import HuggingfacePreTrainedModel
from fast_llm.engine.inference.model import HuggingfacePreTrainedModel
from fast_llm.engine.multi_stage.fast_llm_model import FastLLMModel

logger = logging.getLogger(__name__)
12 changes: 10 additions & 2 deletions fast_llm/engine/multi_stage/multi_stage.py
@@ -209,12 +209,15 @@ def __init__(
"Bfloat16 gradient accumulation and reduction is not recommended. (use --full_precision_gradients=1)"
)

def setup(self, distributed: Distributed, mode: StageMode = StageMode.training) -> None:
def setup(self, distributed: Distributed | None = None, mode: StageMode = StageMode.training) -> None:
# TODO: More checks?
stage: Stage
assert distributed.config is self._config.distributed
assert not self._is_setup
self._is_setup = True
if distributed is None:
distributed = Distributed(self._config.distributed)
else:
distributed.check_config(self._config.distributed)
self._distributed = distributed
self._mode = mode
self._base_model.setup(distributed)
@@ -381,6 +384,10 @@ def get_shard(self, name: str) -> torch.Tensor:
raise KeyError(f"Unknown shard name {name}")
return self._shards[name]

@property
def is_setup(self) -> bool:
return self._is_setup

@property
def support_forward(self) -> bool:
assert self._is_setup
@@ -442,6 +449,7 @@ def is_parameter_on_device(self, parameter_name: str) -> bool:

@property
def distributed(self) -> Distributed:
assert self._is_setup
return self._distributed

def invalidate_buffers(self) -> None:
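`setup` now accepts `distributed=None` and builds a `Distributed` from its own config in that case, while an explicitly passed instance only needs to satisfy `check_config` rather than be built from the exact same config object. Both call patterns, as a hedged sketch (`trained_model` and `reference_model` are hypothetical names):

trained_model.setup(mode=StageMode.training)                                 # builds its own Distributed internally
reference_model.setup(trained_model.distributed, mode=StageMode.inference)   # reuses a compatible Distributed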
2 changes: 1 addition & 1 deletion fast_llm/engine/multi_stage/stage_base.py
@@ -128,7 +128,7 @@ def setup(
mode: StageMode = StageMode.training,
) -> None:
assert not self._is_setup
assert distributed.config is self._distributed_config
distributed.check_config(self._distributed_config)
self._mode = mode
self._is_setup = True
self._distributed = distributed
2 changes: 1 addition & 1 deletion fast_llm/engine/schedule/runner.py
@@ -97,7 +97,7 @@ def __init__(

def setup(self, distributed: Distributed, optimizer: Optimizer | None = None) -> None:
assert not self._is_setup
assert distributed.config is self._distributed_config
distributed.check_config(self._distributed_config)
self._is_setup = True
self._optimizer = optimizer
assert self._multi_stage.support_forward