diff --git a/fast_llm/data/dataset/gpt/config.py b/fast_llm/data/dataset/gpt/config.py
index fc326d366..41a2fe7ff 100644
--- a/fast_llm/data/dataset/gpt/config.py
+++ b/fast_llm/data/dataset/gpt/config.py
@@ -64,7 +64,11 @@ def build(self, preprocessing: PreprocessingConfig) -> SamplableDataset[SampleType]:
 
     def _load_config(self) -> SampledDatasetConfig[SampleType]:
         assert self.path.is_file(), f"File {self.path} does not exist."
-        return SampledDatasetConfig[SampleType].from_dict(self._convert_paths(yaml.safe_load(self.path.open("r"))))
+        config = yaml.safe_load(self.path.open("r"))
+        if config.keys() == {"config", "metadata"}:
+            # Newer format with metadata.
+            config = config["config"]
+        return SampledDatasetConfig[SampleType].from_dict(self._convert_paths(config))
 
     def _convert_paths(self, config):
         # Recursively convert paths relative to `self.path.parent` to make them relative to cwd.
diff --git a/fast_llm/data/dataset/memmap.py b/fast_llm/data/dataset/memmap.py
index f80a48b0a..9831f81ba 100644
--- a/fast_llm/data/dataset/memmap.py
+++ b/fast_llm/data/dataset/memmap.py
@@ -8,7 +8,12 @@
 from fast_llm.data.dataset.config import SamplingParameters
 from fast_llm.data.dataset.indexed import IndexedDataset
 from fast_llm.data.preprocessing.abstract import PreprocessingConfig
-from fast_llm.data.sample.abstract import MemmapIndexDatasetReaderConfig, MemmapWriter, Sample
+from fast_llm.data.sample.abstract import (
+    MemmapIndexDatasetReaderConfig,
+    MemmapIndexedDatasetReader,
+    MemmapWriter,
+    Sample,
+)
 
 FILE_HEADER = b"fast_llm_prepared_dataset"
@@ -82,6 +87,10 @@ def get_document_sizes(self) -> torch.Tensor:
     def get_document_size(self, index: int) -> int:
         return self._reader.get_document_size(index)
 
+    @property
+    def reader(self) -> MemmapIndexedDatasetReader:
+        return self._reader
+
     @classmethod
     def write_dataset(
         cls,
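For reference, a minimal sketch (not part of the diff) of the new config envelope that `_save_dataset_config` writes and `_load_config` unwraps; the `type`, `path`, and `num_tokens` values are illustrative only:

```python
import yaml

# Hypothetical on-disk layout: the dataset config is wrapped together with
# its metadata under the "config" / "metadata" keys.
wrapped = {
    "config": {"type": "memmap", "path": "shard_0_0"},
    "metadata": {"num_tokens": 1_000_000},
}
text = yaml.safe_dump(wrapped)

# Loading mirrors `_load_config`: unwrap the envelope when present,
# otherwise treat the document as a bare dataset config (older files).
config = yaml.safe_load(text)
if config.keys() == {"config", "metadata"}:
    config = config["config"]
assert config["type"] == "memmap"
```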
diff --git a/fast_llm/data/preparator/gpt_memmap/prepare.py b/fast_llm/data/preparator/gpt_memmap/prepare.py
index 2ea81d8a6..e0f5f02fc 100644
--- a/fast_llm/data/preparator/gpt_memmap/prepare.py
+++ b/fast_llm/data/preparator/gpt_memmap/prepare.py
@@ -39,7 +39,7 @@
 from fast_llm.data.sample.token import TokenSample
 from fast_llm.engine.config_utils.data_type import DataType, get_unsigned_integer_type
 from fast_llm.engine.config_utils.run import log_main_rank
-from fast_llm.utils import Assert, normalize_probabilities, padded_cumsum
+from fast_llm.utils import normalize_probabilities, padded_cumsum
 
 logger = logging.getLogger(__name__)
@@ -346,16 +346,18 @@ def generate_config_yaml_for_sharded_dst(
         # Create the config file(s) on rank 0
         dataset_configs, reader_configs = zip(*dataset_and_reader_configs)
         if self._config.splits:
-            for split_name, split_config in self._split_and_blend_dataset_configs(
+            for split_name, (split_config, metadata) in self._split_and_blend_dataset_configs(
                 dataset_configs, reader_configs, self._config.splits, self._config.output_path
             ).items():
                 self._save_dataset_config(
-                    split_config, self._config.output_path / f"fast_llm_config_{split_name}.yaml"
+                    split_config,
+                    metadata,
+                    output_path=self._config.output_path / f"fast_llm_config_{split_name}.yaml",
                 )
         else:
             self._save_dataset_config(
-                self._blend_dataset_configs(dataset_configs, reader_configs),
-                self._config.output_path / f"fast_llm_config.yaml",
+                *self._blend_dataset_configs(dataset_configs, reader_configs),
+                output_path=self._config.output_path / f"fast_llm_config.yaml",
             )
 
         # Save metadata on rank 0
@@ -368,29 +370,30 @@ def generate_config_yaml_for_sharded_dst(
 
     @classmethod
     def _save_dataset_config(
-        cls, dataset_config: IndexedDatasetConfig[_sample_type], output_path: pathlib.Path
+        cls,
+        dataset_config: IndexedDatasetConfig[_sample_type],
+        metadata: dict[str, typing.Any],
+        output_path: pathlib.Path,
     ) -> None:
         logger.info(f"Saving config to {output_path}")
-        yaml.safe_dump(
-            dataset_config.to_dict(),
-            output_path.open("w"),
-        )
+        yaml.safe_dump({"config": dataset_config.to_dict(), "metadata": metadata}, output_path.open("w"))
 
     @classmethod
     def _blend_dataset_configs(
         cls,
         dataset_configs: list[MemmapDatasetConfig[_sample_type]],
         reader_configs: list[MemmapIndexDatasetReaderConfig],
-    ) -> IndexedDatasetConfig[_sample_type]:
+    ) -> tuple[IndexedDatasetConfig[_sample_type], dict[str, typing.Any]]:
+        datasets_metadata = [reader_config.get_metadata() for reader_config in reader_configs]
         if len(dataset_configs) == 1:
-            return dataset_configs[0]
+            return dataset_configs[0], datasets_metadata[0]
         return SampledDatasetConfig[cls._sample_type].from_dict(
             {
                 "type": "blended",
                 "datasets": dataset_configs,
                 "weights": [reader_config.num_tokens for reader_config in reader_configs],
             }
-        )
+        ), reader_configs[0].blend_metadata(datasets_metadata)
 
     def _split_and_blend_dataset_configs(
         self,
@@ -398,7 +401,7 @@ def _split_and_blend_dataset_configs(
         reader_configs: list[MemmapIndexDatasetReaderConfig],
         splits: dict[str, int | float],
         output_path: pathlib.Path,
-    ) -> dict[str, SampledDatasetConfig[_sample_type]]:
+    ) -> dict[str, tuple[SampledDatasetConfig[_sample_type], dict[str, typing.Any]]]:
         split_cumsum = padded_cumsum(normalize_probabilities(list(splits.values()), return_array=True)).tolist()
         dataset_sizes = [reader_config.num_tokens for reader_config in reader_configs]
         dataset_probabilities = normalize_probabilities(dataset_sizes)
@@ -407,7 +410,7 @@ def _split_and_blend_dataset_configs(
 
         for split_index, split_name in enumerate(splits):
             datasets_in_split = []
-            dataset_tokens_in_split = []
+            datasets_metadata = []
             for dataset_index, (dataset_config, reader_config) in enumerate(
                 zip(dataset_configs, reader_configs, strict=True)
             ):
@@ -424,17 +427,17 @@ def _split_and_blend_dataset_configs(
                 if split_begin_in_dataset == 0 and split_end_in_dataset == 1:
                     # All the dataset belongs to the split.
                     datasets_in_split.append(dataset_configs[dataset_index])
-                    dataset_tokens_in_split.append(dataset_sizes[dataset_index])
+                    datasets_metadata.append(reader_config.get_metadata())
+
                 elif split_end_in_dataset > split_begin_in_dataset:
                     # Part of the dataset belongs to the split.
                     # TODO: Somehow getting a segfault when merging two lines below (numpy bug?).
                     dataset = dataset_config.to_copy({"path": output_path / dataset_config.path}).build(
                         self._preprocessing_config
                     )
-                    sizes_cumsum = dataset.get_document_sizes().numpy().cumsum()
-                    Assert.eq(sizes_cumsum[-1], reader_config.num_tokens)
-                    begin_index = _get_nearest_split(sizes_cumsum, split_begin_in_dataset * reader_config.num_tokens)
-                    end_index = _get_nearest_split(sizes_cumsum, split_end_in_dataset * reader_config.num_tokens)
+                    begin_index, end_index, metadata = dataset.reader.get_split(
+                        split_begin_in_dataset, split_end_in_dataset
+                    )
                     if end_index > begin_index:
                         datasets_in_split.append(
                             DatasetSliceConfig[self._sample_type].from_dict(
                                 {
@@ -446,10 +449,7 @@ def _split_and_blend_dataset_configs(
                                 }
                             )
                         )
-                        dataset_tokens_in_split.append(
-                            sizes_cumsum[end_index - 1].item()
-                            - (sizes_cumsum[begin_index - 1].item() if begin_index > 0 else 0)
-                        )
+                        datasets_metadata.append(metadata)
 
             # [else] None of the dataset belongs to the split.
 
@@ -457,14 +457,17 @@ def _split_and_blend_dataset_configs(
                 # This is a big problem, but we don't want to crash the whole run.
                 logger.error(f"Datasets split {split_name} is empty!")
             elif len(datasets_in_split) == 1:
-                dataset_splits[split_name] = datasets_in_split[0]
+                dataset_splits[split_name] = (datasets_in_split[0], datasets_metadata[0])
             else:
-                dataset_splits[split_name] = BlendedDatasetConfig[self._sample_type].from_dict(
-                    {
-                        "type": "blended",
-                        "datasets": datasets_in_split,
-                        "weights": dataset_tokens_in_split,
-                    }
+                dataset_splits[split_name] = (
+                    BlendedDatasetConfig[self._sample_type].from_dict(
+                        {
+                            "type": "blended",
+                            "datasets": datasets_in_split,
+                            "weights": [dataset_metadata["num_tokens"] for dataset_metadata in datasets_metadata],
+                        }
+                    ),
+                    reader_configs[0].blend_metadata(datasets_metadata),
                 )
 
         return dataset_splits
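The split/blend path now carries `(config, metadata)` pairs instead of a separately tracked token count, and blending weights come straight out of the metadata. A small stand-alone sketch with made-up numbers:

```python
# Stand-ins for what `reader_config.get_metadata()` returns per dataset.
datasets_metadata = [{"num_tokens": 750_000}, {"num_tokens": 250_000}]

# Blending weights are read from the metadata, mirroring the new
# `weights=[dataset_metadata["num_tokens"] ...]` above.
weights = [metadata["num_tokens"] for metadata in datasets_metadata]
# The blended summary sums the counts, mirroring `blend_metadata`.
blended = {"num_tokens": sum(weights)}

assert weights == [750_000, 250_000]
assert blended == {"num_tokens": 1_000_000}
```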
diff --git a/fast_llm/data/sample/abstract.py b/fast_llm/data/sample/abstract.py
index 1d71363b7..494a5c4a5 100644
--- a/fast_llm/data/sample/abstract.py
+++ b/fast_llm/data/sample/abstract.py
@@ -73,6 +73,13 @@ def expected_buffer_size(self) -> int:
         """
         raise NotImplementedError()
 
+    def get_metadata(self) -> dict[str, typing.Any]:
+        raise NotImplementedError()
+
+    @classmethod
+    def blend_metadata(cls, metadata: list[dict[str, typing.Any]]) -> dict[str, typing.Any]:
+        raise NotImplementedError()
+
 
 @config_class(dynamic_type={MemmapReaderBaseConfig: "none"})
 class NullReaderConfig(MemmapReaderBaseConfig):
@@ -159,6 +166,13 @@ def reader_class(self) -> "type[MemmapIndexedDatasetReader]":
 
     def get_reader(self, buffer: memoryview, model_preprocessing: PreprocessingConfig) -> "MemmapIndexedDatasetReader":
         return self.reader_class(self, buffer, model_preprocessing)
 
+    def get_metadata(self) -> dict[str, typing.Any]:
+        return {"num_tokens": self.num_tokens}
+
+    @classmethod
+    def blend_metadata(cls, metadata: list[dict[str, typing.Any]]) -> dict[str, typing.Any]:
+        return {"num_tokens": sum(metadata_["num_tokens"] for metadata_ in metadata)}
+
 
 class MemmapReaderBase[ConfigType: MemmapReaderBaseConfig](Configurable[ConfigType]):
     @abc.abstractmethod
@@ -196,6 +210,9 @@ def get_document_sizes(self) -> "torch.Tensor":
     def get_document_size(self, index: int) -> int:
         pass
 
+    def get_split(self, begin_ratio: float, end_ratio: float) -> tuple[int, int, dict[str, typing.Any]]:
+        raise NotImplementedError()
+
 
 class MemmapWriter(abc.ABC):
     def __init__(
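The contract added to the reader configs is symmetric: `get_metadata` summarizes one reader, and `blend_metadata` folds several summaries into one. A hypothetical `CountReaderConfig` (not in the PR) illustrating the expected shape:

```python
import typing


class CountReaderConfig:
    """Hypothetical reader config showing the get_metadata/blend_metadata pair."""

    def __init__(self, num_documents: int) -> None:
        self.num_documents = num_documents

    def get_metadata(self) -> dict[str, typing.Any]:
        # Summarize this reader's buffer.
        return {"num_documents": self.num_documents}

    @classmethod
    def blend_metadata(cls, metadata: list[dict[str, typing.Any]]) -> dict[str, typing.Any]:
        # Fold several summaries into one; counts are additive.
        return {"num_documents": sum(metadata_["num_documents"] for metadata_ in metadata)}


merged = CountReaderConfig.blend_metadata(
    [CountReaderConfig(3).get_metadata(), CountReaderConfig(5).get_metadata()]
)
assert merged == {"num_documents": 8}
```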
diff --git a/fast_llm/data/sample/language_model.py b/fast_llm/data/sample/language_model.py
index 3183a9ec1..22b89acf1 100644
--- a/fast_llm/data/sample/language_model.py
+++ b/fast_llm/data/sample/language_model.py
@@ -23,6 +23,7 @@
 from fast_llm.data.sample.patch import (
     EmptyPatchReader,
     PatchBatch,
+    PatchReader,
     PatchReaderBaseConfig,
     PatchReaderConfig,
     PatchSample,
@@ -31,6 +32,7 @@
 from fast_llm.data.sample.range import (
     EmptyRangeReader,
     RangeBatch,
+    RangeReader,
     RangeReaderBaseConfig,
     RangeReaderConfig,
     RangeSample,
@@ -222,6 +224,41 @@ def _expected_buffer_size(self) -> int:
             + self.image_patches.expected_buffer_size
         )
 
+    def get_metadata(self) -> dict[str, typing.Any]:
+        out = super().get_metadata()
+        out["tokens"] = self.tokens.get_metadata()
+        if not isinstance(self.loss_masking_spans, NullReaderConfig):
+            out["loss_masking_spans"] = self.loss_masking_spans.get_metadata()
+        if not isinstance(self.chosen_spans, NullReaderConfig):
+            out["chosen_spans"] = self.chosen_spans.get_metadata()
+        if not isinstance(self.rejected_spans, NullReaderConfig):
+            out["rejected_spans"] = self.rejected_spans.get_metadata()
+        if not isinstance(self.image_patches, NullReaderConfig):
+            out["image_patches"] = self.image_patches.get_metadata()
+        return out
+
+    @classmethod
+    def blend_metadata(cls, metadata: list[dict[str, typing.Any]]) -> dict[str, typing.Any]:
+        out = super().blend_metadata(metadata)
+        out["tokens"] = TokenReaderConfig.blend_metadata([metadata_["tokens"] for metadata_ in metadata])
+        if "loss_masking_spans" in metadata[0]:
+            out["loss_masking_spans"] = RangeReaderConfig.blend_metadata(
+                [metadata_["loss_masking_spans"] for metadata_ in metadata]
+            )
+        if "chosen_spans" in metadata[0]:
+            out["chosen_spans"] = RangeReaderConfig.blend_metadata(
+                [metadata_["chosen_spans"] for metadata_ in metadata]
+            )
+        if "rejected_spans" in metadata[0]:
+            out["rejected_spans"] = RangeReaderConfig.blend_metadata(
+                [metadata_["rejected_spans"] for metadata_ in metadata]
+            )
+        if "image_patches" in metadata[0]:
+            out["image_patches"] = PatchReaderConfig.blend_metadata(
+                [metadata_["image_patches"] for metadata_ in metadata]
+            )
+        return out
+
 
 class LanguageModelReader[ConfigType: LanguageModelReaderConfig](MemmapIndexedDatasetReader[ConfigType]):
     _model_preprocessing: LanguageModelPreprocessingConfig
@@ -305,6 +342,23 @@ def get_document_sizes(self) -> torch.Tensor:
     def get_document_size(self, index: int) -> int:
         return self._tokens.get_document_size(index)
 
+    def get_split(self, begin_ratio: float, end_ratio: float) -> tuple[int, int, dict[str, typing.Any]]:
+        begin_index, end_index, token_metadata = self._tokens.get_split(begin_ratio, end_ratio)
+        metadata = {
+            "num_tokens": token_metadata["num_tokens"],
+            "tokens": token_metadata,
+        }
+        if hasattr(self, "_loss_masking_spans") and isinstance(self._loss_masking_spans, RangeReader):
+            metadata["loss_masking_spans"] = self._loss_masking_spans.get_split(begin_index, end_index)
+        if hasattr(self, "_chosen_spans") and isinstance(self._chosen_spans, RangeReader):
+            metadata["chosen_spans"] = self._chosen_spans.get_split(begin_index, end_index)
+        if hasattr(self, "_rejected_spans") and isinstance(self._rejected_spans, RangeReader):
+            metadata["rejected_spans"] = self._rejected_spans.get_split(begin_index, end_index)
+        if hasattr(self, "_image_patches") and isinstance(self._image_patches, PatchReader):
+            metadata["image_patches"] = self._image_patches.get_split(begin_index, end_index)
+
+        return begin_index, end_index, metadata
+
 
 class LanguageModelWriter(MemmapWriter):
     _preprocessing_config: LanguageModelPreprocessingConfig
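In `LanguageModelReader.get_split`, only the token reader converts the ratio range into document indices; the optional component readers then report metadata for that fixed document range, so all sub-metadata describes the same documents. A sketch with stand-in readers (no real memmap buffers):

```python
class FakeTokenReader:
    # Stand-in for TokenReader: ratios in, document indices plus summary out.
    def get_split(self, begin_ratio: float, end_ratio: float):
        return 0, 2, {"num_tokens": 120, "num_documents": 2, "data_type": "int32"}


class FakeRangeReader:
    # Stand-in for RangeReader: document indices in, summary out.
    def get_split(self, begin_index: int, end_index: int):
        return {"num_documents": end_index - begin_index, "num_ranges": 7}


begin_index, end_index, token_metadata = FakeTokenReader().get_split(0.0, 0.5)
metadata = {"num_tokens": token_metadata["num_tokens"], "tokens": token_metadata}
metadata["loss_masking_spans"] = FakeRangeReader().get_split(begin_index, end_index)
assert metadata["loss_masking_spans"]["num_documents"] == 2
```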
diff --git a/fast_llm/data/sample/patch.py b/fast_llm/data/sample/patch.py
index 221746752..7ae537104 100644
--- a/fast_llm/data/sample/patch.py
+++ b/fast_llm/data/sample/patch.py
@@ -192,6 +192,27 @@ def _expected_buffer_size(self) -> int:
             * torch.int32.itemsize
         )
 
+    def get_metadata(self) -> dict[str, typing.Any]:
+        return {
+            "num_documents": self.num_documents,
+            "num_patches": self.num_patches,
+            "num_patch_groups": self.num_patch_groups,
+            "num_pixels": self.patch_size * self.num_patches,
+            "patch_shape": self.patch_shape,
+            "data_type": str(self.data_type),
+        }
+
+    @classmethod
+    def blend_metadata(cls, metadata: list[dict[str, typing.Any]]) -> dict[str, typing.Any]:
+        return {
+            "num_documents": sum(metadata_["num_documents"] for metadata_ in metadata),
+            "num_patches": sum(metadata_["num_patches"] for metadata_ in metadata),
+            "num_patch_groups": sum(metadata_["num_patch_groups"] for metadata_ in metadata),
+            "num_pixels": sum(metadata_["num_pixels"] for metadata_ in metadata),
+            "patch_shape": get_unique(metadata_["patch_shape"] for metadata_ in metadata),
+            "data_type": get_unique(metadata_["data_type"] for metadata_ in metadata),
+        }
+
 
 class PatchReader[ConfigType: PatchReaderConfig](MemmapReader[ConfigType]):
     def __init__(self, config: ConfigType, buffer: memoryview, model_preprocessing: PreprocessingConfig | None = None):
@@ -253,6 +274,19 @@ def get_document(self, index: int, begin: int, end: int) -> Sample:
             ),
         )
 
+    def get_split(self, begin_index: int, end_index: int) -> dict[str, typing.Any]:
+        Assert.custom(lambda x: x == sorted(x), [0, begin_index, end_index, self._config.num_documents])
+        num_patches = self._patch_count_cumsums[end_index].item() - self._patch_count_cumsums[begin_index].item()
+        return {
+            "num_documents": end_index - begin_index,
+            "num_patches": num_patches,
+            "num_patch_groups": self._group_count_cumsums[end_index].item()
+            - self._group_count_cumsums[begin_index].item(),
+            "num_pixels": self._config.patch_size * num_patches,
+            "patch_shape": self._config.patch_shape,
+            "data_type": str(self._config.data_type),
+        }
+
 
 class EmptyPatchReader[ConfigType: PatchReaderBaseConfig](MemmapReaderBase[ConfigType]):
     def get_document(self, index: int, begin: int, end: int) -> Sample:
diff --git a/fast_llm/data/sample/range.py b/fast_llm/data/sample/range.py
index a77846725..53683342a 100644
--- a/fast_llm/data/sample/range.py
+++ b/fast_llm/data/sample/range.py
@@ -92,6 +92,19 @@ def writer_class(self) -> "type[RangeWriter]":
     def _expected_buffer_size(self) -> int:
         return self.num_ranges * torch.int32.itemsize * 2 + (self.num_documents + 1) * torch.int32.itemsize
 
+    def get_metadata(self) -> dict[str, typing.Any]:
+        return {
+            "num_documents": self.num_documents,
+            "num_ranges": self.num_ranges,
+        }
+
+    @classmethod
+    def blend_metadata(cls, metadata: list[dict[str, typing.Any]]) -> dict[str, typing.Any]:
+        return {
+            "num_documents": sum(metadata_["num_documents"] for metadata_ in metadata),
+            "num_ranges": sum(metadata_["num_ranges"] for metadata_ in metadata),
+        }
+
 
 class RangeReader[ConfigType: RangeReaderConfig](MemmapReader[ConfigType]):
     def __init__(self, config: ConfigType, buffer: memoryview, model_preprocessing: PreprocessingConfig | None = None):
@@ -116,6 +129,13 @@ def get_document(self, index: int, begin: int, end: int) -> Sample:
         )
         return RangeSample([(begin_, end_) for begin_, end_ in cropped_ranges if end_ > begin_], sample_size)
 
+    def get_split(self, begin_index: int, end_index: int) -> dict[str, typing.Any]:
+        Assert.custom(lambda x: x == sorted(x), [0, begin_index, end_index, self._config.num_documents])
+        return {
+            "num_documents": end_index - begin_index,
+            "num_ranges": self._count_cumsums[end_index].item() - self._count_cumsums[begin_index].item(),
+        }
+
 
 class EmptyRangeReader[ConfigType: RangeReaderBaseConfig](MemmapReaderBase[ConfigType]):
     def get_document(self, index: int, begin: int, end: int) -> Sample:
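Both `RangeReader.get_split` and `PatchReader.get_split` count items in a document slice by differencing a zero-padded per-document cumulative sum. A tiny numeric sketch:

```python
import torch

# Per-document range counts, then the padded cumulative sum the readers keep.
counts_per_document = torch.tensor([2, 0, 3, 1])
count_cumsums = torch.cat([torch.zeros(1, dtype=torch.int64), counts_per_document.cumsum(0)])

# Items in documents [1, 3): subtract the cumsum at the slice boundaries.
begin_index, end_index = 1, 3
num_ranges = count_cumsums[end_index].item() - count_cumsums[begin_index].item()
assert num_ranges == 3  # 0 ranges in document 1 plus 3 in document 2
```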
diff --git a/fast_llm/data/sample/token.py b/fast_llm/data/sample/token.py
index 1bc9ef1a1..cd4d7fa02 100644
--- a/fast_llm/data/sample/token.py
+++ b/fast_llm/data/sample/token.py
@@ -14,7 +14,7 @@
     Sample,
 )
 from fast_llm.engine.config_utils.data_type import DataType
-from fast_llm.utils import Assert
+from fast_llm.utils import Assert, get_unique
 
 
 def crop_lengths(lengths: list[int], begin: int, end: int) -> list[int]:
@@ -110,6 +110,21 @@ def writer_class(self) -> "type[TokenWriter]":
     def _expected_buffer_size(self) -> int:
         return self.num_tokens * self.data_type.torch.itemsize + (self.num_documents + 1) * torch.int64.itemsize
 
+    def get_metadata(self) -> dict[str, typing.Any]:
+        return {
+            "num_tokens": self.num_tokens,
+            "num_documents": self.num_documents,
+            "data_type": str(self.data_type),
+        }
+
+    @classmethod
+    def blend_metadata(cls, metadata: list[dict[str, typing.Any]]) -> dict[str, typing.Any]:
+        return {
+            "num_tokens": sum(metadata_["num_tokens"] for metadata_ in metadata),
+            "num_documents": sum(metadata_["num_documents"] for metadata_ in metadata),
+            "data_type": get_unique(metadata_["data_type"] for metadata_ in metadata),
+        }
+
 
 class TokenReader[ConfigType: TokenReaderConfig](MemmapIndexedDatasetReader[ConfigType]):
     def __init__(self, config: ConfigType, buffer: memoryview, model_preprocessing: PreprocessingConfig | None = None):
@@ -135,6 +150,29 @@ def get_document_sizes(self) -> torch.Tensor:
     def get_document_size(self, index: int) -> int:
         return self._size_cumsums[index + 1].item() - self._size_cumsums[index].item()
 
+    def get_split(self, begin_ratio: float, end_ratio: float) -> tuple[int, int, dict[str, typing.Any]]:
+        Assert.custom(lambda x: x == sorted(x), [0, begin_ratio, end_ratio, 1])
+        begin_index = _get_nearest_split(self._size_cumsums[1:], begin_ratio * self.num_tokens)
+        end_index = _get_nearest_split(self._size_cumsums[1:], end_ratio * self.num_tokens)
+
+        return (
+            begin_index,
+            end_index,
+            {
+                "num_tokens": self._size_cumsums[end_index].item() - self._size_cumsums[begin_index].item(),
+                "num_documents": end_index - begin_index,
+                "data_type": str(self._config.data_type),
+            },
+        )
+
+
+def _get_nearest_split(cumsum: torch.Tensor, value: float) -> int:
+    left = torch.searchsorted(cumsum, value, side="right").item()
+    if left == len(cumsum):
+        return left
+    lower = cumsum[left - 1].item() if left > 0 else 0
+    return left + 1 if (value - lower) / (cumsum[left].item() - lower) > 0.5 else left
+
 
 class TokenWriter(MemmapWriter):
     def __enter__(self):
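A worked example for `_get_nearest_split` as written above: document sizes [10, 10, 10] give boundaries at 0, 10, 20, and 30 tokens, and a target token offset rounds to whichever boundary is closer.

```python
import torch


def _get_nearest_split(cumsum: torch.Tensor, value: float) -> int:
    # Same logic as the helper above: find the document containing `value`,
    # then round to the nearer of that document's two boundaries.
    left = torch.searchsorted(cumsum, value, side="right").item()
    if left == len(cumsum):
        return left
    lower = cumsum[left - 1].item() if left > 0 else 0
    return left + 1 if (value - lower) / (cumsum[left].item() - lower) > 0.5 else left


cumsum = torch.tensor([10.0, 20.0, 30.0])  # cumulative sizes of three 10-token documents
assert _get_nearest_split(cumsum, 14.0) == 1  # closer to the boundary at 10 tokens
assert _get_nearest_split(cumsum, 16.0) == 2  # closer to the boundary at 20 tokens
assert _get_nearest_split(cumsum, 30.0) == 3  # exactly at the end of the dataset
```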