|
6 | 6 | from decimal import Decimal |
7 | 7 | from ipaddress import IPv4Address, IPv4Interface, IPv4Network, IPv6Address, IPv6Interface, IPv6Network |
8 | 8 | from pathlib import Path |
9 | | -from typing import Callable, Dict, FrozenSet, List, Literal, Optional, Sequence, Set, Tuple, Type, Union |
| 9 | +from typing import Any, Callable, Dict, FrozenSet, List, Literal, Optional, Sequence, Set, Tuple, Type, Union |
10 | 10 | from uuid import UUID |
11 | 11 |
|
12 | 12 | import pytest |
|
64 | 64 | validator, |
65 | 65 | ) |
66 | 66 |
|
| 67 | +from polyfactory.exceptions import ParameterException |
67 | 68 | from polyfactory.factories import DataclassFactory |
68 | 69 | from polyfactory.factories.pydantic_factory import _IS_PYDANTIC_V1, ModelFactory |
| 70 | +from polyfactory.field_meta import FieldMeta |
69 | 71 | from tests.models import Person, PetFactory |
70 | 72 |
|
71 | 73 | IS_PYDANTIC_V1 = _IS_PYDANTIC_V1 |
@@ -634,6 +636,49 @@ class A(BaseModel): |
634 | 636 | assert AFactory.build() |
635 | 637 |
|
636 | 638 |
|
| 639 | +@pytest.mark.skipif(sys.version_info < (3, 10), reason="requires modern union types") |
| 640 | +@pytest.mark.skipif(IS_PYDANTIC_V1, reason="pydantic 2 only test") |
| 641 | +def test_optional_custom_type() -> None: |
| 642 | + from pydantic_core import core_schema |
| 643 | + |
| 644 | + class CustomType: |
| 645 | + def __init__(self, _: Any) -> None: |
| 646 | + pass |
| 647 | + |
| 648 | + def __get_pydantic_core_schema__(self, _: Any) -> core_schema.StringSchema: |
| 649 | + # for pydantic to stop complaining |
| 650 | + return core_schema.str_schema() |
| 651 | + |
| 652 | + class OptionalFormOne(BaseModel): |
| 653 | + optional_custom_type: Optional[CustomType] |
| 654 | + |
| 655 | + @classmethod |
| 656 | + def should_set_none_value(cls, field_meta: FieldMeta) -> bool: |
| 657 | + return False |
| 658 | + |
| 659 | + class OptionalFormOneFactory(ModelFactory[OptionalFormOne]): |
| 660 | + @classmethod |
| 661 | + def should_set_none_value(cls, field_meta: FieldMeta) -> bool: |
| 662 | + return False |
| 663 | + |
| 664 | + class OptionalFormTwo(BaseModel): |
| 665 | + # this is represented differently than `Optional[None]` internally |
| 666 | + optional_custom_type_second_form: CustomType | None |
| 667 | + |
| 668 | + class OptionalFormTwoFactory(ModelFactory[OptionalFormTwo]): |
| 669 | + @classmethod |
| 670 | + def should_set_none_value(cls, field_meta: FieldMeta) -> bool: |
| 671 | + return False |
| 672 | + |
| 673 | + # ensure the custom type field name and variant is in the error message |
| 674 | + |
| 675 | + with pytest.raises(ParameterException, match=r"optional_custom_type"): |
| 676 | + OptionalFormOneFactory.build() |
| 677 | + |
| 678 | + with pytest.raises(ParameterException, match=r"optional_custom_type_second_form"): |
| 679 | + OptionalFormTwoFactory.build() |
| 680 | + |
| 681 | + |
637 | 682 | def test_collection_unions_with_models() -> None: |
638 | 683 | class A(BaseModel): |
639 | 684 | a: int |
|
0 commit comments