
Commit 8247e48

Enum: add tests
1 parent e5d6c70 commit 8247e48

7 files changed: +43 -0 lines changed

7 files changed

+43
-0
lines changed

tests/integration_tests/mock_llm_outputs.py

Lines changed: 2 additions & 0 deletions
@@ -28,6 +28,8 @@ def _invoke_llm(self, prompt, *args, **kwargs):
             pydantic.COMPILED_PROMPT_FULL_REASK_1: pydantic.LLM_OUTPUT_FULL_REASK_1,
             pydantic.COMPILED_PROMPT_REASK_2: pydantic.LLM_OUTPUT_REASK_2,
             pydantic.COMPILED_PROMPT_FULL_REASK_2: pydantic.LLM_OUTPUT_FULL_REASK_2,
+            pydantic.COMPILED_PROMPT_ENUM: pydantic.LLM_OUTPUT_ENUM,
+            pydantic.COMPILED_PROMPT_ENUM_2: pydantic.LLM_OUTPUT_ENUM_2,
             string.COMPILED_PROMPT: string.LLM_OUTPUT,
             string.COMPILED_PROMPT_REASK: string.LLM_OUTPUT_REASK,
             string.COMPILED_LIST_PROMPT: string.LIST_LLM_OUTPUT,
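For context, the mock LLM used by these integration tests is just a lookup from compiled prompt text to a canned output string, so the new enum test runs deterministically without calling OpenAI. A minimal sketch of that pattern follows; the names (MOCK_LLM_RESPONSES, fake_llm) and the literal strings are illustrative assumptions, not the repo's exact code:

# Sketch only: a prompt-keyed lookup table standing in for a real LLM call.
# The real MockOpenAICallable wires a mapping like this into guardrails'
# LLM-provider interface; names and literals here are assumptions.
COMPILED_PROMPT_ENUM = "What is the status of this task?\n\nJson Output:\n"
LLM_OUTPUT_ENUM = '{"status": "not started"}'

MOCK_LLM_RESPONSES = {
    COMPILED_PROMPT_ENUM: LLM_OUTPUT_ENUM,
    # ...one entry per compiled prompt/output pair registered above...
}


def fake_llm(prompt: str) -> str:
    """Return the canned output for a known prompt; fail loudly otherwise."""
    try:
        return MOCK_LLM_RESPONSES[prompt]
    except KeyError:
        raise ValueError(f"no mock output registered for prompt: {prompt!r}") from None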

tests/integration_tests/test_assets/pydantic/__init__.py

Lines changed: 4 additions & 0 deletions
@@ -31,12 +31,16 @@
 COMPILED_PROMPT_REASK_2 = reader("compiled_prompt_reask_2.txt")
 COMPILED_PROMPT_FULL_REASK_2 = reader("compiled_prompt_full_reask_2.txt")
 COMPILED_INSTRUCTIONS_REASK_2 = reader("compiled_instructions_reask_2.txt")
+COMPILED_PROMPT_ENUM = reader("compiled_prompt_enum.txt")
+COMPILED_PROMPT_ENUM_2 = reader("compiled_prompt_enum_2.txt")
 
 LLM_OUTPUT = reader("llm_output.txt")
 LLM_OUTPUT_REASK_1 = reader("llm_output_reask_1.txt")
 LLM_OUTPUT_FULL_REASK_1 = reader("llm_output_full_reask_1.txt")
 LLM_OUTPUT_REASK_2 = reader("llm_output_reask_2.txt")
 LLM_OUTPUT_FULL_REASK_2 = reader("llm_output_full_reask_2.txt")
+LLM_OUTPUT_ENUM = reader("llm_output_enum.txt")
+LLM_OUTPUT_ENUM_2 = reader("llm_output_enum_2.txt")
 
 RAIL_SPEC_WITH_REASK = reader("reask.rail")
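The reader calls above load text fixtures for the test suite. A hypothetical sketch of what such a helper typically does, assuming it simply returns the contents of a file sitting next to this __init__.py (the repository's actual implementation may differ, e.g. in newline handling):

# Hypothetical fixture loader: read a text asset from this package's directory.
# Assumed behaviour only; not copied from the repository.
import os

ASSETS_DIR = os.path.abspath(os.path.dirname(__file__))


def reader(filename: str) -> str:
    """Return the contents of a fixture file in this assets directory."""
    with open(os.path.join(ASSETS_DIR, filename)) as f:
        return f.read()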

tests/integration_tests/test_assets/pydantic/compiled_prompt_enum.txt

Lines changed: 4 additions & 0 deletions
@@ -0,0 +1,4 @@
+What is the status of this task?
+
+Json Output:
+
tests/integration_tests/test_assets/pydantic/compiled_prompt_enum_2.txt

Lines changed: 4 additions & 0 deletions
@@ -0,0 +1,4 @@
+What is the status of this task REALLY?
+
+Json Output:
+
tests/integration_tests/test_assets/pydantic/llm_output_enum.txt

Lines changed: 1 addition & 0 deletions
@@ -0,0 +1 @@
+{"status": "not started"}
tests/integration_tests/test_assets/pydantic/llm_output_enum_2.txt

Lines changed: 1 addition & 0 deletions
@@ -0,0 +1 @@
+{"status": "i dont know?"}

tests/integration_tests/test_guard.py

Lines changed: 27 additions & 0 deletions
@@ -1,3 +1,4 @@
+import enum
 import json
 import os
 from typing import Optional, Union
@@ -782,3 +783,29 @@ def test_in_memory_validator_log_is_not_duplicated(mocker):
 
     finally:
         OneLine.run_in_separate_process = separate_proc_bak
+
+
+def test_enum_datatype(mocker):
+    mocker.patch("guardrails.llm_providers.OpenAICallable", new=MockOpenAICallable)
+
+    class TaskStatus(enum.Enum):
+        not_started = "not started"
+        on_hold = "on hold"
+        in_progress = "in progress"
+
+    class Task(BaseModel):
+        status: TaskStatus
+
+    guard = gd.Guard.from_pydantic(Task)
+    _, dict_o = guard(
+        get_static_openai_create_func(),
+        prompt="What is the status of this task?",
+    )
+    assert dict_o == {"status": "not started"}
+
+    guard = gd.Guard.from_pydantic(Task)
+    with pytest.raises(ValueError):
+        guard(
+            get_static_openai_create_func(),
+            prompt="What is the status of this task REALLY?",
+        )
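The behaviour the test leans on is that a field annotated with an Enum type only accepts that enum's values: the second mock output, {"status": "i dont know?"}, is outside TaskStatus, so validation fails and the guard call raises. A standalone illustration of that core check with plain pydantic (not guardrails, and not part of this commit):

# Plain-pydantic illustration of enum-field validation; context only.
import enum

from pydantic import BaseModel, ValidationError


class TaskStatus(enum.Enum):
    not_started = "not started"
    on_hold = "on hold"
    in_progress = "in progress"


class Task(BaseModel):
    status: TaskStatus


Task(status="not started")       # ok: coerced to TaskStatus.not_started

try:
    Task(status="i dont know?")  # not a TaskStatus value
except ValidationError as err:
    print(err)                   # pydantic rejects it; the guard surfaces this as a failure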
