Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
7 changes: 5 additions & 2 deletions src/guidellm/__main__.py
Original file line number Diff line number Diff line change
Expand Up @@ -33,7 +33,7 @@
try:
import uvloop
except ImportError:
uvloop = None # type: ignore[assignment] # Optional dependency
uvloop = None # type: ignore[assignment] # Optional dependency

from guidellm.backends import BackendType
from guidellm.benchmark import (
Expand Down Expand Up @@ -116,6 +116,7 @@ def benchmark():
)
@click.option(
"--scenario",
"-c",
type=cli_tools.Union(
click.Path(
exists=True,
Expand Down Expand Up @@ -392,8 +393,10 @@ def run(**kwargs):
disable_progress = kwargs.pop("disable_progress", False)

try:
# Only set CLI args that differ from click defaults
new_kwargs = cli_tools.set_if_not_default(click.get_current_context(), **kwargs)
args = BenchmarkGenerativeTextArgs.create(
scenario=kwargs.pop("scenario", None), **kwargs
scenario=new_kwargs.pop("scenario", None), **new_kwargs
)
except ValidationError as err:
# Translate pydantic validation error to click argument error
Expand Down
39 changes: 33 additions & 6 deletions src/guidellm/benchmark/schemas.py
Original file line number Diff line number Diff line change
Expand Up @@ -23,7 +23,15 @@
from typing import Any, ClassVar, Literal, TypeVar, cast

import yaml
from pydantic import ConfigDict, Field, computed_field, model_serializer
from pydantic import (
ConfigDict,
Field,
ValidationError,
ValidatorFunctionWrapHandler,
computed_field,
field_validator,
model_serializer,
)
from torch.utils.data import Sampler
from transformers import PreTrainedTokenizerBase

Expand Down Expand Up @@ -1142,7 +1150,8 @@ def update_estimate(
)
request_duration = (
(request_end_time - request_start_time)
if request_end_time and request_start_time else None
if request_end_time and request_start_time
else None
)

# Always track concurrency
Expand Down Expand Up @@ -1669,7 +1678,7 @@ def compile(
estimated_state: EstimatedBenchmarkState,
scheduler_state: SchedulerState,
profile: Profile,
requests: Iterable,
requests: Iterable, # noqa: ARG003
backend: BackendInterface,
environment: Environment,
strategy: SchedulingStrategy,
Expand Down Expand Up @@ -1818,8 +1827,6 @@ def get_default(cls: type[BenchmarkGenerativeTextArgs], field: str) -> Any:
else:
return factory({}) # type: ignore[call-arg] # Confirmed correct at runtime by code above



model_config = ConfigDict(
extra="ignore",
use_enum_values=True,
Expand All @@ -1838,7 +1845,7 @@ def get_default(cls: type[BenchmarkGenerativeTextArgs], field: str) -> Any:
profile: StrategyType | ProfileType | Profile = Field(
default="sweep", description="Benchmark profile or scheduling strategy type"
)
rate: float | list[float] | None = Field(
rate: list[float] | None = Field(
default=None, description="Request rate(s) for rate-based scheduling"
)
# Backend configuration
Expand Down Expand Up @@ -1931,6 +1938,26 @@ def get_default(cls: type[BenchmarkGenerativeTextArgs], field: str) -> Any:
default=None, description="Maximum global error rate (0-1) before stopping"
)

@field_validator("data", "data_args", "rate", mode="wrap")
@classmethod
def single_to_list(
    cls, value: Any, handler: ValidatorFunctionWrapHandler
) -> Any:
    """
    Coerce a scalar input into a single-element list for list-typed fields.

    Applied to the ``data``, ``data_args``, and ``rate`` fields so callers
    may pass either a single item or a list of items. The field's normal
    validation runs first; only a ``list_type`` failure triggers the retry
    with the value wrapped in a list.

    :param value: Raw input value for the field being validated
    :param handler: Pydantic handler that performs the field's standard
        validation
    :return: The validated value (a list, or ``None`` for optional fields
        such as ``rate``)
    :raises ValidationError: If validation fails for a reason other than
        the value not being a list
    """
    try:
        return handler(value)
    except ValidationError as err:
        # If validation fails, try wrapping the value in a list
        # NOTE(review): only the FIRST reported error is inspected; a
        # list_type error deeper in err.errors() will not trigger the
        # wrap-and-retry — confirm this is intended
        if err.errors()[0]["type"] == "list_type":
            return handler([value])
        else:
            raise

@model_serializer
def serialize_model(self):
"""
Expand Down
15 changes: 14 additions & 1 deletion src/guidellm/data/deserializers/synthetic.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,7 +9,7 @@
import yaml
from datasets import Features, IterableDataset, Value
from faker import Faker
from pydantic import ConfigDict, Field, model_validator
from pydantic import ConfigDict, Field, ValidationError, model_validator
from transformers import PreTrainedTokenizerBase

from guidellm.data.deserializers.deserializer import (
Expand Down Expand Up @@ -242,6 +242,10 @@ def __call__(
if (config := self._load_config_str(data)) is not None:
return self(config, processor_factory, random_seed, **data_kwargs)

# Try to parse dict-like data directly
if (config := self._load_config_dict(data)) is not None:
return self(config, processor_factory, random_seed, **data_kwargs)

if not isinstance(data, SyntheticTextDatasetConfig):
raise DataNotSupportedError(
"Unsupported data for SyntheticTextDatasetDeserializer, "
Expand All @@ -266,6 +270,15 @@ def __call__(
),
)

def _load_config_dict(self, data: Any) -> SyntheticTextDatasetConfig | None:
    """
    Attempt to interpret dict- or list-shaped data as a synthetic dataset config.

    :param data: Candidate configuration payload
    :return: A validated config, or None if the data is not dict/list-shaped
        or fails validation
    """
    # Only mapping- or sequence-shaped inputs can carry a config payload.
    if isinstance(data, dict | list):
        try:
            return SyntheticTextDatasetConfig.model_validate(data)
        except ValidationError:
            # Not a valid config; let the caller fall through to other loaders.
            pass
    return None

def _load_config_file(self, data: Any) -> SyntheticTextDatasetConfig | None:
if (not isinstance(data, str) and not isinstance(data, Path)) or (
not Path(data).is_file()
Expand Down
Loading