diff --git a/src/betterproto/templates/template.py.j2 b/src/betterproto/templates/template.py.j2 index 4a252aec6..a08b36cbd 100644 --- a/src/betterproto/templates/template.py.j2 +++ b/src/betterproto/templates/template.py.j2 @@ -17,7 +17,13 @@ class {{ enum.py_name }}(betterproto.Enum): def __get_pydantic_core_schema__(cls, _source_type, _handler): from pydantic_core import core_schema - return core_schema.int_schema(ge=0) + # Return the schema for validation and serialization + return core_schema.chain_schema( + [ + core_schema.int_schema(ge=0), # Validate as a string first + core_schema.no_info_plain_validator_function(lambda value: cls(value)), # Custom validation + ] + ) {% endif %} {% endfor %} diff --git a/tests/inputs/enum/test_enum.py b/tests/inputs/enum/test_enum.py index 21a5ac3b9..578aba877 100644 --- a/tests/inputs/enum/test_enum.py +++ b/tests/inputs/enum/test_enum.py @@ -3,6 +3,10 @@ Choice, Test, ) +from tests.output_betterproto_pydantic.enum import ( + Choice as ChoicePyd, + Test as TestPyd, +) def test_enum_set_and_get(): @@ -112,3 +116,8 @@ def test_renamed_enum_members(): "MINUS", "_0_PREFIXED", } + + +def test_pydantic_enum_preserve_type(): + test = TestPyd(choice=ChoicePyd.ZERO) + assert isinstance(test.choice, ChoicePyd)