Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
10 changes: 5 additions & 5 deletions fast_llm/data/data/gpt/config.py
Original file line number Diff line number Diff line change
@@ -1,13 +1,14 @@
import logging
import typing

from fast_llm.config import Field, FieldHint, FieldUpdate, check_field, config_class
from fast_llm.config import Field, FieldHint, check_field, config_class
from fast_llm.data.config import MultiprocessingContext
from fast_llm.data.data.config import DataConfig
from fast_llm.data.dataset.config import SampledDatasetConfig
from fast_llm.data.dataset.gpt.config import GPTSamplingConfig
from fast_llm.data.sample.gpt import GPTSample
from fast_llm.utils import Assert

if typing.TYPE_CHECKING:
from fast_llm.data.sample.language_model import LanguageModelSample
logger = logging.getLogger(__name__)


Expand All @@ -22,12 +23,11 @@ class GPTDataConfig(DataConfig):
_abstract = False

# TODO: Review field. Move closer to phase definition in training config?
datasets: dict[str, SampledDatasetConfig[GPTSample]] = Field(
datasets: dict[str, SampledDatasetConfig["LanguageModelSample"]] = Field(
default_factory=dict,
desc="Configuration for the dataset(s).",
hint=FieldHint.core,
)
sampling: GPTSamplingConfig = FieldUpdate()
data_sample_warn_time_ms: float = Field(
default=1000,
desc="Warn if a sample takes too long to load.",
Expand Down
33 changes: 3 additions & 30 deletions fast_llm/data/data/gpt/data.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,6 @@
import pathlib
import typing
import warnings
from functools import partial

import torch
import torch.utils.data
Expand All @@ -14,7 +13,7 @@
from fast_llm.data.dataset.gpt.config import GPTSamplingData, GPTSamplingParameters
from fast_llm.data.dataset.monitor import DatasetMonitor
from fast_llm.data.iterator import SampledDatasetIterator
from fast_llm.data.sample.gpt import GPTBatch, GPTSample
from fast_llm.data.sample.language_model import LanguageModelBatch
from fast_llm.engine.config_utils.run import log_main_rank
from fast_llm.engine.distributed.config import DistributedConfig
from fast_llm.engine.distributed.distributed import Distributed
Expand All @@ -24,32 +23,9 @@
logger = logging.getLogger(__name__)


def gpt_data_collate_fn(batch: list[GPTSample], sampling_parameters: GPTSamplingParameters) -> GPTBatch:
stacked_spans = None
sequence_lengths = None
stacked_chosen_spans = None
stacked_rejected_spans = None
if sampling_parameters.use_loss_masking_spans:
stacked_spans = [sample.loss_masking_spans for sample in batch]
if sampling_parameters.use_preference_loss_spans:
stacked_chosen_spans = [sample.chosen_span for sample in batch]
stacked_rejected_spans = [sample.rejected_span for sample in batch]
if not sampling_parameters.cross_document_attention:
sequence_lengths = [sample.sequence_lengths for sample in batch]
return GPTBatch(
token_ids=torch.stack([sample.token_ids for sample in batch]),
loss_masking_spans=stacked_spans,
sequence_lengths=sequence_lengths,
chosen_spans=stacked_chosen_spans,
rejected_spans=stacked_rejected_spans,
)


class GPTData[ConfigType: GPTDataConfig](Data[ConfigType]):
"""
A global class for all dataset needs, including loading, splitting, sampling and iteration.
Currently hard-coded to a GPT dataset.
TODO: Separate generic and GPT classes.
"""

_datasets: dict[str, SampledDataset]
Expand Down Expand Up @@ -124,7 +100,7 @@ def get_iterator(
num_workers: int,
prefetch_factor: int | None = None,
timeout: float = 60,
) -> typing.Iterator[GPTBatch]:
) -> typing.Iterator[LanguageModelBatch]:
assert self._is_setup

# Some dataset names may come from phases and are capitalized,
Expand All @@ -149,10 +125,7 @@ def get_iterator(
num_workers=num_workers,
prefetch_factor=prefetch_factor,
pin_memory=True,
collate_fn=partial(
gpt_data_collate_fn,
sampling_parameters=sampling_parameters,
),
collate_fn=LanguageModelBatch.from_samples,
multiprocessing_context=self._config.multiprocessing_context.value if num_workers > 0 else None,
)
)
34 changes: 30 additions & 4 deletions fast_llm/data/dataset/config.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,5 @@
import dataclasses
import enum
import functools
import itertools
import math
Expand All @@ -15,6 +16,17 @@
from fast_llm.engine.distributed.distributed import Distributed


class ShufflingType(str, enum.Enum):
# Shuffle all epochs together. Not extendable.
full = "full"
# Shuffle all epochs separately. Default mode, recommended if the dataset doesn't come pre-shuffled.
epoch = "epoch"
# Shuffle all epochs except the first one. Recommended for pre-shuffled datasets, especially big ones.
skip_first_epoch = "skip_first_epoch"
# Disable shuffling entirely.
disabled = "disabled"


@config_class()
class SamplingConfig(Config):
"""
Expand All @@ -26,6 +38,18 @@ class SamplingConfig(Config):
desc="Seed for random sampling.",
hint=FieldHint.feature,
)
gpu: bool = Field(
default=True,
desc="Enable fast sampling on GPU."
" Note that random sampling works differently on GPU,"
" so the sample won't match the CPU equivalent.",
hint=FieldHint.feature,
)
shuffle: ShufflingType = Field(
default=ShufflingType.epoch,
desc="Shuffling strategy.",
hint=FieldHint.feature,
)


@dataclasses.dataclass(kw_only=True)
Expand All @@ -34,7 +58,12 @@ class SamplingParameters:
Sampling parameters set externally to the dataset and data, ex. determined by the trainer or model.
"""

sequence_length: int
num_samples: int
truncate_documents: bool = True
# How many extra tokens to add to the sequence length.
# This is used to provide labels even for the last tokens in the sequence.
extra_tokens: int = 1


@dataclasses.dataclass(kw_only=True)
Expand Down Expand Up @@ -118,10 +147,7 @@ class ConcatenatedDatasetConfig[SampleType: Sample](SamplableDatasetConfig[Sampl
def build(self) -> "ConcatenatedDataset":
from fast_llm.data.dataset.indexed import ConcatenatedDataset

return self._build(ConcatenatedDataset)

def _build[T: ConcatenatedDataset](self, cls: type[T]) -> T:
return cls(self.name, [dataset.build() for dataset in self.datasets])
return ConcatenatedDataset(self.name, [dataset.build() for dataset in self.datasets])


@config_class(dynamic_type={SampledDatasetConfig: "slice"})
Expand Down
71 changes: 17 additions & 54 deletions fast_llm/data/dataset/gpt/config.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,4 @@
import dataclasses
import enum
import pathlib
import time
import typing
Expand All @@ -13,64 +12,27 @@
IndexedDatasetConfig,
SamplableDatasetConfig,
SampledDatasetConfig,
SamplingConfig,
SamplingData,
SamplingParameters,
)
from fast_llm.data.sample.gpt import GPTSample
from fast_llm.data.sample.language_model import LanguageModelSample
from fast_llm.utils import Assert

if typing.TYPE_CHECKING:
from fast_llm.data.dataset.gpt.fim import GPTFimDataset
from fast_llm.data.dataset.gpt.memmap import GPTMemmapDataset
from fast_llm.data.dataset.gpt.random import GPTRandomDataset


class ShufflingType(str, enum.Enum):
# Shuffle all epochs together. Not extendable.
full = "full"
# Shuffle all epochs separately. Default mode, recommended if the dataset doesn't come pre-shuffled.
epoch = "epoch"
# Shuffle all epochs except the first one. Recommended for pre-shuffled datasets, especially big ones.
skip_first_epoch = "skip_first_epoch"
# Disable shuffling entirely.
disabled = "disabled"


@config_class()
class GPTSamplingConfig(SamplingConfig):
"""
A dataset-dependent configuration for sampling.
"""

gpu: bool = Field(
default=True,
desc="Enable fast sampling on GPU."
" Note that random sampling works differently on GPU,"
" so the sample won't match the CPU equivalent.",
hint=FieldHint.feature,
)
shuffle: ShufflingType = Field(
default=ShufflingType.epoch,
desc="Shuffling strategy.",
hint=FieldHint.feature,
)


@dataclasses.dataclass(kw_only=True)
class GPTSamplingParameters(SamplingParameters):
"""
Sampling parameters set externally to the dataset and data, ex. determined by the trainer or model.
"""

sequence_length: int
vocab_size: int
use_loss_masking_spans: bool = False
use_preference_loss_spans: bool = False
cross_document_attention: bool = True
truncate_documents: bool = True
# How many extra tokens to add to the sequence length.
# This is used to provide labels even for the last tokens in the sequence.
extra_tokens: int = 1


@dataclasses.dataclass(kw_only=True)
Expand All @@ -80,27 +42,26 @@ class GPTSamplingData(SamplingData):
usage-dependent ones (`GPTSamplingParameters`), and others set by the `Data`.
"""

config: GPTSamplingConfig
parameters: GPTSamplingParameters


@config_class(dynamic_type={SampledDatasetConfig: "random"})
class GPTRandomDatasetConfig[SampleType: GPTSample](SamplableDatasetConfig[SampleType]):
class GPTRandomDatasetConfig[SampleType: LanguageModelSample](SamplableDatasetConfig[SampleType]):
_abstract: typing.ClassVar[bool] = False
name: str = Field(
default="dummy",
desc="The name of the dataset.",
hint=FieldHint.core,
)

def build(self) -> "GPTRandomDataset":
def build(self) -> "GPTRandomDataset[SampleType]":
from fast_llm.data.dataset.gpt.random import GPTRandomDataset

return GPTRandomDataset(self.name)
return GPTRandomDataset[SampleType](self.name)


@config_class(dynamic_type={SampledDatasetConfig: "memmap"})
class GPTMemmapDatasetConfig[SampleType: GPTSample](IndexedDatasetConfig[SampleType]):
class GPTMemmapDatasetConfig[SampleType: LanguageModelSample](IndexedDatasetConfig[SampleType]):
_abstract: typing.ClassVar[bool] = False
path: pathlib.Path = Field(
default=None,
Expand All @@ -118,14 +79,16 @@ class GPTMemmapDatasetConfig[SampleType: GPTSample](IndexedDatasetConfig[SampleT
hint=FieldHint.optional,
)

def build(self) -> "GPTMemmapDataset":
def build(self) -> "GPTMemmapDataset[SampleType]":
from fast_llm.data.dataset.gpt.memmap import GPTMemmapDataset

return GPTMemmapDataset(str(self.path).replace("/", "__"), self.path, self.num_documents, self.num_tokens)
return GPTMemmapDataset[SampleType](
str(self.path).replace("/", "__"), self.path, self.num_documents, self.num_tokens
)


@config_class(dynamic_type={SampledDatasetConfig: "file"})
class GPTDatasetFromFileConfig[SampleType: GPTSample](SamplableDatasetConfig[SampleType]):
class GPTDatasetFromFileConfig[SampleType: LanguageModelSample](SamplableDatasetConfig[SampleType]):
_abstract: typing.ClassVar[bool] = False
path: pathlib.Path = Field(
default=None,
Expand Down Expand Up @@ -235,30 +198,30 @@ class FimConfig(Config):


@config_class(dynamic_type={SampledDatasetConfig: "fim"})
class GPTFimSampledDatasetConfig[SampleType: GPTSample](SampledDatasetConfig[SampleType], FimConfig):
class GPTFimSampledDatasetConfig[SampleType: LanguageModelSample](SampledDatasetConfig[SampleType], FimConfig):
"""
Configuration for FIM.
"""

_abstract: typing.ClassVar[bool] = False

dataset: SampledDatasetConfig = Field(
dataset: SampledDatasetConfig[SampleType] = Field(
default=None,
desc="The dataset to wrap with fim.",
hint=FieldHint.core,
)

def build_and_sample(
self,
sampling: SamplingData,
) -> SampledDataset:
sampling: GPTSamplingData,
) -> "GPTFimDataset[SampleType]":
from fast_llm.data.dataset.gpt.fim import GPTFimDataset

return GPTFimDataset(self, self.dataset.build_and_sample(sampling), sampling)
return GPTFimDataset[SampleType](self, self.dataset.build_and_sample(sampling), sampling)


@config_class(dynamic_type={SampledDatasetConfig: "test_slow"})
class GPTTestSlowDatasetConfig[SampleType: GPTSample](SampledDatasetConfig[SampleType]):
class GPTTestSlowDatasetConfig[SampleType: LanguageModelSample](SampledDatasetConfig[SampleType]):
"""
A mock dataset that mimics a slow dataset creation on one rank, which may trigger a timeout.
"""
Expand Down
Loading