
Commit 14c980b: Support frozen weights (#185)

1 parent: 8f5de31

File tree: 16 files changed, +1055 −655 lines

fast_llm/engine/checkpoint/config.py

Lines changed: 0 additions & 3 deletions

@@ -251,8 +251,5 @@ def save(self, config: CheckpointSaveConfig, metadata: "CheckpointMetadata"):
     def load(self, config: CheckpointLoadConfig, metadata: "CheckpointMetadata"):
         pass
 
-    def get_num_shards(self, config: CheckpointStateConfigBase) -> int:
-        return len(self._model.state_shard_names) if config.optimizer_state else 1
-
     def get_shard_names(self, config: CheckpointStateConfigBase) -> tuple[str, ...]:
         return self._model.state_shard_names if config.optimizer_state else self._model.state_shard_names[:1]
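Note: with get_num_shards removed, the shard count is now implied by the tuple that get_shard_names returns. A minimal sketch of the replacement pattern (handler and config are assumed to be in scope; the shard names shown are illustrative):

    # Before: num_shards = handler.get_num_shards(config)
    # After: derive the count from the names, if a count is needed at all.
    shard_names = handler.get_shard_names(config)  # e.g. ("weights",) or ("weights", "momentum", ...)
    num_shards = len(shard_names)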

fast_llm/engine/checkpoint/distributed.py

Lines changed: 45 additions & 25 deletions

@@ -36,7 +36,7 @@ def save(self, config: CheckpointSaveConfig, metadata: CheckpointMetadata) -> None:
         if self._model.config.distributed.rank == 0:
             yaml.safe_dump(serialized_metadata, (config.path / "metadata.yaml").open("w"))
         safetensors.torch.save_file(
-            tensors={"state_shard": self._model.state_shard[: self.get_num_shards(config)]},
+            tensors={f"{shard_name}_shard": self._model.get_shard(shard_name) for shard_name in metadata.shards},
             filename=config.path / f"rank_{self._model.config.distributed.rank}.safetensors",
             metadata=export_safetensors_metadata(serialized_metadata),
         )

@@ -45,9 +45,10 @@ def load(self, config: CheckpointLoadConfig, metadata: CheckpointMetadata) -> None:
         # TODO: More safety checks
         loaded_config_dict = config.to_copy({"load_config": ModelConfigType.fast_llm})
         loaded_config = self._model.config_class.from_metadata(loaded_config_dict, metadata)
-        num_shards = self.get_num_shards(config)
         shard_names = self.get_shard_names(config)
-        Assert.eq(metadata.shards[:num_shards], list(shard_names))
+        # Make sure all shards to load are in the checkpoint.
+        Assert.leq(set(self.get_shard_names(config)), set(metadata.shards))
+        Assert.eq(metadata.shards[: len(shard_names)], list(shard_names))
 
         same_format = (
             loaded_config.to_serialized(verbose=None) == self._model.config.to_serialized(verbose=None)

@@ -58,49 +59,68 @@ def load(self, config: CheckpointLoadConfig, metadata: CheckpointMetadata) -> None:
         same_format = broadcast_scalar(same_format, torch.uint8, self._model.distributed.world_group)
 
         if same_format:
-            log_main_rank("Checkpoint format matches, using fast load")
+            log_main_rank("Checkpoint format matches, using fast load", log_fn=logger.info)
             # TODO: Add version without optimizer state?
             with safetensors.safe_open(
                 config.path / f"rank_{self._model.config.distributed.rank}.safetensors",
                 framework="pt",
                 device=str(self._model.distributed.device),
             ) as f:
-                # TODO: Does this copy twice?
-                self._model.state_shard[:num_shards].copy_(f.get_slice("state_shard")[:num_shards])
+                if "state_shard" in f.keys():
+                    # Old format `state_shard` with shape `(num_shards, shard_size)
+                    # TODO v0.3: Use checkpoint version? Drop support?
+                    log_main_rank("Using legacy distributed checkpoint loader.", log_fn=logger.warning)
+                    for shard_name in shard_names:
+                        self._model.get_shard(shard_name).copy_(
+                            f.get_slice("state_shard")[metadata.shards.index(shard_name)]
+                        )
+                else:
+                    # TODO: Does this copy twice?
+                    for shard_name in shard_names:
+                        self._model.get_shard(shard_name).copy_(f.get_tensor(f"{shard_name}_shard"))
+
         else:
-            log_main_rank("Checkpoint format doesn't match, using safe load")
+            log_main_rank("Checkpoint format doesn't match, using safe load", log_fn=logger.info)
             self._model.config.base_model.compare_architecture(loaded_config.base_model, config.compare_log_fn)
-            with SafeLoad(self._model, num_shards=num_shards, timeout=config.timeout) as context:
+            with SafeLoad(self._model, shard_names=shard_names, timeout=config.timeout) as context:
                 for rank in range(loaded_config.distributed.world_size):
                     loaded_model = self._model.__class__(
                         loaded_config.to_copy({("distributed", "rank"): rank}),
                         optimizer_state_names=shard_names[1:],
                         verbose=False,
                     )
                     path = config.path / f"rank_{rank}.safetensors"
-                    log_main_rank(f"Loading from {path}")
+                    log_main_rank(f"Loading from {path}", log_fn=logger.info)
                     # TODO: skip shards without overlap.
                     with safetensors.safe_open(path, framework="pt", device=str(self._model.distributed.device)) as f:
                         # TODO: Use self_shard
-                        loaded_shard = f.get_slice("state_shard")[:num_shards]
-                        loaded_model.state_shard_meta.validate(loaded_shard)
+                        if "state_shard" in f.keys():
+                            # Old format `state_shard` with shape `(num_shards, shard_size)
+                            # TODO v0.3: Use checkpoint version? Drop support?
+                            log_main_rank("Using legacy distributed checkpoint loader.", log_fn=logger.warning)
+                            loaded_shards = {
+                                shard_name: f.get_slice("state_shard")[metadata.shards.index(shard_name)]
+                                for shard_name in shard_names
+                            }
+                        else:
+                            loaded_shards = {
+                                shard_name: f.get_tensor(f"{shard_name}_shard") for shard_name in shard_names
+                            }
 
-                        # TODO: Improve num shard selection.
-                        self_shard_split = self._model.state_shard[: loaded_shard.size(0)].split(
-                            self._model.stage_shard_sizes, 1
-                        )
-                        loaded_shard_split = loaded_shard.split(loaded_model.stage_shard_sizes, 1)
+                        for shard_name, loaded_shard in loaded_shards.items():
+                            loaded_model.get_shard_meta(shard_name).validate(loaded_shard)
+
+                        self_shards = {shard_name: self._model.get_shard(shard_name) for shard_name in shard_names}
 
                         counter = torch.zeros(1, dtype=torch.int64, device=self._model.distributed.device)
-                        for loaded_shard_index, loaded_stage in enumerate(loaded_model.stages_on_device.values()):
-                            loaded_shards = (
-                                loaded_shard_split[loaded_shard_index].to(self._model.distributed.device).unbind(0)
-                            )
-                            for self_shard_index, self_stage in enumerate(self._model.stages_on_device.values()):
-                                self_stage._copy_shard_overlaps(  # noqa
-                                    loaded_stage,
-                                    self_shard_split[self_shard_index].unbind(0),
-                                    loaded_shards,
+                        for _, loaded_fsdp, loaded_fsdp_shards in loaded_model.split_shards_by_fsdp(loaded_shards):
+                            for _, self_fsdp, self_fsdp_shards in self._model.split_shards_by_fsdp(self_shards):
+                                self_fsdp.copy_shard_overlaps(
+                                    loaded_fsdp,
+                                    self_fsdp_shards,
+                                    loaded_fsdp_shards,
                                     counter,
+                                    self._model.distributed.device,
                                 )
+
                         context.mark_as_loaded(counter.item())
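The save path above switches the per-rank checkpoint file from a single stacked state_shard tensor to one flat tensor per named shard, and the load path accepts both layouts. A minimal sketch of inspecting the two formats with the safetensors API, assuming a rank file name and illustrative shard names:

    import safetensors

    # Hypothetical inspection of a per-rank checkpoint file.
    with safetensors.safe_open("rank_0.safetensors", framework="pt") as f:
        if "state_shard" in f.keys():
            # Legacy layout: one tensor of shape (num_shards, shard_size);
            # row order follows metadata.shards.
            weights = f.get_slice("state_shard")[0]
        else:
            # New layout: one 1-D tensor per shard, keyed "<shard_name>_shard".
            weights = f.get_tensor("weights_shard")  # "weights" is an illustrative shard name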

fast_llm/engine/checkpoint/safe_load.py

Lines changed: 25 additions & 24 deletions

@@ -24,24 +24,24 @@ class SafeLoad:
     In case of failure, it will attempt to find out as precisely as possible where the problem comes from.
     """
 
-    def __init__(self, model: "FastLLMModel", *, num_shards: int, timeout: float | None = None):
+    def __init__(self, model: "FastLLMModel", *, shard_names: tuple[str, ...], timeout: float | None = None):
         self._model = model
         self._distributed = self._model.distributed
-        self._num_shards = num_shards
-        self._self_shard = self._model.state_shard[: self._num_shards]
+        # self._num_shards = num_shards
+        self._self_shards = {shard_name: self._model.get_shard(shard_name) for shard_name in shard_names}
         self._timeout = timeout
 
     def __enter__(self) -> "SafeLoad":
         self._loaded = 0
         self._loaded_parameters = {}
         # Track the number of loaded entries.
         # Use nan to mark non-loaded entries.
-        triton_fill(self._self_shard, math.nan)
+        for self_shard in self._self_shards.values():
+            triton_fill(self_shard, math.nan)
         # Reset and count shard pads
-        for shard in self._model.state_shard[: self._num_shards]:
-            shard_split = shard.split(self._model.stage_shard_sizes, 0)
-            for stage, stage_shard in zip(self._model.stages_on_device.values(), shard_split):
-                self._loaded += stage.reset_shard_pad(stage_shard)
+        for _, fsdp, fsdp_shards in self._model.split_shards_by_fsdp(self._self_shards):
+            for fsdp_shard in fsdp_shards.values():
+                self._loaded += fsdp.reset_shard_pad(fsdp_shard)
         return self
 
     def __exit__(self, exc_type, exc_val, exc_tb):

@@ -70,18 +70,19 @@ def _validate(self) -> None:
         logger.info(f"{self._loaded:,} state entries loaded successfully")
 
     def _check_counter(self, errors: list[str]) -> None:
-        to_load = self._self_shard.numel()
+        to_load = sum(self_shard.numel() for self_shard in self._self_shards.values())
         if self._loaded != to_load:
             # Ensure the right amount of weights is loaded.
             errors.append(f"Loaded a total of {self._loaded:,}, state entries, expected {to_load:,}")
 
     def _check_missing(self, errors: list[str]) -> None:
         # Ensure the loaded weights have a 1-1 mapping by looking for nans.
-        missing = self._self_shard.new_zeros([], dtype=torch.int64)
+        missing = torch.zeros([], dtype=torch.int64, device=self._distributed.device)
         # Count nans in slices of 100M parameters to limit memory usage.
         # TODO: Find better solution (triton kernel?)
-        for shard_slice in self._self_shard.flatten().split(100000000):
-            missing += shard_slice.isnan().sum()
+        for shard in self._self_shards.values():
+            for shard_slice in shard.flatten().split(100000000):
+                missing += shard_slice.isnan().sum()
         local_missing = missing.item()
         if self._distributed.world_group is not None:
             all_reduce(missing, group=self._distributed.world_group)

@@ -90,32 +91,32 @@ def _check_missing(self, errors: list[str]) -> None:
             errors.append(f"{global_missing:,} state entries failed to load or corrupted (local={local_missing:,}).")
             # Determine where the missing values are coming from.
             global_total, local_total = 0, 0
-            for shard_name, shard_ in zip(self._model.state_shard_names[: self._num_shards], self._self_shard):
-                shard_split = shard_.split(self._model.stage_shard_sizes, 0)
-                for stage, shard in zip(self._model.stages_on_device.values(), shard_split):
-                    buffer = stage._reconstruct_from_shard(shard)
-                    for i, parameter in enumerate(stage._split_buffer(buffer)):
+            for stage, fsdp, fsdp_shards in self._model.split_shards_by_fsdp(self._self_shards):
+                for shard_name, fsdp_shard in fsdp_shards.items():
+                    buffer = fsdp.reconstruct_from_shard(fsdp_shard)
+                    for parameter_name, parameter in fsdp.split_buffer(buffer).items():
                         missing_for_param = parameter.isnan().sum().item()
                         if missing_for_param > 0:
                             global_total += missing_for_param
-                            local_values = stage._split_shard(shard)[i]
+                            local_values = fsdp.split_shard(fsdp_shard)[parameter_name]
                             local_missing_for_param = local_values.isnan().sum().item()
                             local_total += local_missing_for_param
                             errors.append(
-                                f"{missing_for_param:,} values missing out of {parameter.numel():,} for parameter {stage.parameter_names[i]} in stage {stage.index}, shard {shard_name}"
+                                f"{missing_for_param:,} values missing out of {parameter.numel():,} for parameter {parameter_name} in stage {stage.index}, shard {shard_name}"
                                 f" (locally {local_missing_for_param:,} out of {local_values.numel():,})"
                             )
-                    missing_for_pad = buffer[-stage._global_pad :].isnan().sum().item()
+                    missing_for_pad = buffer[-fsdp._global_pad :].isnan().sum().item()
                     if missing_for_pad > 0:
                         global_total += missing_for_pad
                         local_missing_for_pad = (
-                            shard[-stage._shard_pad :].isnan().sum().item() if stage._shard_pad > 0 else 0
+                            fsdp_shard[-fsdp._shard_pad :].isnan().sum().item() if fsdp._shard_pad > 0 else 0
                        )
                        local_total += local_missing_for_pad
                        errors.append(
-                            f"{missing_for_pad:,} values missing out of {stage._global_pad:,} for padding in stage {stage.index}, shard {shard_name}"
-                            f" (locally {local_missing_for_pad:,} out of {stage._shard_pad:,})"
+                            f"{missing_for_pad:,} values missing out of {fsdp._global_pad:,} for padding in stage {stage.index}, shard {shard_name}"
+                            f" (locally {local_missing_for_pad:,} out of {fsdp._shard_pad:,})"
                        )
+
             if global_total != global_missing:
                 errors.append(
                     f"Incorrect global breakdown of missing state entries (expected {global_missing:,}, got {global_total:,})"

@@ -127,7 +128,7 @@ def _check_missing(self, errors: list[str]) -> None:
 
     def _check_parameters(self, errors: list[str]) -> None:
         loaded_shard_names = set(self._loaded_parameters)
-        shard_names = set(self._model.state_shard_names[: self._num_shards])
+        shard_names = set(self._self_shards)
         if loaded_shard_names != shard_names:
             errors.append(f"Incorrect loaded shards: {loaded_shard_names}!={shard_names}")
         for shard_name in shard_names & loaded_shard_names:
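SafeLoad is now constructed from shard names rather than a shard count, and all of its checks iterate over the resulting name-to-shard dict. A minimal usage sketch, assuming a FastLLMModel instance and illustrative shard names; the helper that performs the actual copy is hypothetical:

    shard_names = ("weights", "momentum")  # illustrative; taken from get_shard_names(config) in practice
    with SafeLoad(model, shard_names=shard_names, timeout=60.0) as context:
        # On entry, the selected shards are filled with NaN and their pads are reset.
        loaded_entries = copy_checkpoint_into_model(model)  # hypothetical helper returning an entry count
        context.mark_as_loaded(loaded_entries)
        # On exit, the counter, any remaining NaNs, and per-parameter stats are validated.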

fast_llm/engine/checkpoint/state_dict.py

Lines changed: 2 additions & 2 deletions

@@ -72,7 +72,7 @@ def _serialize_metadata(
         return metadata.to_serialized()
 
     def load(self, config: CheckpointLoadConfig, metadata: CheckpointMetadata) -> None:
-        with SafeLoad(self._model, num_shards=self.get_num_shards(config), timeout=config.timeout) as context:
+        with SafeLoad(self._model, shard_names=self.get_shard_names(config), timeout=config.timeout) as context:
             # The tensor mapping may not be one-to-one. `convert_state_dict` pops all tensors from
             # `state_dict` that are ready for conversion,
             # and return a dict containing the converted tensors(s).

@@ -145,7 +145,7 @@ def _load_weights(
     ) -> typing.Iterator[tuple[str, str, torch.Tensor | SafeTensorSlice]]:
         metadata = self.load_metadata(config)
         shard_names = self.get_shard_names(config)
-        Assert.eq(metadata.shards[: self.get_num_shards(config)], list(shard_names))
+        Assert.leq(set(shard_names), set(metadata.shards))
         for file_name in set(metadata.metadata["state_index"].values()):
             logger.info(f"Loading from {config.path / file_name}")
             with safetensors.safe_open(
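The prefix-equality assertion is relaxed to a subset check, so a checkpoint that also contains optimizer-state shards can serve a weights-only load. A small illustration with assumed values:

    # Illustrative values only.
    metadata_shards = ["weights", "momentum", "exp_avg_sq"]  # shards present in the checkpoint
    shard_names = ("weights",)                                # shards requested by this load
    assert set(shard_names) <= set(metadata_shards)           # what Assert.leq enforces here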

fast_llm/engine/multi_stage/config.py

Lines changed: 8 additions & 0 deletions

@@ -70,6 +70,14 @@ class StageConfig(Config):
         desc="Reduce and accumulate gradients in fp32 to improve numerical stability.",
         hint=FieldHint.optional,
     )
+    store_frozen_weights_in_optimization_precision: bool = Field(
+        # TODO: Implement and set default to False
+        default=True,
+        desc="Store frozen weights in full precision even if not not needed."
+        "Allows preserving the precision for saved checkpoints,"
+        " at the cost of memory and compute (copy) overheads.",
+        hint=FieldHint.optional,
+    )
     debug_layer_outputs: int = Field(
         default=0,
         desc="Log the output of each layer.",

fast_llm/engine/multi_stage/fast_llm_model.py

Lines changed: 1 addition & 1 deletion

@@ -98,7 +98,7 @@ def initialize_weights(self, timeout: float | None = None) -> None:
 
     def _finalize_load(self, reset_optimizer: bool = True) -> None:
         if reset_optimizer:
-            triton_fill(self._state_shard[1:], 0.0)
+            triton_fill(self._flat_shard[self._weight_shard_size :], 0.0)
         if self._mode.support_forward:
             self.invalidate_buffers()
         self._is_loaded = True
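The optimizer reset now slices a single flat buffer in which the weight shard occupies the first _weight_shard_size entries and any optimizer-state shards follow, replacing the old (num_shards, shard_size) indexing. A minimal sketch of that layout with assumed sizes (the optimizer region need not be a multiple of the weight shard, for example when some weights are frozen):

    import torch

    weight_shard_size = 1_000     # assumed
    optimizer_state_size = 1_600  # assumed
    flat_shard = torch.empty(weight_shard_size + optimizer_state_size)
    flat_shard[weight_shard_size:].zero_()  # same effect as the triton_fill call above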
