Merged

Changes from all commits (34 commits)
2ab1825
Fix rotary 2d
jlamypoirier Dec 4, 2025
8305dd5
stuff
jlamypoirier Dec 4, 2025
b6e38b8
stuff
jlamypoirier Dec 4, 2025
72f915d
Merge branch 'main' into jlp/consistent_preprocessing
jlamypoirier Dec 4, 2025
350fb3d
stuff
jlamypoirier Dec 5, 2025
d27a815
fix
jlamypoirier Dec 5, 2025
72f3a31
Merge branch 'main' into jlp/consistent_preprocessing
jlamypoirier Dec 5, 2025
5ab6cd0
fixes
jlamypoirier Dec 6, 2025
1e74469
Merge remote-tracking branch 'origin/main' into jlp/consistent_prepro…
jlamypoirier Dec 8, 2025
6454db4
stuff
jlamypoirier Dec 9, 2025
916af7a
cleanup
jlamypoirier Dec 9, 2025
355af7c
Merge remote-tracking branch 'origin/main' into jlp/consistent_prepro…
jlamypoirier Dec 9, 2025
8f6841e
Merge branch 'jlp/consistent_preprocessing' into jlp/varlen_tweaks
jlamypoirier Dec 10, 2025
bd7a8e6
cleanup
jlamypoirier Dec 10, 2025
660fecc
fix
jlamypoirier Dec 10, 2025
db93bb5
fixes
jlamypoirier Dec 10, 2025
a3fa577
Merge branch 'jlp/consistent_preprocessing' into jlp/varlen_tweaks
jlamypoirier Dec 10, 2025
a1c0ade
Merge remote-tracking branch 'origin/main' into jlp/varlen_tweaks
jlamypoirier Dec 10, 2025
96ce759
misc
jlamypoirier Dec 11, 2025
e23ea04
Merge remote-tracking branch 'origin/main' into jlp/varlen_tweaks
jlamypoirier Dec 11, 2025
e5fe8b2
stuff
jlamypoirier Dec 11, 2025
68f457b
fixes
jlamypoirier Dec 11, 2025
fa668fa
Merge remote-tracking branch 'origin/main' into jlp/varlen_tweaks
jlamypoirier Dec 12, 2025
f7c5d1b
Remove mamba and discrete mamba 2
jlamypoirier Dec 12, 2025
31d856d
fix
jlamypoirier Dec 12, 2025
e74d30d
Merge remote-tracking branch 'origin/main' into jlp_remove_mamba
jlamypoirier Dec 12, 2025
75ad78a
fixes
jlamypoirier Dec 12, 2025
30e0419
Add metadata to dataset config files
jlamypoirier Dec 12, 2025
9fae16e
fix
jlamypoirier Dec 13, 2025
2b6527a
Merge remote-tracking branch 'origin/main' into jlp_remove_mamba
jlamypoirier Dec 13, 2025
4ddabf1
fix
jlamypoirier Dec 13, 2025
69095ea
Merge branch 'jlp_remove_mamba' into jlp_dataset_metadata
jlamypoirier Dec 13, 2025
fabba8f
Merge remote-tracking branch 'origin/main' into jlp_dataset_metadata
jlamypoirier Dec 13, 2025
a8ad476
Merge remote-tracking branch 'origin/main' into jlp_dataset_metadata
jlamypoirier Dec 22, 2025
7 changes: 6 additions & 1 deletion fast_llm/data/dataset/gpt/config.py
@@ -64,7 +64,12 @@ def build(self, preprocessing: PreprocessingConfig) -> SamplableDataset[SampleType

def _load_config(self) -> SampledDatasetConfig[SampleType]:
assert self.path.is_file(), f"File {self.path} does not exist."
return SampledDatasetConfig[SampleType].from_dict(self._convert_paths(yaml.safe_load(self.path.open("r"))))
config = yaml.safe_load(self.path.open("r"))
Assert.eq(config.keys(), {"config", "metadata"})
Collaborator commented:

@jlamypoirier this is causing crashes now. remove?
Suggested change
Assert.eq(config.keys(), {"config", "metadata"})

Collaborator (author) replied:

Yes, I must have left this there by accident

if config.keys() == {"config", "metadata"}:
# Newer format with metadata
config = config["config"]
return SampledDatasetConfig[SampleType].from_dict(self._convert_paths(config))

def _convert_paths(self, config):
# Recursively convert paths relative to `self.path.parent` to make them relative to cwd.
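For reference, `_load_config` now accepts both file layouts: the legacy one, where the YAML file holds the dataset config directly, and the newer one wrapping it under `config` next to `metadata`. A minimal sketch of the branch, with hypothetical file contents:

import yaml

# Hypothetical contents; real files are written by the prepare step.
legacy = yaml.safe_load("{type: memmap, path: shard_0_0}")
wrapped = yaml.safe_load(
    "{config: {type: memmap, path: shard_0_0}, metadata: {num_tokens: 1000000}}"
)

for config in (legacy, wrapped):
    if config.keys() == {"config", "metadata"}:
        # Newer format: unwrap the dataset config; the metadata is not needed here.
        config = config["config"]
    assert config == {"type": "memmap", "path": "shard_0_0"}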
11 changes: 10 additions & 1 deletion fast_llm/data/dataset/memmap.py
@@ -8,7 +8,12 @@
from fast_llm.data.dataset.config import SamplingParameters
from fast_llm.data.dataset.indexed import IndexedDataset
from fast_llm.data.preprocessing.abstract import PreprocessingConfig
from fast_llm.data.sample.abstract import MemmapIndexDatasetReaderConfig, MemmapWriter, Sample
from fast_llm.data.sample.abstract import (
MemmapIndexDatasetReaderConfig,
MemmapIndexedDatasetReader,
MemmapWriter,
Sample,
)

FILE_HEADER = b"fast_llm_prepared_dataset"

@@ -82,6 +87,10 @@ def get_document_sizes(self) -> torch.Tensor:
def get_document_size(self, index: int) -> int:
return self._reader.get_document_size(index)

@property
def reader(self) -> MemmapIndexedDatasetReader:
return self._reader

@classmethod
def write_dataset(
cls,
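The new `reader` property exists so that callers (the preparator's split logic below) can reach `get_split` on the underlying reader. A hedged usage sketch, assuming `dataset` is an already-built memmap dataset:

# Maps the fractional range [0.1, 0.9) onto document boundaries and returns
# metadata describing the resulting slice (`dataset` is hypothetical here).
begin_index, end_index, metadata = dataset.reader.get_split(0.1, 0.9)
print(begin_index, end_index, metadata["num_tokens"])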
65 changes: 34 additions & 31 deletions fast_llm/data/preparator/gpt_memmap/prepare.py
@@ -39,7 +39,7 @@
from fast_llm.data.sample.token import TokenSample
from fast_llm.engine.config_utils.data_type import DataType, get_unsigned_integer_type
from fast_llm.engine.config_utils.run import log_main_rank
from fast_llm.utils import Assert, normalize_probabilities, padded_cumsum
from fast_llm.utils import normalize_probabilities, padded_cumsum

logger = logging.getLogger(__name__)

@@ -346,16 +346,18 @@ def generate_config_yaml_for_sharded_dst(
# Create the config file(s) on rank 0
dataset_configs, reader_configs = zip(*dataset_and_reader_configs)
if self._config.splits:
for split_name, split_config in self._split_and_blend_dataset_configs(
for split_name, (split_config, metadata) in self._split_and_blend_dataset_configs(
dataset_configs, reader_configs, self._config.splits, self._config.output_path
).items():
self._save_dataset_config(
split_config, self._config.output_path / f"fast_llm_config_{split_name}.yaml"
split_config,
metadata,
output_path=self._config.output_path / f"fast_llm_config_{split_name}.yaml",
)
else:
self._save_dataset_config(
self._blend_dataset_configs(dataset_configs, reader_configs),
self._config.output_path / f"fast_llm_config.yaml",
*self._blend_dataset_configs(dataset_configs, reader_configs),
output_path=self._config.output_path / f"fast_llm_config.yaml",
)

# Save metadata on rank 0
@@ -368,37 +370,38 @@

@classmethod
def _save_dataset_config(
cls, dataset_config: IndexedDatasetConfig[_sample_type], output_path: pathlib.Path
cls,
dataset_config: IndexedDatasetConfig[_sample_type],
metadata: dict[str, typing.Any],
output_path: pathlib.Path,
) -> None:
logger.info(f"Saving config to {output_path}")
yaml.safe_dump(
dataset_config.to_dict(),
output_path.open("w"),
)
yaml.safe_dump({"config": dataset_config.to_dict(), "metadata": metadata}, output_path.open("w"))

@classmethod
def _blend_dataset_configs(
cls,
dataset_configs: list[MemmapDatasetConfig[_sample_type]],
reader_configs: list[MemmapIndexDatasetReaderConfig],
) -> IndexedDatasetConfig[_sample_type]:
) -> tuple[IndexedDatasetConfig[_sample_type], dict[str, typing.Any]]:
datasets_metadata = [reader_config.get_metadata() for reader_config in reader_configs]
if len(dataset_configs) == 1:
return dataset_configs[0]
return dataset_configs[0], datasets_metadata[0]
return SampledDatasetConfig[cls._sample_type].from_dict(
{
"type": "blended",
"datasets": dataset_configs,
"weights": [reader_config.num_tokens for reader_config in reader_configs],
}
)
), reader_configs[0].blend_metadata(datasets_metadata)

def _split_and_blend_dataset_configs(
self,
dataset_configs: list[MemmapDatasetConfig[_sample_type]],
reader_configs: list[MemmapIndexDatasetReaderConfig],
splits: dict[str, int | float],
output_path: pathlib.Path,
) -> dict[str, SampledDatasetConfig[_sample_type]]:
) -> dict[str, tuple[SampledDatasetConfig[_sample_type], dict[str, typing.Any]]]:
split_cumsum = padded_cumsum(normalize_probabilities(list(splits.values()), return_array=True)).tolist()
dataset_sizes = [reader_config.num_tokens for reader_config in reader_configs]
dataset_probabilities = normalize_probabilities(dataset_sizes)
@@ -407,7 +410,7 @@ def _split_and_blend_dataset_configs(

for split_index, split_name in enumerate(splits):
datasets_in_split = []
dataset_tokens_in_split = []
datasets_metadata = []
for dataset_index, (dataset_config, reader_config) in enumerate(
zip(dataset_configs, reader_configs, strict=True)
):
@@ -424,17 +427,17 @@ def _split_and_blend_dataset_configs(
if split_begin_in_dataset == 0 and split_end_in_dataset == 1:
# All the dataset belongs to the split.
datasets_in_split.append(dataset_configs[dataset_index])
dataset_tokens_in_split.append(dataset_sizes[dataset_index])
datasets_metadata.append(reader_config.get_metadata())

elif split_end_in_dataset > split_begin_in_dataset:
# Part of the dataset belongs to the split.
# TODO: Somehow getting a segfault when merging two lines below (numpy bug?).
dataset = dataset_config.to_copy({"path": output_path / dataset_config.path}).build(
self._preprocessing_config
)
sizes_cumsum = dataset.get_document_sizes().numpy().cumsum()
Assert.eq(sizes_cumsum[-1], reader_config.num_tokens)
begin_index = _get_nearest_split(sizes_cumsum, split_begin_in_dataset * reader_config.num_tokens)
end_index = _get_nearest_split(sizes_cumsum, split_end_in_dataset * reader_config.num_tokens)
begin_index, end_index, metadata = dataset.reader.get_split(
split_begin_in_dataset, split_end_in_dataset
)
if end_index > begin_index:
datasets_in_split.append(
DatasetSliceConfig[self._sample_type].from_dict(
Expand All @@ -446,25 +449,25 @@ def _split_and_blend_dataset_configs(
}
)
)
dataset_tokens_in_split.append(
sizes_cumsum[end_index - 1].item()
- (sizes_cumsum[begin_index - 1].item() if begin_index > 0 else 0)
)
datasets_metadata.append(metadata)

# [else] None of the dataset belongs to the split.

if len(datasets_in_split) == 0:
# This is a big problem, but we don't want to crash the whole run.
logger.error(f"Datasets split {split_name} is empty!")
elif len(datasets_in_split) == 1:
dataset_splits[split_name] = datasets_in_split[0]
dataset_splits[split_name] = (datasets_in_split[0], datasets_metadata[0])
else:
dataset_splits[split_name] = BlendedDatasetConfig[self._sample_type].from_dict(
{
"type": "blended",
"datasets": datasets_in_split,
"weights": dataset_tokens_in_split,
}
dataset_splits[split_name] = (
BlendedDatasetConfig[self._sample_type].from_dict(
{
"type": "blended",
"datasets": datasets_in_split,
"weights": [dataset_metadata["num_tokens"] for dataset_metadata in datasets_metadata],
}
),
reader_configs[0].blend_metadata(datasets_metadata),
)

return dataset_splits
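Putting the pieces together, each split's YAML file now carries the dataset config under `config` and the blended reader metadata under `metadata`. An invented example of what `_save_dataset_config` might write for a blended split, where the blend weights are each dataset's `num_tokens`:

import yaml

# All paths and counts below are made up for illustration.
print(yaml.safe_dump({
    "config": {
        "type": "blended",
        "datasets": [
            {"type": "memmap", "path": "shard_0_0"},
            {"type": "memmap", "path": "shard_1_0"},
        ],
        "weights": [600000, 400000],
    },
    "metadata": {"num_tokens": 1000000},
}))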
17 changes: 17 additions & 0 deletions fast_llm/data/sample/abstract.py
@@ -73,6 +73,13 @@ def expected_buffer_size(self) -> int:
"""
raise NotImplementedError()

def get_metadata(self) -> dict[str, typing.Any]:
raise NotImplementedError()

@classmethod
def blend_metadata(cls, metadata: list[dict[str, typing.Any]]) -> dict[str, typing.Any]:
raise NotImplementedError()


@config_class(dynamic_type={MemmapReaderBaseConfig: "none"})
class NullReaderConfig(MemmapReaderBaseConfig):
@@ -159,6 +166,13 @@ def reader_class(self) -> "type[MemmapIndexedDatasetReader]":
def get_reader(self, buffer: memoryview, model_preprocessing: PreprocessingConfig) -> "MemmapIndexedDatasetReader":
return self.reader_class(self, buffer, model_preprocessing)

def get_metadata(self) -> dict[str, typing.Any]:
return {"num_tokens": self.num_tokens}

@classmethod
def blend_metadata(cls, metadata: list[dict[str, typing.Any]]) -> dict[str, typing.Any]:
return {"num_tokens": sum(metadata_["num_tokens"] for metadata_ in metadata)}


class MemmapReaderBase[ConfigType: MemmapReaderBaseConfig](Configurable[ConfigType]):
@abc.abstractmethod
@@ -196,6 +210,9 @@ def get_document_sizes(self) -> "torch.Tensor":
def get_document_size(self, index: int) -> int:
pass

def get_split(self, begin_ratio: float, end_ratio: float) -> tuple[int, int, dict[str, typing.Any]]:
raise NotImplementedError()


class MemmapWriter(abc.ABC):
def __init__(
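At the base level, `MemmapIndexDatasetReaderConfig` metadata only tracks token counts, so blending reduces to a sum. A small self-contained sketch with invented counts:

# Mirrors MemmapIndexDatasetReaderConfig.blend_metadata without needing the class.
shards = [{"num_tokens": 600000}, {"num_tokens": 400000}]
blended = {"num_tokens": sum(metadata_["num_tokens"] for metadata_ in shards)}
assert blended == {"num_tokens": 1000000}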
54 changes: 54 additions & 0 deletions fast_llm/data/sample/language_model.py
@@ -23,6 +23,7 @@
from fast_llm.data.sample.patch import (
EmptyPatchReader,
PatchBatch,
PatchReader,
PatchReaderBaseConfig,
PatchReaderConfig,
PatchSample,
@@ -31,6 +32,7 @@
from fast_llm.data.sample.range import (
EmptyRangeReader,
RangeBatch,
RangeReader,
RangeReaderBaseConfig,
RangeReaderConfig,
RangeSample,
@@ -222,6 +224,41 @@ def _expected_buffer_size(self) -> int:
+ self.image_patches.expected_buffer_size
)

def get_metadata(self) -> dict[str, typing.Any]:
out = super().get_metadata()
out["tokens"] = self.tokens.get_metadata()
if not isinstance(self.loss_masking_spans, NullReaderConfig):
out["loss_masking_spans"] = self.loss_masking_spans.get_metadata()
if not isinstance(self.chosen_spans, NullReaderConfig):
out["chosen_spans"] = self.chosen_spans.get_metadata()
if not isinstance(self.rejected_spans, NullReaderConfig):
out["rejected_spans"] = self.rejected_spans.get_metadata()
if not isinstance(self.image_patches, NullReaderConfig):
out["image_patches"] = self.image_patches.get_metadata()
return out

@classmethod
def blend_metadata(cls, metadata: list[dict[str, typing.Any]]) -> dict[str, typing.Any]:
out = super().blend_metadata(metadata)
out["tokens"] = TokenReaderConfig.blend_metadata([metadata_["tokens"] for metadata_ in metadata])
if "loss_masking_spans" in metadata[0]:
out["loss_masking_spans"] = RangeReaderConfig.blend_metadata(
[metadata_["loss_masking_spans"] for metadata_ in metadata]
)
if "chosen_spans" in metadata[0]:
out["chosen_spans"] = RangeReaderConfig.blend_metadata(
[metadata_["chosen_spans"] for metadata_ in metadata]
)
if "rejected_spans" in metadata[0]:
out["image_patches"] = RangeReaderConfig.blend_metadata(
[metadata_["image_patches"] for metadata_ in metadata]
)
if "image_patches" in metadata[0]:
out["image_patches"] = PatchReaderConfig.blend_metadata(
[metadata_["image_patches"] for metadata_ in metadata]
)
return out


class LanguageModelReader[ConfigType: LanguageModelReaderConfig](MemmapIndexedDatasetReader[ConfigType]):
_model_preprocessing: LanguageModelPreprocessingConfig
@@ -305,6 +342,23 @@ def get_document_sizes(self) -> torch.Tensor:
def get_document_size(self, index: int) -> int:
return self._tokens.get_document_size(index)

def get_split(self, begin_ratio: float, end_ratio: float) -> tuple[int, int, dict[str, typing.Any]]:
begin_index, end_index, token_metadata = self._tokens.get_split(begin_ratio, end_ratio)
metadata = {
"num_tokens": token_metadata["num_tokens"],
"tokens": token_metadata,
}
if hasattr(self, "_loss_masking_spans") and isinstance(self._loss_masking_spans, RangeReader):
metadata["loss_masking_spans"] = self._loss_masking_spans.get_split(begin_index, end_index)
if hasattr(self, "_chosen_spans") and isinstance(self._chosen_spans, RangeReader):
metadata["chosen_spans"] = self._chosen_spans.get_split(begin_index, end_index)
if hasattr(self, "_rejected_spans") and isinstance(self._rejected_spans, RangeReader):
metadata["rejected_spans"] = self._rejected_spans.get_split(begin_index, end_index)
if hasattr(self, "_image_patches") and isinstance(self._image_patches, PatchReader):
metadata["image_patches"] = self._image_patches.get_split(begin_index, end_index)

return begin_index, end_index, metadata


class LanguageModelWriter(MemmapWriter):
_preprocessing_config: LanguageModelPreprocessingConfig
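For language-model datasets, the metadata is nested: the top level keeps `num_tokens`, with per-component entries added only when the matching reader is present. An invented sketch of the shape `LanguageModelReader.get_split` might return (the exact token fields depend on `TokenReaderConfig`):

# Hypothetical result for a split covering documents [0, 1200).
begin_index, end_index = 0, 1200
metadata = {
    "num_tokens": 500000,
    "tokens": {"num_tokens": 500000},
    "loss_masking_spans": {"num_documents": 1200, "num_ranges": 3400},
}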
34 changes: 34 additions & 0 deletions fast_llm/data/sample/patch.py
@@ -192,6 +192,27 @@ def _expected_buffer_size(self) -> int:
* torch.int32.itemsize
)

def get_metadata(self) -> dict[str, typing.Any]:
return {
"num_documents": self.num_documents,
"num_patches": self.num_patches,
"num_patch_groups": self.num_patch_groups,
"num_pixels": self.patch_size * self.num_patches,
"patch_shape": self.patch_shape,
"data_type": str(self.data_type),
}

@classmethod
def blend_metadata(cls, metadata: list[dict[str, typing.Any]]) -> dict[str, typing.Any]:
return {
"num_documents": sum(metadata_["num_documents"] for metadata_ in metadata),
"num_patches": sum(metadata_["num_patches"] for metadata_ in metadata),
"num_patch_groups": sum(metadata_["num_patch_groups"] for metadata_ in metadata),
"num_pixels": sum(metadata_["num_pixels"] for metadata_ in metadata),
"patch_shape": get_unique(metadata_["patch_shape"] for metadata_ in metadata),
"data_type": get_unique(metadata_["data_type"] for metadata_ in metadata),
}


class PatchReader[ConfigType: PatchReaderConfig](MemmapReader[ConfigType]):
def __init__(self, config: ConfigType, buffer: memoryview, model_preprocessing: PreprocessingConfig | None = None):
@@ -253,6 +274,19 @@ def get_document(self, index: int, begin: int, end: int) -> Sample:
),
)

def get_split(self, begin_index: int, end_index: int) -> dict[str, typing.Any]:
Assert.custom(lambda x: x == sorted(x), [0, begin_index, end_index, self._config.num_documents])
num_patches = self._patch_count_cumsums[end_index].item() - self._patch_count_cumsums[begin_index].item()
return {
"num_documents": end_index - begin_index,
"num_patches": num_patches,
"num_patch_groups": self._group_count_cumsums[end_index].item()
- self._group_count_cumsums[begin_index].item(),
"num_pixels": self._config.patch_size * num_patches,
"patch_shape": self._config.patch_shape,
"data_type": str(self._config.data_type),
}


class EmptyPatchReader[ConfigType: PatchReaderBaseConfig](MemmapReaderBase[ConfigType]):
def get_document(self, index: int, begin: int, end: int) -> Sample:
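For patches, blending sums the counts while `patch_shape` and `data_type` must agree across datasets (that is what `get_unique` enforces). A self-contained sketch with invented numbers:

a = {"num_documents": 10, "num_patches": 40, "num_patch_groups": 12,
     "num_pixels": 40 * 768, "patch_shape": [16, 16], "data_type": "uint8"}
b = {"num_documents": 5, "num_patches": 20, "num_patch_groups": 6,
     "num_pixels": 20 * 768, "patch_shape": [16, 16], "data_type": "uint8"}
# Counts are summed; shape and dtype are checked for equality, then carried over.
blended = {key: a[key] + b[key] for key in
           ("num_documents", "num_patches", "num_patch_groups", "num_pixels")}
assert a["patch_shape"] == b["patch_shape"] and a["data_type"] == b["data_type"]
blended |= {"patch_shape": a["patch_shape"], "data_type": a["data_type"]}
assert blended["num_patches"] == 60 and blended["num_pixels"] == 60 * 768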
20 changes: 20 additions & 0 deletions fast_llm/data/sample/range.py
@@ -92,6 +92,19 @@ def writer_class(self) -> "type[RangeWriter]":
def _expected_buffer_size(self) -> int:
return self.num_ranges * torch.int32.itemsize * 2 + (self.num_documents + 1) * torch.int32.itemsize

def get_metadata(self) -> dict[str, typing.Any]:
return {
"num_documents": self.num_documents,
"num_ranges": self.num_ranges,
}

@classmethod
def blend_metadata(cls, metadata: list[dict[str, typing.Any]]) -> dict[str, typing.Any]:
return {
"num_documents": sum(metadata_["num_documents"] for metadata_ in metadata),
"num_ranges": sum(metadata_["num_ranges"] for metadata_ in metadata),
}


class RangeReader[ConfigType: RangeReaderConfig](MemmapReader[ConfigType]):
def __init__(self, config: ConfigType, buffer: memoryview, model_preprocessing: PreprocessingConfig | None = None):
@@ -116,6 +129,13 @@ def get_document(self, index: int, begin: int, end: int) -> Sample:
)
return RangeSample([(begin_, end_) for begin_, end_ in cropped_ranges if end_ > begin_], sample_size)

def get_split(self, begin_index: int, end_index: int) -> dict[str, typing.Any]:
Assert.custom(lambda x: x == sorted(x), [0, begin_index, end_index, self._config.num_documents])
return {
"num_documents": end_index - begin_index,
"num_ranges": self._count_cumsums[end_index].item() - self._count_cumsums[begin_index].item(),
}


class EmptyRangeReader[ConfigType: RangeReaderBaseConfig](MemmapReaderBase[ConfigType]):
def get_document(self, index: int, begin: int, end: int) -> Sample:
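Range metadata follows the same pattern: the counts for a document sub-range fall out of the cumulative table. A self-contained sketch mirroring `RangeReader.get_split` for documents [2, 7), with an invented cumulative table where entry i counts the ranges in documents [0, i):

count_cumsums = [0, 1, 3, 3, 6, 8, 8, 10, 13]
begin_index, end_index = 2, 7
metadata = {
    "num_documents": end_index - begin_index,
    "num_ranges": count_cumsums[end_index] - count_cumsums[begin_index],
}
assert metadata == {"num_documents": 5, "num_ranges": 7}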