2 changes: 1 addition & 1 deletion Dockerfile
@@ -38,7 +38,7 @@ COPY --chmod=777 ./fast_llm/__init__.py fast_llm/
COPY --chmod=777 ./fast_llm/csrc/ fast_llm/csrc/

# Install dependencies within the virtual environment.
RUN pip install --no-cache-dir --no-build-isolation -e ".[CORE,OPTIONAL,HUGGINGFACE,SSM,GENERATION,DEV]" triton==3.1.0
RUN pip install --no-cache-dir --no-build-isolation -e ".[CORE,OPTIONAL,HUGGINGFACE,SSM,VISION,GENERATION,DEV]" triton==3.1.0

# Copy the remaining source code with universal write permissions.
COPY --chmod=777 ./Megatron-LM Megatron-LM
16 changes: 16 additions & 0 deletions fast_llm/core/distributed.py
@@ -91,6 +91,22 @@ def allreduce_scalar(
return value


def all_gather_scalar(
value: float | int,
dtype: torch.dtype = torch.float64,
group: torch.distributed.ProcessGroup | None = None,
timeout: float | None = None,
):
if group:
value = torch.full([1], value, dtype=dtype, device=torch.cuda.current_device())
add_ephemeral_timeout(group, timeout)
output_tensor = value.new_empty((group.size(),))
torch.distributed.all_gather_into_tensor(output_tensor, value, group=group)
return output_tensor.tolist()
else:
return value


def broadcast_scalar(
value: float | int,
dtype: torch.dtype = torch.float64,
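For reference, a minimal usage sketch of the new `all_gather_scalar` helper (not part of the PR; the process group is assumed to come from an existing `Distributed` instance):

import torch
from fast_llm.core.distributed import all_gather_scalar

# Hypothetical example: gather each rank's local sample count so every rank sees the full picture.
local_count = 123  # assumed per-rank scalar
counts = all_gather_scalar(local_count, dtype=torch.int64, group=sequence_data_group)  # `sequence_data_group` is assumed
# With a group of size N this returns a Python list of N values, one per rank;
# with `group=None` the input scalar is returned unchanged.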
60 changes: 47 additions & 13 deletions fast_llm/engine/base_model/base_model.py
@@ -1,4 +1,5 @@
import abc
import functools
import typing

import torch.nn
@@ -52,10 +53,15 @@ def get_loss_definitions(self, count: int = 1) -> list[LossDef]:
losses += layer.get_loss_definitions(count)
return losses

def preprocess(self, batch: "torch.Tensor", kwargs: dict[str, typing.Any]) -> None:
def preprocess(self, kwargs: dict[str, typing.Any]) -> None:
for layer in self.get_layers():
if layer is not self:
layer.preprocess(batch, kwargs)
layer.preprocess(kwargs)

def unwrap(self) -> "LayerBase":
# Get the actual module contained in this layer,
# undoing any wrapping for the Fast-LLM engine (e.g. `LayerBaseWithNamespace`)
return self


class Layer(LayerBase):
@@ -74,30 +80,63 @@ def forward(
pass

def unwrap(self) -> "Layer":
# Get the actual module contained in this layer,
# undoing any wrapping for the Fast-LLM engine (e.g. `LayerWithNamespace`)
return self


class LayerWithNamespace(Layer):
class LayerBaseWithNamespace(LayerBase):
"""
A layer with its own namespace for preprocessing (kwargs),
A layer base with its own namespace for preprocessing (kwargs),
so that it doesn't inadvertently interact with other layers.
TODO: Consider namespace for losses and metrics?
"""

def __init__(self, layer: Layer, namespace: str = None):
def __init__(self, layer: LayerBase, namespace: str = None):
super().__init__(layer._distributed_config)
self._layer = layer
self._namespace = namespace
self.layer_count = self._layer.layer_count
self.get_compute_usage = self._layer.get_compute_usage
self.module_name = self._layer.module_name

def setup(self, distributed: Distributed) -> None:
self._layer.setup(distributed)
super().setup(distributed)

def get_layers(self) -> list["Layer"]:
"""
Wrap individual layers so the namespace is used in forward.
"""
return self._layers_with_namespace

def preprocess(self, kwargs: dict[str, typing.Any]) -> None:
"""
Preprocess with namespace.
"""
if self._namespace not in kwargs:
kwargs[self._namespace] = kwargs.copy()
self._layer.preprocess(kwargs[self._namespace])

def unwrap(self) -> "LayerBase":
return self._layer.unwrap()

@functools.cached_property
def _layers_with_namespace(self) -> list[Layer]:
# This needs to be in a property because `module_name` is set after `__init__`.
# Wrap each set of blocks with identical config in a namespace
# using the unique module name of the first such block.
return [LayerWithNamespace(layer, self._namespace) for layer in self._layer.get_layers()]


class LayerWithNamespace(LayerBaseWithNamespace, Layer):
_layer: Layer

def __init__(self, layer: Layer, namespace: str = None):
super().__init__(layer, namespace)
self.layer_count = self._layer.layer_count

def get_layers(self) -> list["Layer"]:
# Need to override since `LayerBaseWithNamespace.get_layers` comes first in the MRO.
return [self]

def forward(
self, input_: torch.Tensor, kwargs: dict, losses: dict | None = None, metrics: dict | None = None
) -> torch.Tensor:
@@ -109,11 +148,6 @@ def forward(
assert isinstance(input_, TensorMeta)
return self._layer.forward(input_, kwargs, losses, metrics)

def preprocess(self, batch: "torch.Tensor", kwargs: dict[str, typing.Any]) -> None:
assert self._namespace not in kwargs
kwargs[self._namespace] = kwargs.copy()
self._layer.preprocess(batch, kwargs[self._namespace])

def unwrap(self) -> "Layer":
return self._layer.unwrap()

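A rough sketch of what the namespace wrapper guarantees (the `decoder` layer base and the kwargs are hypothetical): `preprocess` stores a private copy of the kwargs under the namespace key, the wrapped layers read from that copy in `forward`, and `unwrap()` recovers the original module.

# Hypothetical illustration of the namespace isolation, not actual Fast-LLM usage.
wrapped = LayerBaseWithNamespace(decoder, namespace="decoder_0")  # `decoder` is an assumed LayerBase

kwargs = {"sequence_length": 4096}
wrapped.preprocess(kwargs)

assert "decoder_0" in kwargs                           # private copy created by the wrapper
assert kwargs["decoder_0"]["sequence_length"] == 4096  # the wrapped layer only sees this copy
assert wrapped.unwrap() is decoder                     # wrapping is transparent to unwrap()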
14 changes: 12 additions & 2 deletions fast_llm/engine/config_utils/tensor_dim.py
@@ -14,12 +14,15 @@


class TensorDim:
def __init__(self, name: str, global_size: int | None, parallel_dim: DistributedDim | None = None):
def __init__(
self, name: str, global_size: int, parallel_dim: DistributedDim | None = None, variable_size: bool = False
):
# TODO: Handle None for unknown sizes?
self._name = name
self._global_size = global_size
self._size = self._global_size if parallel_dim is None else div(global_size, parallel_dim.size)
self._parallel_dim = parallel_dim
self._variable_size = variable_size

def __repr__(self) -> str:
return (
Expand All @@ -28,6 +31,7 @@ def __repr__(self) -> str:
f" size={self._size},"
f" global_size={self._global_size},"
f" parallel_dim={self._parallel_dim}"
f" variable_size={self._variable_size}"
f")"
)

@@ -60,9 +64,13 @@ def parallel_group(self) -> "ProcessGroup|None":
# TODO: Make more flexible for derived classes?
return None if self._parallel_dim is None else self._parallel_dim.group

@property
def variable_size(self) -> bool:
return self._variable_size

def replace_parallel_dim(self, distributed_dim: DistributedDim) -> typing.Self:
assert self.is_parallel
return TensorDim(self.name, self.size * distributed_dim.size, distributed_dim)
return TensorDim(self.name, self.size * distributed_dim.size, distributed_dim, self.variable_size)

def local_to_global(self, tensor: "torch.Tensor", dim: int = 0) -> "torch.Tensor":
if self.is_parallel:
@@ -99,6 +107,7 @@ def __init__(self, name: str, tensor_dims: tuple[TensorDim, ...]):
assert parallel_dim is None
parallel_dim = tensor_dim.parallel_dim
self._parallel_dim_index = dim
assert not tensor_dim.variable_size

super().__init__(
name=name,
@@ -142,6 +151,7 @@ def __init__(self, name: str, tensor_dims: tuple[TensorDim, ...]):
for dim, tensor_dim in enumerate(tensor_dims[1:]):
# TODO: Allow more flexibility?
Assert.is_(tensor_dim.parallel_dim, parallel_dim)
assert not tensor_dim.variable_size

super().__init__(
name=name,
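A small sketch of the new `variable_size` flag (the dim name and size are made up); note the compound dims above now assert that none of their components are variable-size:

# Hypothetical example of a variable-size dim.
sequence_dim = TensorDim("sequence", 4096, variable_size=True)
assert sequence_dim.variable_size
assert sequence_dim.size == sequence_dim.global_size == 4096  # no parallel_dim, so local == global
# Composite and concatenated dims reject such components, so a variable-size dim can only be used on its own.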
27 changes: 18 additions & 9 deletions fast_llm/engine/distributed/config.py
@@ -97,6 +97,7 @@ class DistributedDimNames:
sequence_data = "sequence_data"
batch_data = "batch_data"
tensor_and_sequence_data = "tensor_and_sequence_data"
tensor_and_data = "tensor_and_data"


@config_class()
@@ -255,8 +256,6 @@ def _validate(self) -> None:
Assert.multiple(self.local_world_size, self.tensor_parallel)

if self.pipeline_first:
# Case is useless and would cause too many complications.
Assert.eq(self.sequence_data_parallel, 1)
# Smaller models can be more demanding on pipeline parallel.
self.data_rank = (self.rank // self.tensor_parallel) // self.pipeline_parallel
self.pipeline_rank = (self.rank // self.tensor_parallel) % self.pipeline_parallel
@@ -334,14 +333,24 @@ def _validate(self) -> None:
),
)
)
self._add_distributed_dim(
DistributedDim(
name=DistributedDimNames.tensor_and_sequence_data,
size=self.sequence_data_parallel * self.tensor_parallel,
rank=self.tensor_rank + self.sequence_data_rank * self.tensor_parallel,
global_ranks=self._get_global_ranks(self.sequence_data_parallel * self.tensor_parallel, 1),
# Global ranks are wrong with pipeline_first, so we hide these dims as a safety check.
if not self.pipeline_first:
self._add_distributed_dim(
DistributedDim(
name=DistributedDimNames.tensor_and_sequence_data,
size=self.sequence_data_parallel * self.tensor_parallel,
rank=self.tensor_rank + self.sequence_data_rank * self.tensor_parallel,
global_ranks=self._get_global_ranks(self.sequence_data_parallel * self.tensor_parallel, 1),
)
)
self._add_distributed_dim(
DistributedDim(
name=DistributedDimNames.tensor_and_data,
size=self.data_parallel * self.tensor_parallel,
rank=self.tensor_rank + self.data_rank * self.tensor_parallel,
global_ranks=self._get_global_ranks(self.data_parallel * self.tensor_parallel, 1),
)
)
)

super()._validate()

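A worked example of the rank layout for the new `tensor_and_data` dim (sizes are made up; this assumes the default ordering where the tensor rank varies fastest, which is exactly what breaks under `pipeline_first` and why the dims are hidden in that case):

# Hypothetical sizes: tensor_parallel=2, data_parallel=4, so tensor_and_data has size 8.
tensor_parallel, data_parallel = 2, 4
tensor_rank, data_rank = 1, 3  # this process's coordinates

size = data_parallel * tensor_parallel            # 8
rank = tensor_rank + data_rank * tensor_parallel  # 1 + 3 * 2 = 7
assert (size, rank) == (8, 7)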
11 changes: 8 additions & 3 deletions fast_llm/engine/distributed/distributed.py
@@ -171,9 +171,14 @@ def __init__(self, config: DistributedConfig, use_cpu: bool = False):
self.tensor_group = self.add_group(self._config.distributed_dims[DistributedDimNames.tensor])
self.sequence_data_group = self.add_group(self._config.distributed_dims[DistributedDimNames.sequence_data])
self.batch_data_group = self.add_group(self._config.distributed_dims[DistributedDimNames.batch_data])
self.tensor_and_sequence_data_group = self.add_group(
self._config.distributed_dims[DistributedDimNames.tensor_and_sequence_data]
)
# Global ranks are wrong with pipeline_first, so we hide these dims as a safety check.
if not self._config.pipeline_first:
self.tensor_and_sequence_data_group = self.add_group(
self._config.distributed_dims[DistributedDimNames.tensor_and_sequence_data]
)
self.tensor_and_data_group = self.add_group(
self._config.distributed_dims[DistributedDimNames.tensor_and_data]
)

self._config.log_first_rank(f"Setting random seeds...")

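Because the two groups are only created when `pipeline_first` is off, downstream code should not assume the attributes exist. A hedged sketch of defensive access (`distributed` is assumed to be an already constructed `Distributed` instance):

# Hypothetical defensive access: with pipeline_first the attribute is never set,
# so fall back to None instead of raising AttributeError.
tensor_and_data_group = getattr(distributed, "tensor_and_data_group", None)
if tensor_and_data_group is not None:
    values = all_gather_scalar(1.0, group=tensor_and_data_group)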