Skip to content

Commit 645c827

Browse files
committed
feat Add Enum datatype
1 parent ff395e8 commit 645c827

File tree

3 files changed

+97
-1
lines changed

3 files changed

+97
-1
lines changed

guardrails/datatypes.py

Lines changed: 45 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -6,6 +6,7 @@
66
from typing import Any, Dict, Iterable
77
from typing import List as TypedList
88
from typing import Optional, Sequence, Type, TypeVar, Union
9+
from enum import Enum
910

1011
from dateutil.parser import parse
1112
from lxml import etree as ET
@@ -386,6 +387,50 @@ class Percentage(ScalarType):
386387
tag = "percentage"
387388

388389

390+
@register_type("enum")
391+
class Enum(ScalarType):
392+
"""Element tag: `<enum>`"""
393+
394+
tag = "enum"
395+
396+
def __init__(
397+
self,
398+
children: Dict[str, Any],
399+
validators_attr: ValidatorsAttr,
400+
optional: bool,
401+
name: Optional[str],
402+
description: Optional[str],
403+
enum_values: TypedList[str],
404+
) -> None:
405+
super().__init__(children, validators_attr, optional, name, description)
406+
self.enum_values = enum_values
407+
408+
def from_str(self, s: str) -> Optional[str]:
409+
"""Create an Enum from a string."""
410+
if s is None:
411+
return None
412+
if s not in self.enum_values:
413+
raise ValueError(f"Invalid enum value: {s}")
414+
return s
415+
416+
@classmethod
417+
def from_xml(
418+
cls,
419+
enum_values: TypedList[str],
420+
validators: Sequence[ValidatorSpec],
421+
description: Optional[str] = None,
422+
strict: bool = False,
423+
) -> "Enum":
424+
return cls(
425+
children={},
426+
validators_attr=ValidatorsAttr.from_validators(validators, cls.tag, strict),
427+
optional=False,
428+
name=None,
429+
description=description,
430+
enum_values=enum_values,
431+
)
432+
433+
389434
@register_type("list")
390435
class List(NonScalarType):
391436
"""Element tag: `<list>`"""

guardrails/utils/pydantic_utils/v1.py

Lines changed: 26 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -3,6 +3,7 @@
33
import warnings
44
from copy import deepcopy
55
from datetime import date, time
6+
from enum import Enum
67
from typing import Any, Callable, Dict, Optional, Type, Union, get_args, get_origin
78

89
from pydantic import BaseModel, validator
@@ -20,6 +21,7 @@
2021
from guardrails.datatypes import Object as ObjectDataType
2122
from guardrails.datatypes import String as StringDataType
2223
from guardrails.datatypes import Time as TimeDataType
24+
from guardrails.datatypes import Enum as EnumDataType
2325
from guardrails.validator_base import Validator
2426
from guardrails.validatorsattr import ValidatorsAttr
2527

@@ -66,6 +68,16 @@ def is_dict(type_annotation: Any) -> bool:
6668
return True
6769
return False
6870

71+
def is_enum(type_annotation: Any) -> bool:
72+
"""Check if a type_annotation is an enum"""
73+
74+
type_annotation = prepare_type_annotation(type_annotation)
75+
76+
if is_pydantic_base_model(type_annotation):
77+
return False
78+
if issubclass(type_annotation, Enum):
79+
return True
80+
return False
6981

7082
def prepare_type_annotation(type_annotation: Union[ModelField, Type]) -> Type:
7183
"""Get the raw type annotation that can be used for downstream processing.
@@ -262,6 +274,8 @@ def field_to_datatype(field: Union[ModelField, Type]) -> Type[DataType]:
262274
return ListDataType
263275
elif is_dict(type_annotation):
264276
return ObjectDataType
277+
elif is_enum(type_annotation):
278+
return EnumDataType
265279
elif type_annotation == bool:
266280
return BooleanDataType
267281
elif type_annotation == date:
@@ -301,6 +315,9 @@ def convert_pydantic_model_to_datatype(
301315

302316
model_fields = add_pydantic_validators_as_guardrails_validators(model)
303317

318+
# Use inline import to avoid circular dependency
319+
from guardrails.validators import ValidChoices
320+
304321
children = {}
305322
for field_name, field in model_fields.items():
306323
if field_name in excluded_fields:
@@ -356,6 +373,15 @@ def convert_pydantic_model_to_datatype(
356373
strict=strict,
357374
discriminator_key=discriminator,
358375
)
376+
elif target_datatype == EnumDataType:
377+
valid_choices = type_annotation._member_names_
378+
field.field_info.extra["validators"] = [ValidChoices(choices=valid_choices)]
379+
return pydantic_field_to_datatype(
380+
EnumDataType,
381+
field,
382+
strict=strict,
383+
enum_values=valid_choices
384+
)
359385
elif isinstance(field.type_, type) and issubclass(field.type_, BaseModel):
360386
children[field_name] = convert_pydantic_model_to_datatype(
361387
field, datatype=target_datatype, strict=strict

guardrails/utils/pydantic_utils/v2.py

Lines changed: 26 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -2,8 +2,8 @@
22
import warnings
33
from copy import deepcopy
44
from datetime import date, time
5+
from enum import Enum
56
from typing import Any, Callable, Dict, List, Optional, Type, TypeVar, Union, get_args
6-
77
from pydantic import BaseModel, ConfigDict, HttpUrl, field_validator
88
from pydantic.fields import FieldInfo
99

@@ -21,6 +21,7 @@
2121
from guardrails.datatypes import PythonCode as PythonCodeDataType
2222
from guardrails.datatypes import String as StringDataType
2323
from guardrails.datatypes import Time as TimeDataType
24+
from guardrails.datatypes import Enum as EnumDataType
2425
from guardrails.validator_base import Validator
2526
from guardrails.validatorsattr import ValidatorsAttr
2627

@@ -87,6 +88,16 @@ def is_dict(type_annotation: Any) -> bool:
8788
return True
8889
return False
8990

91+
def is_enum(type_annotation: Any) -> bool:
92+
"""Check if a type_annotation is an enum"""
93+
94+
type_annotation = prepare_type_annotation(type_annotation)
95+
96+
if is_pydantic_base_model(type_annotation):
97+
return False
98+
if issubclass(type_annotation, Enum):
99+
return True
100+
return False
90101

91102
def _create_bare_model(model: Type[BaseModel]) -> Type[BaseModel]:
92103
class BareModel(BaseModel):
@@ -277,6 +288,8 @@ def field_to_datatype(field: Union[FieldInfo, Type]) -> Type[DataType]:
277288
return ListDataType
278289
elif is_dict(type_annotation):
279290
return ObjectDataType
291+
elif is_enum(type_annotation):
292+
return EnumDataType
280293
elif type_annotation == bool:
281294
return BooleanDataType
282295
elif type_annotation == date:
@@ -322,6 +335,9 @@ def convert_pydantic_model_to_datatype(
322335

323336
model_fields = add_pydantic_validators_as_guardrails_validators(model)
324337

338+
# Use inline import to avoid circular dependency
339+
from guardrails.validators import ValidChoices
340+
325341
children = {}
326342
for field_name, field in model_fields.items():
327343
if field_name in excluded_fields:
@@ -382,6 +398,15 @@ def convert_pydantic_model_to_datatype(
382398
discriminator_key=discriminator,
383399
name=field_name,
384400
)
401+
elif target_datatype == EnumDataType:
402+
valid_choices = type_annotation._member_names_
403+
field.json_schema_extra["validators"] = [ValidChoices(choices=valid_choices)]
404+
return pydantic_field_to_datatype(
405+
EnumDataType,
406+
field,
407+
strict=strict,
408+
enum_values=valid_choices
409+
)
385410
elif is_pydantic_base_model(field.annotation):
386411
children[field_name] = convert_pydantic_model_to_datatype(
387412
field,

0 commit comments

Comments
 (0)