[Prototype] LoRA #180

Draft: wants to merge 4 commits into base branch main
14 changes: 7 additions & 7 deletions fast_llm/core/ops.py
@@ -4,11 +4,11 @@
"""

import logging
+import typing

import torch
import torch._dynamo # noqa
import torch.autograd
-from torch._C._distributed_c10d import Work

from fast_llm.core.distributed import ProcessGroup, ReduceOp, all_gather_into_tensor, all_reduce, reduce_scatter_tensor
from fast_llm.utils import Assert, div
@@ -18,12 +18,12 @@

def reduce_op(
input_: torch.Tensor, group: ProcessGroup | None, *, op: ReduceOp = ReduceOp.SUM, async_op: bool = False
-) -> tuple[torch.Tensor, Work] | torch.Tensor:
+) -> tuple[torch.Tensor, typing.Callable[[], None]] | torch.Tensor:
if group:
handle = all_reduce(input_, group=group, async_op=async_op, op=op)
else:
handle = None
-return (input_, handle) if async_op else input_
+return (input_, handle.wait) if async_op else input_


def split_op(input_: torch.Tensor, group: ProcessGroup | None, dim: int) -> list[torch.Tensor]:
@@ -62,7 +62,7 @@ def swap_mult_dim(tensor: torch.Tensor, factor: int, old_dim: int, new_dim: int)

def gather_op(
input_: torch.Tensor, group: ProcessGroup | None, dim: int, async_op: bool = False, out=None
-) -> tuple[torch.Tensor, Work] | torch.Tensor:
+) -> tuple[torch.Tensor, typing.Callable[[], None]] | torch.Tensor:
"""Gather tensors and concatenate along the last dimension."""
# Bypass the function if we are using only 1 GPU.
if not group:
@@ -79,7 +79,7 @@ def gather_op(
assert not async_op
# TODO: contiguous?
out = swap_mult_dim(out, group.size(), 0, dim)
-return (out, handle) if async_op else out
+return (out, handle.wait) if async_op else out


def reduce_scatter_op(
@@ -89,7 +89,7 @@
op: ReduceOp = ReduceOp.SUM,
dim: int = 0,
async_op: bool = False,
-) -> tuple[torch.Tensor, Work] | torch.Tensor:
+) -> tuple[torch.Tensor, typing.Callable[[], None]] | torch.Tensor:
"""Reduce-scatter the input tensor across model parallel group."""
# Bypass the function if we are using only 1 GPU.
if not group:
@@ -99,7 +99,7 @@
input_ = swap_mult_dim(input_, group.size(), dim, 0)
# TODO: May give the wrong output without the contiguous call.
handle = reduce_scatter_tensor(output, input_.contiguous(), group=group, async_op=async_op, op=op)
-return (output, handle) if async_op else output
+return (output, handle.wait) if async_op else output


class _ReduceBackward(torch.autograd.Function):
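For reference, the async variants now return a zero-argument callable (`handle.wait`) instead of the `Work` object itself, so callers invoke it directly to complete the communication. A minimal caller-side sketch, assuming `torch.distributed` is already initialized; the tensor, the group choice, and the overlapped-compute comment are illustrative, not code from this PR:

```python
import torch

from fast_llm.core.ops import reduce_op

# Placeholder setup: in practice the group would come from the distributed config.
group = torch.distributed.group.WORLD
grad_buffer = torch.ones(1024, device="cuda")

reduced, wait = reduce_op(grad_buffer, group, async_op=True)  # now (tensor, handle.wait)
# ... overlap independent compute here while the all-reduce runs ...
wait()  # replaces the previous `handle.wait()` call on the returned Work object
# `reduced` is only safe to read once wait() has returned.
```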
22 changes: 17 additions & 5 deletions fast_llm/engine/base_model/base_model.py
@@ -14,17 +14,29 @@
from fast_llm.utils import Assert


-class Module(torch.nn.Module, abc.ABC):
-""" """
-
-def forward(self, input_, kwargs):
+class FastLLMModule(torch.nn.Module, abc.ABC):
+def forward(self, *args, **kwargs):
"""
Run a forward pass for the module, with autograd support.
"""
raise NotImplementedError()

+def forward_only(self, *args, **kwargs) -> tuple[typing.Any, typing.Any]:
+"""
+Run only the forward pass, and return the output and context for backward.
+TODO: Make a generic type for the context?
+"""
+raise NotImplementedError()
+
+def backward(self, *grad_outputs: torch.Tensor, context: typing.Any) -> tuple[torch.Tensor, ...]:
+"""
+Run the full backward pass using the output grads and the context, and return the input grads.
+Parameter gradients should be accumulated directly in their gradient buffer rather than returned.
+"""
+raise NotImplementedError()


-class Layer(Module):
+class Layer(FastLLMModule):
# Weight used to determine the stage size
layer_count: float = 1.0

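To illustrate the new interface (this example is not part of the PR): a subclass implements `forward` for the autograd path and pairs `forward_only` with `backward`, returning whatever context the backward pass needs. A toy, parameter-free sketch, assuming `FastLLMModule` is imported from the file changed above:

```python
import typing

import torch

from fast_llm.engine.base_model.base_model import FastLLMModule


class ScaleLayer(FastLLMModule):
    """Toy layer that multiplies its input by a constant (no parameters)."""

    def __init__(self, scale: float = 2.0):
        super().__init__()
        self._scale = scale

    def forward(self, input_: torch.Tensor) -> torch.Tensor:
        # Autograd-supported path.
        return input_ * self._scale

    def forward_only(self, input_: torch.Tensor) -> tuple[torch.Tensor, typing.Any]:
        # Nothing beyond the scale is needed for backward, so the context is None.
        return input_ * self._scale, None

    def backward(self, *grad_outputs: torch.Tensor, context: typing.Any) -> tuple[torch.Tensor, ...]:
        (grad_output,) = grad_outputs
        # Input grads are returned; a layer with weights would accumulate their
        # grads directly into the gradient buffers instead of returning them.
        return (grad_output * self._scale,)
```

The context object effectively takes the place of autograd's saved tensors, so each module decides what it keeps between the two passes.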
252 changes: 0 additions & 252 deletions fast_llm/functional/linear.py

This file was deleted.
