From a423be79d58c298f1fbd71fccab74908c70aa4c5 Mon Sep 17 00:00:00 2001 From: Ruari Phipps Date: Mon, 15 Jun 2026 17:17:56 +0100 Subject: [PATCH 1/4] fix: add dataclass-aware YAML field type validation at parse time YAML values with unquoted colons silently parse as dicts instead of strings, causing opaque TypeError from protobuf C code at push time. Uses dataclasses.fields() + typing.get_type_hints() to automatically validate str and list[str] fields in YamlResource.read_local_resource, with recursion into nested dataclass fields. Co-Authored-By: Claude Opus 4.6 (1M context) --- src/poly/resources/flows.py | 5 +- src/poly/resources/pronunciation.py | 5 +- src/poly/resources/resource.py | 45 ++++++++++- src/poly/tests/resources_test.py | 111 +++++++++++++++++++++++++--- 4 files changed, 150 insertions(+), 16 deletions(-) diff --git a/src/poly/resources/flows.py b/src/poly/resources/flows.py index 6b156be5..0957472f 100644 --- a/src/poly/resources/flows.py +++ b/src/poly/resources/flows.py @@ -55,6 +55,7 @@ ResourceMapping, SubResource, YamlResource, + check_yaml_field_types, ) FUNCTION_REGEX = re.compile(r"{{f[nt]:([\w-]+)}}") @@ -635,7 +636,7 @@ def read_local_resource( # Get file name from file_path file_name = os.path.splitext(os.path.basename(file_path))[0] - return cls.from_yaml_dict( + instance = cls.from_yaml_dict( yaml_dict, resource_id=resource_id, file_name=file_name, @@ -645,6 +646,8 @@ def read_local_resource( known_position=known_position, resource_mappings=resource_mappings, ) + check_yaml_field_types(instance) + return instance def validate(self, resource_mappings: list[ResourceMapping] = None, **kwargs): """Validate the flow step resource.""" diff --git a/src/poly/resources/pronunciation.py b/src/poly/resources/pronunciation.py index ccfcd900..b66e0f7c 100644 --- a/src/poly/resources/pronunciation.py +++ b/src/poly/resources/pronunciation.py @@ -17,6 +17,7 @@ MultiResourceYamlResource, ResourceMapping, _parse_multi_resource_path, + check_yaml_field_types, ) @@ -145,9 +146,11 @@ def read_local_resource( f"Resource with name {resource_clean_name} not found in {true_file_path}" ) - return cls.from_yaml_dict( + instance = cls.from_yaml_dict( yaml_dict, resource_id=resource_id, name="", position=position, **kwargs ) + check_yaml_field_types(instance) + return instance @property def command_type(self) -> str: diff --git a/src/poly/resources/resource.py b/src/poly/resources/resource.py index 74497ede..badf30b4 100644 --- a/src/poly/resources/resource.py +++ b/src/poly/resources/resource.py @@ -4,14 +4,53 @@ """ import os +import types +import typing from abc import ABC, abstractmethod -from dataclasses import dataclass +from dataclasses import dataclass, fields as dc_fields, is_dataclass from typing import ClassVar, Optional from google.protobuf.message import Message import poly.resources.resource_utils as utils +_YAML_TYPE_HINT = ( + "A YAML value was likely parsed as the wrong type. " + "Check for an unquoted mid-sentence colon or a leading [ { * & ? and quote the value." +) + + +def _unwrap_optional(hint: type) -> type: + """Strip Optional[X] / X | None down to X.""" + origin = typing.get_origin(hint) + if origin is typing.Union or origin is types.UnionType: + args = [a for a in typing.get_args(hint) if a is not type(None)] + if len(args) == 1: + return args[0] + return hint + + +def check_yaml_field_types(instance: object, _path: str = "") -> None: + """Validate that str and list[str] fields weren't parsed as dicts by YAML.""" + hints = typing.get_type_hints(type(instance)) + for f in dc_fields(instance): + value = getattr(instance, f.name) + if value is None: + continue + hint = _unwrap_optional(hints.get(f.name, type(None))) + field_path = f"{_path}.{f.name}" if _path else f.name + + if hint is str and isinstance(value, dict): + raise ValueError(f"'{field_path}' should be a string but got a dict. {_YAML_TYPE_HINT}") + if typing.get_origin(hint) is list and typing.get_args(hint) == (str,): + for i, item in enumerate(value): + if isinstance(item, dict): + raise ValueError( + f"'{field_path}[{i}]' should be a string but got a dict. {_YAML_TYPE_HINT}" + ) + if is_dataclass(value) and not isinstance(value, type): + check_yaml_field_types(value, field_path) + @dataclass class ResourceMapping: @@ -363,13 +402,15 @@ def read_local_resource( file_path=file_path, **kwargs, ) - return cls.from_yaml_dict( + instance = cls.from_yaml_dict( yaml_dict, resource_id=resource_id, name=resource_name, resource_mappings=resource_mappings, **kwargs, ) + check_yaml_field_types(instance) + return instance @abstractmethod def to_yaml_dict(self) -> dict: diff --git a/src/poly/tests/resources_test.py b/src/poly/tests/resources_test.py index e3fe3ad9..4161423c 100644 --- a/src/poly/tests/resources_test.py +++ b/src/poly/tests/resources_test.py @@ -66,6 +66,7 @@ MultiResourceYamlResource, ResourceMapping, _parse_multi_resource_path, + check_yaml_field_types, ) from poly.resources.safety_filters import ( ChatSafetyFilters, @@ -7208,9 +7209,7 @@ def test_validate(self): prompts=[], function_calls=[], ), - tags=TestCaseTags( - resource_id="TEST-missing-scenario", name="tags", tags=[] - ), + tags=TestCaseTags(resource_id="TEST-missing-scenario", name="tags", tags=[]), ).validate() self.assertIn("Scenario is required", str(cm.exception)) @@ -7227,9 +7226,7 @@ def test_validate(self): prompts=[], function_calls=[], ), - tags=TestCaseTags( - resource_id="TEST-missing-language", name="tags", tags=[] - ), + tags=TestCaseTags(resource_id="TEST-missing-language", name="tags", tags=[]), ).validate() self.assertIn("Language is required", str(cm.exception)) @@ -7287,9 +7284,7 @@ def test_get_new_updated_deleted_subresources(self): self.assertEqual(deleted_after_edit, []) def test_discover_resources(self): - base_path = os.path.join( - os.path.dirname(__file__), "test_projects", "test_project" - ) + base_path = os.path.join(os.path.dirname(__file__), "test_projects", "test_project") discovered = TestCase.discover_resources(base_path) self.assertCountEqual( discovered, @@ -7299,6 +7294,7 @@ def test_discover_resources(self): ], ) + class ParseMultiResourcePathTests(unittest.TestCase): """Tests for _parse_multi_resource_path including Windows drive-letter handling.""" @@ -7682,9 +7678,7 @@ def test_update_command_type_is_a_real_command_field(self): from poly.handlers.protobuf.commands_pb2 import Command lang = DefaultLanguage(resource_id="en-US", name="en-US") - self.assertEqual( - lang.update_command_type, "languages_update_default_language" - ) + self.assertEqual(lang.update_command_type, "languages_update_default_language") Command(**{lang.update_command_type: lang.build_update_proto()}) def test_validate_duplicate_with_additional_raises(self): @@ -7921,5 +7915,98 @@ def test_unrelated_mappings_only_raises(self): self.assertIn("ChatStylePrompt", str(cm.exception)) +class CheckYamlFieldTypesTest(unittest.TestCase): + """Tests for the check_yaml_field_types YAML colon-space footgun guard.""" + + def test_str_field_with_dict_raises(self): + """A str field that got a dict (colon-space misparse) should raise.""" + topic = Topic( + resource_id="TOPIC-1", + name="test", + actions={"bad key": "value"}, + content="fine", + example_queries=[], + ) + with self.assertRaises(ValueError) as ctx: + check_yaml_field_types(topic) + self.assertIn("actions", str(ctx.exception)) + self.assertIn("should be a string but got a dict", str(ctx.exception)) + + def test_list_str_field_with_dict_item_raises(self): + """A list[str] field containing a dict item should raise.""" + topic = Topic( + resource_id="TOPIC-1", + name="test", + actions="fine", + content="fine", + example_queries=["good", {"bad key": "value"}], + ) + with self.assertRaises(ValueError) as ctx: + check_yaml_field_types(topic) + self.assertIn("example_queries[1]", str(ctx.exception)) + self.assertIn("should be a string but got a dict", str(ctx.exception)) + + def test_nested_dataclass_field_checked(self): + """Nested dataclass str fields should be recursively validated.""" + assertion = TestCaseAssertion( + resource_id="TC-1", + name="assertions", + prompts=["fine", {"the outcome: either": "a specific time"}], + function_calls=[], + ) + test_case = TestCase( + resource_id="TC-1", + name="test", + scenario="test scenario", + channel="chat.polyai", + language="en-GB", + assertions=assertion, + tags=TestCaseTags(resource_id="TC-1", name="tags", tags=[]), + ) + with self.assertRaises(ValueError) as ctx: + check_yaml_field_types(test_case) + self.assertIn("assertions.prompts[1]", str(ctx.exception)) + + def test_valid_resource_passes(self): + """A correctly typed resource should not raise.""" + topic = Topic( + resource_id="TOPIC-1", + name="test", + actions="do something", + content="some content", + example_queries=["query 1", "query 2"], + ) + check_yaml_field_types(topic) + + def test_optional_str_field_none_passes(self): + """Optional[str] fields with None should not raise.""" + test_case = TestCase( + resource_id="TC-1", + name="test", + scenario="test scenario", + channel="chat.polyai", + language="en-GB", + assertions=TestCaseAssertion( + resource_id="TC-1", name="assertions", prompts=[], function_calls=[] + ), + tags=TestCaseTags(resource_id="TC-1", name="tags", tags=[]), + variant=None, + ) + check_yaml_field_types(test_case) + + def test_error_message_includes_hint(self): + """The error message should guide the user to quote the YAML value.""" + personality = SettingsPersonality( + resource_id="P-1", + name="personality", + custom={"the tone: friendly": "and warm"}, + adjectives={}, + ) + with self.assertRaises(ValueError) as ctx: + check_yaml_field_types(personality) + self.assertIn("unquoted", str(ctx.exception)) + self.assertIn("colon", str(ctx.exception)) + + if __name__ == "__main__": unittest.main() From 6c7c4e6d3a1b0ab21768e311fdfcecfd44e89742 Mon Sep 17 00:00:00 2001 From: Ruari Phipps Date: Mon, 15 Jun 2026 17:23:16 +0100 Subject: [PATCH 2/4] fix: also check that list[str] fields are actually lists, not dicts A list[str] field receiving a dict (e.g., prompt_assertions parsed as a mapping) would iterate dict keys as strings, silently passing the check. Co-Authored-By: Claude Opus 4.6 (1M context) --- src/poly/resources/resource.py | 5 +++++ src/poly/tests/resources_test.py | 22 ++++++++++++++++++++++ 2 files changed, 27 insertions(+) diff --git a/src/poly/resources/resource.py b/src/poly/resources/resource.py index badf30b4..bdd0f141 100644 --- a/src/poly/resources/resource.py +++ b/src/poly/resources/resource.py @@ -43,6 +43,11 @@ def check_yaml_field_types(instance: object, _path: str = "") -> None: if hint is str and isinstance(value, dict): raise ValueError(f"'{field_path}' should be a string but got a dict. {_YAML_TYPE_HINT}") if typing.get_origin(hint) is list and typing.get_args(hint) == (str,): + if not isinstance(value, list): + raise ValueError( + f"'{field_path}' should be a list of strings but got " + f"{type(value).__name__}. {_YAML_TYPE_HINT}" + ) for i, item in enumerate(value): if isinstance(item, dict): raise ValueError( diff --git a/src/poly/tests/resources_test.py b/src/poly/tests/resources_test.py index 4161423c..b6394335 100644 --- a/src/poly/tests/resources_test.py +++ b/src/poly/tests/resources_test.py @@ -7994,6 +7994,28 @@ def test_optional_str_field_none_passes(self): ) check_yaml_field_types(test_case) + def test_list_str_field_with_dict_value_raises(self): + """A list[str] field that got a dict instead of a list should raise.""" + assertion = TestCaseAssertion( + resource_id="TC-1", + name="assertions", + prompts={"assertions": ["It responds with a nice message"]}, + function_calls=[], + ) + test_case = TestCase( + resource_id="TC-1", + name="test", + scenario="test scenario", + channel="chat.polyai", + language="en-GB", + assertions=assertion, + tags=TestCaseTags(resource_id="TC-1", name="tags", tags=[]), + ) + with self.assertRaises(ValueError) as ctx: + check_yaml_field_types(test_case) + self.assertIn("assertions.prompts", str(ctx.exception)) + self.assertIn("list of strings", str(ctx.exception)) + def test_error_message_includes_hint(self): """The error message should guide the user to quote the YAML value.""" personality = SettingsPersonality( From 983de7b2b443df93920ae1fab73d3daa6f3bdaa1 Mon Sep 17 00:00:00 2001 From: Ruari Phipps Date: Mon, 15 Jun 2026 17:38:29 +0100 Subject: [PATCH 3/4] fix: broaden type validation to all scalars, dicts, and int/float compat - Check all scalar types (str, int, float, bool) not just str - Validate dict[K, V] keys and values when K/V are scalar types - Accept int where float is expected (YAML parses 500 as int) - Reject bool where int/float is expected (yes/no auto-cast) - Validate list[str] field is actually a list, not a dict - Remove verbose YAML hint from error messages Co-Authored-By: Claude Opus 4.6 (1M context) --- src/poly/resources/resource.py | 57 ++++++++++++++++----- src/poly/tests/resources_test.py | 85 ++++++++++++++++++++++++++++---- 2 files changed, 120 insertions(+), 22 deletions(-) diff --git a/src/poly/resources/resource.py b/src/poly/resources/resource.py index bdd0f141..d6598d40 100644 --- a/src/poly/resources/resource.py +++ b/src/poly/resources/resource.py @@ -14,11 +14,6 @@ import poly.resources.resource_utils as utils -_YAML_TYPE_HINT = ( - "A YAML value was likely parsed as the wrong type. " - "Check for an unquoted mid-sentence colon or a leading [ { * & ? and quote the value." -) - def _unwrap_optional(hint: type) -> type: """Strip Optional[X] / X | None down to X.""" @@ -30,8 +25,25 @@ def _unwrap_optional(hint: type) -> type: return hint +_SCALAR_TYPES = (str, int, float, bool) + + +def _matches_scalar(value: object, expected: type) -> bool: + """Check if a value matches a scalar type, with YAML-aware compatibility. + + bool is excluded from int/float checks because YAML's yes/no auto-cast + should not silently pass as a number. int is accepted for float because + YAML parses 500 as int but 500.0 as float — both are valid numerics. + """ + if expected is float: + return isinstance(value, (int, float)) and not isinstance(value, bool) + if expected is int: + return isinstance(value, int) and not isinstance(value, bool) + return isinstance(value, expected) + + def check_yaml_field_types(instance: object, _path: str = "") -> None: - """Validate that str and list[str] fields weren't parsed as dicts by YAML.""" + """Validate that scalar and list[scalar] fields have the correct type after YAML parsing.""" hints = typing.get_type_hints(type(instance)) for f in dc_fields(instance): value = getattr(instance, f.name) @@ -40,18 +52,37 @@ def check_yaml_field_types(instance: object, _path: str = "") -> None: hint = _unwrap_optional(hints.get(f.name, type(None))) field_path = f"{_path}.{f.name}" if _path else f.name - if hint is str and isinstance(value, dict): - raise ValueError(f"'{field_path}' should be a string but got a dict. {_YAML_TYPE_HINT}") - if typing.get_origin(hint) is list and typing.get_args(hint) == (str,): + if hint in _SCALAR_TYPES and not _matches_scalar(value, hint): + raise ValueError( + f"'{field_path}' should be {hint.__name__} but got {type(value).__name__}." + ) + list_args = typing.get_args(hint) if typing.get_origin(hint) is list else () + if list_args and list_args[0] in _SCALAR_TYPES: + expected_item_type = list_args[0] if not isinstance(value, list): raise ValueError( - f"'{field_path}' should be a list of strings but got " - f"{type(value).__name__}. {_YAML_TYPE_HINT}" + f"'{field_path}' should be a list of {expected_item_type.__name__} " + f"but got {type(value).__name__}." ) for i, item in enumerate(value): - if isinstance(item, dict): + if not _matches_scalar(item, expected_item_type): + raise ValueError( + f"'{field_path}[{i}]' should be {expected_item_type.__name__} " + f"but got {type(item).__name__}." + ) + dict_args = typing.get_args(hint) if typing.get_origin(hint) is dict else () + if len(dict_args) == 2 and isinstance(value, dict): + key_type, val_type = dict_args + for k, v in value.items(): + if key_type in _SCALAR_TYPES and not _matches_scalar(k, key_type): + raise ValueError( + f"'{field_path}' has key {k!r} which should be {key_type.__name__} " + f"but got {type(k).__name__}." + ) + if val_type in _SCALAR_TYPES and not _matches_scalar(v, val_type): raise ValueError( - f"'{field_path}[{i}]' should be a string but got a dict. {_YAML_TYPE_HINT}" + f"'{field_path}[{k!r}]' should be {val_type.__name__} " + f"but got {type(v).__name__}." ) if is_dataclass(value) and not isinstance(value, type): check_yaml_field_types(value, field_path) diff --git a/src/poly/tests/resources_test.py b/src/poly/tests/resources_test.py index b6394335..e6f2aef1 100644 --- a/src/poly/tests/resources_test.py +++ b/src/poly/tests/resources_test.py @@ -65,6 +65,7 @@ from poly.resources.resource import ( MultiResourceYamlResource, ResourceMapping, + _matches_scalar, _parse_multi_resource_path, check_yaml_field_types, ) @@ -7930,7 +7931,7 @@ def test_str_field_with_dict_raises(self): with self.assertRaises(ValueError) as ctx: check_yaml_field_types(topic) self.assertIn("actions", str(ctx.exception)) - self.assertIn("should be a string but got a dict", str(ctx.exception)) + self.assertIn("should be str but got dict", str(ctx.exception)) def test_list_str_field_with_dict_item_raises(self): """A list[str] field containing a dict item should raise.""" @@ -7944,7 +7945,7 @@ def test_list_str_field_with_dict_item_raises(self): with self.assertRaises(ValueError) as ctx: check_yaml_field_types(topic) self.assertIn("example_queries[1]", str(ctx.exception)) - self.assertIn("should be a string but got a dict", str(ctx.exception)) + self.assertIn("should be str but got dict", str(ctx.exception)) def test_nested_dataclass_field_checked(self): """Nested dataclass str fields should be recursively validated.""" @@ -8014,20 +8015,86 @@ def test_list_str_field_with_dict_value_raises(self): with self.assertRaises(ValueError) as ctx: check_yaml_field_types(test_case) self.assertIn("assertions.prompts", str(ctx.exception)) - self.assertIn("list of strings", str(ctx.exception)) + self.assertIn("list of str", str(ctx.exception)) - def test_error_message_includes_hint(self): - """The error message should guide the user to quote the YAML value.""" + def test_str_field_with_bool_raises(self): + """A str field that got a bool (yes/no auto-cast) should raise.""" + topic = Topic( + resource_id="TOPIC-1", + name="test", + actions=True, + content="fine", + example_queries=[], + ) + with self.assertRaises(ValueError) as ctx: + check_yaml_field_types(topic) + self.assertIn("actions", str(ctx.exception)) + self.assertIn("bool", str(ctx.exception)) + + def test_str_field_with_int_raises(self): + """A str field that got an int (bare number) should raise.""" + topic = Topic( + resource_id="TOPIC-1", + name="test", + actions=42, + content="fine", + example_queries=[], + ) + with self.assertRaises(ValueError) as ctx: + check_yaml_field_types(topic) + self.assertIn("actions", str(ctx.exception)) + self.assertIn("int", str(ctx.exception)) + + def test_list_str_field_with_bool_item_raises(self): + """A list[str] field containing a bool item should raise.""" + topic = Topic( + resource_id="TOPIC-1", + name="test", + actions="fine", + content="fine", + example_queries=["good", True], + ) + with self.assertRaises(ValueError) as ctx: + check_yaml_field_types(topic) + self.assertIn("example_queries[1]", str(ctx.exception)) + self.assertIn("bool", str(ctx.exception)) + + def test_dict_value_wrong_type_raises(self): + """A dict[str, bool] field with a str value instead of bool should raise.""" personality = SettingsPersonality( resource_id="P-1", name="personality", - custom={"the tone: friendly": "and warm"}, - adjectives={}, + custom="fine", + adjectives={"Polite": "sure"}, ) with self.assertRaises(ValueError) as ctx: check_yaml_field_types(personality) - self.assertIn("unquoted", str(ctx.exception)) - self.assertIn("colon", str(ctx.exception)) + self.assertIn("adjectives", str(ctx.exception)) + self.assertIn("should be bool but got str", str(ctx.exception)) + + def test_dict_valid_types_passes(self): + """A dict[str, bool] field with correct types should not raise.""" + personality = SettingsPersonality( + resource_id="P-1", + name="personality", + custom="fine", + adjectives={"Polite": True, "Calm": False}, + ) + check_yaml_field_types(personality) + + def test_int_accepted_for_float_field(self): + """int values should be accepted where float is expected (YAML parses 500 as int).""" + self.assertTrue(_matches_scalar(42, float)) + self.assertTrue(_matches_scalar(3.14, float)) + + def test_bool_rejected_for_int_field(self): + """bool should not pass as int even though bool is a subclass of int in Python.""" + self.assertFalse(_matches_scalar(True, int)) + self.assertFalse(_matches_scalar(False, int)) + + def test_bool_rejected_for_float_field(self): + """bool should not pass as float either.""" + self.assertFalse(_matches_scalar(True, float)) if __name__ == "__main__": From 05fbf6a1a81b239d4a1cdf8c77ad422421ecf581 Mon Sep 17 00:00:00 2001 From: Ruari Phipps Date: Tue, 16 Jun 2026 10:45:26 +0100 Subject: [PATCH 4/4] Move to resource utils --- src/poly/resources/flows.py | 3 +- src/poly/resources/pronunciation.py | 3 +- src/poly/resources/resource.py | 79 +--------------------------- src/poly/resources/resource_utils.py | 78 ++++++++++++++++++++++++++- src/poly/resources/safety_filters.py | 2 +- src/poly/tests/project_test.py | 2 +- src/poly/tests/resources_test.py | 65 ++++++++++++++--------- src/poly/utils.py | 7 ++- 8 files changed, 125 insertions(+), 114 deletions(-) diff --git a/src/poly/resources/flows.py b/src/poly/resources/flows.py index 0957472f..8ae27cbe 100644 --- a/src/poly/resources/flows.py +++ b/src/poly/resources/flows.py @@ -55,7 +55,6 @@ ResourceMapping, SubResource, YamlResource, - check_yaml_field_types, ) FUNCTION_REGEX = re.compile(r"{{f[nt]:([\w-]+)}}") @@ -646,7 +645,7 @@ def read_local_resource( known_position=known_position, resource_mappings=resource_mappings, ) - check_yaml_field_types(instance) + utils.check_yaml_field_types(instance) return instance def validate(self, resource_mappings: list[ResourceMapping] = None, **kwargs): diff --git a/src/poly/resources/pronunciation.py b/src/poly/resources/pronunciation.py index b66e0f7c..38bca602 100644 --- a/src/poly/resources/pronunciation.py +++ b/src/poly/resources/pronunciation.py @@ -17,7 +17,6 @@ MultiResourceYamlResource, ResourceMapping, _parse_multi_resource_path, - check_yaml_field_types, ) @@ -149,7 +148,7 @@ def read_local_resource( instance = cls.from_yaml_dict( yaml_dict, resource_id=resource_id, name="", position=position, **kwargs ) - check_yaml_field_types(instance) + utils.check_yaml_field_types(instance) return instance @property diff --git a/src/poly/resources/resource.py b/src/poly/resources/resource.py index d6598d40..a67c7609 100644 --- a/src/poly/resources/resource.py +++ b/src/poly/resources/resource.py @@ -4,10 +4,8 @@ """ import os -import types -import typing from abc import ABC, abstractmethod -from dataclasses import dataclass, fields as dc_fields, is_dataclass +from dataclasses import dataclass from typing import ClassVar, Optional from google.protobuf.message import Message @@ -15,79 +13,6 @@ import poly.resources.resource_utils as utils -def _unwrap_optional(hint: type) -> type: - """Strip Optional[X] / X | None down to X.""" - origin = typing.get_origin(hint) - if origin is typing.Union or origin is types.UnionType: - args = [a for a in typing.get_args(hint) if a is not type(None)] - if len(args) == 1: - return args[0] - return hint - - -_SCALAR_TYPES = (str, int, float, bool) - - -def _matches_scalar(value: object, expected: type) -> bool: - """Check if a value matches a scalar type, with YAML-aware compatibility. - - bool is excluded from int/float checks because YAML's yes/no auto-cast - should not silently pass as a number. int is accepted for float because - YAML parses 500 as int but 500.0 as float — both are valid numerics. - """ - if expected is float: - return isinstance(value, (int, float)) and not isinstance(value, bool) - if expected is int: - return isinstance(value, int) and not isinstance(value, bool) - return isinstance(value, expected) - - -def check_yaml_field_types(instance: object, _path: str = "") -> None: - """Validate that scalar and list[scalar] fields have the correct type after YAML parsing.""" - hints = typing.get_type_hints(type(instance)) - for f in dc_fields(instance): - value = getattr(instance, f.name) - if value is None: - continue - hint = _unwrap_optional(hints.get(f.name, type(None))) - field_path = f"{_path}.{f.name}" if _path else f.name - - if hint in _SCALAR_TYPES and not _matches_scalar(value, hint): - raise ValueError( - f"'{field_path}' should be {hint.__name__} but got {type(value).__name__}." - ) - list_args = typing.get_args(hint) if typing.get_origin(hint) is list else () - if list_args and list_args[0] in _SCALAR_TYPES: - expected_item_type = list_args[0] - if not isinstance(value, list): - raise ValueError( - f"'{field_path}' should be a list of {expected_item_type.__name__} " - f"but got {type(value).__name__}." - ) - for i, item in enumerate(value): - if not _matches_scalar(item, expected_item_type): - raise ValueError( - f"'{field_path}[{i}]' should be {expected_item_type.__name__} " - f"but got {type(item).__name__}." - ) - dict_args = typing.get_args(hint) if typing.get_origin(hint) is dict else () - if len(dict_args) == 2 and isinstance(value, dict): - key_type, val_type = dict_args - for k, v in value.items(): - if key_type in _SCALAR_TYPES and not _matches_scalar(k, key_type): - raise ValueError( - f"'{field_path}' has key {k!r} which should be {key_type.__name__} " - f"but got {type(k).__name__}." - ) - if val_type in _SCALAR_TYPES and not _matches_scalar(v, val_type): - raise ValueError( - f"'{field_path}[{k!r}]' should be {val_type.__name__} " - f"but got {type(v).__name__}." - ) - if is_dataclass(value) and not isinstance(value, type): - check_yaml_field_types(value, field_path) - - @dataclass class ResourceMapping: """Data class to hold resource mapping information.""" @@ -445,7 +370,7 @@ def read_local_resource( resource_mappings=resource_mappings, **kwargs, ) - check_yaml_field_types(instance) + utils.check_yaml_field_types(instance) return instance @abstractmethod diff --git a/src/poly/resources/resource_utils.py b/src/poly/resources/resource_utils.py index bf5fce64..a30e0f99 100644 --- a/src/poly/resources/resource_utils.py +++ b/src/poly/resources/resource_utils.py @@ -13,10 +13,13 @@ import re import subprocess import sys +import types +import typing +from dataclasses import fields, is_dataclass from difflib import unified_diff from enum import Enum from io import StringIO -from typing import TYPE_CHECKING, Optional +from typing import TYPE_CHECKING, Optional, Union import langcodes import ruamel.yaml as yaml @@ -564,6 +567,79 @@ def is_valid_language_code(code: str) -> bool: return langcodes.tag_is_valid(code) +def _unwrap_optional(hint: type) -> type: + """Strip Optional[X] / X | None down to X.""" + origin = typing.get_origin(hint) + if origin is Union or origin is types.UnionType: + args = [a for a in typing.get_args(hint) if a is not type(None)] + if len(args) == 1: + return args[0] + return hint + + +_SCALAR_TYPES = (str, int, float, bool) + + +def _matches_scalar(value: object, expected: type) -> bool: + """Check if a value matches a scalar type, with YAML-aware compatibility. + + bool is excluded from int/float checks because YAML's yes/no auto-cast + should not silently pass as a number. int is accepted for float because + YAML parses 500 as int but 500.0 as float — both are valid numerics. + """ + if expected is float: + return isinstance(value, (int, float)) and not isinstance(value, bool) + if expected is int: + return isinstance(value, int) and not isinstance(value, bool) + return isinstance(value, expected) + + +def check_yaml_field_types(instance: object, _path: str = "") -> None: + """Validate that scalar and list[scalar] fields have the correct type after YAML parsing.""" + hints = typing.get_type_hints(type(instance)) + for f in fields(instance): + value = getattr(instance, f.name) + if value is None: + continue + hint = _unwrap_optional(hints.get(f.name, type(None))) + field_path = f"{_path}.{f.name}" if _path else f.name + + if hint in _SCALAR_TYPES and not _matches_scalar(value, hint): + raise ValueError( + f"'{field_path}' should be {hint.__name__} but got {type(value).__name__}." + ) + list_args = typing.get_args(hint) if typing.get_origin(hint) is list else () + if list_args and list_args[0] in _SCALAR_TYPES: + expected_item_type = list_args[0] + if not isinstance(value, list): + raise ValueError( + f"'{field_path}' should be a list of {expected_item_type.__name__} " + f"but got {type(value).__name__}." + ) + for i, item in enumerate(value): + if not _matches_scalar(item, expected_item_type): + raise ValueError( + f"'{field_path}[{i}]' should be {expected_item_type.__name__} " + f"but got {type(item).__name__}." + ) + dict_args = typing.get_args(hint) if typing.get_origin(hint) is dict else () + if len(dict_args) == 2 and isinstance(value, dict): + key_type, val_type = dict_args + for k, v in value.items(): + if key_type in _SCALAR_TYPES and not _matches_scalar(k, key_type): + raise ValueError( + f"'{field_path}' has key {k!r} which should be {key_type.__name__} " + f"but got {type(k).__name__}." + ) + if val_type in _SCALAR_TYPES and not _matches_scalar(v, val_type): + raise ValueError( + f"'{field_path}[{k!r}]' should be {val_type.__name__} " + f"but got {type(v).__name__}." + ) + if is_dataclass(value) and not isinstance(value, type): + check_yaml_field_types(value, field_path) + + def assign_flow_positions( nodes: list["BaseFlowStep"], start_node_id: str, diff --git a/src/poly/resources/safety_filters.py b/src/poly/resources/safety_filters.py index b8b5ee51..c46b3f3f 100644 --- a/src/poly/resources/safety_filters.py +++ b/src/poly/resources/safety_filters.py @@ -9,13 +9,13 @@ from google.protobuf.message import Message +import poly.resources.resource_utils as utils from poly.handlers.protobuf.channels_pb2 import Channel_UpdateSafetyFilters, ChannelType from poly.handlers.protobuf.content_filter_settings_pb2 import ( AzureContentFilter, AzureContentFilterCategory, ContentFilterSettings_UpdateContentFilterSettings, ) -import poly.resources.resource_utils as utils from poly.resources.resource import ResourceMapping, YamlResource PRECISION_MAPPING = {"LOOSE": "lenient", "MEDIUM": "medium", "STRICT": "strict"} diff --git a/src/poly/tests/project_test.py b/src/poly/tests/project_test.py index 89929e37..e8b83633 100644 --- a/src/poly/tests/project_test.py +++ b/src/poly/tests/project_test.py @@ -32,10 +32,10 @@ SettingsRole, SettingsRules, SMSTemplate, - Topic, TestCase, TestCaseAssertion, TestCaseTags, + Topic, TranscriptCorrection, Translation, Variable, diff --git a/src/poly/tests/resources_test.py b/src/poly/tests/resources_test.py index e6f2aef1..bef98f77 100644 --- a/src/poly/tests/resources_test.py +++ b/src/poly/tests/resources_test.py @@ -7,10 +7,9 @@ import unittest import yaml - -import poly.resources.resource_utils as resource_utils from jsonschema import ValidationError +import poly.resources.resource_utils as resource_utils from poly.handlers.sync_client import SyncClientHandler from poly.resources.agent_settings import ( SettingsPersonality, @@ -53,7 +52,6 @@ FunctionParameters, FunctionType, ) - from poly.resources.handoff import Handoff from poly.resources.keyphrase_boosting import KeyphraseBoosting from poly.resources.languages import ( @@ -65,9 +63,7 @@ from poly.resources.resource import ( MultiResourceYamlResource, ResourceMapping, - _matches_scalar, _parse_multi_resource_path, - check_yaml_field_types, ) from poly.resources.safety_filters import ( ChatSafetyFilters, @@ -76,11 +72,6 @@ VoiceSafetyFilters, ) from poly.resources.sms import EnvPhoneNumbers, SMSTemplate -from poly.resources.topic import ( - FUNCTION_REGEX, - Topic, -) -from poly.resources.transcript_correction import RegularExpressionRule, TranscriptCorrection from poly.resources.test_suite import ( FunctionCallArgumentAssertion, FunctionCallAssertion, @@ -88,6 +79,11 @@ TestCaseAssertion, TestCaseTags, ) +from poly.resources.topic import ( + FUNCTION_REGEX, + Topic, +) +from poly.resources.transcript_correction import RegularExpressionRule, TranscriptCorrection from poly.resources.translations import Translation from poly.resources.variable import Variable from poly.resources.variant_attributes import Variant, VariantAttribute @@ -7929,7 +7925,7 @@ def test_str_field_with_dict_raises(self): example_queries=[], ) with self.assertRaises(ValueError) as ctx: - check_yaml_field_types(topic) + resource_utils.check_yaml_field_types(topic) self.assertIn("actions", str(ctx.exception)) self.assertIn("should be str but got dict", str(ctx.exception)) @@ -7943,7 +7939,7 @@ def test_list_str_field_with_dict_item_raises(self): example_queries=["good", {"bad key": "value"}], ) with self.assertRaises(ValueError) as ctx: - check_yaml_field_types(topic) + resource_utils.check_yaml_field_types(topic) self.assertIn("example_queries[1]", str(ctx.exception)) self.assertIn("should be str but got dict", str(ctx.exception)) @@ -7965,7 +7961,7 @@ def test_nested_dataclass_field_checked(self): tags=TestCaseTags(resource_id="TC-1", name="tags", tags=[]), ) with self.assertRaises(ValueError) as ctx: - check_yaml_field_types(test_case) + resource_utils.check_yaml_field_types(test_case) self.assertIn("assertions.prompts[1]", str(ctx.exception)) def test_valid_resource_passes(self): @@ -7977,7 +7973,7 @@ def test_valid_resource_passes(self): content="some content", example_queries=["query 1", "query 2"], ) - check_yaml_field_types(topic) + resource_utils.check_yaml_field_types(topic) def test_optional_str_field_none_passes(self): """Optional[str] fields with None should not raise.""" @@ -7993,7 +7989,7 @@ def test_optional_str_field_none_passes(self): tags=TestCaseTags(resource_id="TC-1", name="tags", tags=[]), variant=None, ) - check_yaml_field_types(test_case) + resource_utils.check_yaml_field_types(test_case) def test_list_str_field_with_dict_value_raises(self): """A list[str] field that got a dict instead of a list should raise.""" @@ -8013,7 +8009,7 @@ def test_list_str_field_with_dict_value_raises(self): tags=TestCaseTags(resource_id="TC-1", name="tags", tags=[]), ) with self.assertRaises(ValueError) as ctx: - check_yaml_field_types(test_case) + resource_utils.check_yaml_field_types(test_case) self.assertIn("assertions.prompts", str(ctx.exception)) self.assertIn("list of str", str(ctx.exception)) @@ -8027,7 +8023,7 @@ def test_str_field_with_bool_raises(self): example_queries=[], ) with self.assertRaises(ValueError) as ctx: - check_yaml_field_types(topic) + resource_utils.check_yaml_field_types(topic) self.assertIn("actions", str(ctx.exception)) self.assertIn("bool", str(ctx.exception)) @@ -8041,7 +8037,7 @@ def test_str_field_with_int_raises(self): example_queries=[], ) with self.assertRaises(ValueError) as ctx: - check_yaml_field_types(topic) + resource_utils.check_yaml_field_types(topic) self.assertIn("actions", str(ctx.exception)) self.assertIn("int", str(ctx.exception)) @@ -8055,7 +8051,7 @@ def test_list_str_field_with_bool_item_raises(self): example_queries=["good", True], ) with self.assertRaises(ValueError) as ctx: - check_yaml_field_types(topic) + resource_utils.check_yaml_field_types(topic) self.assertIn("example_queries[1]", str(ctx.exception)) self.assertIn("bool", str(ctx.exception)) @@ -8068,7 +8064,7 @@ def test_dict_value_wrong_type_raises(self): adjectives={"Polite": "sure"}, ) with self.assertRaises(ValueError) as ctx: - check_yaml_field_types(personality) + resource_utils.check_yaml_field_types(personality) self.assertIn("adjectives", str(ctx.exception)) self.assertIn("should be bool but got str", str(ctx.exception)) @@ -8080,21 +8076,38 @@ def test_dict_valid_types_passes(self): custom="fine", adjectives={"Polite": True, "Calm": False}, ) - check_yaml_field_types(personality) + resource_utils.check_yaml_field_types(personality) def test_int_accepted_for_float_field(self): """int values should be accepted where float is expected (YAML parses 500 as int).""" - self.assertTrue(_matches_scalar(42, float)) - self.assertTrue(_matches_scalar(3.14, float)) + self.assertTrue(resource_utils._matches_scalar(42, float)) + self.assertTrue(resource_utils._matches_scalar(3.14, float)) def test_bool_rejected_for_int_field(self): """bool should not pass as int even though bool is a subclass of int in Python.""" - self.assertFalse(_matches_scalar(True, int)) - self.assertFalse(_matches_scalar(False, int)) + self.assertFalse(resource_utils._matches_scalar(True, int)) + self.assertFalse(resource_utils._matches_scalar(False, int)) def test_bool_rejected_for_float_field(self): """bool should not pass as float either.""" - self.assertFalse(_matches_scalar(True, float)) + self.assertFalse(resource_utils._matches_scalar(True, float)) + self.assertFalse(resource_utils._matches_scalar(False, float)) + + def test_none_on_optional_field_passes(self): + """None in an Optional[str] field should not raise.""" + test_case = TestCase( + resource_id="TC-1", + name="test", + scenario="test scenario", + channel="chat.polyai", + language="en-GB", + assertions=TestCaseAssertion( + resource_id="TC-1", name="assertions", prompts=[], function_calls=[] + ), + tags=TestCaseTags(resource_id="TC-1", name="tags", tags=[]), + variant=None, + ) + resource_utils.check_yaml_field_types(test_case) if __name__ == "__main__": diff --git a/src/poly/utils.py b/src/poly/utils.py index 2cda34bf..5f0c9db5 100644 --- a/src/poly/utils.py +++ b/src/poly/utils.py @@ -13,14 +13,13 @@ import re from typing import Callable, Optional -from poly.resources import Function, FunctionStep, Resource, ResourceMapping - -from poly.handlers.protobuf.commands_pb2 import Command from poly.handlers.protobuf.channels_pb2 import ( Channel_UpdateStatus, - WebChatChannel_UpdateStatus, ChannelStatus, + WebChatChannel_UpdateStatus, ) +from poly.handlers.protobuf.commands_pb2 import Command +from poly.resources import Function, FunctionStep, Resource, ResourceMapping logger = logging.getLogger(__name__)