
Commit 5ba1f0f

Reference model support for distillation, etc. (#216)
1 parent 9d99dc2 commit 5ba1f0f

File tree: 19 files changed (+306, -116 lines)


fast_llm/engine/config_utils/tensor_space.py (+2, -1)

@@ -119,8 +119,9 @@ def __init__(self, distributed_config: DistributedConfig):
         self.add_tensor_dim(TensorDim(DefaultDimNames.scalar, 1))

     def setup(self, distributed: "Distributed") -> None:
-        assert distributed.config is self._distributed_config
         assert not self._is_setup
+        if distributed.config is not self._distributed_config:
+            distributed.config.compare(self._distributed_config, ValueError)
         self._is_setup = True
         self._distributed = distributed
fast_llm/engine/distributed/config.py (+77, -57)

@@ -251,6 +251,12 @@ class DistributedConfig(Config):
         desc="Ensure the initialization is the same for any distributed configuration.",
         hint=FieldHint.testing,
     )
+    reference_config: "DistributedConfig|None" = Field(
+        default=None,
+        init=False,
+        desc="Pointer to the distributed config this one is an identical copy of.",
+        hint=FieldHint.derived,
+    )

     def _validate(self) -> None:
         if self.world_size is None:
@@ -281,76 +287,90 @@ def _validate(self) -> None:
         if self.tensor_parallel == 1:
             self.sequence_tensor_parallel = False

-        self.distributed_dims = {}
+        if self.reference_config is not None:
+            self.reference_config.validate()
+            if self.reference_config.reference_config is not None:
+                self.reference_config = self.reference_config.reference_config
+            assert self.reference_config.reference_config is None
+            self.compare(self.reference_config, ValueError)
+            self.distributed_dims = self.reference_config.distributed_dims
+        else:
+            self.distributed_dims = {}

-        self.add_distributed_dim(
-            DistributedDim(name=DistributedDimNames.world, size=self.world_size, rank=self.rank, id_=None, parent=None)
-        )
-        self.add_distributed_dim(
-            DistributedDim(
-                name=DistributedDimNames.data,
-                size=self.data_parallel,
-                rank=self.data_rank,
-                id_=f"x_{self.pipeline_rank}_{self.tensor_rank}",
-                parent=DistributedDimNames.world,
-            )
-        )
-        self.add_distributed_dim(
-            DistributedDim(
-                name=DistributedDimNames.pipeline,
-                size=self.pipeline_parallel,
-                rank=self.pipeline_rank,
-                id_=f"x_{self.data_rank}_{self.tensor_rank}",
-                parent=DistributedDimNames.world,
-            )
-        )
-        self.add_distributed_dim(
-            DistributedDim(
-                name=DistributedDimNames.tensor,
-                size=self.tensor_parallel,
-                rank=self.tensor_rank,
-                id_=f"x_{self.data_rank}_{self.pipeline_rank}",
-                parent=DistributedDimNames.world,
-            )
-        )
-        self.add_distributed_dim(
-            DistributedDim(
-                name=DistributedDimNames.sequence_data,
-                size=self.sequence_data_parallel,
-                rank=self.sequence_data_rank,
-                id_=f"{self.batch_data_rank}_{self.pipeline_rank}_{self.tensor_rank}",
-                parent=DistributedDimNames.data,
-            )
-        )
-        self.add_distributed_dim(
-            DistributedDim(
-                name=DistributedDimNames.batch_data,
-                size=self.batch_data_parallel,
-                rank=self.batch_data_rank,
-                id_=f"{self.sequence_data_rank}_{self.pipeline_rank}_{self.tensor_rank}",
-                parent=DistributedDimNames.data,
-            )
-        )
-        self.add_distributed_dim(
-            DistributedDim(
-                name=DistributedDimNames.tensor_and_sequence_data,
-                size=self.sequence_data_parallel * self.tensor_parallel,
-                rank=self.tensor_rank + self.sequence_data_rank * self.tensor_parallel,
-                id_=f"{self.batch_data_rank}_{self.pipeline_rank}",
-                parent=(
-                    DistributedDimNames.tensor
-                    if self.sequence_data_parallel == 1
-                    else DistributedDimNames.sequence_data if self.tensor_parallel == 1 else DistributedDimNames.world
-                ),
-            )
-        )
+            self._add_distributed_dim(
+                DistributedDim(
+                    name=DistributedDimNames.world, size=self.world_size, rank=self.rank, id_=None, parent=None
+                )
+            )
+            self._add_distributed_dim(
+                DistributedDim(
+                    name=DistributedDimNames.data,
+                    size=self.data_parallel,
+                    rank=self.data_rank,
+                    id_=f"x_{self.pipeline_rank}_{self.tensor_rank}",
+                    parent=DistributedDimNames.world,
+                )
+            )
+            self._add_distributed_dim(
+                DistributedDim(
+                    name=DistributedDimNames.pipeline,
+                    size=self.pipeline_parallel,
+                    rank=self.pipeline_rank,
+                    id_=f"x_{self.data_rank}_{self.tensor_rank}",
+                    parent=DistributedDimNames.world,
+                )
+            )
+            self._add_distributed_dim(
+                DistributedDim(
+                    name=DistributedDimNames.tensor,
+                    size=self.tensor_parallel,
+                    rank=self.tensor_rank,
+                    id_=f"x_{self.data_rank}_{self.pipeline_rank}",
+                    parent=DistributedDimNames.world,
+                )
+            )
+            self._add_distributed_dim(
+                DistributedDim(
+                    name=DistributedDimNames.sequence_data,
+                    size=self.sequence_data_parallel,
+                    rank=self.sequence_data_rank,
+                    id_=f"{self.batch_data_rank}_{self.pipeline_rank}_{self.tensor_rank}",
+                    parent=DistributedDimNames.data,
+                )
+            )
+            self._add_distributed_dim(
+                DistributedDim(
+                    name=DistributedDimNames.batch_data,
+                    size=self.batch_data_parallel,
+                    rank=self.batch_data_rank,
+                    id_=f"{self.sequence_data_rank}_{self.pipeline_rank}_{self.tensor_rank}",
+                    parent=DistributedDimNames.data,
+                )
+            )
+            self._add_distributed_dim(
+                DistributedDim(
+                    name=DistributedDimNames.tensor_and_sequence_data,
+                    size=self.sequence_data_parallel * self.tensor_parallel,
+                    rank=self.tensor_rank + self.sequence_data_rank * self.tensor_parallel,
+                    id_=f"{self.batch_data_rank}_{self.pipeline_rank}",
+                    parent=(
+                        DistributedDimNames.tensor
+                        if self.sequence_data_parallel == 1
+                        else (
+                            DistributedDimNames.sequence_data
+                            if self.tensor_parallel == 1
+                            else DistributedDimNames.world
+                        )
+                    ),
+                )
+            )

         super()._validate()

         Assert.in_range(self.rank, 0, self.world_size)
         Assert.in_range(self.local_rank, 0, self.local_world_size)

-    def add_distributed_dim(self, distributed_dim: DistributedDim) -> None:
+    def _add_distributed_dim(self, distributed_dim: DistributedDim) -> None:
         if distributed_dim.name in self.distributed_dims:
             Assert.eq(distributed_dim, self.distributed_dims[distributed_dim.name])
         else:
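The new `reference_config` field and the `_validate` branch above let a second model (for example a distillation reference model) carry a `DistributedConfig` that is an identical copy of the trainer's: validation then only checks equality and shares the parent's `distributed_dims` instead of rebuilding them. Below is a toy re-implementation that mirrors that logic purely for illustration; it is not fast_llm code, and how the real copy gets its `reference_config` assigned is handled elsewhere in this PR:

```python
# Toy mirror of the reference_config logic in _validate() above (illustration only).
class ToyDistributedConfig:
    def __init__(self, world_size: int, reference_config: "ToyDistributedConfig | None" = None):
        self.world_size = world_size
        self.reference_config = reference_config
        self.distributed_dims: dict[str, object] = {}

    def compare(self, other: "ToyDistributedConfig", error_type: type[Exception]) -> None:
        # Stand-in for Config.compare: raise `error_type` on any mismatch.
        if self.world_size != other.world_size:
            raise error_type("distributed configs differ")

    def validate(self) -> None:
        if self.reference_config is not None:
            self.reference_config.validate()
            # Flatten chains of copies so we always point at the root config.
            if self.reference_config.reference_config is not None:
                self.reference_config = self.reference_config.reference_config
            assert self.reference_config.reference_config is None
            self.compare(self.reference_config, ValueError)
            # Reuse the root's process-group dimensions instead of rebuilding them.
            self.distributed_dims = self.reference_config.distributed_dims
        else:
            self.distributed_dims = {"world": object()}


root = ToyDistributedConfig(world_size=8)
copy = ToyDistributedConfig(world_size=8, reference_config=root)
copy.validate()
assert copy.distributed_dims is root.distributed_dims
```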

fast_llm/engine/distributed/distributed.py (+8)

@@ -33,6 +33,7 @@ class Distributed[ConfigType: DistributedConfig](Configurable[ConfigType]):

     def __init__(self, config: DistributedConfig, use_cpu: bool = False):
         super().__init__(config)
+        assert self._config.reference_config is None
         self._use_cpu = use_cpu

         if self._use_cpu:

@@ -148,6 +149,13 @@ def add_group(self, distributed_dim: DistributedDim) -> ProcessGroup | None:
         distributed_dim.setup(group)
         return group

+    def check_config(self, config: DistributedConfig) -> None:
+        # Allows using this `Distributed` on a model with a distributed config that is a copy of `self._config`
+        if config.reference_config is None:
+            Assert.is_(config, self._config)
+        else:
+            Assert.is_(config.reference_config, self._config)
+
     def set_step(self, step: int, phase: PhaseType) -> None:
         """
         Reseed pytorch for a given training step.
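In short, `check_config` relaxes the old identity assertion: a `Distributed` built from a root config can now also serve models whose config is a registered copy of that root. A hedged usage sketch, with `root_config`, `copy_config`, and `other_config` as placeholder `DistributedConfig` instances where `copy_config.reference_config is root_config`:

```python
from fast_llm.engine.distributed.distributed import Distributed

distributed = Distributed(root_config)     # requires root_config.reference_config is None

distributed.check_config(root_config)      # passes: identical object
distributed.check_config(copy_config)      # passes: its reference_config is root_config
distributed.check_config(other_config)     # raises via Assert.is_: no identity relation
```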

fast_llm/engine/huggingface/model.py renamed to fast_llm/engine/inference/huggingface.py (+14, -33)

@@ -4,57 +4,38 @@

 import transformers.modeling_outputs

-from fast_llm.config import NoAutoValidate
 from fast_llm.engine.checkpoint.config import CheckpointLoadConfig, FastLLMCheckpointFormat
-from fast_llm.engine.distributed.config import PhaseType
-from fast_llm.engine.huggingface.config import HuggingfaceModelConfig
+from fast_llm.engine.inference.config import HuggingfaceModelConfig
+from fast_llm.engine.inference.runner import InferenceRunner
 from fast_llm.engine.multi_stage.config import StageMode
 from fast_llm.engine.multi_stage.fast_llm_model import FastLLMModel
-from fast_llm.engine.schedule.config import BatchConfig, ScheduleConfig
-from fast_llm.engine.schedule.runner import ScheduleRunner
-from fast_llm.engine.schedule.schedule import Schedule


 class HuggingfacePreTrainedModel(transformers.PreTrainedModel):
     config_class: typing.ClassVar[type[HuggingfaceModelConfig]] = HuggingfaceModelConfig
-    model_class: typing.ClassVar[type[FastLLMModel]] = FastLLMModel
+    runner_class: typing.ClassVar[type[InferenceRunner]] = InferenceRunner
     config: HuggingfaceModelConfig
     # base_model_prefix = ""
     # _no_split_modules = None
     # _supports_cache_class = False
     # _tied_weights_keys = []

     def __init__(self, config: HuggingfaceModelConfig, fast_llm_model: FastLLMModel, **kwargs):
-        assert self.model_class.config_class is config.model_config_class
+        assert self.runner_class.model_class.config_class is config.model_config_class
         assert config.fast_llm_config is fast_llm_model.config
         assert isinstance(config, self.config_class)
+
         super().__init__(config, **kwargs)
-        self._fast_llm_config = config.fast_llm_config
-        self._fast_llm_model = fast_llm_model
+
+        self._inference_runner = self.runner_class(fast_llm_model)
+        if not fast_llm_model.is_setup:
+            fast_llm_model.setup(mode=StageMode.inference)
+        self._inference_runner.setup()
         # Transformers needs to be able to inspect the base model.
-        self.fast_llm_base_model = self._fast_llm_model.base_model
-        self._distributed_config = self._fast_llm_config.distributed
+        self.fast_llm_base_model = fast_llm_model.base_model
         # TODO: Support distributed models?
-        assert self._distributed_config.world_size == 1
-        self._schedule_config = ScheduleConfig()
-        # We only need a basic schedule and don't care about dimensions.
-        # TODO: Sort things out.
-        with NoAutoValidate():
-            self._batch_config = BatchConfig()
-        self._batch_config.setup(self._distributed_config)
-        self._batch_config.validate()
-        self._runner = ScheduleRunner(
-            config=self._schedule_config, multi_stage=self._fast_llm_model, distributed_config=self._distributed_config
-        )
-        self._runner.setup(self._fast_llm_model.distributed)
-        # TODO: Random state? (Distributed.set_step)
-        self._schedule = Schedule(
-            multi_stage=self._fast_llm_model,
-            batch_config=self._batch_config,
-            schedule_config=self._schedule_config,
-            distributed_config=self._distributed_config,
-            phase=PhaseType.inference,
-        )
+        assert fast_llm_model.config.distributed.world_size == 1
+
         with transformers.modeling_utils.no_init_weights():
             self.post_init()

@@ -79,7 +60,7 @@ def from_pretrained(
         config_updates[("distributed", "training_dtype")] = torch_dtype

         # Create the model
-        fast_llm_model = cls.model_class.from_pretrained(
+        fast_llm_model = cls.runner_class.model_class.from_pretrained(
             pretrained_model_name_or_path, config_updates=config_updates, mode=mode
         )
         config = cls.config_class(fast_llm_model.config)
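The net effect of the rename and refactor above: the Hugging Face wrapper no longer assembles a schedule itself but delegates to the new `InferenceRunner`. A hedged sketch of the resulting construction flow; `wrapper_cls`, `hf_config`, and `fast_llm_model` are placeholders, not names from this commit:

```python
# Hedged sketch of the refactored __init__ flow (placeholder names, see lead-in).
wrapper = wrapper_cls(hf_config, fast_llm_model)
# Inside __init__, roughly:
#   self._inference_runner = self.runner_class(fast_llm_model)   # InferenceRunner by default
#   if not fast_llm_model.is_setup:
#       fast_llm_model.setup(mode=StageMode.inference)
#   self._inference_runner.setup()
# The batch/schedule plumbing that used to live here moved to InferenceRunner,
# and distributed inference is still unsupported (world_size must be 1).
```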

fast_llm/engine/inference/runner.py (+58, new file)

@@ -0,0 +1,58 @@
+import abc
+import typing
+
+from fast_llm.config import NoAutoValidate
+from fast_llm.engine.distributed.config import PhaseType
+from fast_llm.engine.multi_stage.fast_llm_model import FastLLMModel
+from fast_llm.engine.schedule.config import BatchConfig, ScheduleConfig
+from fast_llm.engine.schedule.runner import ScheduleRunner
+from fast_llm.engine.schedule.schedule import Schedule
+
+
+class InferenceRunner(abc.ABC):
+    model_class: typing.ClassVar[type[FastLLMModel]] = FastLLMModel
+
+    def __init__(self, fast_llm_model: FastLLMModel):
+        assert isinstance(fast_llm_model, self.model_class)
+        self._fast_llm_model = fast_llm_model
+        # We only need a basic schedule and don't care about dimensions.
+        self._schedule_config = ScheduleConfig()
+        # TODO: Sort things out.
+        with NoAutoValidate():
+            self._batch_config = BatchConfig()
+        self._batch_config.setup(self._fast_llm_model.config.distributed)
+        self._batch_config.validate()
+        self._runner = ScheduleRunner(
+            config=self._schedule_config,
+            multi_stage=self._fast_llm_model,
+            distributed_config=self._fast_llm_model.config.distributed,
+        )
+        # TODO: Random state? (Distributed.set_step)
+        self._schedule = Schedule(
+            multi_stage=self._fast_llm_model,
+            batch_config=self._batch_config,
+            schedule_config=self._schedule_config,
+            distributed_config=self._fast_llm_model.config.distributed,
+            phase=PhaseType.inference,
+        )
+
+    @property
+    def fast_llm_model(self) -> FastLLMModel:
+        return self._fast_llm_model
+
+    def setup(self):
+        self._runner.setup(self._fast_llm_model.distributed)
+
+    def forward(
+        self, input_, kwargs: dict, *, iteration: int = 1, return_metrics: bool = False
+    ) -> tuple[dict[str, float | int], dict[str, typing.Any] | None]:
+        # TODO: Return an actual model output.
+        reduced_losses, update_successful, metrics = self._runner.run_step(
+            iter((((input_, kwargs),),)),
+            self._schedule,
+            iteration=iteration,
+            return_metrics=return_metrics,
+            preprocessed=True,
+        )
+        assert update_successful
+        return reduced_losses, metrics
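Typical use of the new runner, mirroring how the Hugging Face wrapper above drives it. This is a hedged sketch: `fast_llm_model` stands in for an already-loaded `FastLLMModel`, and `input_`/`kwargs` for a batch already in the preprocessed form `ScheduleRunner.run_step` expects:

```python
from fast_llm.engine.inference.runner import InferenceRunner
from fast_llm.engine.multi_stage.config import StageMode

runner = InferenceRunner(fast_llm_model)            # placeholder model (see lead-in)
if not fast_llm_model.is_setup:
    fast_llm_model.setup(mode=StageMode.inference)  # builds its own Distributed if none is passed
runner.setup()

# One pass through the single-batch inference schedule; losses/metrics come back reduced.
losses, metrics = runner.forward(input_, kwargs, iteration=1, return_metrics=True)
```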

fast_llm/engine/multi_stage/config.py (+1, -1)

@@ -29,7 +29,7 @@
 from fast_llm.utils import Assert

 if typing.TYPE_CHECKING:
-    from fast_llm.engine.huggingface.model import HuggingfacePreTrainedModel
+    from fast_llm.engine.inference.model import HuggingfacePreTrainedModel
     from fast_llm.engine.multi_stage.fast_llm_model import FastLLMModel

 logger = logging.getLogger(__name__)

fast_llm/engine/multi_stage/multi_stage.py (+10, -2)

@@ -209,12 +209,15 @@ def __init__(
                 "Bfloat16 gradient accumulation and reduction is not recommended. (use --full_precision_gradients=1)"
             )

-    def setup(self, distributed: Distributed, mode: StageMode = StageMode.training) -> None:
+    def setup(self, distributed: Distributed | None = None, mode: StageMode = StageMode.training) -> None:
         # TODO: More checks?
         stage: Stage
-        assert distributed.config is self._config.distributed
         assert not self._is_setup
         self._is_setup = True
+        if distributed is None:
+            distributed = Distributed(self._config.distributed)
+        else:
+            distributed.check_config(self._config.distributed)
         self._distributed = distributed
         self._mode = mode
         self._base_model.setup(distributed)

@@ -381,6 +384,10 @@ def get_shard(self, name: str) -> torch.Tensor:
             raise KeyError(f"Unknown shard name {name}")
         return self._shards[name]

+    @property
+    def is_setup(self) -> bool:
+        return self._is_setup
+
     @property
     def support_forward(self) -> bool:
         assert self._is_setup

@@ -442,6 +449,7 @@ def is_parameter_on_device(self, parameter_name: str) -> bool:

     @property
     def distributed(self) -> Distributed:
+        assert self._is_setup
         return self._distributed

     def invalidate_buffers(self) -> None:
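The optional `distributed` argument gives `setup` the two paths this commit needs. A hedged sketch, with `student_model` and `reference_model` as placeholder `FastLLMModel` instances whose distributed configs are linked through `reference_config`:

```python
from fast_llm.engine.multi_stage.config import StageMode

# 1. Stand-alone: no Distributed is passed, so setup() builds one from the model's own config.
student_model.setup(mode=StageMode.training)

# 2. Shared: a reference model (e.g. a frozen distillation teacher) reuses the student's
#    Distributed. check_config() accepts it because the reference model's distributed config
#    is marked as a copy of the student's via reference_config.
reference_model.setup(student_model.distributed, mode=StageMode.inference)
```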

fast_llm/engine/multi_stage/stage_base.py (+1, -1)

@@ -128,7 +128,7 @@ def setup(
         mode: StageMode = StageMode.training,
     ) -> None:
         assert not self._is_setup
-        assert distributed.config is self._distributed_config
+        distributed.check_config(self._distributed_config)
         self._mode = mode
         self._is_setup = True
         self._distributed = distributed

fast_llm/engine/schedule/runner.py (+1, -1)

@@ -97,7 +97,7 @@ def __init__(

     def setup(self, distributed: Distributed, optimizer: Optimizer | None = None) -> None:
         assert not self._is_setup
-        assert distributed.config is self._distributed_config
+        distributed.check_config(self._distributed_config)
         self._is_setup = True
         self._optimizer = optimizer
         assert self._multi_stage.support_forward
