|
4 | 4 | from datetime import date, time |
5 | 5 | from enum import Enum |
6 | 6 | from typing import Any, Callable, Dict, List, Optional, Type, TypeVar, Union, get_args |
| 7 | + |
7 | 8 | from pydantic import BaseModel, ConfigDict, HttpUrl, field_validator |
8 | 9 | from pydantic.fields import FieldInfo |
9 | 10 |
|
|
14 | 15 | from guardrails.datatypes import Choice as ChoiceDataType |
15 | 16 | from guardrails.datatypes import DataType |
16 | 17 | from guardrails.datatypes import Date as DateDataType |
| 18 | +from guardrails.datatypes import Enum as EnumDataType |
17 | 19 | from guardrails.datatypes import Float as FloatDataType |
18 | 20 | from guardrails.datatypes import Integer as IntegerDataType |
19 | 21 | from guardrails.datatypes import List as ListDataType |
20 | 22 | from guardrails.datatypes import Object as ObjectDataType |
21 | 23 | from guardrails.datatypes import PythonCode as PythonCodeDataType |
22 | 24 | from guardrails.datatypes import String as StringDataType |
23 | 25 | from guardrails.datatypes import Time as TimeDataType |
24 | | -from guardrails.datatypes import Enum as EnumDataType |
25 | 26 | from guardrails.validator_base import Validator |
26 | 27 | from guardrails.validatorsattr import ValidatorsAttr |
27 | 28 |
|
@@ -88,17 +89,20 @@ def is_dict(type_annotation: Any) -> bool: |
88 | 89 | return True |
89 | 90 | return False |
90 | 91 |
|
| 92 | + |
91 | 93 | def is_enum(type_annotation: Any) -> bool: |
92 | | - """Check if a type_annotation is an enum""" |
| 94 | + """Check if a type_annotation is an enum.""" |
93 | 95 |
|
94 | 96 | type_annotation = prepare_type_annotation(type_annotation) |
95 | 97 |
|
96 | | - if is_pydantic_base_model(type_annotation): |
97 | | - return False |
98 | | - if issubclass(type_annotation, Enum): |
99 | | - return True |
| 98 | + try: |
| 99 | + if issubclass(type_annotation, Enum): |
| 100 | + return True |
| 101 | + except TypeError: |
| 102 | + pass |
100 | 103 | return False |
101 | 104 |
|
| 105 | + |
102 | 106 | def _create_bare_model(model: Type[BaseModel]) -> Type[BaseModel]: |
103 | 107 | class BareModel(BaseModel): |
104 | 108 | __annotations__ = getattr(model, "__annotations__", {}) |
@@ -335,9 +339,6 @@ def convert_pydantic_model_to_datatype( |
335 | 339 |
|
336 | 340 | model_fields = add_pydantic_validators_as_guardrails_validators(model) |
337 | 341 |
|
338 | | - # Use inline import to avoid circular dependency |
339 | | - from guardrails.validators import ValidChoices |
340 | | - |
341 | 342 | children = {} |
342 | 343 | for field_name, field in model_fields.items(): |
343 | 344 | if field_name in excluded_fields: |
@@ -399,13 +400,14 @@ def convert_pydantic_model_to_datatype( |
399 | 400 | name=field_name, |
400 | 401 | ) |
401 | 402 | elif target_datatype == EnumDataType: |
402 | | - valid_choices = type_annotation._member_names_ |
403 | | - field.json_schema_extra["validators"] = [ValidChoices(choices=valid_choices)] |
404 | | - return pydantic_field_to_datatype( |
| 403 | + assert issubclass(type_annotation, Enum) |
| 404 | + valid_choices = [choice.value for choice in type_annotation] |
| 405 | + children[field_name] = pydantic_field_to_datatype( |
405 | 406 | EnumDataType, |
406 | 407 | field, |
407 | 408 | strict=strict, |
408 | | - enum_values=valid_choices |
| 409 | + enum_values=valid_choices, |
| 410 | + name=field_name, |
409 | 411 | ) |
410 | 412 | elif is_pydantic_base_model(field.annotation): |
411 | 413 | children[field_name] = convert_pydantic_model_to_datatype( |
|
0 commit comments