diff --git a/src/guidellm/__main__.py b/src/guidellm/__main__.py
index d81b7ddf..3262de8a 100644
--- a/src/guidellm/__main__.py
+++ b/src/guidellm/__main__.py
@@ -8,6 +8,7 @@
 from guidellm.backend import BackendType
 from guidellm.benchmark import ProfileType, benchmark_generative_text
 from guidellm.config import print_config
+from guidellm.preprocess.dataset import ShortPromptStrategy, process_dataset
 from guidellm.scheduler import StrategyType
 
 STRATEGY_PROFILE_CHOICES = set(
@@ -290,5 +291,187 @@ def config():
     print_config()
 
 
+@cli.group(help="Preprocessing utilities for datasets.")
+def preprocess():
+    pass
+
+
+@preprocess.command(
+    help="Convert a dataset to have specific prompt and output token sizes.\n\n"
+    "INPUT_DATA: Path to the input dataset or dataset ID.\n"
+    "OUTPUT_PATH: Path to save the converted dataset, including the file suffix. "
+    "Supported suffixes are .json, .csv, and .parquet."
+)
+@click.argument(
+    "input_data",
+    type=str,
+    metavar="INPUT_DATA",
+    required=True,
+)
+@click.argument(
+    "output_path",
+    type=click.Path(file_okay=True, dir_okay=False, writable=True, resolve_path=True),
+    metavar="OUTPUT_PATH",
+    required=True,
+)
+@click.option(
+    "--processor",
+    type=str,
+    required=True,
+    help=(
+        "The processor or tokenizer to use to calculate token counts for statistics "
+        "and synthetic data generation."
+    ),
+)
+@click.option(
+    "--processor-args",
+    default=None,
+    callback=parse_json,
+    help=(
+        "A JSON string containing any arguments to pass to the processor constructor "
+        "as a dict with **kwargs."
+    ),
+)
+@click.option(
+    "--data-args",
+    callback=parse_json,
+    help=(
+        "A JSON string containing any arguments to pass to the dataset creation "
+        "as a dict with **kwargs."
+    ),
+)
+@click.option(
+    "--short-prompt-strategy",
+    type=click.Choice([s.value for s in ShortPromptStrategy]),
+    default=ShortPromptStrategy.IGNORE.value,
+    show_default=True,
+    help="Strategy to handle prompts shorter than the target length.",
+)
+@click.option(
+    "--pad-token",
+    type=str,
+    default=None,
+    help="The token to pad short prompts with when using the 'pad' strategy.",
+)
+@click.option(
+    "--prompt-tokens-average",
+    type=int,
+    default=10,
+    show_default=True,
+    help="Average target number of tokens for prompts.",
+)
+@click.option(
+    "--prompt-tokens-stdev",
+    type=int,
+    default=None,
+    help="Standard deviation for prompt tokens sampling.",
+)
+@click.option(
+    "--prompt-tokens-min",
+    type=int,
+    default=None,
+    help="Minimum number of prompt tokens.",
+)
+@click.option(
+    "--prompt-tokens-max",
+    type=int,
+    default=None,
+    help="Maximum number of prompt tokens.",
+)
+@click.option(
+    "--prompt-random-seed",
+    type=int,
+    default=42,
+    show_default=True,
+    help="Random seed for prompt token sampling.",
+)
+@click.option(
+    "--output-tokens-average",
+    type=int,
+    default=10,
+    show_default=True,
+    help="Average target number of tokens for outputs.",
+)
+@click.option(
+    "--output-tokens-stdev",
+    type=int,
+    default=None,
+    help="Standard deviation for output tokens sampling.",
+)
+@click.option(
+    "--output-tokens-min",
+    type=int,
+    default=None,
+    help="Minimum number of output tokens.",
+)
+@click.option(
+    "--output-tokens-max",
+    type=int,
+    default=None,
+    help="Maximum number of output tokens.",
+)
+@click.option(
+    "--output-random-seed",
+    type=int,
+    default=123,
+    show_default=True,
+    help="Random seed for output token sampling.",
+)
+@click.option(
+    "--push-to-hub",
+    is_flag=True,
+    help="Set this flag to push the converted dataset to the Hugging Face Hub.",
+)
+@click.option(
+    "--hub-dataset-id",
+    type=str,
+    default=None,
+    help="The Hugging Face Hub dataset ID to push to. "
+    "Required if --push-to-hub is used.",
+)
+def dataset(
+    input_data,
+    output_path,
+    processor,
+    processor_args,
+    data_args,
+    short_prompt_strategy,
+    pad_token,
+    prompt_tokens_average,
+    prompt_tokens_stdev,
+    prompt_tokens_min,
+    prompt_tokens_max,
+    prompt_random_seed,
+    output_tokens_average,
+    output_tokens_stdev,
+    output_tokens_min,
+    output_tokens_max,
+    output_random_seed,
+    push_to_hub,
+    hub_dataset_id,
+):
+    process_dataset(
+        input_data=input_data,
+        output_path=output_path,
+        processor=processor,
+        processor_args=processor_args,
+        data_args=data_args,
+        short_prompt_strategy=short_prompt_strategy,
+        pad_token=pad_token,
+        prompt_tokens_average=prompt_tokens_average,
+        prompt_tokens_stdev=prompt_tokens_stdev,
+        prompt_tokens_min=prompt_tokens_min,
+        prompt_tokens_max=prompt_tokens_max,
+        prompt_random_seed=prompt_random_seed,
+        output_tokens_average=output_tokens_average,
+        output_tokens_stdev=output_tokens_stdev,
+        output_tokens_min=output_tokens_min,
+        output_tokens_max=output_tokens_max,
+        output_random_seed=output_random_seed,
+        push_to_hub=push_to_hub,
+        hub_dataset_id=hub_dataset_id,
+    )
+
+
 if __name__ == "__main__":
     cli()
diff --git a/src/guidellm/preprocess/__init__.py b/src/guidellm/preprocess/__init__.py
new file mode 100644
index 00000000..95d01e5f
--- /dev/null
+++ b/src/guidellm/preprocess/__init__.py
@@ -0,0 +1,3 @@
+from .dataset import ShortPromptStrategy, process_dataset
+
+__all__ = ["ShortPromptStrategy", "process_dataset"]
diff --git a/src/guidellm/preprocess/dataset.py b/src/guidellm/preprocess/dataset.py
new file mode 100644
index 00000000..baae943c
--- /dev/null
+++ b/src/guidellm/preprocess/dataset.py
@@ -0,0 +1,224 @@
+import os
+from collections.abc import Iterator
+from enum import Enum
+from pathlib import Path
+from typing import Any, Callable, Optional, Union
+
+from datasets import Dataset
+from loguru import logger
+from transformers import PreTrainedTokenizerBase
+
+from guidellm.dataset import load_dataset as guidellm_load_dataset
+from guidellm.utils import IntegerRangeSampler, check_load_processor
+
+SUPPORTED_TYPES = {
+    ".json",
+    ".csv",
+    ".parquet",
+}
+
+
+class ShortPromptStrategy(str, Enum):
+    IGNORE = "ignore"
+    CONCATENATE = "concatenate"
+    PAD = "pad"
+
+
+def handle_ignore_strategy(
+    current_prompt: str,
+    min_prompt_tokens: int,
+    tokenizer: PreTrainedTokenizerBase,
+    **_kwargs,
+) -> Optional[str]:
+    if len(tokenizer.encode(current_prompt)) < min_prompt_tokens:
+        logger.warning("Prompt too short, ignoring")
+        return None
+    return current_prompt
+
+
+def handle_concatenate_strategy(
+    current_prompt: str,
+    min_prompt_tokens: int,
+    dataset_iterator: Iterator[dict[str, Any]],
+    prompt_column: str,
+    tokenizer: PreTrainedTokenizerBase,
+    **_kwargs,
+) -> Optional[str]:
+    tokens_len = len(tokenizer.encode(current_prompt))
+    while tokens_len < min_prompt_tokens:
+        try:
+            next_row = next(dataset_iterator)
+        except StopIteration:
+            logger.warning(
+                "Could not concatenate enough prompts to reach minimum length, ignoring"
+            )
+            return None
+        current_prompt += next_row[prompt_column]
+        tokens_len = len(tokenizer.encode(current_prompt))
+    return current_prompt
+
+
+def handle_pad_strategy(
+    current_prompt: str,
+    min_prompt_tokens: int,
+    tokenizer: PreTrainedTokenizerBase,
+    pad_token: str,
+    **_kwargs,
+) -> str:
+    while len(tokenizer.encode(current_prompt)) < min_prompt_tokens:
+        current_prompt += pad_token
+    return current_prompt
+
+
+STRATEGY_HANDLERS: dict[ShortPromptStrategy, Callable] = {
+    ShortPromptStrategy.IGNORE: handle_ignore_strategy,
+    ShortPromptStrategy.CONCATENATE: handle_concatenate_strategy,
+    ShortPromptStrategy.PAD: handle_pad_strategy,
+}
+
+
+def save_dataset_to_file(dataset: Dataset, output_path: Union[str, Path]) -> None:
+    output_path = Path(output_path)
+    output_path.parent.mkdir(parents=True, exist_ok=True)
+
+    if output_path.suffix.lower() == ".csv":
+        dataset.to_csv(str(output_path))
+    elif output_path.suffix.lower() == ".json":
+        dataset.to_json(str(output_path))
+    elif output_path.suffix.lower() == ".parquet":
+        dataset.to_parquet(str(output_path))
+    else:
+        raise ValueError(
+            f"Unsupported file suffix '{output_path.suffix}' in output_path "
+            f"'{output_path}'. Only {SUPPORTED_TYPES} are supported."
+        )
+
+
+def _validate_output_suffix(output_path: Union[str, Path]) -> None:
+    output_path = Path(output_path)
+    suffix = output_path.suffix.lower()
+    if suffix not in SUPPORTED_TYPES:
+        raise ValueError(
+            f"Unsupported file suffix '{suffix}' in output_path '{output_path}'. "
+            f"Only {SUPPORTED_TYPES} are supported."
+        )
+
+
+def process_dataset(
+    input_data: Union[str, Path],
+    output_path: Union[str, Path],
+    processor: Union[str, Path, PreTrainedTokenizerBase],
+    processor_args: Optional[dict[str, Any]] = None,
+    data_args: Optional[dict[str, Any]] = None,
+    short_prompt_strategy: ShortPromptStrategy = ShortPromptStrategy.IGNORE,
+    pad_token: Optional[str] = None,
+    prompt_tokens_average: int = 10,
+    prompt_tokens_stdev: Optional[int] = None,
+    prompt_tokens_min: Optional[int] = None,
+    prompt_tokens_max: Optional[int] = None,
+    prompt_random_seed: int = 42,
+    output_tokens_average: int = 10,
+    output_tokens_stdev: Optional[int] = None,
+    output_tokens_min: Optional[int] = None,
+    output_tokens_max: Optional[int] = None,
+    output_random_seed: int = 123,
+    push_to_hub: bool = False,
+    hub_dataset_id: Optional[str] = None,
+) -> None:
+    _validate_output_suffix(output_path)
+    logger.info(
+        f"Starting dataset conversion | Input: {input_data} | "
+        f"Output path: {output_path}"
+    )
+
+    dataset, column_mappings = guidellm_load_dataset(
+        input_data, data_args, processor, processor_args
+    )
+    tokenizer = check_load_processor(
+        processor,
+        processor_args,
+        "dataset conversion.",
+    )
+    prompt_column = column_mappings.get("prompt_column")
+    output_column = column_mappings.get(
+        "output_tokens_count_column", "output_tokens_count"
+    )
+
+    prompt_token_sampler = iter(
+        IntegerRangeSampler(
+            average=prompt_tokens_average,
+            variance=prompt_tokens_stdev,
+            min_value=prompt_tokens_min,
+            max_value=prompt_tokens_max,
+            random_seed=prompt_random_seed,
+        )
+    )
+
+    output_token_sampler = iter(
+        IntegerRangeSampler(
+            average=output_tokens_average,
+            variance=output_tokens_stdev,
+            min_value=output_tokens_min,
+            max_value=output_tokens_max,
+            random_seed=output_random_seed,
+        )
+    )
+
+    dataset_iterator = iter(dataset)
+    processed_prompts = []
+    handler = STRATEGY_HANDLERS[short_prompt_strategy]
+
+    for prompt_row in dataset_iterator:
+        prompt_text = prompt_row[prompt_column]
+        target_prompt_len = next(prompt_token_sampler)
+
+        if len(tokenizer.encode(prompt_text)) < target_prompt_len:
+            prompt_text = handler(
+                current_prompt=prompt_text,
+                min_prompt_tokens=target_prompt_len,
+                dataset_iterator=dataset_iterator,
+                prompt_column=prompt_column,
+                tokenizer=tokenizer,
+                pad_token=pad_token,
+            )
+            if prompt_text is None:
+                continue
+
+        if len(tokenizer.encode(prompt_text)) > target_prompt_len:
+            tokens = tokenizer.encode(prompt_text, add_special_tokens=True)
+            prompt_text = tokenizer.decode(
+                tokens[:target_prompt_len], skip_special_tokens=True
+            )
+
+        processed_prompt = prompt_row.copy()
+        processed_prompt[prompt_column] = prompt_text
+        processed_prompt["prompt_tokens_count"] = target_prompt_len
+        processed_prompt[output_column] = next(output_token_sampler)
+
+        processed_prompts.append(processed_prompt)
+
+    if not processed_prompts:
+        logger.error("No prompts remained after processing")
+        return
+
+    logger.info(f"Generated processed dataset with {len(processed_prompts)} prompts")
+
+    processed_dataset = Dataset.from_list(processed_prompts)
+    save_dataset_to_file(processed_dataset, output_path)
+    logger.info(f"Conversion complete. Dataset saved to: {output_path}")
+
+    if push_to_hub:
+        push_dataset_to_hub(hub_dataset_id, processed_dataset)
+        logger.info(f"Pushed dataset to: {hub_dataset_id}")
+
+
+def push_dataset_to_hub(
+    hub_dataset_id: Optional[str], processed_dataset: Dataset,
+) -> None:
+    hf_token = os.environ.get("HF_TOKEN")
+    if not hub_dataset_id or not hf_token:
+        raise ValueError(
+            "hub_dataset_id and HF_TOKEN env var must be provided when push_to_hub"
+            " is True"
+        )
+    processed_dataset.push_to_hub(hub_dataset_id, token=hf_token)
diff --git a/tests/unit/preprocess/__init__.py b/tests/unit/preprocess/__init__.py
new file mode 100644
index 00000000..e69de29b
diff --git a/tests/unit/preprocess/test_dataset.py b/tests/unit/preprocess/test_dataset.py
new file mode 100644
index 00000000..dd4ab645
--- /dev/null
+++ b/tests/unit/preprocess/test_dataset.py
@@ -0,0 +1,276 @@
+import os
+from pathlib import Path
+from typing import TYPE_CHECKING
+from unittest.mock import MagicMock, patch
+
+if TYPE_CHECKING:
+    from collections.abc import Iterator
+
+import pytest
+from datasets import Dataset
+from transformers import PreTrainedTokenizerBase
+
+from guidellm.preprocess.dataset import (
+    STRATEGY_HANDLERS,
+    ShortPromptStrategy,
+    handle_concatenate_strategy,
+    handle_ignore_strategy,
+    handle_pad_strategy,
+    process_dataset,
+    push_dataset_to_hub,
+    save_dataset_to_file,
+)
+
+
+@pytest.fixture
+def tokenizer_mock():
+    tokenizer = MagicMock(spec=PreTrainedTokenizerBase)
+    tokenizer.encode.side_effect = lambda x: [1] * len(x)
+    tokenizer.decode.side_effect = lambda x, *args, **kwargs: "".join(
+        str(item) for item in x
+    )
+    return tokenizer
+
+
+@patch(f"{process_dataset.__module__}.guidellm_load_dataset")
+@patch(f"{process_dataset.__module__}.check_load_processor")
+@patch(f"{process_dataset.__module__}.Dataset")
+@patch(f"{process_dataset.__module__}.IntegerRangeSampler")
+def test_strategy_handler_called(
+    mock_sampler,
+    mock_dataset_class,
+    mock_check_processor,
+    mock_load_dataset,
+    tokenizer_mock,
+):
+    mock_handler = MagicMock(return_value="processed_prompt")
+    with patch.dict(STRATEGY_HANDLERS, {ShortPromptStrategy.IGNORE: mock_handler}):
+        mock_dataset = [{"prompt": "abc"}, {"prompt": "def"}]
+        mock_load_dataset.return_value = (mock_dataset, {"prompt_column": "prompt"})
+        mock_check_processor.return_value = tokenizer_mock
+        mock_sampler.side_effect = lambda **kwargs: [10, 10]
+
+        mock_dataset_obj = MagicMock(spec=Dataset)
+        mock_dataset_class.from_list.return_value = mock_dataset_obj
+
+        process_dataset(
+            "input",
+            "output_dir/data.json",
+            tokenizer_mock,
+            short_prompt_strategy=ShortPromptStrategy.IGNORE,
+        )
+
+        assert mock_handler.call_count == 2
+        mock_load_dataset.assert_called_once()
+        mock_check_processor.assert_called_once()
+
+
+def test_handle_ignore_strategy_too_short(tokenizer_mock):
+    result = handle_ignore_strategy("short", 10, tokenizer_mock)
+    assert result is None
+    tokenizer_mock.encode.assert_called_with("short")
+
+
+def test_handle_ignore_strategy_sufficient_length(tokenizer_mock):
+    result = handle_ignore_strategy("long prompt", 5, tokenizer_mock)
+    assert result == "long prompt"
+    tokenizer_mock.encode.assert_called_with("long prompt")
+
+
+def test_handle_concatenate_strategy_enough_prompts(tokenizer_mock):
+    dataset_iter = iter([{"prompt": "longer"}])
+    result = handle_concatenate_strategy(
+        "short", 10, dataset_iter, "prompt", tokenizer_mock
+    )
+    assert result == "shortlonger"
+
+
+def test_handle_concatenate_strategy_not_enough_prompts(tokenizer_mock):
+    dataset_iter: Iterator = iter([])
+    result = handle_concatenate_strategy(
+        "short", 10, dataset_iter, "prompt", tokenizer_mock
+    )
+    assert result is None
+
+
+def test_handle_pad_strategy(tokenizer_mock):
+    result = handle_pad_strategy("short", 10, tokenizer_mock, "p")
+    assert result == "shortppppp"
+
+
+@patch("guidellm.preprocess.dataset.save_dataset_to_file")
+@patch("guidellm.preprocess.dataset.Dataset")
+@patch("guidellm.preprocess.dataset.guidellm_load_dataset")
+@patch("guidellm.preprocess.dataset.check_load_processor")
+@patch("guidellm.preprocess.dataset.IntegerRangeSampler")
+def test_process_dataset_non_empty(
+    mock_sampler,
+    mock_check_processor,
+    mock_load_dataset,
+    mock_dataset_class,
+    mock_save_to_file,
+    tokenizer_mock,
+):
+    from guidellm.preprocess.dataset import process_dataset
+
+    mock_dataset = [{"prompt": "Hello"}, {"prompt": "How are you?"}]
+    mock_load_dataset.return_value = (mock_dataset, {"prompt_column": "prompt"})
+    mock_check_processor.return_value = tokenizer_mock
+    mock_sampler.side_effect = lambda **kwargs: [3, 3, 3]
+
+    mock_dataset_obj = MagicMock(spec=Dataset)
+    mock_dataset_class.from_list.return_value = mock_dataset_obj
+
+    output_path = "output_dir/data.json"
+    process_dataset("input", output_path, tokenizer_mock)
+
+    mock_load_dataset.assert_called_once()
+    mock_check_processor.assert_called_once()
+    mock_dataset_class.from_list.assert_called_once()
+    mock_save_to_file.assert_called_once_with(mock_dataset_obj, output_path)
+
+    args, _ = mock_dataset_class.from_list.call_args
+    processed_list = args[0]
+    assert len(processed_list) == 2
+    for item in processed_list:
+        assert "prompt" in item
+        assert "prompt_tokens_count" in item
+        assert "output_tokens_count" in item
+        assert len(tokenizer_mock.encode(item["prompt"])) <= 3
+
+
+@patch(f"{process_dataset.__module__}.Dataset")
+@patch(f"{process_dataset.__module__}.guidellm_load_dataset")
+@patch(f"{process_dataset.__module__}.check_load_processor")
+@patch(f"{process_dataset.__module__}.IntegerRangeSampler")
+def test_process_dataset_empty_after_processing(
+    mock_sampler,
+    mock_check_processor,
+    mock_load_dataset,
+    mock_dataset_class,
+    tokenizer_mock,
+):
+    mock_dataset = [{"prompt": ""}]
+    mock_load_dataset.return_value = (mock_dataset, {"prompt_column": "prompt"})
+    mock_check_processor.return_value = tokenizer_mock
+    mock_sampler.side_effect = lambda **kwargs: [10]
+
+    process_dataset("input", "output_dir/data.json", tokenizer_mock)
+
+    mock_load_dataset.assert_called_once()
+    mock_check_processor.assert_called_once()
+    mock_dataset_class.from_list.assert_not_called()
+
+
+@patch(f"{process_dataset.__module__}.push_dataset_to_hub")
+@patch(f"{process_dataset.__module__}.Dataset")
+@patch(f"{process_dataset.__module__}.guidellm_load_dataset")
+@patch(f"{process_dataset.__module__}.check_load_processor")
+@patch(f"{process_dataset.__module__}.IntegerRangeSampler")
+def test_process_dataset_push_to_hub_called(
+    mock_sampler,
+    mock_check_processor,
+    mock_load_dataset,
+    mock_dataset_class,
+    mock_push,
+    tokenizer_mock,
+):
+    mock_dataset = [{"prompt": "abc"}]
+    mock_load_dataset.return_value = (mock_dataset, {"prompt_column": "prompt"})
+    mock_check_processor.return_value = tokenizer_mock
+    mock_sampler.side_effect = lambda **kwargs: [3]
+
+    mock_dataset_obj = MagicMock(spec=Dataset)
+    mock_dataset_class.from_list.return_value = mock_dataset_obj
+
+    process_dataset(
+        "input",
+        "output_dir/data.json",
+        tokenizer_mock,
+        push_to_hub=True,
+        hub_dataset_id="id123",
+    )
+    mock_push.assert_called_once_with("id123", mock_dataset_obj)
+
+
+@patch(f"{process_dataset.__module__}.push_dataset_to_hub")
+@patch(f"{process_dataset.__module__}.Dataset")
+@patch(f"{process_dataset.__module__}.guidellm_load_dataset")
+@patch(f"{process_dataset.__module__}.check_load_processor")
+@patch(f"{process_dataset.__module__}.IntegerRangeSampler")
+def test_process_dataset_push_to_hub_not_called(
+    mock_sampler,
+    mock_check_processor,
+    mock_load_dataset,
+    mock_dataset_class,
+    mock_push,
+    tokenizer_mock,
+):
+    mock_dataset = [{"prompt": "abc"}]
+    mock_load_dataset.return_value = (mock_dataset, {"prompt_column": "prompt"})
+    mock_check_processor.return_value = tokenizer_mock
+    mock_sampler.side_effect = lambda **kwargs: [3]
+
+    mock_dataset_obj = MagicMock(spec=Dataset)
+    mock_dataset_class.from_list.return_value = mock_dataset_obj
+
+    process_dataset("input", "output_dir/data.json", tokenizer_mock, push_to_hub=False)
+    mock_push.assert_not_called()
+
+
+def test_push_dataset_to_hub_success():
+    os.environ["HF_TOKEN"] = "token"
+    mock_dataset = MagicMock(spec=Dataset)
+    push_dataset_to_hub("dataset_id", mock_dataset)
+    mock_dataset.push_to_hub.assert_called_once_with("dataset_id", token="token")
+
+
+def test_push_dataset_to_hub_error_no_env():
+    if "HF_TOKEN" in os.environ:
+        del os.environ["HF_TOKEN"]
+    mock_dataset = MagicMock(spec=Dataset)
+    with pytest.raises(ValueError, match="hub_dataset_id and HF_TOKEN"):
+        push_dataset_to_hub("dataset_id", mock_dataset)
+
+
+def test_push_dataset_to_hub_error_no_id():
+    os.environ["HF_TOKEN"] = "token"
+    mock_dataset = MagicMock(spec=Dataset)
+    with pytest.raises(ValueError, match="hub_dataset_id and HF_TOKEN"):
+        push_dataset_to_hub(None, mock_dataset)
+
+
+@patch.object(Path, "mkdir")
+def test_save_dataset_to_file_csv(mock_mkdir):
+    mock_dataset = MagicMock(spec=Dataset)
+    output_path = Path("some/path/output.csv")
+    save_dataset_to_file(mock_dataset, output_path)
+    mock_dataset.to_csv.assert_called_once_with(str(output_path))
+    mock_mkdir.assert_called_once()
+
+
+@patch.object(Path, "mkdir")
+def test_save_dataset_to_file_json(mock_mkdir):
+    mock_dataset = MagicMock(spec=Dataset)
+    output_path = Path("some/path/output.json")
+    save_dataset_to_file(mock_dataset, output_path)
+    mock_dataset.to_json.assert_called_once_with(str(output_path))
+    mock_mkdir.assert_called_once()
+
+
+@patch.object(Path, "mkdir")
+def test_save_dataset_to_file_parquet(mock_mkdir):
+    mock_dataset = MagicMock(spec=Dataset)
+    output_path = Path("some/path/output.parquet")
+    save_dataset_to_file(mock_dataset, output_path)
+    mock_dataset.to_parquet.assert_called_once_with(str(output_path))
+    mock_mkdir.assert_called_once()
+
+
+@patch.object(Path, "mkdir")
+def test_save_dataset_to_file_unsupported_type(mock_mkdir):
+    mock_dataset = MagicMock(spec=Dataset)
+    output_path = Path("some/path/output.txt")
+    with pytest.raises(ValueError, match=r"Unsupported file suffix '.txt'.*"):
+        save_dataset_to_file(mock_dataset, output_path)
+    mock_mkdir.assert_called_once()
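
Example usage (an illustrative sketch, not part of the patch): assuming the package's
console script is guidellm and a local data.json file whose prompt column can be
inferred by the dataset loader, the new command could be invoked roughly as follows.
The tokenizer name, token sizes, and pad token are placeholder values, not defaults
taken from the code.

    # CLI: pad short prompts and target ~128 prompt / ~256 output tokens
    guidellm preprocess dataset data.json processed.json \
        --processor hf-internal-testing/llama-tokenizer \
        --short-prompt-strategy pad \
        --pad-token "<pad>" \
        --prompt-tokens-average 128 \
        --output-tokens-average 256

    # Equivalent programmatic call through the new guidellm.preprocess module
    from guidellm.preprocess import ShortPromptStrategy, process_dataset

    process_dataset(
        input_data="data.json",
        output_path="processed.json",
        processor="hf-internal-testing/llama-tokenizer",
        short_prompt_strategy=ShortPromptStrategy.PAD,
        pad_token="<pad>",
        prompt_tokens_average=128,
        output_tokens_average=256,
    )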