diff --git a/src/poly/resources/flows.py b/src/poly/resources/flows.py index 6b156be..8ae27cb 100644 --- a/src/poly/resources/flows.py +++ b/src/poly/resources/flows.py @@ -635,7 +635,7 @@ def read_local_resource( # Get file name from file_path file_name = os.path.splitext(os.path.basename(file_path))[0] - return cls.from_yaml_dict( + instance = cls.from_yaml_dict( yaml_dict, resource_id=resource_id, file_name=file_name, @@ -645,6 +645,8 @@ def read_local_resource( known_position=known_position, resource_mappings=resource_mappings, ) + utils.check_yaml_field_types(instance) + return instance def validate(self, resource_mappings: list[ResourceMapping] = None, **kwargs): """Validate the flow step resource.""" diff --git a/src/poly/resources/pronunciation.py b/src/poly/resources/pronunciation.py index ccfcd90..38bca60 100644 --- a/src/poly/resources/pronunciation.py +++ b/src/poly/resources/pronunciation.py @@ -145,9 +145,11 @@ def read_local_resource( f"Resource with name {resource_clean_name} not found in {true_file_path}" ) - return cls.from_yaml_dict( + instance = cls.from_yaml_dict( yaml_dict, resource_id=resource_id, name="", position=position, **kwargs ) + utils.check_yaml_field_types(instance) + return instance @property def command_type(self) -> str: diff --git a/src/poly/resources/resource.py b/src/poly/resources/resource.py index 74497ed..a67c760 100644 --- a/src/poly/resources/resource.py +++ b/src/poly/resources/resource.py @@ -363,13 +363,15 @@ def read_local_resource( file_path=file_path, **kwargs, ) - return cls.from_yaml_dict( + instance = cls.from_yaml_dict( yaml_dict, resource_id=resource_id, name=resource_name, resource_mappings=resource_mappings, **kwargs, ) + utils.check_yaml_field_types(instance) + return instance @abstractmethod def to_yaml_dict(self) -> dict: diff --git a/src/poly/resources/resource_utils.py b/src/poly/resources/resource_utils.py index bf5fce6..a30e0f9 100644 --- a/src/poly/resources/resource_utils.py +++ b/src/poly/resources/resource_utils.py @@ -13,10 +13,13 @@ import re import subprocess import sys +import types +import typing +from dataclasses import fields, is_dataclass from difflib import unified_diff from enum import Enum from io import StringIO -from typing import TYPE_CHECKING, Optional +from typing import TYPE_CHECKING, Optional, Union import langcodes import ruamel.yaml as yaml @@ -564,6 +567,79 @@ def is_valid_language_code(code: str) -> bool: return langcodes.tag_is_valid(code) +def _unwrap_optional(hint: type) -> type: + """Strip Optional[X] / X | None down to X.""" + origin = typing.get_origin(hint) + if origin is Union or origin is types.UnionType: + args = [a for a in typing.get_args(hint) if a is not type(None)] + if len(args) == 1: + return args[0] + return hint + + +_SCALAR_TYPES = (str, int, float, bool) + + +def _matches_scalar(value: object, expected: type) -> bool: + """Check if a value matches a scalar type, with YAML-aware compatibility. + + bool is excluded from int/float checks because YAML's yes/no auto-cast + should not silently pass as a number. int is accepted for float because + YAML parses 500 as int but 500.0 as float — both are valid numerics. + """ + if expected is float: + return isinstance(value, (int, float)) and not isinstance(value, bool) + if expected is int: + return isinstance(value, int) and not isinstance(value, bool) + return isinstance(value, expected) + + +def check_yaml_field_types(instance: object, _path: str = "") -> None: + """Validate that scalar and list[scalar] fields have the correct type after YAML parsing.""" + hints = typing.get_type_hints(type(instance)) + for f in fields(instance): + value = getattr(instance, f.name) + if value is None: + continue + hint = _unwrap_optional(hints.get(f.name, type(None))) + field_path = f"{_path}.{f.name}" if _path else f.name + + if hint in _SCALAR_TYPES and not _matches_scalar(value, hint): + raise ValueError( + f"'{field_path}' should be {hint.__name__} but got {type(value).__name__}." + ) + list_args = typing.get_args(hint) if typing.get_origin(hint) is list else () + if list_args and list_args[0] in _SCALAR_TYPES: + expected_item_type = list_args[0] + if not isinstance(value, list): + raise ValueError( + f"'{field_path}' should be a list of {expected_item_type.__name__} " + f"but got {type(value).__name__}." + ) + for i, item in enumerate(value): + if not _matches_scalar(item, expected_item_type): + raise ValueError( + f"'{field_path}[{i}]' should be {expected_item_type.__name__} " + f"but got {type(item).__name__}." + ) + dict_args = typing.get_args(hint) if typing.get_origin(hint) is dict else () + if len(dict_args) == 2 and isinstance(value, dict): + key_type, val_type = dict_args + for k, v in value.items(): + if key_type in _SCALAR_TYPES and not _matches_scalar(k, key_type): + raise ValueError( + f"'{field_path}' has key {k!r} which should be {key_type.__name__} " + f"but got {type(k).__name__}." + ) + if val_type in _SCALAR_TYPES and not _matches_scalar(v, val_type): + raise ValueError( + f"'{field_path}[{k!r}]' should be {val_type.__name__} " + f"but got {type(v).__name__}." + ) + if is_dataclass(value) and not isinstance(value, type): + check_yaml_field_types(value, field_path) + + def assign_flow_positions( nodes: list["BaseFlowStep"], start_node_id: str, diff --git a/src/poly/resources/safety_filters.py b/src/poly/resources/safety_filters.py index b8b5ee5..c46b3f3 100644 --- a/src/poly/resources/safety_filters.py +++ b/src/poly/resources/safety_filters.py @@ -9,13 +9,13 @@ from google.protobuf.message import Message +import poly.resources.resource_utils as utils from poly.handlers.protobuf.channels_pb2 import Channel_UpdateSafetyFilters, ChannelType from poly.handlers.protobuf.content_filter_settings_pb2 import ( AzureContentFilter, AzureContentFilterCategory, ContentFilterSettings_UpdateContentFilterSettings, ) -import poly.resources.resource_utils as utils from poly.resources.resource import ResourceMapping, YamlResource PRECISION_MAPPING = {"LOOSE": "lenient", "MEDIUM": "medium", "STRICT": "strict"} diff --git a/src/poly/tests/project_test.py b/src/poly/tests/project_test.py index 89929e3..e8b8363 100644 --- a/src/poly/tests/project_test.py +++ b/src/poly/tests/project_test.py @@ -32,10 +32,10 @@ SettingsRole, SettingsRules, SMSTemplate, - Topic, TestCase, TestCaseAssertion, TestCaseTags, + Topic, TranscriptCorrection, Translation, Variable, diff --git a/src/poly/tests/resources_test.py b/src/poly/tests/resources_test.py index e3fe3ad..bef98f7 100644 --- a/src/poly/tests/resources_test.py +++ b/src/poly/tests/resources_test.py @@ -7,10 +7,9 @@ import unittest import yaml - -import poly.resources.resource_utils as resource_utils from jsonschema import ValidationError +import poly.resources.resource_utils as resource_utils from poly.handlers.sync_client import SyncClientHandler from poly.resources.agent_settings import ( SettingsPersonality, @@ -53,7 +52,6 @@ FunctionParameters, FunctionType, ) - from poly.resources.handoff import Handoff from poly.resources.keyphrase_boosting import KeyphraseBoosting from poly.resources.languages import ( @@ -74,11 +72,6 @@ VoiceSafetyFilters, ) from poly.resources.sms import EnvPhoneNumbers, SMSTemplate -from poly.resources.topic import ( - FUNCTION_REGEX, - Topic, -) -from poly.resources.transcript_correction import RegularExpressionRule, TranscriptCorrection from poly.resources.test_suite import ( FunctionCallArgumentAssertion, FunctionCallAssertion, @@ -86,6 +79,11 @@ TestCaseAssertion, TestCaseTags, ) +from poly.resources.topic import ( + FUNCTION_REGEX, + Topic, +) +from poly.resources.transcript_correction import RegularExpressionRule, TranscriptCorrection from poly.resources.translations import Translation from poly.resources.variable import Variable from poly.resources.variant_attributes import Variant, VariantAttribute @@ -7208,9 +7206,7 @@ def test_validate(self): prompts=[], function_calls=[], ), - tags=TestCaseTags( - resource_id="TEST-missing-scenario", name="tags", tags=[] - ), + tags=TestCaseTags(resource_id="TEST-missing-scenario", name="tags", tags=[]), ).validate() self.assertIn("Scenario is required", str(cm.exception)) @@ -7227,9 +7223,7 @@ def test_validate(self): prompts=[], function_calls=[], ), - tags=TestCaseTags( - resource_id="TEST-missing-language", name="tags", tags=[] - ), + tags=TestCaseTags(resource_id="TEST-missing-language", name="tags", tags=[]), ).validate() self.assertIn("Language is required", str(cm.exception)) @@ -7287,9 +7281,7 @@ def test_get_new_updated_deleted_subresources(self): self.assertEqual(deleted_after_edit, []) def test_discover_resources(self): - base_path = os.path.join( - os.path.dirname(__file__), "test_projects", "test_project" - ) + base_path = os.path.join(os.path.dirname(__file__), "test_projects", "test_project") discovered = TestCase.discover_resources(base_path) self.assertCountEqual( discovered, @@ -7299,6 +7291,7 @@ def test_discover_resources(self): ], ) + class ParseMultiResourcePathTests(unittest.TestCase): """Tests for _parse_multi_resource_path including Windows drive-letter handling.""" @@ -7682,9 +7675,7 @@ def test_update_command_type_is_a_real_command_field(self): from poly.handlers.protobuf.commands_pb2 import Command lang = DefaultLanguage(resource_id="en-US", name="en-US") - self.assertEqual( - lang.update_command_type, "languages_update_default_language" - ) + self.assertEqual(lang.update_command_type, "languages_update_default_language") Command(**{lang.update_command_type: lang.build_update_proto()}) def test_validate_duplicate_with_additional_raises(self): @@ -7921,5 +7912,203 @@ def test_unrelated_mappings_only_raises(self): self.assertIn("ChatStylePrompt", str(cm.exception)) +class CheckYamlFieldTypesTest(unittest.TestCase): + """Tests for the check_yaml_field_types YAML colon-space footgun guard.""" + + def test_str_field_with_dict_raises(self): + """A str field that got a dict (colon-space misparse) should raise.""" + topic = Topic( + resource_id="TOPIC-1", + name="test", + actions={"bad key": "value"}, + content="fine", + example_queries=[], + ) + with self.assertRaises(ValueError) as ctx: + resource_utils.check_yaml_field_types(topic) + self.assertIn("actions", str(ctx.exception)) + self.assertIn("should be str but got dict", str(ctx.exception)) + + def test_list_str_field_with_dict_item_raises(self): + """A list[str] field containing a dict item should raise.""" + topic = Topic( + resource_id="TOPIC-1", + name="test", + actions="fine", + content="fine", + example_queries=["good", {"bad key": "value"}], + ) + with self.assertRaises(ValueError) as ctx: + resource_utils.check_yaml_field_types(topic) + self.assertIn("example_queries[1]", str(ctx.exception)) + self.assertIn("should be str but got dict", str(ctx.exception)) + + def test_nested_dataclass_field_checked(self): + """Nested dataclass str fields should be recursively validated.""" + assertion = TestCaseAssertion( + resource_id="TC-1", + name="assertions", + prompts=["fine", {"the outcome: either": "a specific time"}], + function_calls=[], + ) + test_case = TestCase( + resource_id="TC-1", + name="test", + scenario="test scenario", + channel="chat.polyai", + language="en-GB", + assertions=assertion, + tags=TestCaseTags(resource_id="TC-1", name="tags", tags=[]), + ) + with self.assertRaises(ValueError) as ctx: + resource_utils.check_yaml_field_types(test_case) + self.assertIn("assertions.prompts[1]", str(ctx.exception)) + + def test_valid_resource_passes(self): + """A correctly typed resource should not raise.""" + topic = Topic( + resource_id="TOPIC-1", + name="test", + actions="do something", + content="some content", + example_queries=["query 1", "query 2"], + ) + resource_utils.check_yaml_field_types(topic) + + def test_optional_str_field_none_passes(self): + """Optional[str] fields with None should not raise.""" + test_case = TestCase( + resource_id="TC-1", + name="test", + scenario="test scenario", + channel="chat.polyai", + language="en-GB", + assertions=TestCaseAssertion( + resource_id="TC-1", name="assertions", prompts=[], function_calls=[] + ), + tags=TestCaseTags(resource_id="TC-1", name="tags", tags=[]), + variant=None, + ) + resource_utils.check_yaml_field_types(test_case) + + def test_list_str_field_with_dict_value_raises(self): + """A list[str] field that got a dict instead of a list should raise.""" + assertion = TestCaseAssertion( + resource_id="TC-1", + name="assertions", + prompts={"assertions": ["It responds with a nice message"]}, + function_calls=[], + ) + test_case = TestCase( + resource_id="TC-1", + name="test", + scenario="test scenario", + channel="chat.polyai", + language="en-GB", + assertions=assertion, + tags=TestCaseTags(resource_id="TC-1", name="tags", tags=[]), + ) + with self.assertRaises(ValueError) as ctx: + resource_utils.check_yaml_field_types(test_case) + self.assertIn("assertions.prompts", str(ctx.exception)) + self.assertIn("list of str", str(ctx.exception)) + + def test_str_field_with_bool_raises(self): + """A str field that got a bool (yes/no auto-cast) should raise.""" + topic = Topic( + resource_id="TOPIC-1", + name="test", + actions=True, + content="fine", + example_queries=[], + ) + with self.assertRaises(ValueError) as ctx: + resource_utils.check_yaml_field_types(topic) + self.assertIn("actions", str(ctx.exception)) + self.assertIn("bool", str(ctx.exception)) + + def test_str_field_with_int_raises(self): + """A str field that got an int (bare number) should raise.""" + topic = Topic( + resource_id="TOPIC-1", + name="test", + actions=42, + content="fine", + example_queries=[], + ) + with self.assertRaises(ValueError) as ctx: + resource_utils.check_yaml_field_types(topic) + self.assertIn("actions", str(ctx.exception)) + self.assertIn("int", str(ctx.exception)) + + def test_list_str_field_with_bool_item_raises(self): + """A list[str] field containing a bool item should raise.""" + topic = Topic( + resource_id="TOPIC-1", + name="test", + actions="fine", + content="fine", + example_queries=["good", True], + ) + with self.assertRaises(ValueError) as ctx: + resource_utils.check_yaml_field_types(topic) + self.assertIn("example_queries[1]", str(ctx.exception)) + self.assertIn("bool", str(ctx.exception)) + + def test_dict_value_wrong_type_raises(self): + """A dict[str, bool] field with a str value instead of bool should raise.""" + personality = SettingsPersonality( + resource_id="P-1", + name="personality", + custom="fine", + adjectives={"Polite": "sure"}, + ) + with self.assertRaises(ValueError) as ctx: + resource_utils.check_yaml_field_types(personality) + self.assertIn("adjectives", str(ctx.exception)) + self.assertIn("should be bool but got str", str(ctx.exception)) + + def test_dict_valid_types_passes(self): + """A dict[str, bool] field with correct types should not raise.""" + personality = SettingsPersonality( + resource_id="P-1", + name="personality", + custom="fine", + adjectives={"Polite": True, "Calm": False}, + ) + resource_utils.check_yaml_field_types(personality) + + def test_int_accepted_for_float_field(self): + """int values should be accepted where float is expected (YAML parses 500 as int).""" + self.assertTrue(resource_utils._matches_scalar(42, float)) + self.assertTrue(resource_utils._matches_scalar(3.14, float)) + + def test_bool_rejected_for_int_field(self): + """bool should not pass as int even though bool is a subclass of int in Python.""" + self.assertFalse(resource_utils._matches_scalar(True, int)) + self.assertFalse(resource_utils._matches_scalar(False, int)) + + def test_bool_rejected_for_float_field(self): + """bool should not pass as float either.""" + self.assertFalse(resource_utils._matches_scalar(True, float)) + self.assertFalse(resource_utils._matches_scalar(False, float)) + + def test_none_on_optional_field_passes(self): + """None in an Optional[str] field should not raise.""" + test_case = TestCase( + resource_id="TC-1", + name="test", + scenario="test scenario", + channel="chat.polyai", + language="en-GB", + assertions=TestCaseAssertion( + resource_id="TC-1", name="assertions", prompts=[], function_calls=[] + ), + tags=TestCaseTags(resource_id="TC-1", name="tags", tags=[]), + variant=None, + ) + resource_utils.check_yaml_field_types(test_case) + + if __name__ == "__main__": unittest.main() diff --git a/src/poly/utils.py b/src/poly/utils.py index 2cda34b..5f0c9db 100644 --- a/src/poly/utils.py +++ b/src/poly/utils.py @@ -13,14 +13,13 @@ import re from typing import Callable, Optional -from poly.resources import Function, FunctionStep, Resource, ResourceMapping - -from poly.handlers.protobuf.commands_pb2 import Command from poly.handlers.protobuf.channels_pb2 import ( Channel_UpdateStatus, - WebChatChannel_UpdateStatus, ChannelStatus, + WebChatChannel_UpdateStatus, ) +from poly.handlers.protobuf.commands_pb2 import Command +from poly.resources import Function, FunctionStep, Resource, ResourceMapping logger = logging.getLogger(__name__)