Skip to content

Feat/add preprocess dataset #162

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Draft
wants to merge 9 commits into
base: main
Choose a base branch
from
Draft
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
183 changes: 183 additions & 0 deletions src/guidellm/__main__.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,6 +8,7 @@
from guidellm.backend import BackendType
from guidellm.benchmark import ProfileType, benchmark_generative_text
from guidellm.config import print_config
from guidellm.preprocess.dataset import ShortPromptStrategy, process_dataset
from guidellm.scheduler import StrategyType

STRATEGY_PROFILE_CHOICES = set(
Expand Down Expand Up @@ -290,5 +291,187 @@ def config():
print_config()


@cli.group(help="Preprocessing utilities for datasets.")
def preprocess():
    # Container group for `guidellm preprocess <subcommand>`; subcommands
    # (e.g. `dataset` below) attach themselves via @preprocess.command.
    pass


@preprocess.command(
    help="Convert a dataset to have specific prompt and output token sizes.\n\n"
    "INPUT_DATA: Path to the input dataset or dataset ID.\n"
    "OUTPUT_PATH: Path to save the converted dataset to. "
    "The dataset will be saved as a single Arrow file (.arrow)."
)
@click.argument(
    "input_data",
    type=str,
    metavar="INPUT_DATA",
    required=True,
)
@click.argument(
    "output_path",
    # NOTE: dir_okay=False — OUTPUT_PATH must be a file path, not a directory;
    # the command help above is worded to match this constraint.
    type=click.Path(file_okay=True, dir_okay=False, writable=True, resolve_path=True),
    metavar="OUTPUT_PATH",
    required=True,
)
@click.option(
    "--processor",
    type=str,
    required=True,
    help=(
        "The processor or tokenizer to use to calculate token counts for statistics "
        "and synthetic data generation."
    ),
)
@click.option(
    "--processor-args",
    default=None,
    callback=parse_json,
    help=(
        "A JSON string containing any arguments to pass to the processor constructor "
        "as a dict with **kwargs."
    ),
)
@click.option(
    "--data-args",
    # default=None added explicitly for consistency with --processor-args.
    default=None,
    callback=parse_json,
    help=(
        "A JSON string containing any arguments to pass to the dataset creation "
        "as a dict with **kwargs."
    ),
)
@click.option(
    "--short-prompt-strategy",
    type=click.Choice([s.value for s in ShortPromptStrategy]),
    default=ShortPromptStrategy.IGNORE.value,
    show_default=True,
    help="Strategy to handle prompts shorter than the target length. ",
)
@click.option(
    "--pad-token",
    type=str,
    default=None,
    help="The token to pad short prompts with when using the 'pad' strategy.",
)
@click.option(
    "--prompt-tokens-average",
    type=int,
    default=10,
    show_default=True,
    help="Average target number of tokens for prompts.",
)
@click.option(
    "--prompt-tokens-stdev",
    type=int,
    default=None,
    help="Standard deviation for prompt tokens sampling.",
)
@click.option(
    "--prompt-tokens-min",
    type=int,
    default=None,
    help="Minimum number of prompt tokens.",
)
@click.option(
    "--prompt-tokens-max",
    type=int,
    default=None,
    help="Maximum number of prompt tokens.",
)
@click.option(
    "--prompt-random-seed",
    type=int,
    default=42,
    show_default=True,
    help="Random seed for prompt token sampling.",
)
@click.option(
    "--output-tokens-average",
    type=int,
    default=10,
    show_default=True,
    help="Average target number of tokens for outputs.",
)
@click.option(
    "--output-tokens-stdev",
    type=int,
    default=None,
    help="Standard deviation for output tokens sampling.",
)
@click.option(
    "--output-tokens-min",
    type=int,
    default=None,
    help="Minimum number of output tokens.",
)
@click.option(
    "--output-tokens-max",
    type=int,
    default=None,
    help="Maximum number of output tokens.",
)
@click.option(
    "--output-random-seed",
    type=int,
    default=123,
    show_default=True,
    help="Random seed for output token sampling.",
)
@click.option(
    "--push-to-hub",
    is_flag=True,
    help="Set this flag to push the converted dataset to the Hugging Face Hub.",
)
@click.option(
    "--hub-dataset-id",
    type=str,
    default=None,
    help="The Hugging Face Hub dataset ID to push to. "
    "Required if --push-to-hub is used.",
)
def dataset(
    input_data,
    output_path,
    processor,
    processor_args,
    data_args,
    short_prompt_strategy,
    pad_token,
    prompt_tokens_average,
    prompt_tokens_stdev,
    prompt_tokens_min,
    prompt_tokens_max,
    prompt_random_seed,
    output_tokens_average,
    output_tokens_stdev,
    output_tokens_min,
    output_tokens_max,
    output_random_seed,
    push_to_hub,
    hub_dataset_id,
):
    """Convert a dataset to target prompt/output token sizes.

    Thin CLI wrapper: forwards every parsed option unchanged to
    ``guidellm.preprocess.dataset.process_dataset``, which performs the
    actual conversion and optional Hub upload.

    NOTE(review): --pad-token and --hub-dataset-id requirements are
    presumably validated inside process_dataset — confirm, since no
    validation happens at the CLI layer.
    """
    process_dataset(
        input_data=input_data,
        output_path=output_path,
        processor=processor,
        processor_args=processor_args,
        data_args=data_args,
        short_prompt_strategy=short_prompt_strategy,
        pad_token=pad_token,
        prompt_tokens_average=prompt_tokens_average,
        prompt_tokens_stdev=prompt_tokens_stdev,
        prompt_tokens_min=prompt_tokens_min,
        prompt_tokens_max=prompt_tokens_max,
        prompt_random_seed=prompt_random_seed,
        output_tokens_average=output_tokens_average,
        output_tokens_stdev=output_tokens_stdev,
        output_tokens_min=output_tokens_min,
        output_tokens_max=output_tokens_max,
        output_random_seed=output_random_seed,
        push_to_hub=push_to_hub,
        hub_dataset_id=hub_dataset_id,
    )


if __name__ == "__main__":
    # Script entry point: dispatch to the click CLI group defined above.
    cli()
3 changes: 3 additions & 0 deletions src/guidellm/preprocess/__init__.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,3 @@
# Public API of the guidellm.preprocess package: re-export the dataset
# conversion entry points so callers can `from guidellm.preprocess import ...`.
from .dataset import ShortPromptStrategy, process_dataset

__all__ = ["ShortPromptStrategy", "process_dataset"]
Loading
Loading