
Commit 929c1cf

Fix and test LM head (#240)
1 parent df30991 commit 929c1cf

13 files changed, +320 -58 lines changed

fast_llm/engine/multi_stage/fsdp.py (+22 -5)

@@ -167,7 +167,8 @@ def setup(
         grad_shard: torch.Tensor | None,
         weight_buffer: torch.Tensor | None,
         grad_buffer: torch.Tensor | None,
-        sequence_tensor_parallel: bool = False,
+        sequence_tensor_parallel: bool,
+        device: torch.device | None,
     ) -> None:
         assert not self._is_setup
         self._is_setup = True
@@ -176,11 +177,19 @@ def setup(

         # Validate and set the shards and buffers
         if self._mode.on_device:
-            self._weight_shard = self._weight_shard_meta.validate(weight_shard)
+            self._weight_shard = (
+                torch.empty_like(self._weight_shard_meta, device=device)
+                if weight_shard is None
+                else self._weight_shard_meta.validate(weight_shard)
+            )
         else:
             Assert.none(weight_shard)
         if self._mode.support_forward:
-            self._weight_buffer = self._weight_buffer_meta.validate(weight_buffer)
+            self._weight_buffer = (
+                torch.empty_like(self._weight_buffer_meta, device=device)
+                if weight_buffer is None
+                else self._weight_buffer_meta.validate(weight_buffer)
+            )
             # Pre-compute the local shard for restore ops.
             self._weight_buffer_local_shard = self._weight_buffer[
                 self._fsdp_dim.rank * self._shard_size : (self._fsdp_dim.rank + 1) * self._shard_size
@@ -189,8 +198,16 @@ def setup(
             Assert.none(weight_buffer)

         if self._mode.support_backward:
-            self._grad_shard = self._grad_shard_meta.validate(grad_shard)
-            self._grad_buffer = self._grad_buffer_meta.validate(grad_buffer)
+            self._grad_shard = (
+                torch.empty_like(self._grad_shard_meta, device=device)
+                if grad_shard is None
+                else self._grad_shard_meta.validate(grad_shard)
+            )
+            self._grad_buffer = (
+                torch.empty_like(self._grad_buffer_meta, device=device)
+                if grad_buffer is None
+                else self._grad_buffer_meta.validate(grad_buffer)
+            )
             # Pre-compute the local shard for reduce ops.
             self._grad_buffer_local_shard = self._grad_buffer[
                 self._fsdp_dim.rank * self._shard_size : (self._fsdp_dim.rank + 1) * self._shard_size
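
The new optional arguments let callers omit pre-allocated shards and buffers: when one is None, setup now allocates an empty tensor on the given device from the corresponding meta tensor instead of validating a caller-provided one. A minimal sketch of that fallback pattern outside the Fast-LLM API (the helper name below is illustrative):

import torch

def resolve(buffer: torch.Tensor | None, meta: torch.Tensor, device: torch.device) -> torch.Tensor:
    # `meta` stands in for a TensorMeta: a storage-less tensor that carries shape and dtype.
    if buffer is None:
        return torch.empty_like(meta, device=device)  # allocate on demand, as in the patch
    assert buffer.shape == meta.shape and buffer.dtype == meta.dtype  # stand-in for meta.validate(buffer)
    return buffer

meta = torch.empty(1024, dtype=torch.bfloat16, device="meta")
shard = resolve(None, meta, torch.device("cpu"))                                      # freshly allocated
shard = resolve(torch.zeros(1024, dtype=torch.bfloat16), meta, torch.device("cpu"))   # reused after validation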

fast_llm/engine/multi_stage/stage.py (+11 -6)

@@ -1,3 +1,4 @@
+import collections
 import logging
 import typing

@@ -38,13 +39,13 @@ def setup(  # noqa
         self,
         *,
         distributed: Distributed,
-        weight_shards: list[torch.Tensor | None] | None,
-        grad_shards: list[torch.Tensor | None] | None,
-        weight_buffers: list[torch.Tensor | None] | None,
-        grad_buffers: list[torch.Tensor | None] | None,
+        weight_shards: list[torch.Tensor | None] | None = None,
+        grad_shards: list[torch.Tensor | None] | None = None,
+        weight_buffers: list[torch.Tensor | None] | None = None,
+        grad_buffers: list[torch.Tensor | None] | None = None,
         mode: StageMode = StageMode.training,
         is_tied_weight_copy: bool = False,
-        weight_buffer_shared_with: list["Stage"],
+        weight_buffer_shared_with: collections.abc.Sequence["Stage"] = (),
     ) -> None:
         super().setup(
             distributed=distributed,
@@ -92,7 +93,11 @@ def forward_meta(self, input_: TensorMeta, kwargs: dict) -> TensorMeta:
         return input_

     def forward(
-        self, input_: torch.Tensor, kwargs: dict, losses: dict[str, list[torch.Tensor]], metrics: dict | None = None
+        self,
+        input_: torch.Tensor,
+        kwargs: dict,
+        losses: dict[str, list[torch.Tensor]] | None = None,
+        metrics: dict | None = None,
     ) -> tuple[torch.Tensor | None, tuple[torch.Tensor | None, torch.Tensor | None]]:
         assert self._is_restored
         assert self._mode.support_forward

fast_llm/engine/multi_stage/stage_base.py (+3 -2)

@@ -6,7 +6,7 @@

 from fast_llm.config import Configurable
 from fast_llm.core.distributed import check_parallel_match
-from fast_llm.engine.base_model.base_model import BaseModel
+from fast_llm.engine.base_model.base_model import BaseModel, Layer
 from fast_llm.engine.config_utils.data_type import DataType
 from fast_llm.engine.distributed.config import DistributedConfig, DistributedDimNames
 from fast_llm.engine.distributed.distributed import Distributed
@@ -29,7 +29,7 @@ def __init__(
         self,
         *,
         config: StageConfig,
-        base_model: BaseModel,
+        base_model: BaseModel | list[Layer],
         distributed_config: DistributedConfig,
         begin: int,
         end: int,
@@ -153,6 +153,7 @@ def setup(
             weight_buffer=weight_buffer,
             grad_buffer=grad_buffer,
             sequence_tensor_parallel=self._distributed_config.sequence_tensor_parallel,
+            device=self._distributed.device,
         )

         if self._mode.support_forward:

fast_llm/functional/triton/normalization.py (+4 -1)

@@ -176,14 +176,17 @@ def triton_normalization_forward(
     training: bool,
     zero_centered: bool,
 ) -> tuple[torch.Tensor, list[typing.Any]] | None:
+    # Note: Converting input automatically to training dtype to match Apex behaviour,
+    # needed for full precision residual.
+    # TODO: Review this?
     assert weight.shape == input_.shape[-1:]
     if bias is not None:
         assert weight.shape == bias.shape
     assert input_.is_contiguous()
     n_rows = input_.shape[:-1].numel()
     n_cols = weight.numel()

-    output = torch.empty_like(input_)
+    output = torch.empty_like(input_, dtype=weight.dtype)
     inv_var = torch.empty(n_rows, dtype=torch.float32, device="cuda")

     block_size = triton.next_power_of_2(n_cols)
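
The functional change here is only the dtype of the kernel's output buffer: it now follows the weight (training) dtype rather than the input dtype. A minimal CPU-only illustration of the intent, with made-up shapes and no Triton involved:

import torch

input_ = torch.randn(4, 8, dtype=torch.float32)         # full-precision residual stream
weight = torch.ones(8, dtype=torch.bfloat16)            # training-dtype norm weight
output = torch.empty_like(input_, dtype=weight.dtype)   # buffer allocated in the weight dtype
assert output.shape == input_.shape and output.dtype == torch.bfloat16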

fast_llm/layers/common/auxiliary_loss.py (+1 -1)

@@ -16,7 +16,7 @@ def backward(ctx, grad_output: torch.Tensor) -> tuple[torch.Tensor | None, ...]:
 def calculate_z_loss(logits: torch.Tensor, logits_scale_factor: float = 1.0) -> torch.Tensor:
     if logits_scale_factor != 1.0:
         logits *= logits_scale_factor
-    return torch.mean(torch.square(torch.logsumexp(logits, dim=-1)))
+    return torch.mean(torch.logsumexp(logits, dim=-1) ** 2)


 def z_loss(
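
The z-loss is the mean squared log-partition function, i.e. the mean over tokens of (logsumexp(logits))^2; the change only swaps torch.square for the ** 2 operator, which computes the same value. A quick equivalence check with random logits (illustrative only):

import torch

logits = torch.randn(2, 5, 10)
z_old = torch.mean(torch.square(torch.logsumexp(logits, dim=-1)))
z_new = torch.mean(torch.logsumexp(logits, dim=-1) ** 2)
assert torch.allclose(z_old, z_new)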

fast_llm/layers/common/normalization.py (+9 -3)

@@ -3,7 +3,7 @@

 from fast_llm.engine.config_utils.run import log_main_rank
 from fast_llm.engine.config_utils.tensor_space import TensorDim
 from fast_llm.functional.config import TritonConfig
-from fast_llm.functional.triton.normalization import rms_norm, triton_normalization_autograd
+from fast_llm.functional.triton.normalization import triton_normalization_autograd
 from fast_llm.layers.common.config import NormalizationImplementation
 from fast_llm.tensor import ParameterMeta, accumulate_gradient, init_ones_, init_zeros_
 from fast_llm.utils import Assert
@@ -141,6 +141,9 @@ def backward(ctx, grad_output: torch.Tensor) -> tuple[torch.Tensor, None, None,
 class LayerNorm(torch.nn.Module):
     """
     A layer normalization layer, supporting multiple implementations.
+    Note: Converting input automatically to training dtype to match Apex behaviour,
+    needed for full precision residual.
+    TODO: Review this?
     """

     def __init__(
@@ -214,12 +217,15 @@ def _forward_fused(self, input_: torch.Tensor) -> torch.Tensor:
         return FusedLayerNorm.apply(input_, self.normalized_shape, self.weight, self.bias, self._eps)

     def _forward_torch(self, input_: torch.Tensor) -> torch.Tensor:
-        return torch.nn.functional.layer_norm(input_, self.normalized_shape, self.weight, self.bias, self._eps)
+        return torch.layer_norm(input_.to(self.weight.dtype), self.normalized_shape, self.weight, self.bias, self._eps)


 class RMSNorm(torch.nn.Module):
     """
     A RMS normalization layer.
+    Note: Converting input automatically to training dtype to match Apex behaviour,
+    needed for full precision residual.
+    TODO: Review this?
     """

     def __init__(
@@ -276,4 +282,4 @@ def _forward_fused(self, input_: torch.Tensor) -> torch.Tensor:
         return FusedRMSNorm.apply(input_, self.normalized_shape, self.weight, self._eps)

     def _forward_torch(self, input_: torch.Tensor) -> torch.Tensor:
-        return rms_norm(input_, self.weight, self._eps)
+        return torch.rms_norm(input_.to(self.weight.dtype), self.normalized_shape, self.weight, self._eps)
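
Both torch fallbacks now cast the input to the weight's dtype before normalizing, so a full-precision residual stream yields a training-dtype output, consistent with the fused and Triton paths. A small check of that behaviour (torch.rms_norm needs a recent PyTorch; shapes and dtypes below are made up):

import torch

residual = torch.randn(4, 8, dtype=torch.float32)  # full-precision residual
weight = torch.ones(8, dtype=torch.bfloat16)       # training-dtype parameter

out_ln = torch.layer_norm(residual.to(weight.dtype), (8,), weight, None, 1e-5)
out_rms = torch.rms_norm(residual.to(weight.dtype), (8,), weight, 1e-5)
assert out_ln.dtype == out_rms.dtype == torch.bfloat16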

fast_llm/layers/language_model/head.py (+13 -12)

@@ -50,7 +50,9 @@ def __init__(
         self._group_size = tensor_space.distributed_config.tensor_parallel
         self._sequence_parallel = tensor_space.distributed_config.sequence_tensor_parallel
         self._parallel_embeddings = tensor_space.distributed_config.tensor_parallel > 1 and config.parallel_embeddings
-        self._sequence_parallel_logits = self._sequence_parallel and not self._parallel_embeddings
+        self._sequence_parallel_logits = (
+            tensor_space.distributed_config.sequence_tensor_parallel and not config.parallel_embeddings
+        )
         self._cross_entropy_splits = config.cross_entropy_splits
         if self._cross_entropy_splits is not None and self._sequence_parallel:
             assert not self._parallel_embeddings
@@ -67,7 +69,7 @@ def __init__(
         # >0: multi-token prediction (MTP)
         Assert.geq(prediction_distance, 0)
         self._prediction_distance = prediction_distance
-        self.is_last_head = self._prediction_distance == config.prediction_heads - 1
+        self._is_last_head = self._prediction_distance == config.prediction_heads - 1

         self._init_output_weights(hidden_dim, config)

@@ -114,7 +116,7 @@ def forward(
             tensor_name="Loss",
             reductions=((DistributedDimNames.data, ReduceOp.AVG),),  # noqa
         )
-        if not self.is_last_head:
+        if not self._is_last_head:
             # MTP: split the stacked input
             shared_hidden, input_ = torch.unbind(input_, dim=0)
         # TODO: Pytorch copies the grads in backward for no reason (not sure if still the case)
@@ -123,10 +125,10 @@
         # TODO: Drop autograd entirely.
         # TODO: Skip cross-entropy backward if not needed.
         language_model_loss = self._forward(input_, kwargs, losses)
-        if language_model_loss is not None:
+        if losses is not None and language_model_loss is not None:
             losses[self._loss_name].append(language_model_loss)
         # TODO: Return the model output when needed.
-        if self.is_last_head:
+        if self._is_last_head:
             # Last head should return the loss for backward.
             return language_model_loss
         else:
@@ -147,14 +149,13 @@ def _forward_backward(
         if target is not None:
             if self._config.distillation_model is None:
                 # MTP: Shift the labels
-                target = (
-                    target[self._prediction_distance : self._prediction_distance + input_.size(0),]
-                    if kwargs[TransformerKwargs.sequence_first]
-                    else target[
-                        :,
-                        self._prediction_distance : self._prediction_distance + input_.size(1),
-                    ]
+                target_sequence_length = (
+                    target.size(1 - kwargs[TransformerKwargs.sequence_first]) + 1 - self._config.prediction_heads
                 )
+                if TransformerKwargs.sequence_q_dim in kwargs:
+                    Assert.eq(target_sequence_length, kwargs[TransformerKwargs.sequence_q_dim].size)
+                target_slice = slice(self._prediction_distance, self._prediction_distance + target_sequence_length)
+                target = target[target_slice] if kwargs[TransformerKwargs.sequence_first] else target[:, target_slice]
                 target = target.flatten()
             else:
                 # Target is reference model logits.
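
The rewritten label shift derives the per-head window from the target tensor itself: with prediction_heads heads, the targets carry sequence_length + prediction_heads - 1 positions along the sequence dimension, and the head at prediction_distance k reads the sequence_length-sized window starting at offset k. A worked example with made-up sizes:

sequence_length, prediction_heads = 8, 3
target_positions = sequence_length + prediction_heads - 1             # 10 label positions per sample
for prediction_distance in range(prediction_heads):
    target_sequence_length = target_positions + 1 - prediction_heads  # 8, matching sequence_q_dim
    target_slice = slice(prediction_distance, prediction_distance + target_sequence_length)
    print(prediction_distance, target_slice)  # 0 -> slice(0, 8), 1 -> slice(1, 9), 2 -> slice(2, 10)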

fast_llm/layers/transformer/attention.py (+1 -1)

@@ -84,7 +84,7 @@ def __init__(
         super().__init__()
         self._config = config
         self._tensor_space = tensor_space
-        Assert.in_range_incl(layer_index, 1, self._config.num_layers)
+        Assert.in_range_incl(layer_index, 1, max(self._config.num_layers, 1))
         self._layer_index = layer_index
         self._sequence_parallel = self._tensor_space.distributed_config.sequence_tensor_parallel
         self._debug_transformer = self._config.debug_transformer

fast_llm/layers/transformer/config.py (+2 -2)

@@ -674,11 +674,11 @@ def _validate(self) -> None:
         if self.init_method_std_qkv is None:
             self.init_method_std_qkv = self.init_method_std
         if self.init_method_std_attn_proj is None:
-            self.init_method_std_attn_proj = self.init_method_std / (2 * self.num_layers) ** 0.5
+            self.init_method_std_attn_proj = self.init_method_std / max(2 * self.num_layers, 1) ** 0.5
         if self.init_method_std_mlp_1 is None:
             self.init_method_std_mlp_1 = self.init_method_std
         if self.init_method_std_mlp_2 is None:
-            self.init_method_std_mlp_2 = self.init_method_std / (2 * self.num_layers) ** 0.5
+            self.init_method_std_mlp_2 = self.init_method_std / max(2 * self.num_layers, 1) ** 0.5
         if self.init_method_max_qkv is None:
             self.init_method_max_qkv = self.init_method_max
         if self.init_method_min_qkv is None:
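
Clamping the divisor keeps the usual 1 / sqrt(2 * num_layers) scaling of the default output-projection init std while avoiding a division by zero for a zero-layer (embedding-plus-head) model. Restated standalone, with illustrative numbers:

def default_proj_std(init_method_std: float, num_layers: int) -> float:
    # Same expression as in _validate, shown on its own for clarity.
    return init_method_std / max(2 * num_layers, 1) ** 0.5

print(default_proj_std(0.02, 12))  # ~4.08e-03
print(default_proj_std(0.02, 0))   # 0.02, instead of a ZeroDivisionError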

fast_llm/models/gpt/model.py (+19 -23)

@@ -72,34 +72,30 @@ def __init__(
         self._preprocessors.append(BackupAttentionPreprocessor(self._config.transformer, self._tensor_space))

     def get_output_layers(self) -> list[Layer]:
-        return [
-            layer
-            for i in range(self._config.prediction_heads)
-            for layer in [
-                TransformerLayer(
-                    self._config.transformer,
-                    self._tensor_space,
-                    # TODO MTP: which index?
-                    layer_index=self._config.transformer.num_layers,
-                    # The last layer only returns the transformer output.
-                    # The previous layers return a stack of shared_hidden and transformer_output.
-                    return_input=i < self._config.prediction_heads - 1,
-                ),
+        layers = []
+        for i in range(self._config.prediction_heads):
+            if i > 0:
+                layers.append(
+                    TransformerLayer(
+                        self._config.transformer,
+                        self._tensor_space,
+                        # TODO MTP: which index?
+                        layer_index=max(self._config.transformer.num_layers, 1),
+                        # The last layer only returns the transformer output.
+                        # The previous layers return a stack of shared_hidden and transformer_output.
+                        return_input=i < self._config.prediction_heads - 1,
+                    )
+                )
+            layers.append(
                 LanguageModelHead(
                     self._config,
                     self._tensor_space,
                     prediction_distance=i,
-                ),
-            ]
-        ]
+                )
+            )
+        return layers

     def get_layers(self) -> list[Layer]:
-        if self._config.transformer.num_layers == 0:
-            Assert.eq(self._config.prediction_heads, 1)
-            return [
-                LanguageModelEmbedding(self._config, self._tensor_space),
-                LanguageModelHead(self._config, self._tensor_space, 0),
-            ]
         return [
             LanguageModelEmbedding(self._config, self._tensor_space),
             *[
@@ -108,7 +104,7 @@ def get_layers(self) -> list[Layer]:
                     self._tensor_space,
                     layer_index=i + 1,
                 )
-                for i in range(self._config.transformer.num_layers - 1)
+                for i in range(self._config.transformer.num_layers)
             ],
             *self.get_output_layers(),
         ]
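
With this change the layer stack is built uniformly: the embedding, then every transformer layer (range(num_layers) rather than num_layers - 1), then one head per prediction distance, with an extra transformer layer inserted before each head except the first; the special-cased num_layers == 0 branch is no longer needed. A name-only sketch of the resulting layout (illustrative, not the real classes):

def layer_names(num_layers: int, prediction_heads: int) -> list[str]:
    names = ["embedding"] + [f"transformer_{i + 1}" for i in range(num_layers)]
    for i in range(prediction_heads):
        if i > 0:
            names.append(f"mtp_transformer_{i}")  # the i > 0 branch of get_output_layers
        names.append(f"head_{i}")
    return names

print(layer_names(num_layers=2, prediction_heads=2))
# ['embedding', 'transformer_1', 'transformer_2', 'head_0', 'mtp_transformer_1', 'head_1']
print(layer_names(num_layers=0, prediction_heads=1))
# ['embedding', 'head_0']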

fast_llm/utils.py (+12 -2)

@@ -144,7 +144,17 @@ def multiple(x, y):
     @staticmethod
     def rms_close(x, y, threshold):
         rms = rms_diff(x, y).item()
-        assert rms <= threshold, f"Rms diff too big ({rms} > {threshold}) between tensors {x} and {y}"
+        assert rms <= threshold, f"Rms diff too big ({rms:.3e} > {threshold:.3e}) between tensors {x} and {y}"
+
+    @staticmethod
+    def rms_close_relative(x, y, threshold, min_threshold=0):
+        import torch
+
+        Assert.eq(x.shape, y.shape)
+        scale = (torch.sum(x**2 + y**2) / (2 * x.numel())) ** 0.5
+        threshold = max(threshold * scale, min_threshold)
+        rms = rms_diff(x, y).item()
+        assert rms <= threshold, f"Rms diff too big ({rms:.3e} > {threshold:.3e}) between tensors {x} and {y}"

     @staticmethod
     def all_equal(x, y):
@@ -156,7 +166,7 @@ def all_equal(x, y):

         neq = x != y
         if neq.any().item():  # noqa
-            index = torch.where(neq)  # noqa
+            index = None if x.numel() == 1 else torch.where(neq)  # noqa
             raise AssertionError(
                 f"Tensors have {index[0].numel()} different entries out of "
                 f"{x.numel()}: {x[index]} != {y[index]} at index {torch.stack(index, -1)}"
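
The new rms_close_relative scales the tolerance by the root-mean-square magnitude of the two tensors, with an absolute floor, so the comparison stays meaningful regardless of the tensors' scale. A standalone restatement of the effective threshold (illustrative usage):

import torch

def effective_threshold(x: torch.Tensor, y: torch.Tensor, threshold: float, min_threshold: float = 0.0) -> float:
    scale = (torch.sum(x**2 + y**2) / (2 * x.numel())) ** 0.5  # RMS magnitude of the pair
    return max(threshold * scale.item(), min_threshold)

x = torch.randn(1000)
y = x + 1e-3 * torch.randn(1000)
print(effective_threshold(x, y, threshold=1e-2))  # ~1e-2, since the inputs have RMS magnitude ~1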

tests/layers/__init__.py (whitespace-only changes)
