diff --git a/docs/source/en/_toctree.yml b/docs/source/en/_toctree.yml index 981394d7cf..9e9ae9feec 100644 --- a/docs/source/en/_toctree.yml +++ b/docs/source/en/_toctree.yml @@ -86,5 +86,7 @@ title: Webhooks server - local: package_reference/serialization title: Serialization + - local: package_reference/dataclasses + title: Strict dataclasses - local: package_reference/oauth title: OAuth diff --git a/docs/source/en/package_reference/dataclasses.md b/docs/source/en/package_reference/dataclasses.md new file mode 100644 index 0000000000..79d57c84ca --- /dev/null +++ b/docs/source/en/package_reference/dataclasses.md @@ -0,0 +1,220 @@ +# Strict Dataclasses + +The `huggingface_hub` package provides a utility to create **strict dataclasses**. These are enhanced versions of Python's standard `dataclass` with additional validation features. Strict dataclasses ensure that fields are validated both during initialization and assignment, making them ideal for scenarios where data integrity is critical. + +## Overview + +Strict dataclasses are created using the `@strict` decorator. They extend the functionality of regular dataclasses by: + +- Validating field types based on type hints +- Supporting custom validators for additional checks +- Optionally allowing arbitrary keyword arguments in the constructor +- Validating fields both at initialization and during assignment + +## Benefits + +- **Data Integrity**: Ensures fields always contain valid data +- **Ease of Use**: Integrates seamlessly with Python's `dataclass` module +- **Flexibility**: Supports custom validators for complex validation logic +- **Lightweight**: Requires no additional dependencies such as Pydantic, attrs, or similar libraries + +## Usage + +### Basic Example + +```python +from dataclasses import dataclass +from huggingface_hub.dataclasses import strict, as_validated_field + +# Custom validator to ensure a value is positive +@as_validated_field +def positive_int(value: int): + if not value > 0: + raise ValueError(f"Value must be positive, got {value}") + +@strict +@dataclass +class Config: + model_type: str + hidden_size: int = positive_int(default=16) + vocab_size: int = 32 # Default value + + # Methods named `validate_xxx` are treated as class-wise validators + def validate_big_enough_vocab(self): + if self.vocab_size < self.hidden_size: + raise ValueError(f"vocab_size ({self.vocab_size}) must be greater than hidden_size ({self.hidden_size})") +``` + +Fields are validated during initialization: + +```python +config = Config(model_type="bert", hidden_size=24) # Valid +config = Config(model_type="bert", hidden_size=-1) # Raises StrictDataclassFieldValidationError +``` + +Consistency between fields is also validated during initialization (class-wise validation): + +```python +# `vocab_size` too small compared to `hidden_size` +config = Config(model_type="bert", hidden_size=32, vocab_size=16) # Raises StrictDataclassClassValidationError +``` + +Fields are also validated during assignment: + +```python +config.hidden_size = 512 # Valid +config.hidden_size = -1 # Raises StrictDataclassFieldValidationError +``` + +To re-run class-wide validation after assignment, you must call `.validate` explicitly: + +```python +config.validate() # Runs all class validators +``` + +### Custom Validators + +You can attach multiple custom validators to fields using [`validated_field`]. A validator is a callable that takes a single argument and raises an exception if the value is invalid. + +```python +from dataclasses import dataclass +from huggingface_hub.dataclasses import strict, validated_field + +def multiple_of_64(value: int): + if value % 64 != 0: + raise ValueError(f"Value must be a multiple of 64, got {value}") + +@strict +@dataclass +class Config: + hidden_size: int = validated_field(validator=[positive_int, multiple_of_64]) +``` + +In this example, both validators are applied to the `hidden_size` field. + +### Additional Keyword Arguments + +By default, strict dataclasses only accept fields defined in the class. You can allow additional keyword arguments by setting `accept_kwargs=True` in the `@strict` decorator. + +```python +from dataclasses import dataclass +from huggingface_hub.dataclasses import strict + +@strict(accept_kwargs=True) +@dataclass +class ConfigWithKwargs: + model_type: str + vocab_size: int = 16 + +config = ConfigWithKwargs(model_type="bert", vocab_size=30000, extra_field="extra_value") +print(config) # ConfigWithKwargs(model_type='bert', vocab_size=30000, *extra_field='extra_value') +``` + +Additional keyword arguments appear in the string representation of the dataclass but are prefixed with `*` to highlight that they are not validated. + +### Integration with Type Hints + +Strict dataclasses respect type hints and validate them automatically. For example: + +```python +from typing import List +from dataclasses import dataclass +from huggingface_hub.dataclasses import strict + +@strict +@dataclass +class Config: + layers: List[int] + +config = Config(layers=[64, 128]) # Valid +config = Config(layers="not_a_list") # Raises StrictDataclassFieldValidationError +``` + +Supported types include: +- Any +- Union +- Optional +- Literal +- List +- Dict +- Tuple +- Set + +And any combination of these types. If your need more complex type validation, you can do it through a custom validator. + +### Class validators + +Methods named `validate_xxx` are treated as class validators. These methods must only take `self` as an argument. Class validators are run once during initialization, right after `__post_init__`. You can define as many of them as needed—they'll be executed sequentially in the order they appear. + +Note that class validators are not automatically re-run when a field is updated after initialization. To manually re-validate the object, you need to call `obj.validate()`. + +```py +from dataclasses import dataclass +from huggingface_hub.dataclasses import strict + +@strict +@dataclass +class Config: + foo: str + foo_length: int + upper_case: bool = False + + def validate_foo_length(self): + if len(self.foo) != self.foo_length: + raise ValueError(f"foo must be {self.foo_length} characters long, got {len(self.foo)}") + + def validate_foo_casing(self): + if self.upper_case and self.foo.upper() != self.foo: + raise ValueError(f"foo must be uppercase, got {self.foo}") + +config = Config(foo="bar", foo_length=3) # ok + +config.upper_case = True +config.validate() # Raises StrictDataclassClassValidationError + +Config(foo="abcd", foo_length=3) # Raises StrictDataclassFieldValidationError +Config(foo="Bar", foo_length=3, upper_case=True) # Raises StrictDataclassFieldValidationError +``` + + + +Method `.validate()` is a reserved name on strict dataclasses. +To prevent unexpected behaviors, a [`StrictDataclassDefinitionError`] error will be raised if your class already defines one. + + + +## API Reference + +### `@strict` + +The `@strict` decorator enhances a dataclass with strict validation. + +[[autodoc]] dataclasses.strict + +### `as_validated_field` + +Decorator to create a [`validated_field`]. Recommended for fields with a single validator to avoid boilerplate code. + +[[autodoc]] dataclasses.as_validated_field + +### `validated_field` + +Creates a dataclass field with custom validation. + +[[autodoc]] dataclasses.validated_field + +### Errors + +[[autodoc]] errors.StrictDataclassError + +[[autodoc]] errors.StrictDataclassDefinitionError + +[[autodoc]] errors.StrictDataclassFieldValidationError + +## Why Not Use `pydantic`? (or `attrs`? or `marshmallow_dataclass`?) + +- See discussion in https://github.com/huggingface/transformers/issues/36329 regarding adding Pydantic as a dependency. It would be a heavy addition and require careful logic to support both v1 and v2. +- We don't need most of Pydantic's features, especially those related to automatic casting, jsonschema, serialization, aliases, etc. +- We don't need the ability to instantiate a class from a dictionary. +- We don't want to mutate data. In `@strict`, "validation" means "checking if a value is valid." In Pydantic, "validation" means "casting a value, possibly mutating it, and then checking if it's valid." +- We don't need blazing-fast validation. `@strict` isn't designed for heavy loads where performance is critical. Common use cases involve validating a model configuration (performed once and negligible compared to running a model). This allows us to keep the code minimal. \ No newline at end of file diff --git a/src/huggingface_hub/dataclasses.py b/src/huggingface_hub/dataclasses.py new file mode 100644 index 0000000000..c5f8c7a3ea --- /dev/null +++ b/src/huggingface_hub/dataclasses.py @@ -0,0 +1,481 @@ +import inspect +from dataclasses import _MISSING_TYPE, MISSING, Field, field, fields +from functools import wraps +from typing import ( + Any, + Callable, + Dict, + List, + Literal, + Optional, + Tuple, + Type, + TypeVar, + Union, + get_args, + get_origin, + overload, +) + +from .errors import ( + StrictDataclassClassValidationError, + StrictDataclassDefinitionError, + StrictDataclassFieldValidationError, +) + + +Validator_T = Callable[[Any], None] +T = TypeVar("T") + + +# The overload decorator helps type checkers understand the different return types +@overload +def strict(cls: Type[T]) -> Type[T]: ... + + +@overload +def strict(*, accept_kwargs: bool = False) -> Callable[[Type[T]], Type[T]]: ... + + +def strict( + cls: Optional[Type[T]] = None, *, accept_kwargs: bool = False +) -> Union[Type[T], Callable[[Type[T]], Type[T]]]: + """ + Decorator to add strict validation to a dataclass. + + This decorator must be used on top of `@dataclass` to ensure IDEs and static typing tools + recognize the class as a dataclass. + + Can be used with or without arguments: + - `@strict` + - `@strict(accept_kwargs=True)` + + Args: + cls: + The class to convert to a strict dataclass. + accept_kwargs (`bool`, *optional*): + If True, allows arbitrary keyword arguments in `__init__`. Defaults to False. + + Returns: + The enhanced dataclass with strict validation on field assignment. + + Example: + ```py + >>> from dataclasses import dataclass + >>> from huggingface_hub.dataclasses import as_validated_field, strict, validated_field + + >>> @as_validated_field + >>> def positive_int(value: int): + ... if not value >= 0: + ... raise ValueError(f"Value must be positive, got {value}") + + >>> @strict(accept_kwargs=True) + ... @dataclass + ... class User: + ... name: str + ... age: int = positive_int(default=10) + + # Initialize + >>> User(name="John") + User(name='John', age=10) + + # Extra kwargs are accepted + >>> User(name="John", age=30, lastname="Doe") + User(name='John', age=30, *lastname='Doe') + + # Invalid type => raises + >>> User(name="John", age="30") + huggingface_hub.errors.StrictDataclassFieldValidationError: Validation error for field 'age': + TypeError: Field 'age' expected int, got str (value: '30') + + # Invalid value => raises + >>> User(name="John", age=-1) + huggingface_hub.errors.StrictDataclassFieldValidationError: Validation error for field 'age': + ValueError: Value must be positive, got -1 + ``` + """ + + def wrap(cls: Type[T]) -> Type[T]: + if not hasattr(cls, "__dataclass_fields__"): + raise StrictDataclassDefinitionError( + f"Class '{cls.__name__}' must be a dataclass before applying @strict." + ) + + # List and store validators + field_validators: Dict[str, List[Validator_T]] = {} + for f in fields(cls): # type: ignore [arg-type] + validators = [] + validators.append(_create_type_validator(f)) + custom_validator = f.metadata.get("validator") + if custom_validator is not None: + if not isinstance(custom_validator, list): + custom_validator = [custom_validator] + for validator in custom_validator: + if not _is_validator(validator): + raise StrictDataclassDefinitionError( + f"Invalid validator for field '{f.name}': {validator}. Must be a callable taking a single argument." + ) + validators.extend(custom_validator) + field_validators[f.name] = validators + cls.__validators__ = field_validators # type: ignore + + # Override __setattr__ to validate fields on assignment + original_setattr = cls.__setattr__ + + def __strict_setattr__(self: Any, name: str, value: Any) -> None: + """Custom __setattr__ method for strict dataclasses.""" + # Run all validators + for validator in self.__validators__.get(name, []): + try: + validator(value) + except (ValueError, TypeError) as e: + raise StrictDataclassFieldValidationError(field=name, cause=e) from e + + # If validation passed, set the attribute + original_setattr(self, name, value) + + cls.__setattr__ = __strict_setattr__ # type: ignore[method-assign] + + if accept_kwargs: + # (optional) Override __init__ to accept arbitrary keyword arguments + original_init = cls.__init__ + + @wraps(original_init) + def __init__(self, **kwargs: Any) -> None: + # Extract only the fields that are part of the dataclass + dataclass_fields = {f.name for f in fields(cls)} # type: ignore [arg-type] + standard_kwargs = {k: v for k, v in kwargs.items() if k in dataclass_fields} + + # Call the original __init__ with standard fields + original_init(self, **standard_kwargs) + + # Add any additional kwargs as attributes + for name, value in kwargs.items(): + if name not in dataclass_fields: + self.__setattr__(name, value) + + cls.__init__ = __init__ # type: ignore[method-assign] + + # (optional) Override __repr__ to include additional kwargs + original_repr = cls.__repr__ + + @wraps(original_repr) + def __repr__(self) -> str: + # Call the original __repr__ to get the standard fields + standard_repr = original_repr(self) + + # Get additional kwargs + additional_kwargs = [ + # add a '*' in front of additional kwargs to let the user know they are not part of the dataclass + f"*{k}={v!r}" + for k, v in self.__dict__.items() + if k not in cls.__dataclass_fields__ # type: ignore [attr-defined] + ] + additional_repr = ", ".join(additional_kwargs) + + # Combine both representations + return f"{standard_repr[:-1]}, {additional_repr})" if additional_kwargs else standard_repr + + cls.__repr__ = __repr__ # type: ignore [method-assign] + + # List all public methods starting with `validate_` => class validators. + class_validators = [] + + for name in dir(cls): + if not name.startswith("validate_"): + continue + method = getattr(cls, name) + if not callable(method): + continue + if len(inspect.signature(method).parameters) != 1: + raise StrictDataclassDefinitionError( + f"Class '{cls.__name__}' has a class validator '{name}' that takes more than one argument." + " Class validators must take only 'self' as an argument. Methods starting with 'validate_'" + " are considered to be class validators." + ) + class_validators.append(method) + + cls.__class_validators__ = class_validators # type: ignore [attr-defined] + + # Add `validate` method to the class, but first check if it already exists + def validate(self: T) -> None: + """Run class validators on the instance.""" + for validator in cls.__class_validators__: # type: ignore [attr-defined] + try: + validator(self) + except (ValueError, TypeError) as e: + raise StrictDataclassClassValidationError(validator=validator.__name__, cause=e) from e + + # Hack to be able to raise if `.validate()` already exists except if it was created by this decorator on a parent class + # (in which case we just override it) + validate.__is_defined_by_strict_decorator__ = True # type: ignore [attr-defined] + + if hasattr(cls, "validate"): + if not getattr(cls.validate, "__is_defined_by_strict_decorator__", False): # type: ignore [attr-defined] + raise StrictDataclassDefinitionError( + f"Class '{cls.__name__}' already implements a method called 'validate'." + " This method name is reserved when using the @strict decorator on a dataclass." + " If you want to keep your own method, please rename it." + ) + + cls.validate = validate # type: ignore + + # Run class validators after initialization + initial_init = cls.__init__ + + @wraps(initial_init) + def init_with_validate(self, *args, **kwargs) -> None: + """Run class validators after initialization.""" + initial_init(self, *args, **kwargs) # type: ignore [call-arg] + cls.validate(self) # type: ignore [attr-defined] + + setattr(cls, "__init__", init_with_validate) + + return cls + + # Return wrapped class or the decorator itself + return wrap(cls) if cls is not None else wrap + + +def validated_field( + validator: Union[List[Validator_T], Validator_T], + default: Union[Any, _MISSING_TYPE] = MISSING, + default_factory: Union[Callable[[], Any], _MISSING_TYPE] = MISSING, + init: bool = True, + repr: bool = True, + hash: Optional[bool] = None, + compare: bool = True, + metadata: Optional[Dict] = None, + **kwargs: Any, +) -> Any: + """ + Create a dataclass field with a custom validator. + + Useful to apply several checks to a field. If only applying one rule, check out the [`as_validated_field`] decorator. + + Args: + validator (`Callable` or `List[Callable]`): + A method that takes a value as input and raises ValueError/TypeError if the value is invalid. + Can be a list of validators to apply multiple checks. + **kwargs: + Additional arguments to pass to `dataclasses.field()`. + + Returns: + A field with the validator attached in metadata + """ + if not isinstance(validator, list): + validator = [validator] + if metadata is None: + metadata = {} + metadata["validator"] = validator + return field( # type: ignore + default=default, + default_factory=default_factory, + init=init, + repr=repr, + hash=hash, + compare=compare, + metadata=metadata, + **kwargs, + ) + + +def as_validated_field(validator: Validator_T): + """ + Decorates a validator function as a [`validated_field`] (i.e. a dataclass field with a custom validator). + + Args: + validator (`Callable`): + A method that takes a value as input and raises ValueError/TypeError if the value is invalid. + """ + + def _inner( + default: Union[Any, _MISSING_TYPE] = MISSING, + default_factory: Union[Callable[[], Any], _MISSING_TYPE] = MISSING, + init: bool = True, + repr: bool = True, + hash: Optional[bool] = None, + compare: bool = True, + metadata: Optional[Dict] = None, + **kwargs: Any, + ): + return validated_field( + validator, + default=default, + default_factory=default_factory, + init=init, + repr=repr, + hash=hash, + compare=compare, + metadata=metadata, + **kwargs, + ) + + return _inner + + +def type_validator(name: str, value: Any, expected_type: Any) -> None: + """Validate that 'value' matches 'expected_type'.""" + origin = get_origin(expected_type) + args = get_args(expected_type) + + if expected_type is Any: + return + elif validator := _BASIC_TYPE_VALIDATORS.get(origin): + validator(name, value, args) + elif isinstance(expected_type, type): # simple types + _validate_simple_type(name, value, expected_type) + else: + raise TypeError(f"Unsupported type for field '{name}': {expected_type}") + + +def _validate_union(name: str, value: Any, args: Tuple[Any, ...]) -> None: + """Validate that value matches one of the types in a Union.""" + errors = [] + for t in args: + try: + type_validator(name, value, t) + return # Valid if any type matches + except TypeError as e: + errors.append(str(e)) + + raise TypeError( + f"Field '{name}' with value {repr(value)} doesn't match any type in {args}. Errors: {'; '.join(errors)}" + ) + + +def _validate_literal(name: str, value: Any, args: Tuple[Any, ...]) -> None: + """Validate Literal type.""" + if value not in args: + raise TypeError(f"Field '{name}' expected one of {args}, got {value}") + + +def _validate_list(name: str, value: Any, args: Tuple[Any, ...]) -> None: + """Validate List[T] type.""" + if not isinstance(value, list): + raise TypeError(f"Field '{name}' expected a list, got {type(value).__name__}") + + # Validate each item in the list + item_type = args[0] + for i, item in enumerate(value): + try: + type_validator(f"{name}[{i}]", item, item_type) + except TypeError as e: + raise TypeError(f"Invalid item at index {i} in list '{name}'") from e + + +def _validate_dict(name: str, value: Any, args: Tuple[Any, ...]) -> None: + """Validate Dict[K, V] type.""" + if not isinstance(value, dict): + raise TypeError(f"Field '{name}' expected a dict, got {type(value).__name__}") + + # Validate keys and values + key_type, value_type = args + for k, v in value.items(): + try: + type_validator(f"{name}.key", k, key_type) + type_validator(f"{name}[{k!r}]", v, value_type) + except TypeError as e: + raise TypeError(f"Invalid key or value in dict '{name}'") from e + + +def _validate_tuple(name: str, value: Any, args: Tuple[Any, ...]) -> None: + """Validate Tuple type.""" + if not isinstance(value, tuple): + raise TypeError(f"Field '{name}' expected a tuple, got {type(value).__name__}") + + # Handle variable-length tuples: Tuple[T, ...] + if len(args) == 2 and args[1] is Ellipsis: + for i, item in enumerate(value): + try: + type_validator(f"{name}[{i}]", item, args[0]) + except TypeError as e: + raise TypeError(f"Invalid item at index {i} in tuple '{name}'") from e + # Handle fixed-length tuples: Tuple[T1, T2, ...] + elif len(args) != len(value): + raise TypeError(f"Field '{name}' expected a tuple of length {len(args)}, got {len(value)}") + else: + for i, (item, expected) in enumerate(zip(value, args)): + try: + type_validator(f"{name}[{i}]", item, expected) + except TypeError as e: + raise TypeError(f"Invalid item at index {i} in tuple '{name}'") from e + + +def _validate_set(name: str, value: Any, args: Tuple[Any, ...]) -> None: + """Validate Set[T] type.""" + if not isinstance(value, set): + raise TypeError(f"Field '{name}' expected a set, got {type(value).__name__}") + + # Validate each item in the set + item_type = args[0] + for i, item in enumerate(value): + try: + type_validator(f"{name} item", item, item_type) + except TypeError as e: + raise TypeError(f"Invalid item in set '{name}'") from e + + +def _validate_simple_type(name: str, value: Any, expected_type: type) -> None: + """Validate simple type (int, str, etc.).""" + if not isinstance(value, expected_type): + raise TypeError( + f"Field '{name}' expected {expected_type.__name__}, got {type(value).__name__} (value: {repr(value)})" + ) + + +def _create_type_validator(field: Field) -> Validator_T: + """Create a type validator function for a field.""" + # Hacky: we cannot use a lambda here because of reference issues + + def validator(value: Any) -> None: + type_validator(field.name, value, field.type) + + return validator + + +def _is_validator(validator: Any) -> bool: + """Check if a function is a validator. + + A validator is a Callable that can be called with a single positional argument. + The validator can have more arguments with default values. + + Basically, returns True if `validator(value)` is possible. + """ + if not callable(validator): + return False + + signature = inspect.signature(validator) + parameters = list(signature.parameters.values()) + if len(parameters) == 0: + return False + if parameters[0].kind not in ( + inspect.Parameter.POSITIONAL_OR_KEYWORD, + inspect.Parameter.POSITIONAL_ONLY, + inspect.Parameter.VAR_POSITIONAL, + ): + return False + for parameter in parameters[1:]: + if parameter.default == inspect.Parameter.empty: + return False + return True + + +_BASIC_TYPE_VALIDATORS = { + Union: _validate_union, + Literal: _validate_literal, + list: _validate_list, + dict: _validate_dict, + tuple: _validate_tuple, + set: _validate_set, +} + + +__all__ = [ + "strict", + "validated_field", + "Validator_T", + "StrictDataclassClassValidationError", + "StrictDataclassDefinitionError", + "StrictDataclassFieldValidationError", +] diff --git a/src/huggingface_hub/errors.py b/src/huggingface_hub/errors.py index 7b09e180bf..a0f7ed80e3 100644 --- a/src/huggingface_hub/errors.py +++ b/src/huggingface_hub/errors.py @@ -329,6 +329,35 @@ class DDUFInvalidEntryNameError(DDUFExportError): """Exception thrown when the entry name is invalid.""" +# STRICT DATACLASSES ERRORS + + +class StrictDataclassError(Exception): + """Base exception for strict dataclasses.""" + + +class StrictDataclassDefinitionError(StrictDataclassError): + """Exception thrown when a strict dataclass is defined incorrectly.""" + + +class StrictDataclassFieldValidationError(StrictDataclassError): + """Exception thrown when a strict dataclass fails validation for a given field.""" + + def __init__(self, field: str, cause: Exception): + error_message = f"Validation error for field '{field}':" + error_message += f"\n {cause.__class__.__name__}: {cause}" + super().__init__(error_message) + + +class StrictDataclassClassValidationError(StrictDataclassError): + """Exception thrown when a strict dataclass fails validation on a class validator.""" + + def __init__(self, validator: str, cause: Exception): + error_message = f"Class validation error for validator '{validator}':" + error_message += f"\n {cause.__class__.__name__}: {cause}" + super().__init__(error_message) + + # XET ERRORS diff --git a/src/huggingface_hub/utils/__init__.py b/src/huggingface_hub/utils/__init__.py index ab6f90b157..992eac104b 100644 --- a/src/huggingface_hub/utils/__init__.py +++ b/src/huggingface_hub/utils/__init__.py @@ -14,7 +14,6 @@ # limitations under the License # ruff: noqa: F401 - from huggingface_hub.errors import ( BadRequestError, CacheNotFound, diff --git a/tests/test_utils_strict_dataclass.py b/tests/test_utils_strict_dataclass.py new file mode 100644 index 0000000000..4a4cd6d56c --- /dev/null +++ b/tests/test_utils_strict_dataclass.py @@ -0,0 +1,604 @@ +import inspect +from dataclasses import asdict, astuple, dataclass, is_dataclass +from typing import Any, Dict, List, Literal, Optional, Set, Tuple, Union, get_type_hints + +import jedi +import pytest + +from huggingface_hub.dataclasses import _is_validator, as_validated_field, strict, type_validator, validated_field +from huggingface_hub.errors import ( + StrictDataclassClassValidationError, + StrictDataclassDefinitionError, + StrictDataclassFieldValidationError, +) + + +def positive_int(value: int): + if not value >= 0: + raise ValueError(f"Value must be positive, got {value}") + + +def multiple_of_64(value: int): + if not value % 64 == 0: + raise ValueError(f"Value must be a multiple of 64, got {value}") + + +@as_validated_field +def strictly_positive(value: int): + if not value > 0: + raise ValueError(f"Value must be strictly positive, got {value}") + + +@strict +@dataclass +class Config: + model_type: str + hidden_size: int = validated_field(validator=[positive_int, multiple_of_64]) + vocab_size: int = strictly_positive(default=16) + + +@strict(accept_kwargs=True) +@dataclass +class ConfigWithKwargs: + model_type: str + vocab_size: int = validated_field(validator=positive_int, default=16) + + +class DummyClass: + pass + + +def test_valid_initialization(): + config = Config(model_type="bert", vocab_size=30000, hidden_size=768) + assert config.model_type == "bert" + assert config.vocab_size == 30000 + assert config.hidden_size == 768 + + +def test_default_values(): + config = Config(model_type="bert", hidden_size=1024) + assert config.model_type == "bert" + assert config.vocab_size == 16 + assert config.hidden_size == 1024 + + +def test_invalid_type_initialization(): + with pytest.raises(StrictDataclassFieldValidationError): + Config(model_type={"type": "bert"}, vocab_size=30000, hidden_size=768) + + with pytest.raises(StrictDataclassFieldValidationError): + Config(model_type="bert", vocab_size="30000", hidden_size=768) + + +def test_all_validators_are_applied(): + # must be positive + with pytest.raises(StrictDataclassFieldValidationError): + Config(model_type="bert", vocab_size=-1, hidden_size=1024) + + # must be a multiple of 64 + with pytest.raises(StrictDataclassFieldValidationError): + Config(model_type="bert", hidden_size=1025) + + # both validators are applied + with pytest.raises(StrictDataclassFieldValidationError): + Config(model_type="bert", hidden_size=-1024) + + +def test_validated_on_assignment(): + config = Config(model_type="bert", hidden_size=1024) + config.vocab_size = 10000 # ok + with pytest.raises(StrictDataclassFieldValidationError): + config.vocab_size = "10000" # type validator checked + with pytest.raises(StrictDataclassFieldValidationError): + config.vocab_size = -1 # custom validators checked + with pytest.raises(StrictDataclassFieldValidationError): + config.vocab_size = 0 # must be strictly positive + + +def test_lax_on_new_attributes(): + config = Config(model_type="bert", hidden_size=1024) + config.new_attribute = "new_value" + assert config.new_attribute == "new_value" # not validated + + +def test_custom_validator_must_be_callable(): + """Must raise at class definition time.""" + with pytest.raises(StrictDataclassDefinitionError): + + @strict + @dataclass + class Config: + model_type: str = validated_field(validator="not_a_function") + + with pytest.raises(StrictDataclassDefinitionError): + + @strict + @dataclass + class Config: + model_type: str = validated_field(validator=lambda: None) # not a validator either + + +@pytest.mark.parametrize( + "value, type_annotation", + [ + # Basic types + (5, int), + (5.0, float), + ("John", str), + # Union types + (5, Union[int, str]), + ("John", Union[int, str]), + # Optional + (5, Optional[int]), + (None, Optional[int]), + (DummyClass(), Optional[DummyClass]), + # Literal + ("John", Literal["John", "Doe"]), + (5, Literal[4, 5, 6]), + # List + ([1, 2, 3], List[int]), + ([1, 2, "3"], List[Union[int, str]]), + # Tuple + ((1, 2, 3), Tuple[int, int, int]), + ((1, 2, "3"), Tuple[int, int, str]), + ((1, 2, 3, 4), Tuple[int, ...]), + # Dict + ({"a": 1, "b": 2}, Dict[str, int]), + ({"a": 1, "b": "2"}, Dict[str, Union[int, str]]), + # Set + ({1, 2, 3}, Set[int]), + ({1, 2, "3"}, Set[Union[int, str]]), + # Custom classes + (DummyClass(), DummyClass), + # Any + (5, Any), + ("John", Any), + (DummyClass(), Any), + # Deep nested type + ( + { + "a": [ + (1, DummyClass(), {1, "2", "3", 4}), + (2, DummyClass(), None), + ], + }, + Dict[ + str, + List[ + Tuple[ + int, + DummyClass, + Optional[Set[Union[int, str],]], + ] + ], + ], + ), + ], +) +def test_type_validator_valid(value, type_annotation): + type_validator("dummy", value, type_annotation) + + +@pytest.mark.parametrize( + "value, type_annotation", + [ + # Basic types + (5, float), + (5.0, int), + ("John", int), + # Union types + (5.0, Union[int, str]), + (None, Union[int, str]), + (DummyClass(), Union[int, str]), + # Optional + ("John", Optional[int]), + (DummyClass(), Optional[int]), + # Literal + ("Ada", Literal["John", "Doe"]), + (3, Literal[4, 5, 6]), + # List + (5, List[int]), + ([1, 2, "3"], List[int]), + # Tuple + (5, Tuple[int, int, int]), + ((1, 2, "3"), Tuple[int, int, int]), + ((1, 2, 3, 4), Tuple[int, int, int]), + ((1, 2, "3", 4), Tuple[int, ...]), + # Dict + (5, Dict[str, int]), + ({"a": 1, "b": "2"}, Dict[str, int]), + # Set + (5, Set[int]), + ({1, 2, "3"}, Set[int]), + # Custom classes + (5, DummyClass), + ("John", DummyClass), + ], +) +def test_type_validator_invalid(value, type_annotation): + with pytest.raises(TypeError): + type_validator("dummy", value, type_annotation) + + +class DummyValidator: + def __init__(self, threshold): + self.threshold = threshold + + def __call__(self, value): + return value < self.threshold + + def compare(self, value, value2=10): + return value < value2 + + +@pytest.mark.parametrize( + "obj", + [ + positive_int, + multiple_of_64, + lambda value: None, + lambda value, factor=2: None, + lambda value=1, factor=2: value * factor, + lambda *values: None, + DummyValidator(threshold=10), # callable object + DummyValidator(threshold=10).compare, # callable method + ], +) +def test_is_validator(obj): + # Anything that can be called with `obj(value)` is a correct validator. + assert _is_validator(obj) + + +@pytest.mark.parametrize( + "obj", + [ + 5, # not callable + lambda: None, # no argument + lambda value1, value2: None, # more than one argument with default values + lambda *, value: None, # keyword-only argument + ], +) +def test_not_a_validator(obj): + assert not _is_validator(obj) + + +def test_preserve_metadata(): + class ConfigWithMetadataField: + foo: int = strictly_positive(metadata={"foo": "bar"}, default=10) + + assert ConfigWithMetadataField.foo.metadata["foo"] == "bar" + + +def test_accept_kwargs(): + config = ConfigWithKwargs(model_type="bert", vocab_size=30000, hidden_size=768) + assert config.model_type == "bert" + assert config.vocab_size == 30000 + assert config.hidden_size == 768 + + # Defined fields are still validated + with pytest.raises(StrictDataclassFieldValidationError): + ConfigWithKwargs(model_type="bert", vocab_size=-1) + + # Default values are still used + config = ConfigWithKwargs(model_type="bert") + assert config.vocab_size == 16 + + +def test_do_not_accept_kwargs(): + @strict + @dataclass + class Config: + model_type: str + + with pytest.raises(TypeError): + Config(model_type="bert", vocab_size=30000) + + +def test_is_recognized_as_dataclass(): + # Check that dataclasses module recognizes it as a dataclass + assert is_dataclass(Config) + + # Check that an instance is recognized as a dataclass instance + config = Config(model_type="bert", hidden_size=768) + assert is_dataclass(config) + + +def test_behave_as_a_dataclass(): + # Check that dataclasses.asdict works + config = Config(model_type="bert", hidden_size=768) + assert asdict(config) == {"model_type": "bert", "hidden_size": 768, "vocab_size": 16} + + # Check that dataclasses.astuple works + assert astuple(config) == ("bert", 768, 16) + + +def test_type_annotations_preserved(): + # Check that type hints are preserved + hints = get_type_hints(Config) + assert hints["model_type"] is str + assert hints["hidden_size"] is int + assert hints["vocab_size"] is int + + +def test_correct_init_signature(): + # Check that __init__ has the expected signature + signature = inspect.signature(Config.__init__) + parameters = list(signature.parameters.values()) + + # First param should be self + assert parameters[0].name == "self" + + # model_type should be required + assert parameters[1].name == "model_type" + assert parameters[1].default == inspect.Parameter.empty + + # hidden_size should be required (and validated) + assert parameters[2].name == "hidden_size" + assert parameters[2].default == inspect.Parameter.empty + + # vocab_size should be optional with default + assert parameters[3].name == "vocab_size" + assert parameters[3].default == 16 + + +def test_correct_eq_repr(): + # Test equality comparison + config1 = Config(model_type="bert", hidden_size=0) + config2 = Config(model_type="bert", hidden_size=0) + config3 = Config(model_type="gpt", hidden_size=0) + + assert config1 == config2 + assert config1 != config3 + + # Test repr + assert repr(config1) == "Config(model_type='bert', hidden_size=0, vocab_size=16)" + + +def test_repr_if_accept_kwargs(): + config1 = ConfigWithKwargs(foo="bar", model_type="bert") + assert repr(config1) == "ConfigWithKwargs(model_type='bert', vocab_size=16, *foo='bar')" + + +def test_autocompletion_attribute_without_kwargs(): + # Create a sample script + completions = jedi.Script(""" +from dataclasses import dataclass +from huggingface_hub.dataclasses import strict + +@strict +@dataclass +class Config: + model_type: str + hidden_size: int = 768 + +config = Config(model_type="bert") +config. +""").complete(line=12, column=7) + completion_names = [c.name for c in completions] + assert "model_type" in completion_names + assert "hidden_size" in completion_names + + +def test_autocompletion_attribute_with_kwargs(): + # Create a sample script + completions = jedi.Script(""" +from dataclasses import dataclass +from huggingface_hub.dataclasses import strict + +@strict(accept_kwargs=True) +@dataclass +class Config: + model_type: str + hidden_size: int = 768 + +config = Config(model_type="bert", foo="bar") +config. +""").complete(line=12, column=7) + completion_names = [c.name for c in completions] + assert "model_type" in completion_names + assert "hidden_size" in completion_names + assert "foo" not in completion_names # not an official arg + + +def test_autocompletion_init_without_kwargs(): + # Create a sample script + completions = jedi.Script(""" +from dataclasses import dataclass +from huggingface_hub.dataclasses import strict + +@strict +@dataclass +class Config: + model_type: str + hidden_size: int = 768 + +config = Config( +""").complete(line=11, column=16) + completion_names = [c.name for c in completions] + assert "model_type=" in completion_names + assert "hidden_size=" in completion_names + + +def test_autocompletion_init_with_kwargs(): + # Create a sample script + completions = jedi.Script(""" +from dataclasses import dataclass +from huggingface_hub.dataclasses import strict + +@strict(accept_kwargs=True) +@dataclass +class Config: + model_type: str + hidden_size: int = 768 + +config = Config( +""").complete(line=11, column=16) + completion_names = [c.name for c in completions] + assert "model_type=" in completion_names + assert "hidden_size=" in completion_names + + +def test_strict_requires_dataclass(): + with pytest.raises(StrictDataclassDefinitionError): + + @strict + class InvalidConfig: + model_type: str + + +class TestClassValidation: + @strict + @dataclass + class ParentConfig: + foo: str = "bar" + foo_length: int = 3 + + def validate_foo_length(self): + if len(self.foo) != self.foo_length: + raise ValueError(f"foo must be {self.foo_length} characters long, got {len(self.foo)}") + + @strict + @dataclass + class ChildConfig(ParentConfig): + number: int = 42 + + def validate_number_multiple_of_foo_length(self): + if self.number % self.foo_length != 0: + raise ValueError(f"number must be a multiple of foo_length ({self.foo_length}), got {self.number}") + + @strict + @dataclass + class OtherChildConfig(ParentConfig): + number: int = 42 + + @strict + @dataclass + class ChildConfigWithPostInit(ParentConfig): + def __post_init__(self): + # Let's assume post_init doubles each value + # Validation is ran AFTER __post_init__ + self.foo = self.foo * 2 + self.foo_length = self.foo_length * 2 + + def test_parent_config_validation(self): + # Test valid initialization + config = self.ParentConfig(foo="bar", foo_length=3) + assert config.foo == "bar" + assert config.foo_length == 3 + + # Test invalid initialization + with pytest.raises(StrictDataclassClassValidationError): + self.ParentConfig(foo="bar", foo_length=4) + + def test_child_config_validation(self): + # Test valid initialization + config = self.ChildConfig(foo="bar", foo_length=3, number=42) + assert config.foo == "bar" + assert config.foo_length == 3 + assert config.number == 42 + + # Test invalid initialization + with pytest.raises(StrictDataclassClassValidationError): + self.ChildConfig(foo="bar", foo_length=4, number=40) + + with pytest.raises(StrictDataclassClassValidationError): + self.ChildConfig(foo="bar", foo_length=3, number=43) + + def test_other_child_config_validation(self): + # Test valid initialization + config = self.OtherChildConfig(foo="bar", foo_length=3, number=43) + assert config.foo == "bar" + assert config.foo_length == 3 + assert config.number == 43 # not validated => did not fail + + # Test invalid initialization + with pytest.raises(StrictDataclassClassValidationError): + self.OtherChildConfig(foo="bar", foo_length=4, number=42) + + def test_validate_after_init(self): + # Test valid initialization + config = self.ParentConfig(foo="bar", foo_length=3) + + # Attributes can be updated after initialization + config.foo = "abcd" + config.foo_length = 4 + config.validate() # Explicit call required + + # Explicit validation fails + config.foo_length = 5 + with pytest.raises(StrictDataclassClassValidationError): + config.validate() + + def test_validation_runs_after_post_init(self): + config = self.ChildConfigWithPostInit(foo="bar", foo_length=3) + assert config.foo == "barbar" + assert config.foo_length == 6 + + with pytest.raises(StrictDataclassClassValidationError, match="foo must be 4 characters long, got 6"): + # post init doubles the value and then the validation fails + self.ChildConfigWithPostInit(foo="bar", foo_length=2) + + +class TestClassValidationWithInheritance: + """Regression test. + + If parent class is not a strict dataclass but defines validators, the child class should validate them too. + """ + + class Base: + def validate_foo(self): + if self.foo < 0: + raise ValueError("foo must be positive") + + @strict + @dataclass + class Config(Base): + foo: int + bar: int + + def validate_bar(self): + if self.bar < 0: + raise ValueError("bar must be positive") + + def test_class_validation_with_inheritance(self): + # Test valid initialization + config = self.Config(foo=0, bar=0) + assert config.foo == 0 + assert config.bar == 0 + + # Test invalid initialization + with pytest.raises(StrictDataclassClassValidationError): + self.Config(foo=0, bar=-1) # validation from child class + + with pytest.raises(StrictDataclassClassValidationError): + self.Config(foo=-1, bar=0) # validation from parent class + + +class TestClassValidateAlreadyExists: + """Regression test. + + If a class already has a validate method, it should raise a StrictDataclassDefinitionError. + """ + + def test_validate_already_defined_by_class(self): + with pytest.raises(StrictDataclassDefinitionError): + + @strict + @dataclass + class Config: + foo: int = 0 + + def validate(self): + pass # already defined => should raise an error + + def test_validate_already_defined_by_parent(self): + with pytest.raises(StrictDataclassDefinitionError): + + class ParentClass: + def validate(self): + pass + + @strict + @dataclass + class ConfigWithParent(ParentClass): # 'validate' already defined => should raise an error + foo: int = 0