Minimal LoRA implementation #182

Merged
merged 35 commits on Mar 27, 2025

Commits (35)
05a988a
Basic LoRA
jlamypoirier Mar 11, 2025
01e465a
cleanup
jlamypoirier Mar 11, 2025
18a3a53
misc
jlamypoirier Mar 11, 2025
0b93e9b
wip
jlamypoirier Mar 11, 2025
7121067
wip
jlamypoirier Mar 12, 2025
a06d678
fixes
jlamypoirier Mar 13, 2025
863bcf7
fixes
jlamypoirier Mar 13, 2025
1926da1
fix
jlamypoirier Mar 14, 2025
2e416b1
fix
jlamypoirier Mar 14, 2025
811739a
fix
jlamypoirier Mar 14, 2025
420bedc
separate shard wip
jlamypoirier Mar 15, 2025
9313c88
separate shards
jlamypoirier Mar 18, 2025
b9b017f
fix
jlamypoirier Mar 18, 2025
e086908
fixes
jlamypoirier Mar 18, 2025
59c1f8d
fix
jlamypoirier Mar 18, 2025
cc192d5
fix
jlamypoirier Mar 18, 2025
e878656
Add test
jlamypoirier Mar 18, 2025
81e39a9
Merge remote-tracking branch 'origin/main' into frozen_weights
jlamypoirier Mar 18, 2025
3f79798
fix
jlamypoirier Mar 18, 2025
963db68
Merge remote-tracking branch 'origin/main' into lora_small
jlamypoirier Mar 18, 2025
475fc4e
Merge branch 'frozen_weights' into lora_small
jlamypoirier Mar 18, 2025
399fde7
misc
jlamypoirier Mar 19, 2025
a3065a9
fix
jlamypoirier Mar 19, 2025
a55e1a2
fixes
jlamypoirier Mar 20, 2025
bd184c7
fixes
jlamypoirier Mar 20, 2025
cdc8945
Override module
jlamypoirier Mar 21, 2025
94416a2
Separate key and value
jlamypoirier Mar 21, 2025
f9f2883
Add warning
jlamypoirier Mar 21, 2025
d959b61
Merge branch 'frozen_weights' into lora_small
jlamypoirier Mar 21, 2025
e9113c7
Update fast_llm/engine/checkpoint/distributed.py
RaymondLi0 Mar 26, 2025
29af13d
Merge branch 'main' into frozen_weights
jlamypoirier Mar 27, 2025
e20a908
Merge remote-tracking branch 'origin/main' into frozen_weights
jlamypoirier Mar 27, 2025
a75f9c7
Merge branch 'frozen_weights' into lora_small
jlamypoirier Mar 27, 2025
84d7dd1
Merge remote-tracking branch 'origin/main' into lora_small
jlamypoirier Mar 27, 2025
6d6b112
fixes
jlamypoirier Mar 27, 2025
4 changes: 4 additions & 0 deletions fast_llm/engine/multi_stage/stage_base.py
@@ -85,6 +85,10 @@ def __init__(
# TODO: Separate fsdp for tied weights?
self._fsdp_index = {name: i for i, fsdp in enumerate(self._fsdps) for name in fsdp.parameter_names}

@property
def requires_grad(self):
return any(fsdp.requires_grad for fsdp in self._fsdps)

@property
def mode(self) -> StageMode:
assert self._is_setup
2 changes: 1 addition & 1 deletion fast_llm/engine/schedule/runner.py
@@ -406,7 +406,7 @@ def _forward(self, context: BatchContext, step: Step) -> None:
losses=context.losses,
metrics=context.metrics,
)
if context.is_training:
if step.backward_step is not None:
context.contexts[step.backward_step.global_index] = grad_context
self._record_compute(context, step)
return output
41 changes: 29 additions & 12 deletions fast_llm/engine/schedule/schedule.py
@@ -141,7 +141,7 @@ def __init__(
phase=self._phase,
)

self._steps = self._create_steps()
self._steps, self._first_grad_stage = self._create_steps()

self._create_index()

@@ -214,8 +214,8 @@ def _create_index(self) -> None:
# Consistency checks
step_map = self._step_map.copy()
for data_index in range(self._batch_config.num_inputs):
for type_ in (StepType.forward, StepType.backward) if self._is_training else (StepType.forward,):
for stage in range(self._num_stages):
for type_ in (StepType.forward, StepType.backward):
for stage in range(0 if type_ == StepType.forward else self._first_grad_stage, self._num_stages):
assert (
step_map.pop((type_, stage, data_index), None) is not None
), f"Missing {type_.value} step with stage={stage}, data_index={data_index}"
@@ -225,7 +225,8 @@ def _create_index(self) -> None:
for i, step in enumerate(self._steps):
if self._is_training:
if step.type_ == StepType.forward:
step.backward_step = self.get_step(StepType.backward, *step.map_index[1:])
if step.stage >= self._first_grad_stage:
step.backward_step = self.get_step(StepType.backward, *step.map_index[1:])
else:
step.forward_step = self.get_step(StepType.forward, *step.map_index[1:])
if step.type_ == StepType.forward and step.stage == 0:
@@ -236,7 +237,8 @@ def _create_index(self) -> None:
step.prev_step = self.get_step(
step.type_, step.stage + (1 if step.type_ == StepType.backward else -1), *step.map_index[2:]
)
if step.type_ == StepType.backward and step.stage == 0:

if step.type_ == StepType.backward and step.stage == self._first_grad_stage:
step.next_step = None
elif step.type_ == StepType.forward and step.stage == self._num_stages - 1:
step.next_step = self.get_step(StepType.backward, *step.map_index[1:]) if self._is_training else None
@@ -249,11 +251,15 @@ def _create_index(self) -> None:
for step in self._steps:
if self._is_training:
if step.type_ == StepType.forward:
Assert.gt(step.backward_step.global_index, step.global_index)
Assert.is_(step.backward_step.forward_step, step)
if step.stage >= self._first_grad_stage:
Assert.gt(step.backward_step.global_index, step.global_index)
Assert.is_(step.backward_step.forward_step, step)
else:
assert step.backward_step is None
else:
Assert.lt(step.forward_step.global_index, step.global_index)
Assert.is_(step.forward_step.backward_step, step)
if step.stage >= self._first_grad_stage:
Assert.is_(step.forward_step.backward_step, step)
if step.next_step is not None:
Assert.gt(step.next_step.global_index, step.global_index)
Assert.is_(step.next_step.prev_step, step)
@@ -303,7 +309,10 @@ def _setup_reduce_steps(self, grad_buffer_indices: dict[int, int]) -> None:
reduce_step.reduce_accumulate = reduction_count[reduce_step.stage] > 0
reduction_count[reduce_step.stage] += 1
for stage, count in enumerate(reduction_count):
assert (count > 0) == (stage % self._distributed.pipeline_parallel == self._distributed.pipeline_rank)
assert (count > 0) == (
stage >= self._first_grad_stage
and (stage % self._distributed.pipeline_parallel == self._distributed.pipeline_rank)
)

def _setup_timeline(self) -> None:
# TODO: Include network time
@@ -468,8 +477,16 @@ def get_data_index_split(
micro_sequence,
)

def _create_steps(self) -> list[Step]:
def _create_steps(self) -> tuple[list[Step], int]:
steps = []
if self._is_training:
# The first stage(s) may not have any trainable parameters,
# in which case we shouldn't run the backward pass.
first_grad_stage = 0
while first_grad_stage < self._num_stages and not self._multi_stage.stages[first_grad_stage].requires_grad:
first_grad_stage += 1
else:
first_grad_stage = self._num_stages
for depth_first_micro_batch in range(self._batch_config.depth_first_micro_batches):
for stage in range(self._num_stages):
for breadth_first_micro_batch in range(self._batch_config.breadth_first_micro_batches):
@@ -485,7 +502,7 @@ def _create_steps(self) -> list[Step]:
)
)
if self._is_training:
for stage in reversed(range(self._num_stages)):
for stage in reversed(range(first_grad_stage, self._num_stages)):
for breadth_first_micro_batch in range(self._batch_config.breadth_first_micro_batches):
for micro_sequence in reversed(range(self._batch_config.num_micro_sequences)):
steps.append(
@@ -498,4 +515,4 @@ def _create_steps(self) -> list[Step]:
type_=StepType.backward,
)
)
return steps
return steps, first_grad_stage
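
As a reading aid, the scheduling rule introduced above can be stated in isolation: backward steps are only created from the first stage that has trainable parameters, so leading frozen stages (for example, a frozen embedding stage under LoRA) never run a backward pass. A minimal standalone sketch, with hypothetical names that are not part of this diff:

```python
def find_first_grad_stage(stage_requires_grad: list[bool]) -> int:
    # First stage with trainable parameters; backward steps start here.
    stage = 0
    while stage < len(stage_requires_grad) and not stage_requires_grad[stage]:
        stage += 1
    return stage

# Stages 0-1 fully frozen, stages 2-3 hold trainable (e.g. LoRA) parameters.
assert find_first_grad_stage([False, False, True, True]) == 2
# Everything frozen: no backward steps are created at all.
assert find_first_grad_stage([False, False, False]) == 3
```
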
70 changes: 42 additions & 28 deletions fast_llm/functional/triton/normalization.py
@@ -68,6 +68,7 @@ def triton_normalization_backward_kernel_1(
n_cols,
n_rows,
has_bias: tl_constexpr,
parameter_grad: tl_constexpr,
zero_centered: tl_constexpr,
block_size: tl_constexpr,
block_size_row: tl_constexpr,
@@ -108,18 +109,19 @@ def triton_normalization_backward_kernel_1(
tl.store(grad_input_ptr + offsets, grad_input, mask=mask)

# Parameter grad partial sums
parameter_offsets = tl.program_id(0) * n_cols + cols
grad_weight_partial_ptr = grad_weight_partial_ptr + parameter_offsets
grad_weight_partial = (grad_output * input_normalized).to(weight.dtype)
grad_weight_partial = tl.sum(grad_weight_partial, axis=0)[None, :]
if parameter_grad:
parameter_offsets = tl.program_id(0) * n_cols + cols
grad_weight_partial_ptr = grad_weight_partial_ptr + parameter_offsets
grad_weight_partial = (grad_output * input_normalized).to(weight.dtype)
grad_weight_partial = tl.sum(grad_weight_partial, axis=0)[None, :]

if has_bias:
grad_bias_partial_ptr = grad_bias_partial_ptr + parameter_offsets
grad_bias_partial = tl.sum(grad_output.to(weight.dtype), axis=0)[None, :]
if has_bias:
grad_bias_partial_ptr = grad_bias_partial_ptr + parameter_offsets
grad_bias_partial = tl.sum(grad_output.to(weight.dtype), axis=0)[None, :]

tl.store(grad_weight_partial_ptr, grad_weight_partial, mask=col_mask)
if has_bias:
tl.store(grad_bias_partial_ptr, grad_bias_partial, mask=col_mask) # noqa
tl.store(grad_weight_partial_ptr, grad_weight_partial, mask=col_mask)
if has_bias:
tl.store(grad_bias_partial_ptr, grad_bias_partial, mask=col_mask) # noqa


@triton_jit()
@@ -211,6 +213,11 @@ def triton_normalization_backward(grad_output: torch.Tensor, context: list[typin
context.clear()
has_bias = bias is not None

parameter_grad = weight.requires_grad
assert parameter_grad == hasattr(weight, "grad_buffer")
if has_bias:
assert parameter_grad == bias.requires_grad

grad_output = grad_output.contiguous()

n_rows = grad_output.shape[:-1].numel()
@@ -232,12 +239,17 @@ def triton_normalization_backward(grad_output: torch.Tensor, context: list[typin

grad_input = torch.empty_like(grad_output)

grad_is_zero = param_get_and_unset_is_zero(weight)
grad_weight = weight.grad_buffer
# TODO: Any point in making it full precision?
grad_weight_partial = grad_output.new_empty(num_blocks_row, n_cols)
if parameter_grad:
grad_is_zero = param_get_and_unset_is_zero(weight)
grad_weight = weight.grad_buffer
# TODO: Any point in making it full precision?
grad_weight_partial = grad_output.new_empty(num_blocks_row, n_cols)
else:
grad_is_zero = True
grad_weight = None
grad_weight_partial = None

if has_bias:
if has_bias and parameter_grad:
assert param_get_and_unset_is_zero(bias) == grad_is_zero
grad_bias = bias.grad_buffer
grad_bias_partial = grad_output.new_empty(num_blocks_row, n_cols)
@@ -256,24 +268,26 @@ def triton_normalization_backward(grad_output: torch.Tensor, context: list[typin
n_cols,
n_rows,
has_bias,
parameter_grad,
zero_centered,
block_size,
block_size_row,
num_warps=num_warps,
)
triton_normalization_backward_kernel_2[(triton.cdiv(n_cols, block_size_n),)](
grad_weight_partial,
grad_bias_partial,
grad_weight,
grad_bias,
num_blocks_row,
n_cols,
has_bias,
not grad_is_zero,
block_size_m,
block_size_n,
num_ctas=1,
)
if parameter_grad:
triton_normalization_backward_kernel_2[(triton.cdiv(n_cols, block_size_n),)](
grad_weight_partial,
grad_bias_partial,
grad_weight,
grad_bias,
num_blocks_row,
n_cols,
has_bias,
not grad_is_zero,
block_size_m,
block_size_n,
num_ctas=1,
)
return grad_input


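A quick PyTorch-autograd analogue of the behaviour these kernel changes implement (this is not the Triton code path above, just an illustration): when the normalization parameters are frozen, the backward pass still produces a gradient for the input but skips all weight and bias gradient work.

```python
import torch

# Frozen LayerNorm: autograd computes grad_input but no parameter gradients,
# which is the case the patched Triton backward now handles explicitly.
ln = torch.nn.LayerNorm(16)
ln.weight.requires_grad_(False)
ln.bias.requires_grad_(False)

x = torch.randn(4, 16, requires_grad=True)
ln(x).sum().backward()

assert x.grad is not None       # input gradient is still computed
assert ln.weight.grad is None   # no gradient for the frozen weight
assert ln.bias.grad is None     # nor for the frozen bias
```
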
57 changes: 57 additions & 0 deletions fast_llm/layers/common/config.py
@@ -7,6 +7,7 @@

if typing.TYPE_CHECKING:
from fast_llm.engine.config_utils.tensor_space import TensorDim
from fast_llm.layers.common.linear import LinearBase, LinearLike
from fast_llm.layers.common.normalization import LayerNorm, RMSNorm


@@ -115,3 +116,59 @@ def _from_dict(
cls._handle_renamed_field(default, "normalization_implementation", "implementation")
cls._handle_renamed_field(default, "layer_norm_init_range", "initialization_range")
return super()._from_dict(default, strict, flat)


class PeftType(str, enum.Enum):
# TODO : Use a dynamic config type instead.
none = "none"
lora = "lora"


@config_class()
class PeftArchitectureConfig(BaseModelArchitectureConfig):
_abstract = False


@config_class()
class PeftConfig(PeftArchitectureConfig, BaseModelConfig):
# TODO: Architecture/non-architecture split might not make much sense here.

type: PeftType = Field(
default=PeftType.none,
desc="The type of parameter-efficient fine tuning to use Only LoRA is supported at the moment.",
hint=FieldHint.core,
)
rank: int = Field(
default=8,
desc="The LoRA rank, i.e. the size of the intermediate dimension.",
hint=FieldHint.stability,
)
alpha: float = Field(
default=8.0,
desc="The LoRA scaling parameter.",
hint=FieldHint.stability,
)
dropout: float = Field(
default=0.0,
desc="Dropout rate for LoRA.",
hint=FieldHint.stability,
)

def apply_linear(self, linear: "LinearBase", **kwargs) -> "LinearLike":
if self.type == PeftType.none:
return linear
elif self.type == PeftType.lora:
from fast_llm.layers.common.peft import lora_linear

# TODO: Init method?
return lora_linear(
linear,
linear.weight.param_init_method,
linear.weight.param_init_method,
self.rank,
self.alpha,
self.dropout,
**kwargs,
)
else:
raise NotImplementedError(self.type)
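
The `lora_linear` wrapper imported above (from `fast_llm/layers/common/peft.py`) is not shown in this section. As a rough guide to what `apply_linear` produces for `PeftType.lora`, here is a minimal sketch of a LoRA-wrapped linear layer; the class name and details are illustrative only, not the actual implementation in this repository:

```python
import torch

class LoRALinearSketch(torch.nn.Module):
    """Illustrative LoRA wrapper: y = W x + (alpha / rank) * B A dropout(x),
    with the base weight frozen and only A, B trainable."""

    def __init__(self, linear: torch.nn.Linear, rank: int = 8, alpha: float = 8.0, dropout: float = 0.0):
        super().__init__()
        self.linear = linear
        self.linear.weight.requires_grad_(False)  # base weight stays frozen
        if self.linear.bias is not None:
            self.linear.bias.requires_grad_(False)
        self.lora_a = torch.nn.Linear(linear.in_features, rank, bias=False)
        self.lora_b = torch.nn.Linear(rank, linear.out_features, bias=False)
        torch.nn.init.zeros_(self.lora_b.weight)  # adapter starts as a no-op
        self.scaling = alpha / rank
        self.dropout = torch.nn.Dropout(dropout)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        return self.linear(x) + self.scaling * self.lora_b(self.lora_a(self.dropout(x)))
```

With the config defaults above (rank=8, alpha=8.0), the adapter term is scaled by alpha / rank = 1.0.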