Skip to content

Commit f098da4

Browse files
authored
Merge pull request #741 from common-workflow-language/adjust-is-subtype
Implement smarter `is_subtype` logic
2 parents cfde2e6 + fef7256 commit f098da4

File tree

8 files changed

+258
-43
lines changed

8 files changed

+258
-43
lines changed

schema_salad/avro/schema.py

Lines changed: 28 additions & 39 deletions
Original file line numberDiff line numberDiff line change
@@ -620,7 +620,7 @@ def make_avsc_object(json_data: JsonDataType, names: Optional[Names] = None) ->
620620
raise SchemaParseException(fail_msg)
621621

622622

623-
def is_subtype(existing: PropType, new: PropType) -> bool:
623+
def is_subtype(types: Dict[str, Any], existing: PropType, new: PropType) -> bool:
624624
"""Check if a new type specification is compatible with an existing type spec."""
625625
if existing == new:
626626
return True
@@ -632,46 +632,35 @@ def is_subtype(existing: PropType, new: PropType) -> bool:
632632
if isinstance(new, list) and "null" in new:
633633
return False
634634
return True
635-
if (
636-
isinstance(existing, dict)
637-
and "type" in existing
638-
and existing["type"] == "array"
639-
and isinstance(new, dict)
640-
and "type" in new
641-
and new["type"] == "array"
642-
):
643-
return is_subtype(existing["items"], new["items"])
644-
if (
645-
isinstance(existing, dict)
646-
and "type" in existing
647-
and existing["type"] == "enum"
648-
and isinstance(new, dict)
649-
and "type" in new
650-
and new["type"] == "enum"
651-
):
652-
return is_subtype(existing["symbols"], new["symbols"])
653-
if (
654-
isinstance(existing, dict)
655-
and "type" in existing
656-
and existing["type"] == "record"
657-
and isinstance(new, dict)
658-
and "type" in new
659-
and new["type"] == "record"
660-
):
661-
for new_field in cast(List[Dict[str, Any]], new["fields"]):
662-
new_field_missing = True
663-
for existing_field in cast(List[Dict[str, Any]], existing["fields"]):
664-
if new_field["name"] == existing_field["name"]:
665-
if not is_subtype(existing_field["type"], new_field["type"]):
666-
return False
667-
new_field_missing = False
668-
if new_field_missing:
669-
return False
670-
return True
635+
if isinstance(existing, str) and existing in types:
636+
return is_subtype(types, types[existing], new)
637+
if isinstance(new, str) and new in types:
638+
return is_subtype(types, existing, types[new])
639+
if isinstance(existing, dict) and isinstance(new, dict):
640+
if "extends" in new and new["extends"] == existing.get("name"):
641+
return True
642+
if existing.get("type") == "array" and new.get("type") == "array":
643+
return is_subtype(types, existing["items"], new["items"])
644+
if existing.get("type") == "enum" and new.get("type") == "enum":
645+
return is_subtype(types, existing["symbols"], new["symbols"])
646+
if existing.get("type") == "record" and new.get("type") == "record":
647+
for new_field in cast(List[Dict[str, Any]], new["fields"]):
648+
new_field_missing = True
649+
for existing_field in cast(List[Dict[str, Any]], existing["fields"]):
650+
if new_field["name"] == existing_field["name"]:
651+
if not is_subtype(types, existing_field["type"], new_field["type"]):
652+
return False
653+
new_field_missing = False
654+
if new_field_missing:
655+
return False
656+
return True
671657
if isinstance(existing, list) and isinstance(new, list):
672658
missing = False
673-
for _type in new:
674-
if _type not in existing and (not is_subtype(existing, cast(PropType, _type))):
659+
for _type_new in new:
660+
if _type_new not in existing and not any(
661+
is_subtype(types, cast(PropType, _type_existing), cast(PropType, _type_new))
662+
for _type_existing in existing
663+
):
675664
missing = True
676665
return not missing
677666
return False

schema_salad/schema.py

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -594,6 +594,7 @@ def extend_and_specialize(items: List[Dict[str, Any]], loader: Loader) -> List[D
594594
"""Apply 'extend' and 'specialize' to fully materialize derived record types."""
595595
items2 = deepcopy_strip(items)
596596
types = {i["name"]: i for i in items2} # type: Dict[str, Any]
597+
types.update({k[len(saladp) :]: v for k, v in types.items() if k.startswith(saladp)})
597598
results = []
598599

599600
for stype in items2:
@@ -654,7 +655,7 @@ def extend_and_specialize(items: List[Dict[str, Any]], loader: Loader) -> List[D
654655
field = exfield
655656
else:
656657
# make sure field name has not been used yet
657-
if not is_subtype(exfield["type"], field["type"]):
658+
if not is_subtype(types, exfield["type"], field["type"]):
658659
raise SchemaParseException(
659660
f"Field name {field['name']} already in use with "
660661
"incompatible type. "
Lines changed: 35 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,35 @@
1+
$base: "https://example.com/nested_schema#"
2+
3+
$namespaces:
4+
bs: "https://example.com/base_schema#"
5+
dv: "https://example.com/derived_schema#"
6+
7+
$graph:
8+
9+
- $import: avro_subtype.yml
10+
11+
- type: record
12+
name: AbstractContainer
13+
abstract: true
14+
doc: |
15+
This is an abstract container thing that includes an AbstractThing field
16+
fields:
17+
override_me:
18+
type: bs:AbstractThing
19+
jsonldPredicate: "bs:override_me"
20+
21+
22+
- type: record
23+
name: ExtendedContainer
24+
extends: AbstractContainer
25+
doc: |
26+
An extended version of the abstract container that implements an extra field
27+
and uses an ExtendedThing to override the original field
28+
fields:
29+
extra_field:
30+
type:
31+
type: array
32+
items: [string]
33+
override_me:
34+
type: dv:ExtendedThing
35+
jsonldPredicate: "bs:override_me"
Lines changed: 35 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,35 @@
1+
$base: "https://example.com/nested_schema#"
2+
3+
$namespaces:
4+
bs: "https://example.com/base_schema#"
5+
dv: "https://example.com/derived_schema#"
6+
7+
$graph:
8+
9+
- $import: avro_subtype_bad.yml
10+
11+
- type: record
12+
name: AbstractContainer
13+
abstract: true
14+
doc: |
15+
This is an abstract container thing that includes an AbstractThing field
16+
fields:
17+
override_me:
18+
type: bs:AbstractThing
19+
jsonldPredicate: "bs:override_me"
20+
21+
22+
- type: record
23+
name: ExtendedContainer
24+
extends: AbstractContainer
25+
doc: |
26+
An extended version of the abstract container that implements an extra field
27+
and uses an ExtendedThing to override the original field
28+
fields:
29+
extra_field:
30+
type:
31+
type: array
32+
items: [string]
33+
override_me:
34+
type: dv:ExtendedThing
35+
jsonldPredicate: "bs:override_me"
Lines changed: 32 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,32 @@
1+
$base: "https://example.com/recursive_schema#"
2+
3+
$namespaces:
4+
bs: "https://example.com/base_schema#"
5+
6+
$graph:
7+
8+
- $import: "metaschema_base.yml"
9+
10+
- type: record
11+
name: RecursiveThing
12+
doc: |
13+
This is an arbitrary recursive thing that includes itself in its fields
14+
fields:
15+
override_me:
16+
type: RecursiveThing
17+
jsonldPredicate: "bs:override_me"
18+
19+
20+
- type: record
21+
name: ExtendedThing
22+
extends: RecursiveThing
23+
doc: |
24+
An extended version of the recursive thing that implements an extra field
25+
fields:
26+
field_one:
27+
type:
28+
type: array
29+
items: [string]
30+
override_me:
31+
type: ExtendedThing
32+
jsonldPredicate: "bs:override_me"
Lines changed: 36 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,36 @@
1+
$base: "https://example.com/union_schema#"
2+
3+
$namespaces:
4+
bs: "https://example.com/base_schema#"
5+
dv: "https://example.com/derived_schema#"
6+
7+
$graph:
8+
9+
- $import: avro_subtype.yml
10+
11+
- type: record
12+
name: AbstractContainer
13+
abstract: true
14+
doc: |
15+
This is an abstract container thing that includes an AbstractThing
16+
type in its field types
17+
fields:
18+
override_me:
19+
type: [int, string, bs:AbstractThing]
20+
jsonldPredicate: "bs:override_me"
21+
22+
23+
- type: record
24+
name: ExtendedContainer
25+
extends: AbstractContainer
26+
doc: |
27+
An extended version of the abstract container that implements an extra field
28+
and contains an ExtendedThing type in its overridden field types
29+
fields:
30+
extra_field:
31+
type:
32+
type: array
33+
items: [string]
34+
override_me:
35+
type: [int, dv:ExtendedThing]
36+
jsonldPredicate: "bs:override_me"
Lines changed: 36 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,36 @@
1+
$base: "https://example.com/union_schema#"
2+
3+
$namespaces:
4+
bs: "https://example.com/base_schema#"
5+
dv: "https://example.com/derived_schema#"
6+
7+
$graph:
8+
9+
- $import: avro_subtype_bad.yml
10+
11+
- type: record
12+
name: AbstractContainer
13+
abstract: true
14+
doc: |
15+
This is an abstract container thing that includes an AbstractThing
16+
type in its field types
17+
fields:
18+
override_me:
19+
type: [int, string, bs:AbstractThing]
20+
jsonldPredicate: "bs:override_me"
21+
22+
23+
- type: record
24+
name: ExtendedContainer
25+
extends: AbstractContainer
26+
doc: |
27+
An extended version of the abstract container that implements an extra field
28+
and contains an ExtendedThing type in its overridden field types
29+
fields:
30+
extra_field:
31+
type:
32+
type: array
33+
items: [string]
34+
override_me:
35+
type: [int, dv:ExtendedThing]
36+
jsonldPredicate: "bs:override_me"

schema_salad/tests/test_subtypes.py

Lines changed: 54 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -1,10 +1,10 @@
11
"""Confirm subtypes."""
2+
23
import pytest
34

45
from schema_salad.avro import schema
56
from schema_salad.avro.schema import Names, SchemaParseException
67
from schema_salad.schema import load_schema
7-
88
from .util import get_data
99

1010
types = [
@@ -84,7 +84,7 @@
8484
@pytest.mark.parametrize("old,new,result", types)
8585
def test_subtypes(old: schema.PropType, new: schema.PropType, result: bool) -> None:
8686
"""Test is_subtype() function."""
87-
assert schema.is_subtype(old, new) == result
87+
assert schema.is_subtype({}, old, new) == result
8888

8989

9090
def test_avro_loading_subtype() -> None:
@@ -105,4 +105,55 @@ def test_avro_loading_subtype_bad() -> None:
105105
r"Any vs \['string', 'int'\]\."
106106
)
107107
with pytest.raises(SchemaParseException, match=target_error):
108-
document_loader, avsc_names, schema_metadata, metaschema_loader = load_schema(path)
108+
_ = load_schema(path)
109+
110+
111+
def test_subtypes_nested() -> None:
112+
"""Confirm correct subtype handling on a nested type definition."""
113+
path = get_data("tests/test_schema/avro_subtype_nested.yml")
114+
assert path
115+
document_loader, avsc_names, schema_metadata, metaschema_loader = load_schema(path)
116+
assert isinstance(avsc_names, Names)
117+
assert avsc_names.get_name("com.example.nested_schema.ExtendedContainer", None)
118+
119+
120+
def test_subtypes_nested_bad() -> None:
121+
"""Confirm subtype error when overriding incorrectly in nested types."""
122+
path = get_data("tests/test_schema/avro_subtype_nested_bad.yml")
123+
assert path
124+
target_error = (
125+
r"Field name .*\/override_me already in use with incompatible type. "
126+
r"Any vs \['string', 'int'\]\."
127+
)
128+
with pytest.raises(SchemaParseException, match=target_error):
129+
_ = load_schema(path)
130+
131+
132+
def test_subtypes_recursive() -> None:
133+
"""Confirm correct subtype handling on a recursive type definition."""
134+
path = get_data("tests/test_schema/avro_subtype_recursive.yml")
135+
assert path
136+
document_loader, avsc_names, schema_metadata, metaschema_loader = load_schema(path)
137+
assert isinstance(avsc_names, Names)
138+
assert avsc_names.get_name("com.example.recursive_schema.RecursiveThing", None)
139+
140+
141+
def test_subtypes_union() -> None:
142+
"""Confirm correct subtype handling on an union type definition."""
143+
path = get_data("tests/test_schema/avro_subtype_union.yml")
144+
assert path
145+
document_loader, avsc_names, schema_metadata, metaschema_loader = load_schema(path)
146+
assert isinstance(avsc_names, Names)
147+
assert avsc_names.get_name("com.example.union_schema.ExtendedContainer", None)
148+
149+
150+
def test_subtypes_union_bad() -> None:
151+
"""Confirm subtype error when overriding incorrectly in array types."""
152+
path = get_data("tests/test_schema/avro_subtype_union_bad.yml")
153+
assert path
154+
target_error = (
155+
r"Field name .*\/override_me already in use with incompatible type. "
156+
r"Any vs \['string', 'int'\]\."
157+
)
158+
with pytest.raises(SchemaParseException, match=target_error):
159+
_ = load_schema(path)

0 commit comments

Comments
 (0)