Skip to content

Commit e5d6c70

Browse files
committed
Enum: fix functionality
1 parent 645c827 commit e5d6c70

File tree

4 files changed

+30
-30
lines changed

4 files changed

+30
-30
lines changed

guardrails/datatypes.py

Lines changed: 0 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -6,7 +6,6 @@
66
from typing import Any, Dict, Iterable
77
from typing import List as TypedList
88
from typing import Optional, Sequence, Type, TypeVar, Union
9-
from enum import Enum
109

1110
from dateutil.parser import parse
1211
from lxml import etree as ET

guardrails/utils/json_utils.py

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -12,6 +12,7 @@
1212
DataType,
1313
Date,
1414
Email,
15+
Enum,
1516
Float,
1617
Integer,
1718
)
@@ -52,6 +53,7 @@ def verify(
5253
ListDataType: list,
5354
Date: str,
5455
Time: str,
56+
Enum: str,
5557
}
5658

5759
ignore_types = [

guardrails/utils/pydantic_utils/v1.py

Lines changed: 13 additions & 16 deletions
Original file line numberDiff line numberDiff line change
@@ -15,13 +15,13 @@
1515
from guardrails.datatypes import Choice as ChoiceDataType
1616
from guardrails.datatypes import DataType
1717
from guardrails.datatypes import Date as DateDataType
18+
from guardrails.datatypes import Enum as EnumDataType
1819
from guardrails.datatypes import Float as FloatDataType
1920
from guardrails.datatypes import Integer as IntegerDataType
2021
from guardrails.datatypes import List as ListDataType
2122
from guardrails.datatypes import Object as ObjectDataType
2223
from guardrails.datatypes import String as StringDataType
2324
from guardrails.datatypes import Time as TimeDataType
24-
from guardrails.datatypes import Enum as EnumDataType
2525
from guardrails.validator_base import Validator
2626
from guardrails.validatorsattr import ValidatorsAttr
2727

@@ -68,17 +68,20 @@ def is_dict(type_annotation: Any) -> bool:
6868
return True
6969
return False
7070

71+
7172
def is_enum(type_annotation: Any) -> bool:
72-
"""Check if a type_annotation is an enum"""
73+
"""Check if a type_annotation is an enum."""
7374

7475
type_annotation = prepare_type_annotation(type_annotation)
7576

76-
if is_pydantic_base_model(type_annotation):
77-
return False
78-
if issubclass(type_annotation, Enum):
79-
return True
77+
try:
78+
if issubclass(type_annotation, Enum):
79+
return True
80+
except TypeError:
81+
pass
8082
return False
8183

84+
8285
def prepare_type_annotation(type_annotation: Union[ModelField, Type]) -> Type:
8386
"""Get the raw type annotation that can be used for downstream processing.
8487
@@ -315,9 +318,6 @@ def convert_pydantic_model_to_datatype(
315318

316319
model_fields = add_pydantic_validators_as_guardrails_validators(model)
317320

318-
# Use inline import to avoid circular dependency
319-
from guardrails.validators import ValidChoices
320-
321321
children = {}
322322
for field_name, field in model_fields.items():
323323
if field_name in excluded_fields:
@@ -374,13 +374,10 @@ def convert_pydantic_model_to_datatype(
374374
discriminator_key=discriminator,
375375
)
376376
elif target_datatype == EnumDataType:
377-
valid_choices = type_annotation._member_names_
378-
field.field_info.extra["validators"] = [ValidChoices(choices=valid_choices)]
379-
return pydantic_field_to_datatype(
380-
EnumDataType,
381-
field,
382-
strict=strict,
383-
enum_values=valid_choices
377+
assert issubclass(type_annotation, Enum)
378+
valid_choices = [choice.value for choice in type_annotation]
379+
children[field_name] = pydantic_field_to_datatype(
380+
EnumDataType, field, strict=strict, enum_values=valid_choices
384381
)
385382
elif isinstance(field.type_, type) and issubclass(field.type_, BaseModel):
386383
children[field_name] = convert_pydantic_model_to_datatype(

guardrails/utils/pydantic_utils/v2.py

Lines changed: 15 additions & 13 deletions
Original file line numberDiff line numberDiff line change
@@ -4,6 +4,7 @@
44
from datetime import date, time
55
from enum import Enum
66
from typing import Any, Callable, Dict, List, Optional, Type, TypeVar, Union, get_args
7+
78
from pydantic import BaseModel, ConfigDict, HttpUrl, field_validator
89
from pydantic.fields import FieldInfo
910

@@ -14,14 +15,14 @@
1415
from guardrails.datatypes import Choice as ChoiceDataType
1516
from guardrails.datatypes import DataType
1617
from guardrails.datatypes import Date as DateDataType
18+
from guardrails.datatypes import Enum as EnumDataType
1719
from guardrails.datatypes import Float as FloatDataType
1820
from guardrails.datatypes import Integer as IntegerDataType
1921
from guardrails.datatypes import List as ListDataType
2022
from guardrails.datatypes import Object as ObjectDataType
2123
from guardrails.datatypes import PythonCode as PythonCodeDataType
2224
from guardrails.datatypes import String as StringDataType
2325
from guardrails.datatypes import Time as TimeDataType
24-
from guardrails.datatypes import Enum as EnumDataType
2526
from guardrails.validator_base import Validator
2627
from guardrails.validatorsattr import ValidatorsAttr
2728

@@ -88,17 +89,20 @@ def is_dict(type_annotation: Any) -> bool:
8889
return True
8990
return False
9091

92+
9193
def is_enum(type_annotation: Any) -> bool:
92-
"""Check if a type_annotation is an enum"""
94+
"""Check if a type_annotation is an enum."""
9395

9496
type_annotation = prepare_type_annotation(type_annotation)
9597

96-
if is_pydantic_base_model(type_annotation):
97-
return False
98-
if issubclass(type_annotation, Enum):
99-
return True
98+
try:
99+
if issubclass(type_annotation, Enum):
100+
return True
101+
except TypeError:
102+
pass
100103
return False
101104

105+
102106
def _create_bare_model(model: Type[BaseModel]) -> Type[BaseModel]:
103107
class BareModel(BaseModel):
104108
__annotations__ = getattr(model, "__annotations__", {})
@@ -335,9 +339,6 @@ def convert_pydantic_model_to_datatype(
335339

336340
model_fields = add_pydantic_validators_as_guardrails_validators(model)
337341

338-
# Use inline import to avoid circular dependency
339-
from guardrails.validators import ValidChoices
340-
341342
children = {}
342343
for field_name, field in model_fields.items():
343344
if field_name in excluded_fields:
@@ -399,13 +400,14 @@ def convert_pydantic_model_to_datatype(
399400
name=field_name,
400401
)
401402
elif target_datatype == EnumDataType:
402-
valid_choices = type_annotation._member_names_
403-
field.json_schema_extra["validators"] = [ValidChoices(choices=valid_choices)]
404-
return pydantic_field_to_datatype(
403+
assert issubclass(type_annotation, Enum)
404+
valid_choices = [choice.value for choice in type_annotation]
405+
children[field_name] = pydantic_field_to_datatype(
405406
EnumDataType,
406407
field,
407408
strict=strict,
408-
enum_values=valid_choices
409+
enum_values=valid_choices,
410+
name=field_name,
409411
)
410412
elif is_pydantic_base_model(field.annotation):
411413
children[field_name] = convert_pydantic_model_to_datatype(

0 commit comments

Comments
 (0)