Skip to content

Commit 78d5052

Browse files
authored
Merge pull request #486 from emekaokoli19/Enum
[feat] Support enums in pydantic
2 parents ff395e8 + 8247e48 commit 78d5052

File tree

11 files changed

+139
-0
lines changed

11 files changed

+139
-0
lines changed

guardrails/datatypes.py

Lines changed: 44 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -386,6 +386,50 @@ class Percentage(ScalarType):
386386
tag = "percentage"
387387

388388

389+
@register_type("enum")
390+
class Enum(ScalarType):
391+
"""Element tag: `<enum>`"""
392+
393+
tag = "enum"
394+
395+
def __init__(
396+
self,
397+
children: Dict[str, Any],
398+
validators_attr: ValidatorsAttr,
399+
optional: bool,
400+
name: Optional[str],
401+
description: Optional[str],
402+
enum_values: TypedList[str],
403+
) -> None:
404+
super().__init__(children, validators_attr, optional, name, description)
405+
self.enum_values = enum_values
406+
407+
def from_str(self, s: str) -> Optional[str]:
408+
"""Create an Enum from a string."""
409+
if s is None:
410+
return None
411+
if s not in self.enum_values:
412+
raise ValueError(f"Invalid enum value: {s}")
413+
return s
414+
415+
@classmethod
416+
def from_xml(
417+
cls,
418+
enum_values: TypedList[str],
419+
validators: Sequence[ValidatorSpec],
420+
description: Optional[str] = None,
421+
strict: bool = False,
422+
) -> "Enum":
423+
return cls(
424+
children={},
425+
validators_attr=ValidatorsAttr.from_validators(validators, cls.tag, strict),
426+
optional=False,
427+
name=None,
428+
description=description,
429+
enum_values=enum_values,
430+
)
431+
432+
389433
@register_type("list")
390434
class List(NonScalarType):
391435
"""Element tag: `<list>`"""

guardrails/utils/json_utils.py

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -12,6 +12,7 @@
1212
DataType,
1313
Date,
1414
Email,
15+
Enum,
1516
Float,
1617
Integer,
1718
)
@@ -52,6 +53,7 @@ def verify(
5253
ListDataType: list,
5354
Date: str,
5455
Time: str,
56+
Enum: str,
5557
}
5658

5759
ignore_types = [

guardrails/utils/pydantic_utils/v1.py

Lines changed: 23 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -3,6 +3,7 @@
33
import warnings
44
from copy import deepcopy
55
from datetime import date, time
6+
from enum import Enum
67
from typing import Any, Callable, Dict, Optional, Type, Union, get_args, get_origin
78

89
from pydantic import BaseModel, validator
@@ -14,6 +15,7 @@
1415
from guardrails.datatypes import Choice as ChoiceDataType
1516
from guardrails.datatypes import DataType
1617
from guardrails.datatypes import Date as DateDataType
18+
from guardrails.datatypes import Enum as EnumDataType
1719
from guardrails.datatypes import Float as FloatDataType
1820
from guardrails.datatypes import Integer as IntegerDataType
1921
from guardrails.datatypes import List as ListDataType
@@ -67,6 +69,19 @@ def is_dict(type_annotation: Any) -> bool:
6769
return False
6870

6971

72+
def is_enum(type_annotation: Any) -> bool:
73+
"""Check if a type_annotation is an enum."""
74+
75+
type_annotation = prepare_type_annotation(type_annotation)
76+
77+
try:
78+
if issubclass(type_annotation, Enum):
79+
return True
80+
except TypeError:
81+
pass
82+
return False
83+
84+
7085
def prepare_type_annotation(type_annotation: Union[ModelField, Type]) -> Type:
7186
"""Get the raw type annotation that can be used for downstream processing.
7287
@@ -262,6 +277,8 @@ def field_to_datatype(field: Union[ModelField, Type]) -> Type[DataType]:
262277
return ListDataType
263278
elif is_dict(type_annotation):
264279
return ObjectDataType
280+
elif is_enum(type_annotation):
281+
return EnumDataType
265282
elif type_annotation == bool:
266283
return BooleanDataType
267284
elif type_annotation == date:
@@ -356,6 +373,12 @@ def convert_pydantic_model_to_datatype(
356373
strict=strict,
357374
discriminator_key=discriminator,
358375
)
376+
elif target_datatype == EnumDataType:
377+
assert issubclass(type_annotation, Enum)
378+
valid_choices = [choice.value for choice in type_annotation]
379+
children[field_name] = pydantic_field_to_datatype(
380+
EnumDataType, field, strict=strict, enum_values=valid_choices
381+
)
359382
elif isinstance(field.type_, type) and issubclass(field.type_, BaseModel):
360383
children[field_name] = convert_pydantic_model_to_datatype(
361384
field, datatype=target_datatype, strict=strict

guardrails/utils/pydantic_utils/v2.py

Lines changed: 27 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -2,6 +2,7 @@
22
import warnings
33
from copy import deepcopy
44
from datetime import date, time
5+
from enum import Enum
56
from typing import Any, Callable, Dict, List, Optional, Type, TypeVar, Union, get_args
67

78
from pydantic import BaseModel, ConfigDict, HttpUrl, field_validator
@@ -14,6 +15,7 @@
1415
from guardrails.datatypes import Choice as ChoiceDataType
1516
from guardrails.datatypes import DataType
1617
from guardrails.datatypes import Date as DateDataType
18+
from guardrails.datatypes import Enum as EnumDataType
1719
from guardrails.datatypes import Float as FloatDataType
1820
from guardrails.datatypes import Integer as IntegerDataType
1921
from guardrails.datatypes import List as ListDataType
@@ -88,6 +90,19 @@ def is_dict(type_annotation: Any) -> bool:
8890
return False
8991

9092

93+
def is_enum(type_annotation: Any) -> bool:
94+
"""Check if a type_annotation is an enum."""
95+
96+
type_annotation = prepare_type_annotation(type_annotation)
97+
98+
try:
99+
if issubclass(type_annotation, Enum):
100+
return True
101+
except TypeError:
102+
pass
103+
return False
104+
105+
91106
def _create_bare_model(model: Type[BaseModel]) -> Type[BaseModel]:
92107
class BareModel(BaseModel):
93108
__annotations__ = getattr(model, "__annotations__", {})
@@ -277,6 +292,8 @@ def field_to_datatype(field: Union[FieldInfo, Type]) -> Type[DataType]:
277292
return ListDataType
278293
elif is_dict(type_annotation):
279294
return ObjectDataType
295+
elif is_enum(type_annotation):
296+
return EnumDataType
280297
elif type_annotation == bool:
281298
return BooleanDataType
282299
elif type_annotation == date:
@@ -382,6 +399,16 @@ def convert_pydantic_model_to_datatype(
382399
discriminator_key=discriminator,
383400
name=field_name,
384401
)
402+
elif target_datatype == EnumDataType:
403+
assert issubclass(type_annotation, Enum)
404+
valid_choices = [choice.value for choice in type_annotation]
405+
children[field_name] = pydantic_field_to_datatype(
406+
EnumDataType,
407+
field,
408+
strict=strict,
409+
enum_values=valid_choices,
410+
name=field_name,
411+
)
385412
elif is_pydantic_base_model(field.annotation):
386413
children[field_name] = convert_pydantic_model_to_datatype(
387414
field,

tests/integration_tests/mock_llm_outputs.py

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -28,6 +28,8 @@ def _invoke_llm(self, prompt, *args, **kwargs):
2828
pydantic.COMPILED_PROMPT_FULL_REASK_1: pydantic.LLM_OUTPUT_FULL_REASK_1,
2929
pydantic.COMPILED_PROMPT_REASK_2: pydantic.LLM_OUTPUT_REASK_2,
3030
pydantic.COMPILED_PROMPT_FULL_REASK_2: pydantic.LLM_OUTPUT_FULL_REASK_2,
31+
pydantic.COMPILED_PROMPT_ENUM: pydantic.LLM_OUTPUT_ENUM,
32+
pydantic.COMPILED_PROMPT_ENUM_2: pydantic.LLM_OUTPUT_ENUM_2,
3133
string.COMPILED_PROMPT: string.LLM_OUTPUT,
3234
string.COMPILED_PROMPT_REASK: string.LLM_OUTPUT_REASK,
3335
string.COMPILED_LIST_PROMPT: string.LIST_LLM_OUTPUT,

tests/integration_tests/test_assets/pydantic/__init__.py

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -31,12 +31,16 @@
3131
COMPILED_PROMPT_REASK_2 = reader("compiled_prompt_reask_2.txt")
3232
COMPILED_PROMPT_FULL_REASK_2 = reader("compiled_prompt_full_reask_2.txt")
3333
COMPILED_INSTRUCTIONS_REASK_2 = reader("compiled_instructions_reask_2.txt")
34+
COMPILED_PROMPT_ENUM = reader("compiled_prompt_enum.txt")
35+
COMPILED_PROMPT_ENUM_2 = reader("compiled_prompt_enum_2.txt")
3436

3537
LLM_OUTPUT = reader("llm_output.txt")
3638
LLM_OUTPUT_REASK_1 = reader("llm_output_reask_1.txt")
3739
LLM_OUTPUT_FULL_REASK_1 = reader("llm_output_full_reask_1.txt")
3840
LLM_OUTPUT_REASK_2 = reader("llm_output_reask_2.txt")
3941
LLM_OUTPUT_FULL_REASK_2 = reader("llm_output_full_reask_2.txt")
42+
LLM_OUTPUT_ENUM = reader("llm_output_enum.txt")
43+
LLM_OUTPUT_ENUM_2 = reader("llm_output_enum_2.txt")
4044

4145
RAIL_SPEC_WITH_REASK = reader("reask.rail")
4246

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,4 @@
1+
What is the status of this task?
2+
3+
Json Output:
4+
Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,4 @@
1+
What is the status of this task REALLY?
2+
3+
Json Output:
4+
Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1 @@
1+
{"status": "not started"}
Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1 @@
1+
{"status": "i dont know?"}

0 commit comments

Comments
 (0)