From b30d22f49eb5975d7b03fa82c9d46f98fed5d8e4 Mon Sep 17 00:00:00 2001 From: Joel Lamy-Poirier Date: Fri, 7 Mar 2025 18:47:40 -0500 Subject: [PATCH 1/4] lora --- fast_llm/engine/base_model/base_model.py | 62 ++++++++++++++++++++-- fast_llm/functional/linear.py | 6 ++- fast_llm/layers/common/config.py | 52 +++++++++++++++++++ fast_llm/layers/common/linear.py | 13 ++++- fast_llm/layers/common/normalization.py | 5 +- fast_llm/layers/common/peft.py | 66 ++++++++++++++++++++++++ fast_llm/layers/transformer/attention.py | 63 +++++++++++----------- fast_llm/layers/transformer/config.py | 13 ++++- fast_llm/layers/transformer/mlp.py | 36 +++++++------ 9 files changed, 260 insertions(+), 56 deletions(-) create mode 100644 fast_llm/layers/common/peft.py diff --git a/fast_llm/engine/base_model/base_model.py b/fast_llm/engine/base_model/base_model.py index 7233c183..7a23e97d 100644 --- a/fast_llm/engine/base_model/base_model.py +++ b/fast_llm/engine/base_model/base_model.py @@ -14,17 +14,71 @@ from fast_llm.utils import Assert -class Module(torch.nn.Module, abc.ABC): - """ """ +class FastLLMModule(torch.nn.Module, abc.ABC): + def forward(self, *args, **kwargs): + """ + Run a forward pass for the module, with autograd support. + """ + raise NotImplementedError() + + def forward_only(self, *args, **kwargs) -> tuple[typing.Any, typing.Any]: + """ + Run only the forward pass, and return the output and context for backward. + """ + raise NotImplementedError() + + def backward(self, *grad_outputs: torch.Tensor, context: typing.Any) -> tuple[torch.Tensor, ...]: + """ + Run the full backward pass using the output grads and the context, and return the input grads. + Parameter gradients should be accumulated directly in their gradient buffer rather than returned. + """ + raise NotImplementedError() + + def backward_input(self, *grad_outputs: torch.Tensor, context: typing.Any) -> tuple[torch.Tensor, ...]: + """ + Run the backward pass using the output grads and the context, and return the input grads. + Parameter gradients should be accumulated directly in their gradient buffer rather than returned. + """ + raise NotImplementedError() + + def backward_parameters(self, *grad_outputs: torch.Tensor, context: typing.Any) -> None: + """ + Run the backward pass using the output grads and the context, and return the input grads. + Parameter gradients should be accumulated directly in their gradient buffer rather than returned. + """ + raise NotImplementedError() + - def forward(self, input_, kwargs): +class SimpleFastLLMModule(FastLLMModule): + """ + A simple module with a single input and output. + """ + + def forward(self, input_) -> tuple[torch.Tensor, typing.Any]: """ Run a forward pass for the module, with autograd support. """ raise NotImplementedError() + def forward_only(self, input_) -> tuple[torch.Tensor, typing.Any]: + # If there is no custom implementation, revert back to autograd. 
+ input_ = input_.detach().requires_grad_() + output = self.forward(input_) + return output.detach(), (input_, output) + + def backward(self, grad_output: torch.Tensor, context: typing.Any) -> torch.Tensor: + input_, output = context + output.backward(grad_output) + return input_.grad + + def backward_input(self, grad_output: torch.Tensor, context: typing.Any) -> torch.Tensor: + raise NotImplementedError() + + def backward_parameters(self, grad_output: torch.Tensor, context: typing.Any) -> None: + raise NotImplementedError() + -class Layer(Module): +class Layer(FastLLMModule): # Weight used to determine the stage size layer_count: float = 1.0 diff --git a/fast_llm/functional/linear.py b/fast_llm/functional/linear.py index d583d1a9..dbc05184 100644 --- a/fast_llm/functional/linear.py +++ b/fast_llm/functional/linear.py @@ -42,7 +42,9 @@ def update_linear_gradients( input_ = input_.flatten(0, -2) lhs, rhs = (input_.t(), grad_output) if transposed_weight else (grad_output.t(), input_) - if TritonConfig.TRITON_LINEAR or sparse_map is not None: + if not weight.requires_grad: + pass + elif TritonConfig.TRITON_LINEAR or sparse_map is not None: # This assumes the transposed_weight is True for input_sparse, False for output_sparse. input_row_sparse_matmul( lhs, @@ -63,7 +65,7 @@ def update_linear_gradients( ) else: accumulate_gradient(weight, torch.mm(lhs, rhs)) - if bias is not None: + if bias is not None and bias.requires_grad: accumulate_gradient(bias, grad_output.sum(dim=0)) diff --git a/fast_llm/layers/common/config.py b/fast_llm/layers/common/config.py index fff0548c..6fbf683f 100644 --- a/fast_llm/layers/common/config.py +++ b/fast_llm/layers/common/config.py @@ -3,9 +3,11 @@ from fast_llm.config import Field, FieldHint, check_field, config_class from fast_llm.engine.base_model.config import BaseModelArchitectureConfig, BaseModelConfig +from fast_llm.layers.common.linear import LinearBase from fast_llm.utils import Assert if typing.TYPE_CHECKING: + from fast_llm.engine.base_model.base_model import SimpleFastLLMModule from fast_llm.engine.config_utils.tensor_space import TensorDim from fast_llm.layers.common.normalization import LayerNorm, RMSNorm @@ -115,3 +117,53 @@ def _from_dict( cls._handle_renamed_field(default, "normalization_implementation", "implementation") cls._handle_renamed_field(default, "layer_norm_init_range", "initialization_range") return super()._from_dict(default, strict, flat) + + +class PeftType(str, enum.Enum): + # TODO : Use a dynamic config type instead. + none = "none" + lora = "lora" + + +@config_class() +class PeftArchitectureConfig(BaseModelArchitectureConfig): + pass + + +@config_class() +class PeftConfig(PeftArchitectureConfig, BaseModelConfig): + # TODO: Architecture/non-architecture split might not make much sense here. + + type: PeftType = Field( + default=PeftType.none, + desc="The type of parameter-efficient fine tuning to use Only LoRA is supported at the moment.", + hint=FieldHint.core, + ) + rank: int = Field( + default=8, + desc="The LoRA rank, i.e. 
the size of the intermediate dimension.", + hint=FieldHint.stability, + ) + alpha: float = Field( + default=8.0, + desc="The LoRA scaling parameter.", + hint=FieldHint.stability, + ) + dropout: float = Field( + default=0.0, + desc="Dropout rate for LoRA.", + hint=FieldHint.stability, + ) + + def apply_linear(self, linear: LinearBase) -> "SimpleFastLLMModule": + if self.type == PeftType.none: + return linear + elif self.type == PeftType.lora: + from fast_llm.layers.common.lora import LoRALinear + + # TODO: Init method? + return LoRALinear( + linear, linear._weight_init_method, linear._weight_init_method, self.rank, self.alpha, self.dropout + ) + else: + raise NotImplementedError(self.type) diff --git a/fast_llm/layers/common/linear.py b/fast_llm/layers/common/linear.py index afd0d96d..188383bf 100644 --- a/fast_llm/layers/common/linear.py +++ b/fast_llm/layers/common/linear.py @@ -3,7 +3,9 @@ import torch +from fast_llm.engine.base_model.base_model import SimpleFastLLMModule from fast_llm.engine.config_utils.tensor_space import TensorDim +from fast_llm.functional.autograd import wrap_forward_backward from fast_llm.functional.linear import ( input_parallel_linear_autograd, input_parallel_linear_backward, @@ -20,7 +22,7 @@ logger = logging.getLogger(__name__) -class LinearBase(torch.nn.Module): +class LinearBase(SimpleFastLLMModule): """ A base module for linear layers holding weights and biases. """ @@ -41,6 +43,8 @@ def __init__( self._transposed_weight = transposed_weight self._in_dim = in_dim self._out_dim = out_dim + self._lr_scale = lr_scale + self._weight_init_method = weight_init_method self.weight = ParameterMeta.from_dims( (self._in_dim, self._out_dim) if self._transposed_weight else (self._out_dim, self._in_dim), init_method=weight_init_method, @@ -57,11 +61,18 @@ def __init__( ) else: self.bias = None + self._forward = wrap_forward_backward(self.forward_only, self.backward) @property def transposed_weight(self) -> bool: return self._transposed_weight + def forward(self, input: torch.Tensor) -> torch.Tensor: + return self._forward(input) + + def backward(self, input: torch.Tensor) -> torch.Tensor: + pass + class Linear(LinearBase): """ diff --git a/fast_llm/layers/common/normalization.py b/fast_llm/layers/common/normalization.py index 25e8090c..5840f673 100644 --- a/fast_llm/layers/common/normalization.py +++ b/fast_llm/layers/common/normalization.py @@ -1,5 +1,6 @@ import torch +from fast_llm.engine.base_model.base_model import SimpleFastLLMModule from fast_llm.engine.config_utils.run import log_main_rank from fast_llm.engine.config_utils.tensor_space import TensorDim from fast_llm.functional.config import TritonConfig @@ -130,7 +131,7 @@ def backward(ctx, grad_output: torch.Tensor) -> tuple[torch.Tensor, None, None, return grad_input, None, None, None -class LayerNorm(torch.nn.Module): +class LayerNorm(SimpleFastLLMModule): """ A layer normalization layer, supporting multiple implementations. """ @@ -209,7 +210,7 @@ def _forward_torch(self, input_: torch.Tensor) -> torch.Tensor: return torch.nn.functional.layer_norm(input_, self.normalized_shape, self.weight, self.bias, self._eps) -class RMSNorm(torch.nn.Module): +class RMSNorm(SimpleFastLLMModule): """ A RMS normalization layer. 
""" diff --git a/fast_llm/layers/common/peft.py b/fast_llm/layers/common/peft.py new file mode 100644 index 00000000..ecfc88c2 --- /dev/null +++ b/fast_llm/layers/common/peft.py @@ -0,0 +1,66 @@ +import typing + +import torch + +from fast_llm.engine.base_model.base_model import SimpleFastLLMModule +from fast_llm.engine.config_utils.tensor_space import TensorDim +from fast_llm.layers.common.linear import Linear, LinearBase + + +class LoRALinear(SimpleFastLLMModule): + def __init__( + self, + linear: LinearBase, + init_method_0, + init_method_1, + rank: int, + alpha: float, + dropout: float = 0.0, + ): + super().__init__() + self.linear = linear + self.linear.weight.requires_grad = False + if self.linear._in_dim.parallel_dim is not None or self.linear._out_dim.parallel_dim is not None: + # TODO: TP support. + raise ValueError("LoRA not supported with tensor parallelism.") + self._alpha = alpha + self._dropout = dropout + self._transposed_weight = self.linear._transposed_weight + middle_dim = TensorDim("lora_middle", rank) + + self.layer_0 = Linear( + self.linear._in_dim, + middle_dim, + bias=False, + weight_init_method=init_method_0, + transposed_weight=self.linear._transposed_weight, + lr_scale=self.linear._lr_scale, + ) + self.layer_1 = Linear( + middle_dim, + self.linear._out_dim, + bias=False, + weight_init_method=init_method_1, + transposed_weight=self.linear._transposed_weight, + lr_scale=self.linear._lr_scale, + ) + + def forward(self, x: torch.Tensor) -> torch.Tensor: + # TODO: Optimize. + return self.linear(x) + (self._alpha / self._rank) * self.layer_1( + self.layer_0( + torch.nn.functional.dropout(x, self._dropout, training=self._training) if self._dropout > 0.0 else x + ) + ) + + def forward_only(self, input_: torch.Tensor) -> tuple[torch.Tensor, typing.Any]: + # TODO: Make a proper implementation. 
+ # TODO: Make this generic + input_ = input_.detach().requires_grad_() + output = self.forward(input_) + return output.detach(), (input_, output) + + def backward(self, grad_output: torch.Tensor, context: typing.Any): + input_, output = context + output.backward(grad_output) + return input_.grad diff --git a/fast_llm/layers/transformer/attention.py b/fast_llm/layers/transformer/attention.py index f64de9f1..43574434 100644 --- a/fast_llm/layers/transformer/attention.py +++ b/fast_llm/layers/transformer/attention.py @@ -5,16 +5,13 @@ from fast_llm.core.distributed import set_generator from fast_llm.core.kernels import flash_attn from fast_llm.core.ops import gather_op, reduce_op, reduce_scatter_op, swap_mult_dim +from fast_llm.engine.base_model.base_model import FastLLMModule from fast_llm.engine.config_utils.tensor_space import TensorSpace from fast_llm.functional.autograd import wrap_forward_backward from fast_llm.functional.rotary import apply_rotary_embeddings from fast_llm.functional.triton.rotary import triton_rotary_autograd_ from fast_llm.layers.common.linear import InputParallelLinear, OutputParallelLinear -from fast_llm.layers.transformer.config import ( - TransformerConfig, - TransformerDimNames, - TransformerKwargs, -) +from fast_llm.layers.transformer.config import TransformerConfig, TransformerDimNames, TransformerKwargs from fast_llm.logging import log_distributed_grad, log_distributed_tensor from fast_llm.tensor import TensorMeta, init_normal_, init_zeros_ from fast_llm.utils import Assert @@ -44,7 +41,7 @@ def backward(ctx, grad_output: torch.Tensor) -> tuple[torch.Tensor, None]: # no return grad, None -class Attention(torch.nn.Module): +class Attention(FastLLMModule): """ A self-attention layer. """ @@ -103,35 +100,41 @@ def __init__( hidden_dim = self._tensor_space.get_tensor_dim(TransformerDimNames.hidden) # TODO: Merge the query and key-value computations? (harder with sequence parallel.) 
- self.query = OutputParallelLinear( - hidden_dim, - self._tensor_space.get_tensor_dim(TransformerDimNames.composite_query), - bias=self._config.add_attn_qkv_bias, - weight_init_method=init_method_qkv, - bias_init_method=init_method_qkv if self._config.random_bias_init else init_zeros_, - sequence_parallel=self._sequence_parallel, - lr_scale=self._config.attention_lr_scale, + self.query = self._config.peft.apply_linear( + OutputParallelLinear( + hidden_dim, + self._tensor_space.get_tensor_dim(TransformerDimNames.composite_query), + bias=self._config.add_attn_qkv_bias, + weight_init_method=init_method_qkv, + bias_init_method=init_method_qkv if self._config.random_bias_init else init_zeros_, + sequence_parallel=self._sequence_parallel, + lr_scale=self._config.attention_lr_scale, + ) ) - self.key_value = OutputParallelLinear( - hidden_dim, - self._tensor_space.get_tensor_dim(TransformerDimNames.composite_key_value), - bias=self._config.add_attn_qkv_bias, - weight_init_method=init_method_qkv, - bias_init_method=init_method_qkv if self._config.random_bias_init else init_zeros_, - sequence_parallel=self._sequence_parallel, - lr_scale=self._config.attention_lr_scale, + self.key_value = self._config.peft.apply_linear( + OutputParallelLinear( + hidden_dim, + self._tensor_space.get_tensor_dim(TransformerDimNames.composite_key_value), + bias=self._config.add_attn_qkv_bias, + weight_init_method=init_method_qkv, + bias_init_method=init_method_qkv if self._config.random_bias_init else init_zeros_, + sequence_parallel=self._sequence_parallel, + lr_scale=self._config.attention_lr_scale, + ) ) self._query_key_value = wrap_forward_backward(self._query_key_value_forward, self._query_key_value_backward) # Output. - self.dense = InputParallelLinear( - self._tensor_space.get_tensor_dim(TransformerDimNames.composite_dense), - hidden_dim, - bias=self._config.add_attn_dense_bias, - weight_init_method=init_method_std_attn_proj, - bias_init_method=init_method_std_attn_proj if self._config.random_bias_init else init_zeros_, - sequence_parallel=self._sequence_parallel, - lr_scale=self._config.attention_lr_scale, + self.dense = self._config.peft.apply_linear( + InputParallelLinear( + self._tensor_space.get_tensor_dim(TransformerDimNames.composite_dense), + hidden_dim, + bias=self._config.add_attn_dense_bias, + weight_init_method=init_method_std_attn_proj, + bias_init_method=init_method_std_attn_proj if self._config.random_bias_init else init_zeros_, + sequence_parallel=self._sequence_parallel, + lr_scale=self._config.attention_lr_scale, + ) ) def _attn_fused( diff --git a/fast_llm/layers/transformer/config.py b/fast_llm/layers/transformer/config.py index cf985392..28044e80 100644 --- a/fast_llm/layers/transformer/config.py +++ b/fast_llm/layers/transformer/config.py @@ -10,7 +10,12 @@ from fast_llm.engine.config_utils.tensor_space import CompositeTensorDim, TensorDim, TensorSpace from fast_llm.engine.distributed.config import DistributedConfig, DistributedDimNames from fast_llm.functional.config import ActivationType, MLPRecomputeLevel, TritonConfig -from fast_llm.layers.common.config import NormalizationArchitectureConfig, NormalizationConfig +from fast_llm.layers.common.config import ( + NormalizationArchitectureConfig, + NormalizationConfig, + PeftArchitectureConfig, + PeftConfig, +) from fast_llm.utils import Assert, div logger = logging.getLogger(__name__) @@ -163,6 +168,11 @@ class TransformerArchitectureConfig(BaseModelArchitectureConfig): desc="Configuration for the normalization layers architecture.", 
hint=FieldHint.core, ) + peft: PeftArchitectureConfig = Field( + default_factory=PeftArchitectureConfig, + desc="Configuration for the parameter-efficient fine tuning.", + hint=FieldHint.core, + ) num_layers: int = Field( default=12, desc="Number of layers in the transformer.", hint=FieldHint.core, valid=check_field(Assert.geq, 0) ) @@ -370,6 +380,7 @@ def setup_tensor_space(self, tensor_space: TensorSpace) -> None: class TransformerConfig(TransformerArchitectureConfig, BaseModelConfig): normalization: NormalizationConfig = FieldUpdate(default_factory=NormalizationConfig) rotary: RotaryConfig = FieldUpdate(default_factory=RotaryConfig) + peft: PeftConfig = FieldUpdate(default_factory=PeftConfig) # Default: hidden_size**-0.5 # TODO: Allow custom initialization (InitializationConfig?) init_method_std: float = Field( diff --git a/fast_llm/layers/transformer/mlp.py b/fast_llm/layers/transformer/mlp.py index adc6242d..03ca061f 100644 --- a/fast_llm/layers/transformer/mlp.py +++ b/fast_llm/layers/transformer/mlp.py @@ -39,23 +39,27 @@ def __init__(self, config: TransformerConfig, tensor_space: TensorSpace, name: s self._activation_fn = triton_mlp_activation_autograd if TritonConfig.TRITON_ENABLED else torch_mlp_activation # So both layers' weights have shape (num_experts [* gate_up] * ffn, hidden_size) - self.layer_1 = LinearBase( - hidden_dim, - tensor_space.get_tensor_dim(TransformerDimNames.composite_gated_expert_mlp), - bias=config.add_mlp_bias, - weight_init_method=init_method_1, - bias_init_method=init_method_1 if config.random_bias_init else init_zeros_, - lr_scale=tuple(config.mlp_lr_scale), + self.layer_1 = self._config.peft.apply_linear( + LinearBase( + hidden_dim, + tensor_space.get_tensor_dim(TransformerDimNames.composite_gated_expert_mlp), + bias=config.add_mlp_bias, + weight_init_method=init_method_1, + bias_init_method=init_method_1 if config.random_bias_init else init_zeros_, + lr_scale=tuple(config.mlp_lr_scale), + ) ) - self.layer_2 = LinearBase( - self._intermediate_dim, - hidden_dim, - bias=config.add_mlp_bias, - weight_init_method=init_method_2, - bias_init_method=init_method_2 if config.random_bias_init else init_zeros_, - auto_bias_grad_accumulation=tensor_space.distributed_config.tensor_parallel > 1, - transposed_weight=True, - lr_scale=tuple(config.mlp_lr_scale), + self.layer_2 = self._config.peft.apply_linear( + LinearBase( + self._intermediate_dim, + hidden_dim, + bias=config.add_mlp_bias, + weight_init_method=init_method_2, + bias_init_method=init_method_2 if config.random_bias_init else init_zeros_, + auto_bias_grad_accumulation=tensor_space.distributed_config.tensor_parallel > 1, + transposed_weight=True, + lr_scale=tuple(config.mlp_lr_scale), + ) ) From ea0279eec9423db4d338d8d1c440a3acabfb161a Mon Sep 17 00:00:00 2001 From: Joel Lamy-Poirier Date: Mon, 10 Mar 2025 11:08:33 -0400 Subject: [PATCH 2/4] misc --- fast_llm/core/ops.py | 14 +- fast_llm/engine/base_model/base_model.py | 44 +---- fast_llm/layers/common/linear.py | 235 ++++++++++++++++------- fast_llm/layers/common/normalization.py | 6 +- fast_llm/layers/common/peft.py | 50 +++-- fast_llm/layers/transformer/mlp.py | 143 +++++++++++++- 6 files changed, 354 insertions(+), 138 deletions(-) diff --git a/fast_llm/core/ops.py b/fast_llm/core/ops.py index a7492daa..6b2b89ef 100644 --- a/fast_llm/core/ops.py +++ b/fast_llm/core/ops.py @@ -4,11 +4,11 @@ """ import logging +import typing import torch import torch._dynamo # noqa import torch.autograd -from torch._C._distributed_c10d import Work from 
fast_llm.core.distributed import ProcessGroup, ReduceOp, all_gather_into_tensor, all_reduce, reduce_scatter_tensor from fast_llm.utils import Assert, div @@ -18,12 +18,12 @@ def reduce_op( input_: torch.Tensor, group: ProcessGroup | None, *, op: ReduceOp = ReduceOp.SUM, async_op: bool = False -) -> tuple[torch.Tensor, Work] | torch.Tensor: +) -> tuple[torch.Tensor, typing.Callable[[], None]] | torch.Tensor: if group: handle = all_reduce(input_, group=group, async_op=async_op, op=op) else: handle = None - return (input_, handle) if async_op else input_ + return (input_, handle.wait) if async_op else input_ def split_op(input_: torch.Tensor, group: ProcessGroup | None, dim: int) -> list[torch.Tensor]: @@ -62,7 +62,7 @@ def swap_mult_dim(tensor: torch.Tensor, factor: int, old_dim: int, new_dim: int) def gather_op( input_: torch.Tensor, group: ProcessGroup | None, dim: int, async_op: bool = False, out=None -) -> tuple[torch.Tensor, Work] | torch.Tensor: +) -> tuple[torch.Tensor, typing.Callable[[], None]] | torch.Tensor: """Gather tensors and concatenate along the last dimension.""" # Bypass the function if we are using only 1 GPU. if not group: @@ -79,7 +79,7 @@ def gather_op( assert not async_op # TODO: contiguous? out = swap_mult_dim(out, group.size(), 0, dim) - return (out, handle) if async_op else out + return (out, handle.wait) if async_op else out def reduce_scatter_op( @@ -89,7 +89,7 @@ def reduce_scatter_op( op: ReduceOp = ReduceOp.SUM, dim: int = 0, async_op: bool = False, -) -> tuple[torch.Tensor, Work] | torch.Tensor: +) -> tuple[torch.Tensor, typing.Callable[[], None]] | torch.Tensor: """Reduce-scatter the input tensor across model parallel group.""" # Bypass the function if we are using only 1 GPU. if not group: @@ -99,7 +99,7 @@ def reduce_scatter_op( input_ = swap_mult_dim(input_, group.size(), dim, 0) # TODO: May give the wrong output without the contiguous call. handle = reduce_scatter_tensor(output, input_.contiguous(), group=group, async_op=async_op, op=op) - return (output, handle) if async_op else output + return (output, handle.wait) if async_op else output class _ReduceBackward(torch.autograd.Function): diff --git a/fast_llm/engine/base_model/base_model.py b/fast_llm/engine/base_model/base_model.py index 7a23e97d..b0e941d0 100644 --- a/fast_llm/engine/base_model/base_model.py +++ b/fast_llm/engine/base_model/base_model.py @@ -24,6 +24,7 @@ def forward(self, *args, **kwargs): def forward_only(self, *args, **kwargs) -> tuple[typing.Any, typing.Any]: """ Run only the forward pass, and return the output and context for backward. + TODO: Make a generic type for the context? """ raise NotImplementedError() @@ -34,49 +35,6 @@ def backward(self, *grad_outputs: torch.Tensor, context: typing.Any) -> tuple[to """ raise NotImplementedError() - def backward_input(self, *grad_outputs: torch.Tensor, context: typing.Any) -> tuple[torch.Tensor, ...]: - """ - Run the backward pass using the output grads and the context, and return the input grads. - Parameter gradients should be accumulated directly in their gradient buffer rather than returned. - """ - raise NotImplementedError() - - def backward_parameters(self, *grad_outputs: torch.Tensor, context: typing.Any) -> None: - """ - Run the backward pass using the output grads and the context, and return the input grads. - Parameter gradients should be accumulated directly in their gradient buffer rather than returned. 
- """ - raise NotImplementedError() - - -class SimpleFastLLMModule(FastLLMModule): - """ - A simple module with a single input and output. - """ - - def forward(self, input_) -> tuple[torch.Tensor, typing.Any]: - """ - Run a forward pass for the module, with autograd support. - """ - raise NotImplementedError() - - def forward_only(self, input_) -> tuple[torch.Tensor, typing.Any]: - # If there is no custom implementation, revert back to autograd. - input_ = input_.detach().requires_grad_() - output = self.forward(input_) - return output.detach(), (input_, output) - - def backward(self, grad_output: torch.Tensor, context: typing.Any) -> torch.Tensor: - input_, output = context - output.backward(grad_output) - return input_.grad - - def backward_input(self, grad_output: torch.Tensor, context: typing.Any) -> torch.Tensor: - raise NotImplementedError() - - def backward_parameters(self, grad_output: torch.Tensor, context: typing.Any) -> None: - raise NotImplementedError() - class Layer(FastLLMModule): # Weight used to determine the stage size diff --git a/fast_llm/layers/common/linear.py b/fast_llm/layers/common/linear.py index 188383bf..e6aa11ff 100644 --- a/fast_llm/layers/common/linear.py +++ b/fast_llm/layers/common/linear.py @@ -1,28 +1,64 @@ +import dataclasses import logging import typing import torch -from fast_llm.engine.base_model.base_model import SimpleFastLLMModule +from fast_llm.core.ops import gather_op, reduce_op, reduce_scatter_op +from fast_llm.engine.base_model.base_model import FastLLMModule from fast_llm.engine.config_utils.tensor_space import TensorDim from fast_llm.functional.autograd import wrap_forward_backward -from fast_llm.functional.linear import ( - input_parallel_linear_autograd, - input_parallel_linear_backward, - input_parallel_linear_forward, - linear_autograd, - linear_backward, - linear_forward, - output_parallel_linear_autograd, - output_parallel_linear_backward, - output_parallel_linear_forward, -) +from fast_llm.functional.config import TritonConfig +from fast_llm.functional.linear import maybe_transpose, update_linear_gradients +from fast_llm.functional.triton.sparse_copy import SparseMap +from fast_llm.functional.triton.sparse_linear import dense_matmul, input_inner_sparse_matmul, output_sparse_matmul from fast_llm.tensor import ParameterMeta, init_zeros_ logger = logging.getLogger(__name__) -class LinearBase(SimpleFastLLMModule): +@dataclasses.dataclass +class LinearContext: + input_: torch.Tensor + sparse_map: SparseMap | None + + +class LinearLike(FastLLMModule): + def __init__(self): + super().__init__() + self._forward = wrap_forward_backward(self.forward_only, self.backward) + + def forward(self, input_: torch.Tensor) -> torch.Tensor: + return self._forward(input_) + + def forward_only( + self, input_: torch.Tensor, sparse_map: SparseMap | None = None + ) -> tuple[torch.Tensor, LinearContext]: + raise NotImplementedError() + + def backward(self, grad_output: torch.Tensor, context: LinearContext) -> torch.Tensor: + context, gather_handle = self.backward_gather_input(context) + grad_input, reduce_handle = self.backward_activation(grad_output, context) + if gather_handle is not None: + gather_handle() + self.backward_parameters(grad_output, context) + if reduce_handle is not None: + gather_handle() + return grad_input + + def backward_gather_input(self, context: LinearContext) -> tuple[LinearContext, typing.Callable[[], None] | None]: + return context, None + + def backward_activation( + self, grad_output: torch.Tensor, context: LinearContext + ) -> 
tuple[torch.Tensor, typing.Callable[[], None] | None]: + raise NotImplementedError() + + def backward_parameters(self, grad_output: torch.Tensor, context: LinearContext) -> None: + raise NotImplementedError() + + +class LinearBase(LinearLike): """ A base module for linear layers holding weights and biases. """ @@ -67,11 +103,10 @@ def __init__( def transposed_weight(self) -> bool: return self._transposed_weight - def forward(self, input: torch.Tensor) -> torch.Tensor: - return self._forward(input) - - def backward(self, input: torch.Tensor) -> torch.Tensor: - pass + def backward_parameters(self, grad_output: torch.Tensor, context: LinearContext) -> None: + update_linear_gradients( + context.input_, self.weight, self.bias, grad_output, self._transposed_weight, context.sparse_map + ) class Linear(LinearBase): @@ -102,16 +137,38 @@ def __init__( lr_scale=lr_scale, ) - def forward(self, input_: torch.Tensor) -> torch.Tensor: - return linear_autograd(input_, weight=self.weight, bias=self.bias, transposed_weight=self._transposed_weight) - def forward_only( - self, input_: torch.Tensor - ) -> tuple[torch.Tensor, tuple[torch.Tensor, torch.Tensor, torch.Tensor, bool]]: - return linear_forward(input_, weight=self.weight, bias=self.bias, transposed_weight=self._transposed_weight) + self, input_: torch.Tensor, sparse_map: SparseMap | None = None + ) -> tuple[torch.Tensor, LinearContext]: + assert sparse_map is None + + # Matmul + if TritonConfig.TRITON_LINEAR: + assert self.bias is None + output = dense_matmul( + input_.flatten(0, -2), + maybe_transpose(self.weight, not self._transposed_weight), + ).unflatten(0, input_.shape[:-1]) + else: + output = torch.nn.functional.linear( + input_, maybe_transpose(self.weight, self._transposed_weight), self.bias + ) + return output, LinearContext(input_, None) + + def backward_activation( + self, grad_output: torch.Tensor, context: LinearContext + ) -> tuple[torch.Tensor, typing.Callable[[], None] | None]: + weight_t = maybe_transpose(self.weight, self._transposed_weight) + + # Input grad + if TritonConfig.TRITON_LINEAR: + grad_input = dense_matmul(grad_output.flatten(0, -2), weight_t).view( + *grad_output.shape[:-1], weight_t.size(-1) + ) + else: + grad_input = grad_output.matmul(weight_t) - def backward(self, grad_output: torch.Tensor, context) -> torch.Tensor: # noqa - return linear_backward(grad_output, context) + return grad_input, None class OutputParallelLinear(LinearBase): @@ -144,28 +201,58 @@ def __init__( lr_scale=lr_scale, ) - def forward(self, input_: torch.Tensor) -> torch.Tensor: - return output_parallel_linear_autograd( - input_, - weight=self.weight, - bias=self.bias, - group=self._out_dim.parallel_group, - sequence_parallel=self._sequence_parallel, - transposed_weight=self._transposed_weight, - ) + def forward_only( + self, input_: torch.Tensor, sparse_map: SparseMap | None = None + ) -> tuple[torch.Tensor, LinearContext]: + + # Gather sequence-parallel slices (non-overlapped) + input1 = gather_op(input_, self._out_dim.parallel_group, dim=0) if self._sequence_parallel else input_ + + # Matmul + if TritonConfig.TRITON_LINEAR or sparse_map is not None: + assert self.bias is None + if sparse_map is not None: + assert not self._transposed_weight + output = output_sparse_matmul( + input1.flatten(0, -2), + maybe_transpose(self.weight, not self._transposed_weight), + sparse_map, + ).unflatten(0, input_.shape[:-1]) + else: + output = torch.nn.functional.linear( + input1, maybe_transpose(self.weight, self._transposed_weight), self.bias + ) + + return 
output, LinearContext(input_, sparse_map) + + def backward_gather_input(self, context: LinearContext) -> tuple[LinearContext, typing.Callable[[], None] | None]: + # Gather sequence-parallel slices (overlapped) + if self._sequence_parallel: + input_, gather_handle = gather_op(context.input_, self._out_dim.parallel_group, dim=0, async_op=True) + context = dataclasses.replace(context, input_=input_) + else: + gather_handle = None + return context, gather_handle + + def backward_activation( + self, grad_output: torch.Tensor, context: LinearContext + ) -> tuple[torch.Tensor, typing.Callable[[], None] | None]: + weight_t = maybe_transpose(self.weight, self._transposed_weight) - def forward_only(self, input_) -> tuple[torch.Tensor, tuple[typing.Any, ...]]: - return output_parallel_linear_forward( - input_, - weight=self.weight, - bias=self.bias, - group=self._out_dim.parallel_group, - sequence_parallel=self._sequence_parallel, - transposed_weight=self._transposed_weight, + # Input grad + if TritonConfig.TRITON_LINEAR or context.sparse_map is not None: + grad_input = input_inner_sparse_matmul(grad_output.flatten(0, -2), weight_t, context.sparse_map).view( + *grad_output.shape[:-1], weight_t.size(-1) + ) + else: + grad_input = grad_output.matmul(weight_t) + + # Reduce input grad (overlapped) + grad_input, reduce_handle = (reduce_scatter_op if self._sequence_parallel else reduce_op)( + grad_input, group=self._out_dim.parallel_group, async_op=True ) - def backward(self, grad_output: torch.Tensor, context: tuple[typing.Any, ...]): # noqa - return output_parallel_linear_backward(grad_output, context) + return grad_input, reduce_handle class InputParallelLinear(LinearBase): @@ -200,26 +287,44 @@ def __init__( lr_scale=lr_scale, ) - def forward(self, input_: torch.Tensor) -> tuple[torch.Tensor, torch.Tensor | None]: - return input_parallel_linear_autograd( - input_, - weight=self.weight, - bias=self.bias, - group=self._in_dim.parallel_group, - sequence_parallel=self._sequence_parallel, - transposed_weight=self._transposed_weight, - ) + def forward_only( + self, input_: torch.Tensor, sparse_map: SparseMap | None = None + ) -> tuple[tuple[torch.Tensor, torch.Tensor | None], LinearContext]: + # TODO: Fix signature + # Matmul + if TritonConfig.TRITON_LINEAR or sparse_map is not None: + assert self.bias is None + if sparse_map is not None: + assert self._transposed_weight + output = input_inner_sparse_matmul( + input_.flatten(0, -2), maybe_transpose(self.weight, not self._transposed_weight), sparse_map + ).unflatten(0, input_.shape[:-1]) + else: + output = torch.nn.functional.linear( + input_, maybe_transpose(self.weight, self._transposed_weight), self.bias + ) - def forward_only(self, input_: torch.Tensor) -> tuple[torch.Tensor, torch.Tensor | None, tuple[typing.Any, ...]]: - output, context = input_parallel_linear_forward( - input_, - weight=self.weight, - bias=None if self._group else self.bias, - group=self._in_dim.parallel_group, - sequence_parallel=self._sequence_parallel, - transposed_weight=self._transposed_weight, + # Reduce input grad (non-overlapped) + output = (reduce_scatter_op if self._sequence_parallel else reduce_op)( + output, group=self._in_dim.parallel_group ) - return output, self.bias if self._group else None, context + return (output, self.bias if self._in_dim.parallel_group else None), LinearContext(input_, sparse_map) + + def backward_activation( + self, grad_output: torch.Tensor, context: LinearContext + ) -> tuple[torch.Tensor, typing.Callable[[], None] | None]: + weight_t = 
maybe_transpose(self.weight, self._transposed_weight) + + # Gather sequence-parallel slices (non-overlapped) + if self._sequence_parallel: + grad_output = gather_op(grad_output, self._in_dim.parallel_group, dim=0) + + # Input grad + if TritonConfig.TRITON_LINEAR or context.sparse_map is not None: + grad_input = output_sparse_matmul(grad_output.flatten(0, -2), weight_t, context.sparse_map).view( + *grad_output.shape[:-1], weight_t.size(-1) + ) + else: + grad_input = grad_output.matmul(weight_t) - def backward(self, grad_output: torch.Tensor, context: tuple[typing.Any, ...]) -> torch.Tensor: # noqa - return input_parallel_linear_backward(grad_output, context) + return grad_input, None diff --git a/fast_llm/layers/common/normalization.py b/fast_llm/layers/common/normalization.py index 5840f673..c3cc8ff5 100644 --- a/fast_llm/layers/common/normalization.py +++ b/fast_llm/layers/common/normalization.py @@ -1,6 +1,6 @@ import torch -from fast_llm.engine.base_model.base_model import SimpleFastLLMModule +from fast_llm.engine.base_model.base_model import FastLLMModule from fast_llm.engine.config_utils.run import log_main_rank from fast_llm.engine.config_utils.tensor_space import TensorDim from fast_llm.functional.config import TritonConfig @@ -131,7 +131,7 @@ def backward(ctx, grad_output: torch.Tensor) -> tuple[torch.Tensor, None, None, return grad_input, None, None, None -class LayerNorm(SimpleFastLLMModule): +class LayerNorm(FastLLMModule): """ A layer normalization layer, supporting multiple implementations. """ @@ -210,7 +210,7 @@ def _forward_torch(self, input_: torch.Tensor) -> torch.Tensor: return torch.nn.functional.layer_norm(input_, self.normalized_shape, self.weight, self.bias, self._eps) -class RMSNorm(SimpleFastLLMModule): +class RMSNorm(FastLLMModule): """ A RMS normalization layer. """ diff --git a/fast_llm/layers/common/peft.py b/fast_llm/layers/common/peft.py index ecfc88c2..a775ba95 100644 --- a/fast_llm/layers/common/peft.py +++ b/fast_llm/layers/common/peft.py @@ -1,13 +1,20 @@ +import dataclasses import typing import torch -from fast_llm.engine.base_model.base_model import SimpleFastLLMModule from fast_llm.engine.config_utils.tensor_space import TensorDim -from fast_llm.layers.common.linear import Linear, LinearBase +from fast_llm.functional.triton.sparse_copy import SparseMap +from fast_llm.layers.common.linear import Linear, LinearBase, LinearContext, LinearLike -class LoRALinear(SimpleFastLLMModule): +@dataclasses.dataclass +class LoRAContext(LinearContext): + input_: torch.Tensor + output: torch.Tensor + + +class LoRALinear(LinearLike): def __init__( self, linear: LinearBase, @@ -44,23 +51,28 @@ def __init__( transposed_weight=self.linear._transposed_weight, lr_scale=self.linear._lr_scale, ) + # TODO: Implement proper backward pass. + self.layer_0.weight.auto_grad_accumulation = True + self.layer_1.weight.auto_grad_accumulation = True - def forward(self, x: torch.Tensor) -> torch.Tensor: - # TODO: Optimize. - return self.linear(x) + (self._alpha / self._rank) * self.layer_1( - self.layer_0( - torch.nn.functional.dropout(x, self._dropout, training=self._training) if self._dropout > 0.0 else x - ) + def forward_only( + self, input_: torch.Tensor, sparse_map: SparseMap | None = None + ) -> tuple[torch.Tensor, LinearContext]: + # TODO: MoE support + assert sparse_map is None + # TODO: torch compile? 
+ input_ = input_.detach().requires_grad_() + output = self.linear(input_) + (self._alpha / self._rank) * self.layer_1( + self.layer_0(torch.dropout(input_, self._dropout, self._training) if self._dropout > 0.0 else input_) ) + return output.detach(), LoRAContext(input_, sparse_map, input_, output) - def forward_only(self, input_: torch.Tensor) -> tuple[torch.Tensor, typing.Any]: - # TODO: Make a proper implementation. - # TODO: Make this generic - input_ = input_.detach().requires_grad_() - output = self.forward(input_) - return output.detach(), (input_, output) + def backward_activation( + self, grad_output: torch.Tensor, context: LoRAContext + ) -> tuple[torch.Tensor, typing.Callable[[], None] | None]: + # TODO: Separate backward pass + context.output.backward(grad_output) + return context.input_.grad, None - def backward(self, grad_output: torch.Tensor, context: typing.Any): - input_, output = context - output.backward(grad_output) - return input_.grad + def backward_parameters(self, grad_output: torch.Tensor, context: LoRAContext) -> None: + pass diff --git a/fast_llm/layers/transformer/mlp.py b/fast_llm/layers/transformer/mlp.py index 03ca061f..0b369a5a 100644 --- a/fast_llm/layers/transformer/mlp.py +++ b/fast_llm/layers/transformer/mlp.py @@ -6,7 +6,19 @@ from fast_llm.engine.base_model.base_model import Layer from fast_llm.engine.config_utils.tensor_space import TensorSpace from fast_llm.functional.config import TritonConfig -from fast_llm.functional.triton.mlp import mlp_autograd, torch_mlp_activation, triton_mlp_activation_autograd +from fast_llm.functional.triton.mlp import ( + mlp_autograd, + torch_mlp_activation, + triton_mlp_activation_autograd, + triton_mlp_activation_backward, + triton_mlp_activation_forward, +) +from fast_llm.functional.triton.sparse_copy import ( + SparseMap, + copy_dense_to_sparse_forward, + copy_sparse_to_dense_backward, + copy_sparse_to_dense_forward, +) from fast_llm.layers.common.linear import LinearBase from fast_llm.layers.transformer.config import TransformerConfig, TransformerDimNames from fast_llm.tensor import init_normal_, init_zeros_ @@ -62,6 +74,135 @@ def __init__(self, config: TransformerConfig, tensor_space: TensorSpace, name: s ) ) + def forward_only( + self, + input_: torch.Tensor, + scores: torch.Tensor | None, + sparse_map: SparseMap | None = None, + ) -> tuple[torch.Tensor, list[typing.Any] | None]: + # Sparse copy + input_shape = input_.shape + intermediate_0 = input_ if sparse_map is None else copy_dense_to_sparse_forward(input_, sparse_map)[0] + + # Layer 1 + intermediate_1, _ = self.layer_1.forward_only(intermediate_0, sparse_map) + + if self._recompute_level.recompute_sparse_input: + intermediate_0 = None + else: + input_ = None + + # Activation + if TritonConfig.TRITON_ENABLED: + intermediate_2, _ = triton_mlp_activation_forward(intermediate_1, self._gated, self._activation_type) + else: + do_grad = self.training and not self._recompute_level.recompute_activation + with torch.set_grad_enabled(do_grad): + intermediate_2 = torch_mlp_activation( + intermediate_1.detach().requires_grad_(do_grad), self._gated, self._activation_type + ) + if self._recompute_level.recompute_layer_1: + intermediate_1 = None + + # Layer 2 + intermediate_3, _ = self.layer_2.forward_only(intermediate_2, sparse_map) + + # Context + if self._recompute_level.recompute_activation or not self.training: + intermediate_2 = None + + # Sparse copy + if sparse_map is None: + output = intermediate_3 + intermediate_3 = None + else: + output, _ = 
copy_sparse_to_dense_forward(intermediate_3, scores, sparse_map) + + context = ( + [ + input_, + scores, + intermediate_0, + intermediate_1, + intermediate_2, + intermediate_3, + sparse_map, + input_shape, + ] + if self.training + else None + ) + return output, context + + def backward(self, grad_output: torch.Tensor, context: typing.Any) -> torch.Tensor: + ( + input_, + scores, + intermediate_0, + intermediate_1, + intermediate_2, + intermediate_3, + sparse_map, + input_shape, + ) = context + context.clear() + + # Sparse copy + if sparse_map is None: + grad_scores = None + else: + grad_output, grad_scores = copy_sparse_to_dense_backward(grad_output, (sparse_map, intermediate_3, scores)) + + grad_intermediate_2, handle = self.layer_2.backward_input(grad_output, ()) + + # Sparse input recomputation + if intermediate_0 is None: + intermediate_0 = input_ if sparse_map is None else copy_dense_to_sparse_forward(input_, sparse_map)[0] + + # Layer 1 recomputation + if intermediate_1 is None: + intermediate_1 = self.layer_1.forward_only(intermediate_0, sparse_map) + + # Activation recomputation and/or backward + if TritonConfig.TRITON_ENABLED: + grad_intermediate_1, intermediate_2_ = triton_mlp_activation_backward( + grad_intermediate_2, (intermediate_1, self._gated, self._activation_type), intermediate_2 is None + ) + else: + if intermediate_2 is None: + with torch.set_grad_enabled(True): + intermediate_2_ = torch_mlp_activation( + intermediate_1.detach().requires_grad_(True), self._gated, self._activation_type + ) + else: + intermediate_2_ = intermediate_2 + intermediate_2_.backward(grad_intermediate_2) + grad_intermediate_1 = intermediate_1.grad + + # Layer 2 parameter grad + del grad_intermediate_2, intermediate_1 + update_linear_gradients( + intermediate_2_ if intermediate_2 is None else intermediate_2, + weight_2, + bias_2, + grad_output, + transposed_layer_2_weight, + sparse_map, + ) + del grad_output, intermediate_2, intermediate_2_ + + # Layer 1 backward + grad_input = output_parallel_linear_backward( + grad_intermediate_1, + (intermediate_0, weight_1, bias_1, group, sequence_parallel, False, sparse_map), + ) + + # Sparse copy + if sparse_map is not None: + grad_input = copy_dense_to_sparse_backward(grad_input, (sparse_map, input_shape)) + + return grad_input, grad_scores + class MLP(MLPBase): def __init__(self, config: TransformerConfig, tensor_space: TensorSpace, name: str = "mlp"): From d19842a8e350de662facecdb04e09d1c9234f8fe Mon Sep 17 00:00:00 2001 From: Joel Lamy-Poirier Date: Mon, 10 Mar 2025 13:15:16 -0400 Subject: [PATCH 3/4] MLP support --- fast_llm/functional/linear.py | 254 -------------------------- fast_llm/functional/triton/mlp.py | 200 -------------------- fast_llm/layers/common/config.py | 7 +- fast_llm/layers/common/linear.py | 59 +++++- fast_llm/layers/transformer/config.py | 3 + fast_llm/layers/transformer/mlp.py | 123 +++++++------ 6 files changed, 121 insertions(+), 525 deletions(-) delete mode 100644 fast_llm/functional/linear.py diff --git a/fast_llm/functional/linear.py b/fast_llm/functional/linear.py deleted file mode 100644 index dbc05184..00000000 --- a/fast_llm/functional/linear.py +++ /dev/null @@ -1,254 +0,0 @@ -""" -Forward and backward pass of linear layers. 
-""" - -import typing - -import torch - -from fast_llm.core.distributed import ProcessGroup -from fast_llm.core.ops import gather_op, reduce_op, reduce_scatter_op -from fast_llm.functional.autograd import wrap_forward_backward -from fast_llm.functional.config import TritonConfig -from fast_llm.functional.triton.sparse_copy import SparseMap -from fast_llm.functional.triton.sparse_linear import ( - dense_matmul, - input_inner_sparse_matmul, - input_row_sparse_matmul, - output_sparse_matmul, -) -from fast_llm.tensor import accumulate_gradient, param_get_and_unset_is_zero - - -def maybe_transpose(tensor: torch.Tensor, transpose: bool) -> torch.Tensor: - return tensor.t() if transpose else tensor - - -def update_linear_gradients( - input_: torch.Tensor, - weight: torch.Tensor, - bias: torch.Tensor | None, - grad_output: torch.Tensor, - transposed_weight: bool, - sparse_map: SparseMap | None, -) -> None: - """ - Calculate the weight and bias gradients for a linear layer. - TODO: fused_dense_cuda fuses weight gradient with bias gradient, but not with grad accumulation. - Which one is best? (and can we fuse everything?) - """ - - grad_output = grad_output.flatten(0, -2) - input_ = input_.flatten(0, -2) - lhs, rhs = (input_.t(), grad_output) if transposed_weight else (grad_output.t(), input_) - - if not weight.requires_grad: - pass - elif TritonConfig.TRITON_LINEAR or sparse_map is not None: - # This assumes the transposed_weight is True for input_sparse, False for output_sparse. - input_row_sparse_matmul( - lhs, - rhs, - sparse_map, - out=weight.grad_buffer, # noqa - accumulate=not param_get_and_unset_is_zero(weight), - ) - elif weight.grad_buffer.dtype == grad_output.dtype: # noqa - beta = 1 - param_get_and_unset_is_zero(weight) - torch.addmm( - weight.grad_buffer, # noqa - lhs, - rhs, - beta=beta, - alpha=1, - out=weight.grad_buffer, # noqa - ) - else: - accumulate_gradient(weight, torch.mm(lhs, rhs)) - if bias is not None and bias.requires_grad: - accumulate_gradient(bias, grad_output.sum(dim=0)) - - -def linear_forward( - input_: torch.Tensor, weight: torch.Tensor, bias: torch.Tensor | None, transposed_weight: bool = False -) -> tuple[torch.Tensor, tuple[torch.Tensor, torch.Tensor, torch.Tensor, bool]]: - # Matmul - if TritonConfig.TRITON_LINEAR: - assert bias is None - output = dense_matmul( - input_.flatten(0, -2), - maybe_transpose(weight, not transposed_weight), - ).unflatten(0, input_.shape[:-1]) - else: - output = torch.nn.functional.linear(input_, maybe_transpose(weight, transposed_weight), bias) - return output, (input_, weight, bias, transposed_weight) - - -def linear_backward( - grad_output: torch.Tensor, context: tuple[torch.Tensor, torch.Tensor, torch.Tensor, bool] -) -> torch.Tensor: - input_, weight, bias, transposed_weight = context - weight_t = maybe_transpose(weight, transposed_weight) - - # Input grad - if TritonConfig.TRITON_LINEAR: - grad_input = dense_matmul(grad_output.flatten(0, -2), weight_t).view_as(input_) - else: - grad_input = grad_output.matmul(weight_t) - - # Parameter grad - update_linear_gradients(input_, weight, bias, grad_output, transposed_weight, None) - return grad_input - - -def output_parallel_linear_forward( - input_: torch.Tensor, - weight: torch.Tensor, - bias: torch.Tensor | None, - group: ProcessGroup | None, - sequence_parallel: bool, - transposed_weight: bool = False, - sparse_map: SparseMap | None = None, -) -> tuple[torch.Tensor, tuple[typing.Any, ...]]: - # Gather sequence-parallel slices (non-overlapped) - input1 = gather_op(input_, group, 
dim=0) if sequence_parallel else input_ - - # Matmul - if TritonConfig.TRITON_LINEAR or sparse_map is not None: - assert bias is None - if sparse_map is not None: - assert not transposed_weight - output = output_sparse_matmul( - input1.flatten(0, -2), - maybe_transpose(weight, not transposed_weight), - sparse_map, - ).unflatten(0, input_.shape[:-1]) - else: - output = torch.nn.functional.linear(input1, maybe_transpose(weight, transposed_weight), bias) - - return output, ( - input_, - weight, - bias, - group, - sequence_parallel, - transposed_weight, - sparse_map, - ) - - -def output_parallel_linear_backward(grad_output: torch.Tensor, context: tuple[typing.Any, ...]) -> torch.Tensor: - input_, weight, bias, group, sequence_parallel, transposed_weight, sparse_map = context - weight_t = maybe_transpose(weight, transposed_weight) - - # Gather sequence-parallel slices (overlapped) - if sequence_parallel: - input_, gather_handle = gather_op(input_, group, dim=0, async_op=True) - else: - gather_handle = None - - # Input grad - if TritonConfig.TRITON_LINEAR or sparse_map is not None: - grad_input = input_inner_sparse_matmul(grad_output.flatten(0, -2), weight_t, sparse_map).view_as(input_) - else: - grad_input = grad_output.matmul(weight_t) - - # Reduce input grad (overlapped) - grad_input, reduce_handle = (reduce_scatter_op if sequence_parallel else reduce_op)( - grad_input, group=group, async_op=True - ) - if sequence_parallel: - gather_handle.wait() - - # Parameter grad - update_linear_gradients(input_, weight, bias, grad_output, transposed_weight, sparse_map) - - if reduce_handle: - reduce_handle.wait() - return grad_input - - -def input_parallel_linear_forward( - input_: torch.Tensor, - weight: torch.Tensor, - bias: torch.Tensor | None, - group: ProcessGroup | None, - sequence_parallel: bool, - transposed_weight: bool = False, - sparse_map: SparseMap | None = None, -) -> tuple[torch.Tensor, tuple[typing.Any, ...]]: - # Matmul - if TritonConfig.TRITON_LINEAR or sparse_map is not None: - assert bias is None - if sparse_map is not None: - assert transposed_weight - output = input_inner_sparse_matmul( - input_.flatten(0, -2), maybe_transpose(weight, not transposed_weight), sparse_map - ).unflatten(0, input_.shape[:-1]) - else: - output = torch.nn.functional.linear(input_, maybe_transpose(weight, transposed_weight), bias) - - # Reduce input grad (non-overlapped) - output = (reduce_scatter_op if sequence_parallel else reduce_op)(output, group=group) - return output, ( - input_, - weight, - bias, - group, - sequence_parallel, - transposed_weight, - sparse_map, - ) - - -def input_parallel_linear_backward(grad_output: torch.Tensor, context: tuple[typing.Any, ...]) -> torch.Tensor: - input_, weight, bias, group, sequence_parallel, transposed_weight, sparse_map = context - weight_t = maybe_transpose(weight, transposed_weight) - - # Gather sequence-parallel slices (non-overlapped) - if sequence_parallel: - grad_output = gather_op(grad_output, group, dim=0) - - # Input grad - if TritonConfig.TRITON_LINEAR or sparse_map is not None: - grad_input = output_sparse_matmul(grad_output.flatten(0, -2), weight_t, sparse_map).view_as(input_) - else: - grad_input = grad_output.matmul(weight_t) - - # Parameter grad - update_linear_gradients(input_, weight, bias, grad_output, transposed_weight, sparse_map) - - return grad_input - - -linear_autograd = wrap_forward_backward(linear_forward, linear_backward) - -output_parallel_linear_autograd = wrap_forward_backward( - output_parallel_linear_forward, 
output_parallel_linear_backward -) - -_input_parallel_linear = wrap_forward_backward(input_parallel_linear_forward, input_parallel_linear_backward) - - -def input_parallel_linear_autograd( - input_: torch.Tensor, - weight: torch.Tensor, - bias: torch.Tensor | None, - group: ProcessGroup | None, - sequence_parallel: bool, - transposed_weight: bool = False, - sparse_map: SparseMap | None = None, -) -> tuple[torch.Tensor, torch.Tensor | None]: - # Autograd goes nuts it this goes in the function. - return ( - _input_parallel_linear( - input_, - weight, - None if group else bias, - group, - sequence_parallel, - transposed_weight, - sparse_map, - ), - bias if group else None, - ) diff --git a/fast_llm/functional/triton/mlp.py b/fast_llm/functional/triton/mlp.py index 8ab275ab..71ce9474 100644 --- a/fast_llm/functional/triton/mlp.py +++ b/fast_llm/functional/triton/mlp.py @@ -1,28 +1,11 @@ import math -import typing import torch from fast_llm.core.distributed import ProcessGroup -from fast_llm.core.ops import gather_op from fast_llm.functional.autograd import wrap_forward_backward from fast_llm.functional.config import ActivationType, MLPRecomputeLevel, TritonConfig -from fast_llm.functional.linear import ( - input_parallel_linear_forward, - maybe_transpose, - output_parallel_linear_backward, - output_parallel_linear_forward, - update_linear_gradients, -) from fast_llm.functional.triton import tl, tl_constexpr, triton_jit -from fast_llm.functional.triton.sparse_copy import ( - SparseMap, - copy_dense_to_sparse_backward, - copy_dense_to_sparse_forward, - copy_sparse_to_dense_backward, - copy_sparse_to_dense_forward, -) -from fast_llm.functional.triton.sparse_linear import output_sparse_matmul from fast_llm.tensor import param_get_and_unset_is_zero # Triton requires global variables to be annotated with `constexpr`. 
@@ -199,189 +182,6 @@ def torch_mlp_activation( return activation_type.activation_fn(input_) -def mlp_forward( - input_: torch.Tensor, - scores: torch.Tensor | None, - weight_1: torch.Tensor, - bias_1: torch.Tensor | None, - weight_2: torch.Tensor, - bias_2: torch.Tensor | None, - gated: bool, - activation_type: ActivationType, - group: ProcessGroup | None, - sequence_parallel: bool, - training: bool = True, - recompute_level: MLPRecomputeLevel = MLPRecomputeLevel.none, - transposed_layer_2_weight: bool = False, - sparse_map: SparseMap | None = None, -) -> tuple[torch.Tensor, list[typing.Any] | None]: - # Sparse copy - input_shape = input_.shape - intermediate_0 = input_ if sparse_map is None else copy_dense_to_sparse_forward(input_, sparse_map)[0] - - # Layer 1 - intermediate_1, _ = output_parallel_linear_forward( - intermediate_0, weight_1, bias_1, group, sequence_parallel, False, sparse_map - ) - - if recompute_level.recompute_sparse_input: - intermediate_0 = None - else: - input_ = None - - # Activation - if TritonConfig.TRITON_ENABLED: - intermediate_2, _ = triton_mlp_activation_forward(intermediate_1, gated, activation_type) - else: - do_grad = training and not recompute_level.recompute_activation - with torch.set_grad_enabled(do_grad): - intermediate_2 = torch_mlp_activation( - intermediate_1.detach().requires_grad_(do_grad), gated, activation_type - ) - if recompute_level.recompute_layer_1: - intermediate_1 = None - - # Layer 2 - intermediate_3, _ = input_parallel_linear_forward( - intermediate_2, - weight_2, - bias_2, - group, - sequence_parallel, - transposed_layer_2_weight, - sparse_map, - ) - - # Context - if recompute_level.recompute_activation or not training: - intermediate_2 = None - - # Sparse copy - if sparse_map is None: - output = intermediate_3 - intermediate_3 = None - else: - output, _ = copy_sparse_to_dense_forward(intermediate_3, scores, sparse_map) - - context = ( - [ - input_, - scores, - intermediate_0, - intermediate_1, - intermediate_2, - intermediate_3, - weight_1, - bias_1, - weight_2, - bias_2, - gated, - activation_type, - group, - sequence_parallel, - transposed_layer_2_weight, - sparse_map, - input_shape, - ] - if training - else None - ) - return output, context - - -def mlp_backward(grad_output: torch.Tensor, context: list[typing.Any]) -> tuple[torch.Tensor, torch.Tensor]: - ( - input_, - scores, - intermediate_0, - intermediate_1, - intermediate_2, - intermediate_3, - weight_1, - bias_1, - weight_2, - bias_2, - gated, - activation_type, - group, - sequence_parallel, - transposed_layer_2_weight, - sparse_map, - input_shape, - ) = context - context.clear() - - # Sparse copy - if sparse_map is None: - grad_scores = None - else: - grad_output, grad_scores = copy_sparse_to_dense_backward(grad_output, (sparse_map, intermediate_3, scores)) - - # Gather sequence-parallel slices (non-overlapped; from input_parallel_backward) - if sequence_parallel: - grad_output = gather_op(grad_output, group, dim=0) - - # Layer 2 input grad - weight_2_t = maybe_transpose(weight_2, transposed_layer_2_weight) - if sparse_map is None: - grad_intermediate_2 = grad_output.matmul(weight_2_t) - else: - grad_intermediate_2 = output_sparse_matmul(grad_output, weight_2_t, sparse_map) - - # Sparse input recomputation - if intermediate_0 is None: - intermediate_0 = input_ if sparse_map is None else copy_dense_to_sparse_forward(input_, sparse_map)[0] - - # Layer 1 recomputation - if intermediate_1 is None: - intermediate_1 = output_parallel_linear_forward( - intermediate_0, weight_1, 
bias_1, group, sequence_parallel, False, sparse_map - )[0] - - # Activation recomputation and/or backward - if TritonConfig.TRITON_ENABLED: - grad_intermediate_1, intermediate_2_ = triton_mlp_activation_backward( - grad_intermediate_2, (intermediate_1, gated, activation_type), intermediate_2 is None - ) - else: - if intermediate_2 is None: - with torch.set_grad_enabled(True): - intermediate_2_ = torch_mlp_activation( - intermediate_1.detach().requires_grad_(True), gated, activation_type - ) - else: - intermediate_2_ = intermediate_2 - intermediate_2_.backward(grad_intermediate_2) - grad_intermediate_1 = intermediate_1.grad - - # Layer 2 parameter grad - del grad_intermediate_2, intermediate_1 - update_linear_gradients( - intermediate_2_ if intermediate_2 is None else intermediate_2, - weight_2, - bias_2, - grad_output, - transposed_layer_2_weight, - sparse_map, - ) - del grad_output, intermediate_2, intermediate_2_ - - # Layer 1 backward - grad_input = output_parallel_linear_backward( - grad_intermediate_1, - (intermediate_0, weight_1, bias_1, group, sequence_parallel, False, sparse_map), - ) - - # Sparse copy - if sparse_map is not None: - grad_input = copy_dense_to_sparse_backward(grad_input, (sparse_map, input_shape)) - - return grad_input, grad_scores - - -mlp_autograd = wrap_forward_backward(mlp_forward, mlp_backward) - - class ChunkWeight(torch.autograd.Function): """ Chunk a weight without letting autograd know about it, i.e., make it believe it's actually separate weights. diff --git a/fast_llm/layers/common/config.py b/fast_llm/layers/common/config.py index 6fbf683f..8305fb77 100644 --- a/fast_llm/layers/common/config.py +++ b/fast_llm/layers/common/config.py @@ -3,12 +3,11 @@ from fast_llm.config import Field, FieldHint, check_field, config_class from fast_llm.engine.base_model.config import BaseModelArchitectureConfig, BaseModelConfig -from fast_llm.layers.common.linear import LinearBase from fast_llm.utils import Assert if typing.TYPE_CHECKING: - from fast_llm.engine.base_model.base_model import SimpleFastLLMModule from fast_llm.engine.config_utils.tensor_space import TensorDim + from fast_llm.layers.common.linear import LinearBase, LinearLike from fast_llm.layers.common.normalization import LayerNorm, RMSNorm @@ -155,11 +154,11 @@ class PeftConfig(PeftArchitectureConfig, BaseModelConfig): hint=FieldHint.stability, ) - def apply_linear(self, linear: LinearBase) -> "SimpleFastLLMModule": + def apply_linear(self, linear: "LinearBase") -> "LinearLike": if self.type == PeftType.none: return linear elif self.type == PeftType.lora: - from fast_llm.layers.common.lora import LoRALinear + from fast_llm.layers.common.peft import LoRALinear # TODO: Init method? 
return LoRALinear( diff --git a/fast_llm/layers/common/linear.py b/fast_llm/layers/common/linear.py index e6aa11ff..04bcc990 100644 --- a/fast_llm/layers/common/linear.py +++ b/fast_llm/layers/common/linear.py @@ -9,17 +9,26 @@ from fast_llm.engine.config_utils.tensor_space import TensorDim from fast_llm.functional.autograd import wrap_forward_backward from fast_llm.functional.config import TritonConfig -from fast_llm.functional.linear import maybe_transpose, update_linear_gradients from fast_llm.functional.triton.sparse_copy import SparseMap -from fast_llm.functional.triton.sparse_linear import dense_matmul, input_inner_sparse_matmul, output_sparse_matmul -from fast_llm.tensor import ParameterMeta, init_zeros_ +from fast_llm.functional.triton.sparse_linear import ( + dense_matmul, + input_inner_sparse_matmul, + input_row_sparse_matmul, + output_sparse_matmul, +) +from fast_llm.tensor import ParameterMeta, accumulate_gradient, init_zeros_, param_get_and_unset_is_zero logger = logging.getLogger(__name__) +def maybe_transpose(tensor: torch.Tensor, transpose: bool) -> torch.Tensor: + return tensor.t() if transpose else tensor + + @dataclasses.dataclass class LinearContext: - input_: torch.Tensor + # TODO: Check for memory leak + input_: torch.Tensor | None sparse_map: SparseMap | None @@ -28,8 +37,8 @@ def __init__(self): super().__init__() self._forward = wrap_forward_backward(self.forward_only, self.backward) - def forward(self, input_: torch.Tensor) -> torch.Tensor: - return self._forward(input_) + def forward(self, input_: torch.Tensor, sparse_map: SparseMap | None = None) -> torch.Tensor: + return self._forward(input_, sparse_map) def forward_only( self, input_: torch.Tensor, sparse_map: SparseMap | None = None @@ -104,9 +113,41 @@ def transposed_weight(self) -> bool: return self._transposed_weight def backward_parameters(self, grad_output: torch.Tensor, context: LinearContext) -> None: - update_linear_gradients( - context.input_, self.weight, self.bias, grad_output, self._transposed_weight, context.sparse_map - ) + """ + Calculate the weight and bias gradients for a linear layer. + TODO: fused_dense_cuda fuses weight gradient with bias gradient, but not with grad accumulation. + Which one is best? (and can we fuse everything?) + """ + + grad_output = grad_output.flatten(0, -2) + input_ = context.input_.flatten(0, -2) + lhs, rhs = (input_.t(), grad_output) if self._transposed_weight else (grad_output.t(), input_) + + if not self.weight.requires_grad: + pass + elif TritonConfig.TRITON_LINEAR or context.sparse_map is not None: + # This assumes the transposed_weight is True for input_sparse, False for output_sparse. 
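# In the dense, non-sparse case the branches below all reduce to accumulating grad_weight = lhs @ rhs
# into the weight's gradient buffer (the bare `weight` references presumably mean `self.weight`, as in
# the rest of this method). A small autograd cross-check of the lhs/rhs convention, assuming a plain
# torch linear rather than Fast-LLM's gradient buffers:
import torch

input_ = torch.randn(4, 3)
weight = torch.randn(5, 3, requires_grad=True)  # (out_features, in_features), transposed_weight=False
grad_output = torch.ones(4, 5)

torch.nn.functional.linear(input_, weight).backward(grad_output)
assert torch.allclose(weight.grad, grad_output.t() @ input_)      # lhs @ rhs with transposed_weight=False
assert torch.allclose(weight.grad.t(), input_.t() @ grad_output)  # lhs @ rhs when the weight is stored transposed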
+ input_row_sparse_matmul( + lhs, + rhs, + context.sparse_map, + out=weight.grad_buffer, # noqa + accumulate=not param_get_and_unset_is_zero(self.weight), + ) + elif weight.grad_buffer.dtype == grad_output.dtype: # noqa + beta = 1 - param_get_and_unset_is_zero(self.weight) + torch.addmm( + weight.grad_buffer, # noqa + lhs, + rhs, + beta=beta, + alpha=1, + out=weight.grad_buffer, # noqa + ) + else: + accumulate_gradient(self.weight, torch.mm(lhs, rhs)) + if self.bias is not None and self.bias.requires_grad: + accumulate_gradient(self.bias, grad_output.sum(dim=0)) class Linear(LinearBase): diff --git a/fast_llm/layers/transformer/config.py b/fast_llm/layers/transformer/config.py index 28044e80..114d3f94 100644 --- a/fast_llm/layers/transformer/config.py +++ b/fast_llm/layers/transformer/config.py @@ -15,6 +15,7 @@ NormalizationConfig, PeftArchitectureConfig, PeftConfig, + PeftType, ) from fast_llm.utils import Assert, div @@ -619,6 +620,8 @@ def _validate(self) -> None: Assert.geq(self.attention_dropout, 0) Assert.geq(self.hidden_dropout, 0) Assert.incl(len(self.mlp_lr_scale), (1, self.num_experts)) + if self.peft.type != PeftType.none and self.mlp_recompute_level != MLPRecomputeLevel.none: + raise ValueError("Activation recomputation not supported with Peft.") for scale in self.mlp_lr_scale: if scale is not None: Assert.geq(scale, 0) diff --git a/fast_llm/layers/transformer/mlp.py b/fast_llm/layers/transformer/mlp.py index 0b369a5a..628d9e84 100644 --- a/fast_llm/layers/transformer/mlp.py +++ b/fast_llm/layers/transformer/mlp.py @@ -1,5 +1,5 @@ +import dataclasses import typing -from abc import ABC import torch @@ -7,7 +7,6 @@ from fast_llm.engine.config_utils.tensor_space import TensorSpace from fast_llm.functional.config import TritonConfig from fast_llm.functional.triton.mlp import ( - mlp_autograd, torch_mlp_activation, triton_mlp_activation_autograd, triton_mlp_activation_backward, @@ -15,17 +14,30 @@ ) from fast_llm.functional.triton.sparse_copy import ( SparseMap, + copy_dense_to_sparse_backward, copy_dense_to_sparse_forward, copy_sparse_to_dense_backward, copy_sparse_to_dense_forward, ) -from fast_llm.layers.common.linear import LinearBase +from fast_llm.layers.common.linear import LinearBase, LinearContext, LinearLike from fast_llm.layers.transformer.config import TransformerConfig, TransformerDimNames from fast_llm.tensor import init_normal_, init_zeros_ from fast_llm.utils import Assert -class MLPBase(Layer, ABC): +@dataclasses.dataclass +class MLPContext(LinearContext): + # TODO: Check for memory leak + scores: torch.Tensor | None + layer_1: LinearContext + layer_2: LinearContext + intermediate_1: torch.Tensor + intermediate_2: torch.Tensor + intermediate_3: torch.Tensor + input_shape: torch.Size + + +class MLPBase(Layer): def __init__(self, config: TransformerConfig, tensor_space: TensorSpace, name: str = "mlp"): super().__init__() self._name = name @@ -51,7 +63,7 @@ def __init__(self, config: TransformerConfig, tensor_space: TensorSpace, name: s self._activation_fn = triton_mlp_activation_autograd if TritonConfig.TRITON_ENABLED else torch_mlp_activation # So both layers' weights have shape (num_experts [* gate_up] * ffn, hidden_size) - self.layer_1 = self._config.peft.apply_linear( + self.layer_1: LinearLike = self._config.peft.apply_linear( LinearBase( hidden_dim, tensor_space.get_tensor_dim(TransformerDimNames.composite_gated_expert_mlp), @@ -61,7 +73,7 @@ def __init__(self, config: TransformerConfig, tensor_space: TensorSpace, name: s lr_scale=tuple(config.mlp_lr_scale), ) 
) - self.layer_2 = self._config.peft.apply_linear( + self.layer_2: LinearLike = self._config.peft.apply_linear( LinearBase( self._intermediate_dim, hidden_dim, @@ -79,16 +91,16 @@ def forward_only( input_: torch.Tensor, scores: torch.Tensor | None, sparse_map: SparseMap | None = None, - ) -> tuple[torch.Tensor, list[typing.Any] | None]: + ) -> tuple[torch.Tensor, MLPContext | None]: # Sparse copy input_shape = input_.shape intermediate_0 = input_ if sparse_map is None else copy_dense_to_sparse_forward(input_, sparse_map)[0] # Layer 1 - intermediate_1, _ = self.layer_1.forward_only(intermediate_0, sparse_map) + intermediate_1, layer_1_context = self.layer_1.forward_only(intermediate_0, sparse_map) if self._recompute_level.recompute_sparse_input: - intermediate_0 = None + layer_1_context.input_ = None else: input_ = None @@ -105,11 +117,13 @@ def forward_only( intermediate_1 = None # Layer 2 - intermediate_3, _ = self.layer_2.forward_only(intermediate_2, sparse_map) + intermediate_3, layer_2_context = self.layer_2.forward_only(intermediate_2, sparse_map) # Context if self._recompute_level.recompute_activation or not self.training: intermediate_2 = None + # TODO: Doesn't work with LoRA. + layer_2_context.input_ = None # Sparse copy if sparse_map is None: @@ -119,87 +133,80 @@ def forward_only( output, _ = copy_sparse_to_dense_forward(intermediate_3, scores, sparse_map) context = ( - [ + MLPContext( input_, + sparse_map, scores, - intermediate_0, + layer_1_context, + layer_2_context, intermediate_1, intermediate_2, intermediate_3, - sparse_map, input_shape, - ] + ) if self.training else None ) return output, context - def backward(self, grad_output: torch.Tensor, context: typing.Any) -> torch.Tensor: - ( - input_, - scores, - intermediate_0, - intermediate_1, - intermediate_2, - intermediate_3, - sparse_map, - input_shape, - ) = context - context.clear() + def backward(self, grad_output: torch.Tensor, context: MLPContext) -> torch.Tensor: # Sparse copy - if sparse_map is None: + if context.sparse_map is None: grad_scores = None else: - grad_output, grad_scores = copy_sparse_to_dense_backward(grad_output, (sparse_map, intermediate_3, scores)) + grad_output, grad_scores = copy_sparse_to_dense_backward( + grad_output, (context.sparse_map, context.intermediate_3, context.scores) + ) - grad_intermediate_2, handle = self.layer_2.backward_input(grad_output, ()) + grad_intermediate_2, handle = self.layer_2.backward_activation(grad_output, context.layer_2) # Sparse input recomputation - if intermediate_0 is None: - intermediate_0 = input_ if sparse_map is None else copy_dense_to_sparse_forward(input_, sparse_map)[0] + if context.layer_1.input_ is None: + context.layer_1.input_ = ( + context.input_ + if context.sparse_map is None + else copy_dense_to_sparse_forward(context.input_, context.sparse_map)[0] + ) + + del context.input_, context.scores, context.intermediate_3 # Layer 1 recomputation - if intermediate_1 is None: - intermediate_1 = self.layer_1.forward_only(intermediate_0, sparse_map) + if context.intermediate_1 is None: + context.intermediate_1, _ = self.layer_1.forward_only(context.layer_1.input_, context.sparse_map) # Activation recomputation and/or backward if TritonConfig.TRITON_ENABLED: - grad_intermediate_1, intermediate_2_ = triton_mlp_activation_backward( - grad_intermediate_2, (intermediate_1, self._gated, self._activation_type), intermediate_2 is None + grad_intermediate_1, context.intermediate_2 = triton_mlp_activation_backward( + grad_intermediate_2, + (context.intermediate_1, 
self._gated, self._activation_type), + context.intermediate_2 is None, ) else: - if intermediate_2 is None: + if context.intermediate_2 is None: with torch.set_grad_enabled(True): - intermediate_2_ = torch_mlp_activation( - intermediate_1.detach().requires_grad_(True), self._gated, self._activation_type + context.intermediate_2 = torch_mlp_activation( + context.intermediate_1.detach().requires_grad_(True), self._gated, self._activation_type ) - else: - intermediate_2_ = intermediate_2 - intermediate_2_.backward(grad_intermediate_2) - grad_intermediate_1 = intermediate_1.grad + context.intermediate_2.backward(grad_intermediate_2) + grad_intermediate_1 = context.intermediate_1.grad # Layer 2 parameter grad - del grad_intermediate_2, intermediate_1 - update_linear_gradients( - intermediate_2_ if intermediate_2 is None else intermediate_2, - weight_2, - bias_2, - grad_output, - transposed_layer_2_weight, - sparse_map, - ) - del grad_output, intermediate_2, intermediate_2_ + del grad_intermediate_2, context.intermediate_1 + if context.layer_2.input_ is None: + context.layer_2.input_ = context.intermediate_2 + self.layer_2.backward_parameters(grad_output, context.layer_2) + del grad_output, context.intermediate_2, context.layer_2 # Layer 1 backward - grad_input = output_parallel_linear_backward( - grad_intermediate_1, - (intermediate_0, weight_1, bias_1, group, sequence_parallel, False, sparse_map), - ) + grad_input = self.layer_1.backward(grad_intermediate_1, context.layer_1) + del context.layer_1, grad_intermediate_1 # Sparse copy - if sparse_map is not None: - grad_input = copy_dense_to_sparse_backward(grad_input, (sparse_map, input_shape)) + if context.sparse_map is not None: + grad_input = copy_dense_to_sparse_backward(grad_input, (context.sparse_map, context.input_shape)) + + del context.sparse_map return grad_input, grad_scores From f890633369a10b00e6317448e91ef74983036a15 Mon Sep 17 00:00:00 2001 From: Joel Lamy-Poirier Date: Mon, 10 Mar 2025 14:36:58 -0400 Subject: [PATCH 4/4] Select layers --- fast_llm/layers/transformer/attention.py | 16 ++++++--- fast_llm/layers/transformer/config.py | 45 ++++++++++++++++++++++-- fast_llm/layers/transformer/mlp.py | 8 +++-- 3 files changed, 59 insertions(+), 10 deletions(-) diff --git a/fast_llm/layers/transformer/attention.py b/fast_llm/layers/transformer/attention.py index 43574434..75ff8984 100644 --- a/fast_llm/layers/transformer/attention.py +++ b/fast_llm/layers/transformer/attention.py @@ -11,7 +11,12 @@ from fast_llm.functional.rotary import apply_rotary_embeddings from fast_llm.functional.triton.rotary import triton_rotary_autograd_ from fast_llm.layers.common.linear import InputParallelLinear, OutputParallelLinear -from fast_llm.layers.transformer.config import TransformerConfig, TransformerDimNames, TransformerKwargs +from fast_llm.layers.transformer.config import ( + TransformerConfig, + TransformerDimNames, + TransformerKwargs, + TransformerLinearLayerName, +) from fast_llm.logging import log_distributed_grad, log_distributed_tensor from fast_llm.tensor import TensorMeta, init_normal_, init_zeros_ from fast_llm.utils import Assert @@ -109,7 +114,8 @@ def __init__( bias_init_method=init_method_qkv if self._config.random_bias_init else init_zeros_, sequence_parallel=self._sequence_parallel, lr_scale=self._config.attention_lr_scale, - ) + ), + TransformerLinearLayerName.query, ) self.key_value = self._config.peft.apply_linear( OutputParallelLinear( @@ -120,7 +126,8 @@ def __init__( bias_init_method=init_method_qkv if 
self._config.random_bias_init else init_zeros_, sequence_parallel=self._sequence_parallel, lr_scale=self._config.attention_lr_scale, - ) + ), + TransformerLinearLayerName.key_value, ) self._query_key_value = wrap_forward_backward(self._query_key_value_forward, self._query_key_value_backward) @@ -134,7 +141,8 @@ def __init__( bias_init_method=init_method_std_attn_proj if self._config.random_bias_init else init_zeros_, sequence_parallel=self._sequence_parallel, lr_scale=self._config.attention_lr_scale, - ) + ), + TransformerLinearLayerName.dense, ) def _attn_fused( diff --git a/fast_llm/layers/transformer/config.py b/fast_llm/layers/transformer/config.py index 114d3f94..3c403b63 100644 --- a/fast_llm/layers/transformer/config.py +++ b/fast_llm/layers/transformer/config.py @@ -19,6 +19,9 @@ ) from fast_llm.utils import Assert, div +if typing.TYPE_CHECKING: + from fast_llm.layers.common.linear import LinearBase, LinearLike + logger = logging.getLogger(__name__) @@ -161,6 +164,35 @@ class AddLinearBiasChoices(str, enum.Enum): only_attn_qkv = "only_attn_qkv" +class TransformerLinearLayerName(str, enum.Enum): + # TODO: Use this to replace AddLinearBiasChoices. + query = "query" + key_value = "key_value" + dense = "dense" + mlp_1 = "mlp_1" + mlp_2 = "mlp_2" + + +@config_class() +class TransformerPeftConfig(PeftConfig): + layers: list[TransformerLinearLayerName] | None = Field( + default=None, + desc="The layers on which to apply LoRA.", + hint=FieldHint.feature, + ) + freeze_others: bool = Field( + default=True, + desc="Whether to freeze other layers during training.", + ) + + def apply_linear(self, linear: "LinearBase", layer_type: TransformerLinearLayerName | None = None) -> "LinearLike": + if layer_type is None or self.layers is None or layer_type in self.layers: + return super().apply_linear(linear) + elif self.freeze_others: + linear.weight.requires_grad = False + return linear + + @config_class() class TransformerArchitectureConfig(BaseModelArchitectureConfig): _abstract = False @@ -381,7 +413,7 @@ def setup_tensor_space(self, tensor_space: TensorSpace) -> None: class TransformerConfig(TransformerArchitectureConfig, BaseModelConfig): normalization: NormalizationConfig = FieldUpdate(default_factory=NormalizationConfig) rotary: RotaryConfig = FieldUpdate(default_factory=RotaryConfig) - peft: PeftConfig = FieldUpdate(default_factory=PeftConfig) + peft: TransformerPeftConfig = FieldUpdate(default_factory=PeftConfig) # Default: hidden_size**-0.5 # TODO: Allow custom initialization (InitializationConfig?) 
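# With the layer selection above, a hypothetical LoRA setup targeting only the attention input
# projections would look roughly like the following (plain dict sketch; the exact construction and
# serialization path is outside this patch):
lora_attention_only = {
    "type": "lora",
    "layers": ["query", "key_value"],
    "freeze_others": True,  # "dense", "mlp_1" and "mlp_2" keep their weights with requires_grad=False
}
# Because the MLP layers are excluded here, the _validate check below still permits activation
# recomputation and mixture-of-experts for such a configuration.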
init_method_std: float = Field( @@ -620,8 +652,15 @@ def _validate(self) -> None: Assert.geq(self.attention_dropout, 0) Assert.geq(self.hidden_dropout, 0) Assert.incl(len(self.mlp_lr_scale), (1, self.num_experts)) - if self.peft.type != PeftType.none and self.mlp_recompute_level != MLPRecomputeLevel.none: - raise ValueError("Activation recomputation not supported with Peft.") + if self.peft.type != PeftType.none and ( + self.peft.layers is None + or TransformerLinearLayerName.mlp_1 in self.peft.layers + or TransformerLinearLayerName.mlp_2 in self.peft.layers + ): + if self.mlp_recompute_level != MLPRecomputeLevel.none: + raise ValueError("Activation recomputation not supported with Peft.") + if self.num_experts > 1: + raise ValueError("Mixture of experts not supported with Peft.") for scale in self.mlp_lr_scale: if scale is not None: Assert.geq(scale, 0) diff --git a/fast_llm/layers/transformer/mlp.py b/fast_llm/layers/transformer/mlp.py index 628d9e84..c7a394cb 100644 --- a/fast_llm/layers/transformer/mlp.py +++ b/fast_llm/layers/transformer/mlp.py @@ -20,7 +20,7 @@ copy_sparse_to_dense_forward, ) from fast_llm.layers.common.linear import LinearBase, LinearContext, LinearLike -from fast_llm.layers.transformer.config import TransformerConfig, TransformerDimNames +from fast_llm.layers.transformer.config import TransformerConfig, TransformerDimNames, TransformerLinearLayerName from fast_llm.tensor import init_normal_, init_zeros_ from fast_llm.utils import Assert @@ -71,7 +71,8 @@ def __init__(self, config: TransformerConfig, tensor_space: TensorSpace, name: s weight_init_method=init_method_1, bias_init_method=init_method_1 if config.random_bias_init else init_zeros_, lr_scale=tuple(config.mlp_lr_scale), - ) + ), + TransformerLinearLayerName.mlp_1, ) self.layer_2: LinearLike = self._config.peft.apply_linear( LinearBase( @@ -83,7 +84,8 @@ def __init__(self, config: TransformerConfig, tensor_space: TensorSpace, name: s auto_bias_grad_accumulation=tensor_space.distributed_config.tensor_parallel > 1, transposed_weight=True, lr_scale=tuple(config.mlp_lr_scale), - ) + ), + TransformerLinearLayerName.mlp_2, ) def forward_only(