-
Notifications
You must be signed in to change notification settings - Fork 1.4k
Flatten allOf properties for OpenAI compatibility #3451
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
base: main
Are you sure you want to change the base?
Changes from 17 commits
afddade
8ff4db7
e4aebaf
e5815fd
b1058e9
1244bfd
f143029
eaa685e
cea2b10
27192da
756bdc4
6b47179
6114437
16e33e8
ff02ed0
71c1434
83cbd04
60b1db5
98fbc43
169e6bc
35106a8
File filter
Filter by extension
Conversations
Jump to
Diff view
Diff view
There are no files selected for viewing
| Original file line number | Diff line number | Diff line change |
|---|---|---|
|
|
@@ -2,14 +2,17 @@ | |
|
|
||
| import re | ||
| from abc import ABC, abstractmethod | ||
| from collections.abc import Callable | ||
| from copy import deepcopy | ||
| from dataclasses import dataclass | ||
| from typing import Any, Literal | ||
| from typing import Any, Literal, cast | ||
|
|
||
| from .exceptions import UserError | ||
|
|
||
| JsonSchema = dict[str, Any] | ||
|
|
||
| __all__ = ['JsonSchemaTransformer', 'InlineDefsJsonSchemaTransformer'] | ||
|
|
||
|
|
||
| @dataclass(init=False) | ||
| class JsonSchemaTransformer(ABC): | ||
|
|
@@ -26,14 +29,16 @@ def __init__( | |
| strict: bool | None = None, | ||
| prefer_inlined_defs: bool = False, | ||
| simplify_nullable_unions: bool = False, # TODO (v2): Remove this, no longer used | ||
| flatten_allof: bool = False, | ||
| ): | ||
| self.schema = schema | ||
|
|
||
| self.strict = strict | ||
| self.is_strict_compatible = True # Can be set to False by subclasses to set `strict` on `ToolDefinition` when set not set by user explicitly | ||
| self.strict = strict | ||
| self.is_strict_compatible = True # Can be set to False by subclasses to set `strict` on `ToolDefinition` when not set explicitly by the user. | ||
DouweM marked this conversation as resolved.
Show resolved
Hide resolved
|
||
|
|
||
| self.prefer_inlined_defs = prefer_inlined_defs | ||
| self.simplify_nullable_unions = simplify_nullable_unions | ||
| self.flatten_allof = flatten_allof | ||
|
|
||
| self.defs: dict[str, JsonSchema] = self.schema.get('$defs', {}) | ||
| self.refs_stack: list[str] = [] | ||
|
|
@@ -73,6 +78,10 @@ def walk(self) -> JsonSchema: | |
| return handled | ||
|
|
||
| def _handle(self, schema: JsonSchema) -> JsonSchema: | ||
| # Flatten allOf if requested, before processing the schema | ||
| if self.flatten_allof: | ||
| schema = _recurse_flatten_allof(schema) | ||
|
|
||
| nested_refs = 0 | ||
| if self.prefer_inlined_defs: | ||
| while ref := schema.get('$ref'): | ||
|
|
@@ -109,24 +118,7 @@ def _handle(self, schema: JsonSchema) -> JsonSchema: | |
| return schema | ||
|
|
||
| def _handle_object(self, schema: JsonSchema) -> JsonSchema: | ||
| if properties := schema.get('properties'): | ||
| handled_properties = {} | ||
| for key, value in properties.items(): | ||
| handled_properties[key] = self._handle(value) | ||
| schema['properties'] = handled_properties | ||
|
|
||
| if (additional_properties := schema.get('additionalProperties')) is not None: | ||
| if isinstance(additional_properties, bool): | ||
| schema['additionalProperties'] = additional_properties | ||
| else: | ||
| schema['additionalProperties'] = self._handle(additional_properties) | ||
|
|
||
| if (pattern_properties := schema.get('patternProperties')) is not None: | ||
| handled_pattern_properties = {} | ||
| for key, value in pattern_properties.items(): | ||
| handled_pattern_properties[key] = self._handle(value) | ||
| schema['patternProperties'] = handled_pattern_properties | ||
|
|
||
| _process_object_nested_schemas(schema, self._handle) | ||
| return schema | ||
|
|
||
| def _handle_array(self, schema: JsonSchema) -> JsonSchema: | ||
|
|
@@ -187,3 +179,289 @@ def __init__(self, schema: JsonSchema, *, strict: bool | None = None): | |
|
|
||
| def transform(self, schema: JsonSchema) -> JsonSchema: | ||
| return schema | ||
|
|
||
|
|
||
| def _get_type_set(schema: JsonSchema) -> set[str] | None: | ||
| """Extract type(s) from a schema as a set of strings.""" | ||
| schema_type = schema.get('type') | ||
| if isinstance(schema_type, list): | ||
| return {str(t) for t in cast(list[Any], schema_type)} | ||
| if isinstance(schema_type, str): | ||
| return {schema_type} | ||
| return None | ||
|
|
||
|
|
||
| def _process_object_nested_schemas(schema: JsonSchema, process_fn: Callable[[JsonSchema], JsonSchema]) -> None: | ||
| """Process nested schemas in an object schema (properties, additionalProperties, patternProperties). | ||
|
|
||
| Args: | ||
| schema: The object schema to process (modified in place) | ||
| process_fn: Function to apply to each nested schema | ||
| """ | ||
| if properties := schema.get('properties'): | ||
| if isinstance(properties, dict): | ||
| properties_dict = cast(dict[str, Any], properties) | ||
| schema['properties'] = { | ||
| k: process_fn(cast(JsonSchema, v)) if isinstance(v, dict) else v for k, v in properties_dict.items() | ||
| } | ||
|
|
||
| if (additional_properties := schema.get('additionalProperties')) is not None: | ||
| if isinstance(additional_properties, dict): | ||
| schema['additionalProperties'] = process_fn(cast(JsonSchema, additional_properties)) | ||
| # If it's a bool, leave it as is | ||
|
|
||
| if pattern_properties := schema.get('patternProperties'): | ||
| if isinstance(pattern_properties, dict): | ||
| pattern_properties_dict = cast(dict[str, Any], pattern_properties) | ||
| schema['patternProperties'] = { | ||
| k: process_fn(cast(JsonSchema, v)) if isinstance(v, dict) else v | ||
| for k, v in pattern_properties_dict.items() | ||
| } | ||
|
|
||
|
|
||
def _process_nested_schemas_without_allof(s: JsonSchema) -> JsonSchema:
    """Flatten `allOf` in the children of a schema that has no `allOf` of its own.

    Only schemas whose `type` is exactly `'object'` or `'array'` are descended
    into; any other schema is returned untouched. `s` is modified in place and
    also returned.
    """
    kind = s.get('type')
    if kind == 'array':
        element = s.get('items')
        # Only a dict-valued `items` is rewritten; anything else is left alone.
        if isinstance(element, dict):
            s['items'] = _recurse_flatten_allof(cast('JsonSchema', element))
    elif kind == 'object':
        _process_object_nested_schemas(s, _recurse_flatten_allof)
    return s
|
|
||
|
|
||
| def _collect_base_schema_data( | ||
| result: JsonSchema, | ||
| ) -> tuple[dict[str, JsonSchema], set[str], dict[str, JsonSchema], list[Any], list[set[str]]]: | ||
| """Collect data from base schema: properties, required, patternProperties, additionalProperties.""" | ||
| properties: dict[str, JsonSchema] = {} | ||
| required: set[str] = set() | ||
| pattern_properties: dict[str, JsonSchema] = {} | ||
| additional_values: list[Any] = [] | ||
| restricted_property_sets: list[set[str]] = [] | ||
|
|
||
| base_properties = ( | ||
| cast(dict[str, JsonSchema], result.get('properties', {})) if isinstance(result.get('properties'), dict) else {} | ||
| ) | ||
| base_additional = result.get('additionalProperties') | ||
|
|
||
| if base_properties: | ||
| properties.update(base_properties) | ||
| if isinstance(result.get('required'), list): | ||
| required.update(result['required']) | ||
| if isinstance(result.get('patternProperties'), dict): | ||
| pattern_properties.update(result['patternProperties']) | ||
| if base_additional is False: | ||
| additional_values.append(False) | ||
| # Only restrict if base schema has properties; if base has no properties but additionalProperties: False, | ||
| # it means no additional properties are allowed, but properties from allOf members are still valid | ||
| if base_properties: | ||
| restricted_property_sets.append(set(base_properties.keys())) | ||
|
|
||
| return properties, required, pattern_properties, additional_values, restricted_property_sets | ||
|
|
||
|
|
||
| def _collect_member_data( | ||
| processed_members: list[JsonSchema], | ||
| properties: dict[str, JsonSchema], | ||
| required: set[str], | ||
| pattern_properties: dict[str, JsonSchema], | ||
| additional_values: list[Any], | ||
| restricted_property_sets: list[set[str]], | ||
| members_properties: list[dict[str, JsonSchema]], | ||
| members_additional_props: list[Any], | ||
| ) -> None: | ||
| """Collect data from allOf members and update the collections.""" | ||
| for m in processed_members: | ||
| member_props = ( | ||
| cast(dict[str, JsonSchema], m.get('properties', {})) if isinstance(m.get('properties'), dict) else {} | ||
| ) | ||
| members_properties.append(member_props) | ||
| members_additional_props.append(m.get('additionalProperties')) | ||
|
|
||
| if member_props: | ||
| properties.update(member_props) | ||
| if isinstance(m.get('required'), list): | ||
| required.update(m['required']) | ||
| if isinstance(m.get('patternProperties'), dict): | ||
| pattern_properties.update(m['patternProperties']) | ||
| if 'additionalProperties' in m: | ||
| additional_values.append(m['additionalProperties']) | ||
| if m['additionalProperties'] is False: | ||
| restricted_property_sets.append(set(member_props.keys())) | ||
|
|
||
|
|
||
| def _filter_by_restricted_property_sets( | ||
| properties: dict[str, JsonSchema], required: set[str], restricted_property_sets: list[set[str]] | ||
| ) -> tuple[dict[str, JsonSchema], set[str]]: | ||
| """Filter properties and required by restricted property sets (intersection when some/all have additionalProperties: False).""" | ||
| if not restricted_property_sets: | ||
| return properties, required | ||
|
|
||
| # Intersection of allowed properties from all members with additionalProperties: False | ||
| allowed_names = restricted_property_sets[0].copy() | ||
| for prop_set in restricted_property_sets[1:]: | ||
| allowed_names &= prop_set | ||
| # Filter properties to only include allowed names | ||
| if allowed_names: | ||
| properties = {k: v for k, v in properties.items() if k in allowed_names} | ||
| required = {r for r in required if r in allowed_names} | ||
| else: | ||
| # Empty intersection - remove all properties | ||
| properties = {} | ||
| required = set() | ||
|
|
||
| return properties, required | ||
|
|
||
|
|
||
def _filter_incompatible_properties(
    properties: dict[str, JsonSchema],
    required: set[str],
    members_properties: list[dict[str, JsonSchema]],
    members_additional_props: list[Any],
) -> tuple[dict[str, JsonSchema], set[str]]:
    """Drop merged properties whose declared type conflicts with any merged schema.

    `members_properties` and `members_additional_props` are parallel lists (one
    entry per merged schema, base included). A merged property is considered
    incompatible when, for some schema:

    * the schema defines the same property with a `type` that shares no type
      with the merged one, or
    * the schema does not define the property but has a dict-valued
      `additionalProperties` whose `type` does not cover every type of the
      merged property.

    Schemas without a usable `type` never cause a conflict. Because every JSON
    Schema `integer` is also a `number`, a declared `number` is treated as
    accepting `integer` as well, so `integer` vs. `number` is not a conflict.

    Returns:
        The inputs unchanged (same objects) when nothing is incompatible;
        otherwise a new properties dict and required set with exactly the
        incompatible names removed.
    """
    if not properties:
        return properties, required

    def _widen(types: set[str] | None) -> set[str] | None:
        # `number` accepts every `integer` value, so widen it for the comparisons below.
        if types and 'number' in types:
            return types | {'integer'}
        return types

    def _conflicts(prop_name: str, prop_types: set[str] | None) -> bool:
        for member_props, member_additional in zip(members_properties, members_additional_props):
            if prop_name in member_props:
                # Explicitly defined by this schema: the two declarations must share at least one type.
                member_types = _get_type_set(member_props[prop_name])
                if prop_types and member_types and not (_widen(prop_types) or set()) & (_widen(member_types) or set()):
                    return True
            elif isinstance(member_additional, dict):
                # Not defined here: the property falls under this schema's additionalProperties schema.
                allowed_types = _widen(_get_type_set(cast('JsonSchema', member_additional)))
                if prop_types and allowed_types and not prop_types <= allowed_types:
                    return True
        return False

    incompatible = {name for name, sub in properties.items() if _conflicts(name, _get_type_set(sub))}
    if not incompatible:
        return properties, required

    kept_properties = {name: sub for name, sub in properties.items() if name not in incompatible}
    # Only the incompatible names are removed from `required`; unrelated names are left alone.
    return kept_properties, required - incompatible
|
|
||
|
|
||
def _process_result_nested_schemas(result: JsonSchema) -> None:
    """Flatten `allOf` inside the nested schemas of a merged result, in place.

    Rewrites `additionalProperties` and `items` when they are dict schemas, and
    every dict-valued entry of `patternProperties`. Non-dict `patternProperties`
    entries (for example boolean schemas) are kept as they are instead of being
    dropped.
    """
    additional = result.get('additionalProperties')
    if isinstance(additional, dict):
        result['additionalProperties'] = _recurse_flatten_allof(cast('JsonSchema', additional))

    patterns = result.get('patternProperties')
    if isinstance(patterns, dict):
        result['patternProperties'] = {
            pattern: _recurse_flatten_allof(cast('JsonSchema', sub)) if isinstance(sub, dict) else sub
            for pattern, sub in cast('dict[str, Any]', patterns).items()
        }

    items = result.get('items')
    if isinstance(items, dict):
        result['items'] = _recurse_flatten_allof(cast('JsonSchema', items))
|
|
||
|
|
||
def _recurse_flatten_allof(schema: JsonSchema) -> JsonSchema:
    """Return a copy of `schema` in which object-like `allOf` groups are merged into one object schema.

    The input is never mutated; all work happens on a deep copy.

    * No `allOf` (or an empty / non-list one): only nested schemas are processed.
    * An `allOf` containing a non-dict member, or a member that is not
      object-like: the copy is returned unchanged.
    * Otherwise the base schema and all members are merged: `properties`,
      `required` and `patternProperties` are combined, then properties are
      filtered by `additionalProperties: False` restrictions and by type
      compatibility, and a single `additionalProperties` value is derived.

    Returns:
        The flattened schema (a new dict).
    """
    s = deepcopy(schema)

    # Case 1: no allOf at this level - just descend into nested schemas.
    allof = s.get('allOf')
    if not isinstance(allof, list) or not allof:
        return _process_nested_schemas_without_allof(s)

    members = cast('list[JsonSchema]', allof)
    if not all(isinstance(m, dict) for m in members):
        return s

    def _is_object_like(member: JsonSchema) -> bool:
        member_type = member.get('type')
        if member_type is None:
            # No explicit type: treat as an object when it carries object-only keywords.
            return any(k in member for k in ('properties', 'additionalProperties', 'patternProperties'))
        return member_type == 'object'

    # Only object-like members can be merged; anything else is left as-is.
    if not all(_is_object_like(m) for m in members):
        return s

    # Flatten each member first so nested allOf groups are already merged.
    processed_members = [_recurse_flatten_allof(m) for m in members]
    result: JsonSchema = {k: v for k, v in s.items() if k != 'allOf'}
    result['type'] = 'object'

    raw_base_properties = result.get('properties')
    base_properties = (
        cast('dict[str, JsonSchema]', raw_base_properties) if isinstance(raw_base_properties, dict) else {}
    )
    base_additional = result.get('additionalProperties')

    properties, required, pattern_properties, additional_values, restricted_property_sets = _collect_base_schema_data(
        result
    )

    # The base schema takes part in the compatibility checks as the first "member".
    members_properties: list[dict[str, JsonSchema]] = [base_properties]
    members_additional_props: list[Any] = [base_additional]

    # NOTE: `_collect_member_data` returns nothing; it extends every collection passed below IN PLACE.
    _collect_member_data(
        processed_members,
        properties,
        required,
        pattern_properties,
        additional_values,
        restricted_property_sets,
        members_properties,
        members_additional_props,
    )

    # Filter by restricted property sets and incompatible properties.
    properties, required = _filter_by_restricted_property_sets(properties, required, restricted_property_sets)
    properties, required = _filter_incompatible_properties(
        properties, required, members_properties, members_additional_props
    )

    # Apply the filtered properties. `result` still holds the base schema's own `properties` / `required`,
    # so when filtering removed everything those stale entries must be dropped rather than left behind.
    if properties:
        result['properties'] = {k: _recurse_flatten_allof(v) for k, v in properties.items()}
    elif base_properties:
        result.pop('properties', None)

    if required:
        result['required'] = sorted(required)
    elif isinstance(result.get('required'), list) and result['required']:
        result.pop('required', None)

    if pattern_properties:
        result['patternProperties'] = {k: _recurse_flatten_allof(v) for k, v in pattern_properties.items()}

    # Merge additionalProperties (most restrictive wins).
    if additional_values:
        dict_schemas = [v for v in additional_values if isinstance(v, dict)]
        if any(v is False for v in additional_values):
            result['additionalProperties'] = False
        elif dict_schemas and all(d == dict_schemas[0] for d in dict_schemas):
            # A single (distinct) schema combined with `True` values is just that schema.
            result['additionalProperties'] = dict_schemas[0]
        else:
            # Several different schemas cannot be merged, and only-`True` values mean "anything goes".
            result['additionalProperties'] = True

    # Flatten allOf inside additionalProperties / patternProperties. `items` is not expected on an
    # object schema, but the helper also handles it for robustness.
    _process_result_nested_schemas(result)

    return result
| Original file line number | Diff line number | Diff line change |
|---|---|---|
|
|
@@ -143,7 +143,7 @@ class OpenAIJsonSchemaTransformer(JsonSchemaTransformer): | |
| """ | ||
|
|
||
| def __init__(self, schema: JsonSchema, *, strict: bool | None = None): | ||
| super().__init__(schema, strict=strict) | ||
| super().__init__(schema, strict=strict, flatten_allof=True) | ||
| self.root_ref = schema.get('$ref') | ||
|
|
||
| def walk(self) -> JsonSchema: | ||
|
|
@@ -157,7 +157,12 @@ def walk(self) -> JsonSchema: | |
| if self.root_ref is not None: | ||
| result.pop('$ref', None) # We replace references to the self.root_ref with just '#' in the transform method | ||
| root_key = re.sub(r'^#/\$defs/', '', self.root_ref) | ||
| result.update(self.defs.get(root_key) or {}) | ||
| # Use the transformed schema from $defs, not the original self.defs | ||
| if '$defs' in result and root_key in result['$defs']: | ||
| result.update(result['$defs'][root_key]) | ||
| else: | ||
| # Fallback to original if transformed version not available (shouldn't happen in normal flow) | ||
| result.update(self.defs.get(root_key) or {}) | ||
|
||
|
|
||
| return result | ||
|
|
||
|
|
||

There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
This actually belonged on the next line :D