Skip to content
23 changes: 22 additions & 1 deletion code-gen/src/poly_scribe_code_gen/parse_idl.py
Original file line number Diff line number Diff line change
Expand Up @@ -228,10 +228,31 @@ def _flatten_members(members: list[dict[str, Any]]) -> dict[str, Any]:
output = {}
for member in members:
if member["type"] == "field":
# check if the member ext_attrs with the name "Default" exists
default_type = None
if member["ext_attrs"] and any(attr["name"] == "Default" for attr in member["ext_attrs"]):
# get the Default ext_attr
default_ext_attr = next(attr for attr in member["ext_attrs"] if attr["name"] == "Default")
if default_ext_attr["rhs"]["value"] is not None:
default_type = default_ext_attr["rhs"]["value"]

# Check if member["default"]["value"] is an empty dict
if (
member["default"]
and isinstance(member["default"]["value"], dict)
and member["default"]["value"] is not None
):
default_value = "{}"
elif member["default"] and member["default"]["value"] is not None:
default_value = member["default"]["value"]
else:
default_value = None

output[member["name"]] = {
"type": _flatten_type(member["idl_type"], parent_ext_attrs=member["ext_attrs"]), # type: ignore
"required": bool(member["required"]),
"default": member["default"]["value"] if member["default"] and member["default"]["value"] else None,
"default": default_value,
"default_type": default_type,
}
else:
msg = f"Unsupported WebIDL type '{member['type']}'."
Expand Down
9 changes: 9 additions & 0 deletions code-gen/src/poly_scribe_code_gen/py_gen.py
Original file line number Diff line number Diff line change
Expand Up @@ -147,6 +147,15 @@ def _transform_types(parsed_idl: ParsedIDL) -> ParsedIDL:
for struct_name, struct_data in parsed_idl["structs"].items():
for member_data in struct_data["members"].values():
member_data["type"] = _transformer(member_data["type"], parsed_idl["inheritance_data"])

if member_data["default"] == "{}" and member_data["default_type"] is not None:
# member_data["default"] = f"Field(default={member_data['default_type']}())"
member_data["default"] = f"{member_data['default_type']}()"

if member_data["default"] == "{}" and member_data["default_type"] is None:
# member_data["default"] = f"Field(default={member_data['type']}())"
member_data["default"] = f"{member_data['type']}()"

if not member_data["required"]:
member_data["type"] = f"Optional[{member_data['type']}]"

Expand Down
2 changes: 1 addition & 1 deletion code-gen/src/poly_scribe_code_gen/templates/reflect.jinja
Original file line number Diff line number Diff line change
Expand Up @@ -69,7 +69,7 @@ namespace {{ package }} {
{% if member_data.block_comment %}
{{ member_data.block_comment|indent(8) -}}
{% endif %}
{{ member_data.type }} {{ member_name }}{% if member_data.default %} = {{ member_data.default }}{% endif %};
{{ member_data.type }} {{ member_name }}{% if member_data.default is not none %} = {% if member_data.default_type %}{{ member_data.default_type }}{% endif %} {{ member_data.default }}{% endif %};
{% endfor %}
};
{%- if not loop.last %}
Expand Down
35 changes: 35 additions & 0 deletions code-gen/tests/cpp_gen_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -637,3 +637,38 @@ def test_render_template_struct_with_multi_inheritance() -> None:
assert "C_t content;".replace(" ", "") in struct_body.replace(" ", "")

assert 'using X_t = rfl::TaggedUnion<"type", X, B, C, M, N>;'.replace(" ", "") in result.replace(" ", "")


def test_render_template_struct_with_empty_type_default() -> None:
idl = """
dictionary Base {
};

dictionary Foo : Base {
};

dictionary Bar : Base {
};

dictionary Data {
[Default=Foo] Base base = {};
};
"""

parsed_idl = _validate_and_parse(idl)

result = cpp_gen._render_template(parsed_idl, {"package": "test"})

pattern = re.compile(r"struct (\w+) \{((?:[^{}]|\{[^{}]*\})*)\};", re.MULTILINE)
matches = pattern.findall(result)

assert len(matches) == 4
assert "Base" in [match[0] for match in matches]
assert "Foo" in [match[0] for match in matches]
assert "Bar" in [match[0] for match in matches]
assert "Data" in [match[0] for match in matches]

for match in matches:
struct_body = match[1]
if match[0] == "Data":
assert "Base_t base = Foo{};".replace(" ", "") in struct_body.replace(" ", "")
63 changes: 62 additions & 1 deletion code-gen/tests/parse_idl_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -119,11 +119,17 @@ def test__validate_and_parse_struct() -> None:
struct_data = parsed_idl["structs"]["FooBar"]
assert struct_data["inheritance"] is None
struct_members = struct_data["members"]
assert struct_members["foo"] == {"type": "int", "default": None, "required": False}
assert struct_members["foo"] == {
"type": "int",
"default": None,
"required": False,
"default_type": None,
}
assert struct_members["bar"] == {
"type": "float",
"default": None,
"required": False,
"default_type": None,
}
assert struct_members["baz"] == {
"type": {
Expand All @@ -136,6 +142,7 @@ def test__validate_and_parse_struct() -> None:
},
"default": None,
"required": False,
"default_type": None,
}
assert struct_members["qux"] == {
"type": {
Expand All @@ -148,6 +155,7 @@ def test__validate_and_parse_struct() -> None:
},
"default": None,
"required": False,
"default_type": None,
}
assert struct_members["quux"] == { # Fails due to ext attrs!
"type": {
Expand All @@ -160,6 +168,7 @@ def test__validate_and_parse_struct() -> None:
},
"default": None,
"required": False,
"default_type": None,
}

struct_data = parsed_idl["structs"]["BazQux"]
Expand All @@ -176,6 +185,7 @@ def test__validate_and_parse_struct() -> None:
},
"default": None,
"required": False,
"default_type": None,
}


Expand Down Expand Up @@ -246,16 +256,19 @@ def test__validate_and_parse_struct_default_values_and_required() -> None:
"type": "int",
"default": "42",
"required": False,
"default_type": None,
}
assert struct_members["default_float"] == {
"type": "float",
"default": "3.14",
"required": False,
"default_type": None,
}
assert struct_members["required_int"] == {
"type": "int",
"default": None,
"required": True,
"default_type": None,
}


Expand Down Expand Up @@ -907,6 +920,7 @@ def test__validate_and_parse_string_default_value() -> None:
"type": "string",
"default": "default_value",
"required": False,
"default_type": None,
}


Expand All @@ -925,3 +939,50 @@ def test__find_comments_are_associated_with_correct_type() -> None:
# Check that no inline comment is associated with the type "Cls"
for key in comment_data["inline_comments"]:
assert "Cls" not in key, f"Unexpected inline comment key containing 'Cls': {key}"


def test__validate_and_parse_default_empty() -> None:
idl = """
dictionary Foo {
int bar = {};
};
"""

parsed_idl = parsing._validate_and_parse(idl)

struct_data = parsed_idl["structs"]["Foo"]
struct_members = struct_data["members"]
assert struct_members["bar"] == {
"type": "int",
"default": "{}",
"required": False,
"default_type": None,
}


def test__validate_and_parse_default_empty_type_defined() -> None:
idl = """
dictionary Base {
};

dictionary Foo : Base {
};

dictionary Bar : Base {
};

dictionary Data {
[Default=Foo] Base base = {};
};
"""

parsed_idl = parsing._validate_and_parse(idl)

struct_data = parsed_idl["structs"]["Data"]
struct_members = struct_data["members"]
assert struct_members["base"] == {
"type": "Base",
"default": "{}",
"required": False,
"default_type": "Foo",
}
45 changes: 45 additions & 0 deletions code-gen/tests/py_gen_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -563,3 +563,48 @@ def test__render_template_string_default_value() -> None:
struct_body = match[2]
if match[0] == "Foo":
assert 'foo: Optional[str] = "bar"'.replace(" ", "") in struct_body.replace(" ", "")


def test_render_template_struct_with_empty_type_default() -> None:
idl = """
dictionary Base {
};

dictionary Foo : Base {
};

dictionary Bar : Base {
int value = {};
};

dictionary Data {
[Default=Foo] Base base = {};
};
"""

parsed_idl = _validate_and_parse(idl)

result = py_gen._render_template(parsed_idl, {"package": "foo"})

pattern = re.compile(r"class\s+(\w+)\((\w*)\):\s*(.*?)\n\n", re.DOTALL)
matches = pattern.findall(result)

assert len(matches) == 4
assert "Base" in [match[0] for match in matches]
assert "Foo" in [match[0] for match in matches]
assert "Bar" in [match[0] for match in matches]
assert "Data" in [match[0] for match in matches]

for match in matches:
struct_body = match[2]
if match[0] == "Data":
assert 'base: Optional[Annotated[Union[Foo, Bar, Base],Field(discriminator="type")]] = Foo()'.replace(
" ", ""
) in struct_body.replace(" ", "")
elif match[0] == "Base":
assert 'type: Literal["Base"] = "Base"'.replace(" ", "") in struct_body.replace(" ", "")
elif match[0] == "Foo":
assert 'type: Literal["Foo"] = "Foo"'.replace(" ", "") in struct_body.replace(" ", "")
elif match[0] == "Bar":
assert 'type: Literal["Bar"] = "Bar"'.replace(" ", "") in struct_body.replace(" ", "")
assert "value: Optional[int] = int()".replace(" ", "") in struct_body.replace(" ", "")
Loading