Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
4 changes: 3 additions & 1 deletion src/poly/resources/flows.py
Original file line number Diff line number Diff line change
Expand Up @@ -635,7 +635,7 @@ def read_local_resource(
# Get file name from file_path
file_name = os.path.splitext(os.path.basename(file_path))[0]

return cls.from_yaml_dict(
instance = cls.from_yaml_dict(
yaml_dict,
resource_id=resource_id,
file_name=file_name,
Expand All @@ -645,6 +645,8 @@ def read_local_resource(
known_position=known_position,
resource_mappings=resource_mappings,
)
utils.check_yaml_field_types(instance)
return instance

def validate(self, resource_mappings: list[ResourceMapping] = None, **kwargs):
"""Validate the flow step resource."""
Expand Down
4 changes: 3 additions & 1 deletion src/poly/resources/pronunciation.py
Original file line number Diff line number Diff line change
Expand Up @@ -145,9 +145,11 @@ def read_local_resource(
f"Resource with name {resource_clean_name} not found in {true_file_path}"
)

return cls.from_yaml_dict(
instance = cls.from_yaml_dict(
yaml_dict, resource_id=resource_id, name="", position=position, **kwargs
)
utils.check_yaml_field_types(instance)
return instance

@property
def command_type(self) -> str:
Expand Down
4 changes: 3 additions & 1 deletion src/poly/resources/resource.py
Original file line number Diff line number Diff line change
Expand Up @@ -363,13 +363,15 @@ def read_local_resource(
file_path=file_path,
**kwargs,
)
return cls.from_yaml_dict(
instance = cls.from_yaml_dict(
yaml_dict,
resource_id=resource_id,
name=resource_name,
resource_mappings=resource_mappings,
**kwargs,
)
utils.check_yaml_field_types(instance)
return instance

@abstractmethod
def to_yaml_dict(self) -> dict:
Expand Down
78 changes: 77 additions & 1 deletion src/poly/resources/resource_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,10 +13,13 @@
import re
import subprocess
import sys
import types
import typing
from dataclasses import fields, is_dataclass
from difflib import unified_diff
from enum import Enum
from io import StringIO
from typing import TYPE_CHECKING, Optional
from typing import TYPE_CHECKING, Optional, Union

import langcodes
import ruamel.yaml as yaml
Expand Down Expand Up @@ -564,6 +567,79 @@ def is_valid_language_code(code: str) -> bool:
return langcodes.tag_is_valid(code)


def _unwrap_optional(hint: type) -> type:
"""Strip Optional[X] / X | None down to X."""
origin = typing.get_origin(hint)
if origin is Union or origin is types.UnionType:
args = [a for a in typing.get_args(hint) if a is not type(None)]
if len(args) == 1:
return args[0]
return hint


_SCALAR_TYPES = (str, int, float, bool)


def _matches_scalar(value: object, expected: type) -> bool:
"""Check if a value matches a scalar type, with YAML-aware compatibility.

bool is excluded from int/float checks because YAML's yes/no auto-cast
should not silently pass as a number. int is accepted for float because
YAML parses 500 as int but 500.0 as float — both are valid numerics.
"""
if expected is float:
return isinstance(value, (int, float)) and not isinstance(value, bool)
if expected is int:
return isinstance(value, int) and not isinstance(value, bool)
return isinstance(value, expected)


def check_yaml_field_types(instance: object, _path: str = "") -> None:
"""Validate that scalar and list[scalar] fields have the correct type after YAML parsing."""
hints = typing.get_type_hints(type(instance))
for f in fields(instance):
value = getattr(instance, f.name)
if value is None:
continue
hint = _unwrap_optional(hints.get(f.name, type(None)))
field_path = f"{_path}.{f.name}" if _path else f.name

if hint in _SCALAR_TYPES and not _matches_scalar(value, hint):
raise ValueError(
f"'{field_path}' should be {hint.__name__} but got {type(value).__name__}."
)
list_args = typing.get_args(hint) if typing.get_origin(hint) is list else ()
if list_args and list_args[0] in _SCALAR_TYPES:
expected_item_type = list_args[0]
if not isinstance(value, list):
raise ValueError(
f"'{field_path}' should be a list of {expected_item_type.__name__} "
f"but got {type(value).__name__}."
)
for i, item in enumerate(value):
if not _matches_scalar(item, expected_item_type):
raise ValueError(
f"'{field_path}[{i}]' should be {expected_item_type.__name__} "
f"but got {type(item).__name__}."
)
dict_args = typing.get_args(hint) if typing.get_origin(hint) is dict else ()
if len(dict_args) == 2 and isinstance(value, dict):
key_type, val_type = dict_args
for k, v in value.items():
if key_type in _SCALAR_TYPES and not _matches_scalar(k, key_type):
raise ValueError(
f"'{field_path}' has key {k!r} which should be {key_type.__name__} "
f"but got {type(k).__name__}."
)
if val_type in _SCALAR_TYPES and not _matches_scalar(v, val_type):
raise ValueError(
f"'{field_path}[{k!r}]' should be {val_type.__name__} "
f"but got {type(v).__name__}."
)
if is_dataclass(value) and not isinstance(value, type):
check_yaml_field_types(value, field_path)


def assign_flow_positions(
nodes: list["BaseFlowStep"],
start_node_id: str,
Expand Down
2 changes: 1 addition & 1 deletion src/poly/resources/safety_filters.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,13 +9,13 @@

from google.protobuf.message import Message

import poly.resources.resource_utils as utils
from poly.handlers.protobuf.channels_pb2 import Channel_UpdateSafetyFilters, ChannelType
from poly.handlers.protobuf.content_filter_settings_pb2 import (
AzureContentFilter,
AzureContentFilterCategory,
ContentFilterSettings_UpdateContentFilterSettings,
)
import poly.resources.resource_utils as utils
from poly.resources.resource import ResourceMapping, YamlResource

PRECISION_MAPPING = {"LOOSE": "lenient", "MEDIUM": "medium", "STRICT": "strict"}
Expand Down
2 changes: 1 addition & 1 deletion src/poly/tests/project_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -32,10 +32,10 @@
SettingsRole,
SettingsRules,
SMSTemplate,
Topic,
TestCase,
TestCaseAssertion,
TestCaseTags,
Topic,
TranscriptCorrection,
Translation,
Variable,
Expand Down
Loading
Loading