Commit ab17636: Minimal LoRA implementation (#182)
Parent: 14c980b

14 files changed: +389, -74 lines

fast_llm/engine/multi_stage/stage_base.py (+4)

@@ -85,6 +85,10 @@ def __init__(
         # TODO: Separate fsdp for tied weights?
         self._fsdp_index = {name: i for i, fsdp in enumerate(self._fsdps) for name in fsdp.parameter_names}

+    @property
+    def requires_grad(self):
+        return any(fsdp.requires_grad for fsdp in self._fsdps)
+
     @property
     def mode(self) -> StageMode:
         assert self._is_setup

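The new requires_grad property is what lets the schedule detect stages that hold only frozen parameters, for example a LoRA-tuned model whose embedding stage has no adapters. A minimal, self-contained sketch of the same idea, using hypothetical stand-in classes rather than Fast-LLM's actual Stage/FSDP types:

# Hypothetical stand-ins; Fast-LLM's real Stage/FSDP classes carry much more state.
class FsdpStub:
    def __init__(self, requires_grad: bool):
        self.requires_grad = requires_grad


class StageStub:
    def __init__(self, fsdps: list[FsdpStub]):
        self._fsdps = fsdps

    @property
    def requires_grad(self) -> bool:
        # A stage is trainable as soon as any of its FSDP shards needs gradients.
        return any(fsdp.requires_grad for fsdp in self._fsdps)


frozen = StageStub([FsdpStub(False), FsdpStub(False)])    # e.g. a fully frozen embedding stage
trainable = StageStub([FsdpStub(False), FsdpStub(True)])  # e.g. a block with LoRA adapters
assert not frozen.requires_grad and trainable.requires_grad
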
fast_llm/engine/schedule/runner.py (+1, -1)

@@ -406,7 +406,7 @@ def _forward(self, context: BatchContext, step: Step) -> None:
             losses=context.losses,
             metrics=context.metrics,
         )
-        if context.is_training:
+        if step.backward_step is not None:
             context.contexts[step.backward_step.global_index] = grad_context
         self._record_compute(context, step)
         return output

fast_llm/engine/schedule/schedule.py (+29, -12)

@@ -141,7 +141,7 @@ def __init__(
             phase=self._phase,
         )

-        self._steps = self._create_steps()
+        self._steps, self._first_grad_stage = self._create_steps()

         self._create_index()

@@ -214,8 +214,8 @@ def _create_index(self) -> None:
         # Consistency checks
         step_map = self._step_map.copy()
         for data_index in range(self._batch_config.num_inputs):
-            for type_ in (StepType.forward, StepType.backward) if self._is_training else (StepType.forward,):
-                for stage in range(self._num_stages):
+            for type_ in (StepType.forward, StepType.backward):
+                for stage in range(0 if type_ == StepType.forward else self._first_grad_stage, self._num_stages):
                     assert (
                         step_map.pop((type_, stage, data_index), None) is not None
                     ), f"Missing {type_.value} step with stage={stage}, data_index={data_index}"
@@ -225,7 +225,8 @@ def _create_index(self) -> None:
         for i, step in enumerate(self._steps):
             if self._is_training:
                 if step.type_ == StepType.forward:
-                    step.backward_step = self.get_step(StepType.backward, *step.map_index[1:])
+                    if step.stage >= self._first_grad_stage:
+                        step.backward_step = self.get_step(StepType.backward, *step.map_index[1:])
                 else:
                     step.forward_step = self.get_step(StepType.forward, *step.map_index[1:])
             if step.type_ == StepType.forward and step.stage == 0:
@@ -236,7 +237,8 @@ def _create_index(self) -> None:
                 step.prev_step = self.get_step(
                     step.type_, step.stage + (1 if step.type_ == StepType.backward else -1), *step.map_index[2:]
                 )
-            if step.type_ == StepType.backward and step.stage == 0:
+
+            if step.type_ == StepType.backward and step.stage == self._first_grad_stage:
                 step.next_step = None
             elif step.type_ == StepType.forward and step.stage == self._num_stages - 1:
                 step.next_step = self.get_step(StepType.backward, *step.map_index[1:]) if self._is_training else None
@@ -249,11 +251,15 @@ def _create_index(self) -> None:
         for step in self._steps:
             if self._is_training:
                 if step.type_ == StepType.forward:
-                    Assert.gt(step.backward_step.global_index, step.global_index)
-                    Assert.is_(step.backward_step.forward_step, step)
+                    if step.stage >= self._first_grad_stage:
+                        Assert.gt(step.backward_step.global_index, step.global_index)
+                        Assert.is_(step.backward_step.forward_step, step)
+                    else:
+                        assert step.backward_step is None
                 else:
                     Assert.lt(step.forward_step.global_index, step.global_index)
-                    Assert.is_(step.forward_step.backward_step, step)
+                    if step.stage >= self._first_grad_stage:
+                        Assert.is_(step.forward_step.backward_step, step)
             if step.next_step is not None:
                 Assert.gt(step.next_step.global_index, step.global_index)
                 Assert.is_(step.next_step.prev_step, step)
@@ -303,7 +309,10 @@ def _setup_reduce_steps(self, grad_buffer_indices: dict[int, int]) -> None:
             reduce_step.reduce_accumulate = reduction_count[reduce_step.stage] > 0
             reduction_count[reduce_step.stage] += 1
         for stage, count in enumerate(reduction_count):
-            assert (count > 0) == (stage % self._distributed.pipeline_parallel == self._distributed.pipeline_rank)
+            assert (count > 0) == (
+                stage >= self._first_grad_stage
+                and (stage % self._distributed.pipeline_parallel == self._distributed.pipeline_rank)
+            )

     def _setup_timeline(self) -> None:
         # TODO: Include network time
@@ -468,8 +477,16 @@ def get_data_index_split(
             micro_sequence,
         )

-    def _create_steps(self) -> list[Step]:
+    def _create_steps(self) -> tuple[list[Step], int]:
         steps = []
+        if self._is_training:
+            # The first stage(s) may not have any trainable parameters,
+            # in which case we shouldn't run the backward pass.
+            first_grad_stage = 0
+            while first_grad_stage < self._num_stages and not self._multi_stage.stages[first_grad_stage].requires_grad:
+                first_grad_stage += 1
+        else:
+            first_grad_stage = self._num_stages
         for depth_first_micro_batch in range(self._batch_config.depth_first_micro_batches):
             for stage in range(self._num_stages):
                 for breadth_first_micro_batch in range(self._batch_config.breadth_first_micro_batches):
@@ -485,7 +502,7 @@ def _create_steps(self) -> list[Step]:
                         )
                     )
         if self._is_training:
-            for stage in reversed(range(self._num_stages)):
+            for stage in reversed(range(first_grad_stage, self._num_stages)):
                 for breadth_first_micro_batch in range(self._batch_config.breadth_first_micro_batches):
                     for micro_sequence in reversed(range(self._batch_config.num_micro_sequences)):
                         steps.append(
@@ -498,4 +515,4 @@ def _create_steps(self) -> list[Step]:
                                 type_=StepType.backward,
                             )
                         )
-        return steps
+        return steps, first_grad_stage

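The comment in _create_steps carries the core scheduling idea: leading stages with no trainable parameters need no backward (or gradient-reduce) steps, so backward scheduling starts at the first stage whose requires_grad is true. A rough sketch of that scan over a plain list of per-stage flags instead of Fast-LLM stage objects:

def first_grad_stage(stage_requires_grad: list[bool], is_training: bool) -> int:
    """Index of the first stage that needs a backward pass (len(...) if none does)."""
    if not is_training:
        # Evaluation schedules contain no backward steps at all.
        return len(stage_requires_grad)
    stage = 0
    while stage < len(stage_requires_grad) and not stage_requires_grad[stage]:
        stage += 1
    return stage


# Example: LoRA adapters only from stage 2 onward; stages 0-1 are fully frozen.
assert first_grad_stage([False, False, True, True], is_training=True) == 2
assert first_grad_stage([False, False, True, True], is_training=False) == 4
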
fast_llm/functional/triton/normalization.py (+42, -28)

@@ -68,6 +68,7 @@ def triton_normalization_backward_kernel_1(
     n_cols,
     n_rows,
     has_bias: tl_constexpr,
+    parameter_grad: tl_constexpr,
     zero_centered: tl_constexpr,
     block_size: tl_constexpr,
     block_size_row: tl_constexpr,
@@ -108,18 +109,19 @@ def triton_normalization_backward_kernel_1(
     tl.store(grad_input_ptr + offsets, grad_input, mask=mask)

     # Parameter grad partial sums
-    parameter_offsets = tl.program_id(0) * n_cols + cols
-    grad_weight_partial_ptr = grad_weight_partial_ptr + parameter_offsets
-    grad_weight_partial = (grad_output * input_normalized).to(weight.dtype)
-    grad_weight_partial = tl.sum(grad_weight_partial, axis=0)[None, :]
+    if parameter_grad:
+        parameter_offsets = tl.program_id(0) * n_cols + cols
+        grad_weight_partial_ptr = grad_weight_partial_ptr + parameter_offsets
+        grad_weight_partial = (grad_output * input_normalized).to(weight.dtype)
+        grad_weight_partial = tl.sum(grad_weight_partial, axis=0)[None, :]

-    if has_bias:
-        grad_bias_partial_ptr = grad_bias_partial_ptr + parameter_offsets
-        grad_bias_partial = tl.sum(grad_output.to(weight.dtype), axis=0)[None, :]
+        if has_bias:
+            grad_bias_partial_ptr = grad_bias_partial_ptr + parameter_offsets
+            grad_bias_partial = tl.sum(grad_output.to(weight.dtype), axis=0)[None, :]

-    tl.store(grad_weight_partial_ptr, grad_weight_partial, mask=col_mask)
-    if has_bias:
-        tl.store(grad_bias_partial_ptr, grad_bias_partial, mask=col_mask)  # noqa
+        tl.store(grad_weight_partial_ptr, grad_weight_partial, mask=col_mask)
+        if has_bias:
+            tl.store(grad_bias_partial_ptr, grad_bias_partial, mask=col_mask)  # noqa


 @triton_jit()
@@ -211,6 +213,11 @@ def triton_normalization_backward(grad_output: torch.Tensor, context: list[typin
     context.clear()
     has_bias = bias is not None

+    parameter_grad = weight.requires_grad
+    assert parameter_grad == hasattr(weight, "grad_buffer")
+    if has_bias:
+        assert parameter_grad == bias.requires_grad
+
     grad_output = grad_output.contiguous()

     n_rows = grad_output.shape[:-1].numel()
@@ -232,12 +239,17 @@ def triton_normalization_backward(grad_output: torch.Tensor, context: list[typin

     grad_input = torch.empty_like(grad_output)

-    grad_is_zero = param_get_and_unset_is_zero(weight)
-    grad_weight = weight.grad_buffer
-    # TODO: Any point in making it full precision?
-    grad_weight_partial = grad_output.new_empty(num_blocks_row, n_cols)
+    if parameter_grad:
+        grad_is_zero = param_get_and_unset_is_zero(weight)
+        grad_weight = weight.grad_buffer
+        # TODO: Any point in making it full precision?
+        grad_weight_partial = grad_output.new_empty(num_blocks_row, n_cols)
+    else:
+        grad_is_zero = True
+        grad_weight = None
+        grad_weight_partial = None

-    if has_bias:
+    if has_bias and parameter_grad:
         assert param_get_and_unset_is_zero(bias) == grad_is_zero
         grad_bias = bias.grad_buffer
         grad_bias_partial = grad_output.new_empty(num_blocks_row, n_cols)
@@ -256,24 +268,26 @@ def triton_normalization_backward(grad_output: torch.Tensor, context: list[typin
         n_cols,
         n_rows,
         has_bias,
+        parameter_grad,
         zero_centered,
         block_size,
         block_size_row,
         num_warps=num_warps,
     )
-    triton_normalization_backward_kernel_2[(triton.cdiv(n_cols, block_size_n),)](
-        grad_weight_partial,
-        grad_bias_partial,
-        grad_weight,
-        grad_bias,
-        num_blocks_row,
-        n_cols,
-        has_bias,
-        not grad_is_zero,
-        block_size_m,
-        block_size_n,
-        num_ctas=1,
-    )
+    if parameter_grad:
+        triton_normalization_backward_kernel_2[(triton.cdiv(n_cols, block_size_n),)](
+            grad_weight_partial,
+            grad_bias_partial,
+            grad_weight,
+            grad_bias,
+            num_blocks_row,
+            n_cols,
+            has_bias,
+            not grad_is_zero,
+            block_size_m,
+            block_size_n,
+            num_ctas=1,
+        )
     return grad_input


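The new parameter_grad flag mirrors weight.requires_grad: when the normalization weight (and bias) are frozen, only the input gradient is produced and the weight/bias gradient kernels are skipped entirely. A plain-PyTorch sketch of just this gating logic, not of the Triton kernels' actual math (the input-gradient line below is a placeholder):

import torch


def normalization_backward_sketch(grad_output: torch.Tensor, input_normalized: torch.Tensor, weight: torch.Tensor):
    # The input gradient is always needed to keep backpropagating through the layer;
    # this expression stands in for the real fused normalization backward.
    grad_input = grad_output * weight

    parameter_grad = weight.requires_grad  # same gate as in triton_normalization_backward
    grad_weight = None
    if parameter_grad:
        # Weight-gradient work (partial sums + the second kernel) only for trainable weights.
        grad_weight = (grad_output * input_normalized).reshape(-1, weight.numel()).sum(dim=0)
    return grad_input, grad_weight


g = torch.randn(4, 16)
x_hat = torch.randn(4, 16)
frozen_weight = torch.ones(16)  # requires_grad=False, as for LoRA-frozen normalization layers
_, grad_w = normalization_backward_sketch(g, x_hat, frozen_weight)
assert grad_w is None  # the parameter-gradient path is skipped entirely
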
fast_llm/layers/common/config.py (+57)

@@ -7,6 +7,7 @@

 if typing.TYPE_CHECKING:
     from fast_llm.engine.config_utils.tensor_space import TensorDim
+    from fast_llm.layers.common.linear import LinearBase, LinearLike
     from fast_llm.layers.common.normalization import LayerNorm, RMSNorm


@@ -115,3 +116,59 @@ def _from_dict(
         cls._handle_renamed_field(default, "normalization_implementation", "implementation")
         cls._handle_renamed_field(default, "layer_norm_init_range", "initialization_range")
         return super()._from_dict(default, strict, flat)
+
+
+class PeftType(str, enum.Enum):
+    # TODO: Use a dynamic config type instead.
+    none = "none"
+    lora = "lora"
+
+
+@config_class()
+class PeftArchitectureConfig(BaseModelArchitectureConfig):
+    _abstract = False
+
+
+@config_class()
+class PeftConfig(PeftArchitectureConfig, BaseModelConfig):
+    # TODO: Architecture/non-architecture split might not make much sense here.
+
+    type: PeftType = Field(
+        default=PeftType.none,
+        desc="The type of parameter-efficient fine tuning to use. Only LoRA is supported at the moment.",
+        hint=FieldHint.core,
+    )
+    rank: int = Field(
+        default=8,
+        desc="The LoRA rank, i.e. the size of the intermediate dimension.",
+        hint=FieldHint.stability,
+    )
+    alpha: float = Field(
+        default=8.0,
+        desc="The LoRA scaling parameter.",
+        hint=FieldHint.stability,
+    )
+    dropout: float = Field(
+        default=0.0,
+        desc="Dropout rate for LoRA.",
+        hint=FieldHint.stability,
+    )
+
+    def apply_linear(self, linear: "LinearBase", **kwargs) -> "LinearLike":
+        if self.type == PeftType.none:
+            return linear
+        elif self.type == PeftType.lora:
+            from fast_llm.layers.common.peft import lora_linear
+
+            # TODO: Init method?
+            return lora_linear(
+                linear,
+                linear.weight.param_init_method,
+                linear.weight.param_init_method,
+                self.rank,
+                self.alpha,
+                self.dropout,
+                **kwargs,
+            )
+        else:
+            raise NotImplementedError(self.type)

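apply_linear defers to lora_linear from fast_llm/layers/common/peft.py, which is not part of this excerpt. As a hedged sketch of what a minimal LoRA wrapper of this kind typically looks like (frozen base weight plus a rank-`rank` update scaled by alpha / rank, with dropout on the low-rank path); the class and attribute names below are illustrative, not Fast-LLM's API:

import math

import torch


class LoRALinearSketch(torch.nn.Module):
    """Illustrative only: y = x W^T + b + (alpha / rank) * dropout(x) A^T B^T."""

    def __init__(self, base: torch.nn.Linear, rank: int = 8, alpha: float = 8.0, dropout: float = 0.0):
        super().__init__()
        self.base = base
        # The pretrained weight (and bias) stay frozen; only the low-rank factors train.
        self.base.weight.requires_grad_(False)
        if self.base.bias is not None:
            self.base.bias.requires_grad_(False)
        self.scaling = alpha / rank
        self.dropout = torch.nn.Dropout(dropout)
        # Common LoRA init: A random, B zero, so the adapter initially adds nothing.
        self.lora_a = torch.nn.Parameter(torch.empty(rank, base.in_features))
        self.lora_b = torch.nn.Parameter(torch.zeros(base.out_features, rank))
        torch.nn.init.kaiming_uniform_(self.lora_a, a=math.sqrt(5))

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        return self.base(x) + self.scaling * (self.dropout(x) @ self.lora_a.t() @ self.lora_b.t())


layer = LoRALinearSketch(torch.nn.Linear(64, 64), rank=8, alpha=8.0, dropout=0.0)
trainable = [name for name, p in layer.named_parameters() if p.requires_grad]
assert trainable == ["lora_a", "lora_b"]  # base weight and bias are excluded from training

Because lora_b starts at zero, a freshly wrapped layer reproduces the frozen base layer's output exactly; training then updates only the two small factors, which is what makes the schedule and kernel changes above (skipping backward and parameter-gradient work for frozen stages and frozen normalization weights) pay off.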