
Add data cleaning in fast-llm prepare, concept #210

Draft: wants to merge 2 commits into main
Changes from 1 commit
89 changes: 82 additions & 7 deletions fast_llm/config.py
@@ -9,7 +9,16 @@

import yaml

from fast_llm.utils import Assert, Tag, get_type_name, header, log, pop_nested_dict_value, set_nested_dict_value
from fast_llm.utils import (
Assert,
Registry,
Tag,
get_type_name,
header,
log,
pop_nested_dict_value,
set_nested_dict_value,
)

logger = logging.getLogger(__name__)

@@ -634,17 +643,17 @@ def _serialize_value(cls, value: typing.Any) -> int | float | bool | str | None:
value = str(value)
return value

def to_copy[
T
](self: T, *updates: typing.Union["Config", dict[str | tuple[str, ...], typing.Any]], strict: bool = True,) -> T:
def to_copy[T](
self: T,
*updates: typing.Union["Config", dict[str | tuple[str, ...], typing.Any]],
strict: bool = True,
) -> T:
return self.from_dict(self, *updates, strict=strict)

def to_serialized(self, verbose: int | None = FieldVerboseLevel.core) -> dict[str, typing.Any]:
return self._to_dict(verbose=verbose, format_=_ConfigDictFormat.nested, serializable=True)

def to_logs[
T
](
def to_logs[T](
self,
verbose: int | None = FieldVerboseLevel.core,
log_fn: typing.Callable[[str], T] = logger.info,
@@ -916,3 +925,69 @@ def __init__(self, config: ConfigType, *args, **kwargs):
@property
def config(self) -> ConfigType:
return self._config


@config_class()
class TypeableConfig(Config):
"""
Base Config class that instantiates a subclass type
based on the 'type' field in config files or params.
The root class must define its own _registry, and
subclasses must set a unique _type. Final classes
to be instantiated should have _abstract as False.
"""

_abstract: typing.ClassVar[bool] = True
_registry: typing.ClassVar[Registry[str, type["TypeableConfig"]] | None] = None

type_: typing.ClassVar[str | None] = None
type: str | None = Field(
default=None,
desc="Config specifieble type of the class.",
hint=FieldHint.core,
)

def _validate(self) -> None:
if self.type is None:
self.type = self.type_
# Should be handled in `from_dict`, but can fail if instantiating directly.
Assert.eq(self.type, self.__class__.type_)
super()._validate()

@classmethod
def _from_dict(
cls,
default: dict[str, typing.Any],
strict: bool = True,
flat: bool = False,
) -> typing.Self:
type_ = default.get("type")
if type_ is None:
actual_cls = cls
else:
if type_ not in cls._registry:
raise ValueError(
f"Unknown {cls._registry.name} type {type_}." f" Available types: {list(cls._registry.keys())}"
)
actual_cls = cls._registry[type_]
Assert.custom(issubclass, actual_cls, cls)
if actual_cls == cls:
return super()._from_dict(default, strict=strict, flat=flat)
else:
return actual_cls._from_dict(default, strict=strict, flat=flat)

def __init_subclass__(cls) -> None:
registry = getattr(cls, "_registry")
if registry is None:
raise ValueError(f"Sublass {cls.__name__} or one of its parents needs to set __registry")
if cls._abstract and cls.type_ is not None:
# Abstract classes should not have a `type_`
raise ValueError(f"Abstract class {cls.__name__} has type = {cls.type_}, expected None.")
if cls.type_ is not None:
if cls.type_ in registry:
raise ValueError(
f"Registry {cls._registry.name} already contains type {cls.type_}."
f" Make sure all classes either have a unique or `None` type."
)
registry[cls.type_] = cls
super().__init_subclass__()
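
A minimal sketch of the intended usage, with hypothetical names (ProcessorConfig, LowercaseProcessorConfig) that are not part of this PR: the root of a config family owns the Registry, each concrete subclass declares a unique type_, and from_dict on the root resolves the subclass from the "type" key.

@config_class()
class ProcessorConfig(TypeableConfig):
    # Root of the family: owns the registry and stays abstract.
    _registry: typing.ClassVar[Registry[str, type["ProcessorConfig"]] | None] = Registry("ProcessorConfig", {})


@config_class()
class LowercaseProcessorConfig(ProcessorConfig):
    # Concrete subclass, registered under the unique key "lowercase".
    _abstract: typing.ClassVar[bool] = False
    type_: typing.ClassVar[str | None] = "lowercase"


# `from_dict` on the root dispatches through the registry to the concrete subclass.
config = ProcessorConfig.from_dict({"type": "lowercase"})
assert isinstance(config, LowercaseProcessorConfig)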
4 changes: 4 additions & 0 deletions fast_llm/data/preparator/gpt_memmap/config.py
@@ -8,6 +8,8 @@
from fast_llm.engine.config_utils.data_type import DataType
from fast_llm.utils import Assert

from fast_llm.data.preparator.hf_processors.configs.agregator import AgregatorConfig

if typing.TYPE_CHECKING:
from fast_llm.data.preparator.gpt_memmap.prepare import GPTMemmapDatasetPreparator
MEMMAP_DTYPES = {
@@ -165,6 +167,8 @@ class GPTMemmapDatasetPreparatorConfig(DatasetPreparatorConfig):
hint=FieldHint.optional,
)

processors: AgregatorConfig = Field(
default_factory=AgregatorConfig,
desc="Processing steps applied to the dataset before tokenization.",
hint=FieldHint.optional,
)

def _validate(self) -> None:
assert self.tokenizer.path is not None
if self.dataset.data_type is not None:
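
For context, a hedged sketch of how the new processors section could be written as a config dict, assuming the base Config machinery deserializes the typed steps list the way the field annotation implies; only the field names shown come from this PR.

from fast_llm.data.preparator.hf_processors.configs.agregator import AgregatorConfig

# The "processors" section on its own: each entry in `steps` is dispatched on its
# "type" key, so "length_filter" resolves to DocLengthFilterConfig.
processors_config = AgregatorConfig.from_dict(
    {
        "steps": [
            {"type": "length_filter", "field": "text", "min_length_chars": 10, "max_length_chars": 100_000},
        ]
    }
)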
3 changes: 3 additions & 0 deletions fast_llm/data/preparator/gpt_memmap/prepare.py
@@ -221,6 +221,9 @@ def run(self) -> None:
else:
tokenize_fn = self._tokenize_batch

# Process dataset before tokenizing
dataset = self._config.processors.apply(dataset)

# Tokenize the dataset in parallel
tokenized_dataset = dataset.map(
tokenize_fn,
21 changes: 21 additions & 0 deletions fast_llm/data/preparator/hf_processors/configs/agregator.py
@@ -0,0 +1,21 @@
import datasets

from fast_llm.data.preparator.hf_processors.configs.base import Applicable, ShardProcessorConfig
from fast_llm.config import Field, Config, config_class

from fast_llm.data.preparator.hf_processors.configs.doc_length_filter import DocLengthFilterConfig

def default_processors():
"""Default processors to apply"""
return [DocLengthFilterConfig()]


@config_class()
class AgregatorConfig(Config, Applicable):
steps: list[ShardProcessorConfig] = Field(default_factory=default_processors)

def apply(self, dataset: datasets.Dataset) -> datasets.Dataset:
from fast_llm.data.preparator.hf_processors.implementations.agregator import apply
return apply(self, dataset)
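
A small end-to-end illustration on toy data, assuming from_dict({}) yields a validated config with the default length_filter step and that the step's apply is implemented as sketched later in this PR:

import datasets

from fast_llm.data.preparator.hf_processors.configs.agregator import AgregatorConfig

ds = datasets.Dataset.from_dict({"text": ["keep this document", "x" * 2_000_000]})
ds = AgregatorConfig.from_dict({}).apply(ds)  # default steps: [DocLengthFilterConfig()]
assert len(ds) == 1  # the over-long document was filtered out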


19 changes: 19 additions & 0 deletions fast_llm/data/preparator/hf_processors/configs/base.py
@@ -0,0 +1,19 @@
import abc
import typing
import datasets

from fast_llm.config import TypeableConfig, config_class
from fast_llm.utils import Registry


class Applicable:
@abc.abstractmethod
def apply(self, dataset: datasets.Dataset) -> datasets.Dataset:
raise NotImplementedError


@config_class()
class ShardProcessorConfig(TypeableConfig, Applicable):
_registry: typing.ClassVar[Registry[str, type["ShardProcessorConfig"]] | None] = Registry(
"ShardProcessorConfig", {}
)
22 changes: 22 additions & 0 deletions fast_llm/data/preparator/hf_processors/configs/doc_length_filter.py
@@ -0,0 +1,22 @@
import abc
import typing
import datasets

from fast_llm.data.preparator.hf_processors.configs.base import Applicable, ShardProcessorConfig
from fast_llm.config import Field, config_class


@config_class()
class DocLengthFilterConfig(ShardProcessorConfig):
_abstract: typing.ClassVar[bool] = False
type_: typing.ClassVar[str | None] = "length_filter"

field: str = Field(default="text")
min_length_chars: int = Field(default=0)
max_length_chars: int = Field(default=1_000_000)

def apply(self, dataset: datasets.Dataset) -> datasets.Dataset:
from fast_llm.data.preparator.hf_processors.implementations.doc_length_filter import apply
return apply(self, dataset)
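
Because DocLengthFilterConfig registers itself under "length_filter", the base class can build it straight from a plain dict, which is what lets processor steps be declared in config files; a quick illustration:

step = ShardProcessorConfig.from_dict({"type": "length_filter", "max_length_chars": 2048})
assert isinstance(step, DocLengthFilterConfig)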


11 changes: 11 additions & 0 deletions fast_llm/data/preparator/hf_processors/implementations/agregator.py
@@ -0,0 +1,11 @@
import datasets

from fast_llm.data.preparator.hf_processors.configs.agregator import AgregatorConfig

def apply(config: AgregatorConfig, dataset: datasets.Dataset) -> datasets.Dataset:
# do something before applying each processor
for step in config.steps:
dataset = step.apply(dataset)
# compute metrics
# save metrics, from all ranks?
return dataset
7 changes: 7 additions & 0 deletions fast_llm/data/preparator/hf_processors/implementations/doc_length_filter.py
@@ -0,0 +1,7 @@
import datasets

from fast_llm.data.preparator.hf_processors.configs.doc_length_filter import DocLengthFilterConfig

def apply(config: DocLengthFilterConfig, dataset: datasets.Dataset) -> datasets.Dataset:
    # Drop documents whose character count falls outside the configured bounds.
    return dataset.filter(
        lambda sample: config.min_length_chars <= len(sample[config.field]) <= config.max_length_chars
    )
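
A toy sanity check of the filter above, assuming from_dict({}) returns a validated config with the defaults (documents of 0 to 1,000,000 characters are kept):

import datasets

ds = datasets.Dataset.from_dict({"text": ["short document", "x" * 2_000_000]})
assert len(apply(DocLengthFilterConfig.from_dict({}), ds)) == 1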