diff --git a/code-gen/src/poly_scribe_code_gen/parse_idl.py b/code-gen/src/poly_scribe_code_gen/parse_idl.py index aadda54..36b7f93 100644 --- a/code-gen/src/poly_scribe_code_gen/parse_idl.py +++ b/code-gen/src/poly_scribe_code_gen/parse_idl.py @@ -228,10 +228,31 @@ def _flatten_members(members: list[dict[str, Any]]) -> dict[str, Any]: output = {} for member in members: if member["type"] == "field": + # check if the member ext_attrs with the name "Default" exists + default_type = None + if member["ext_attrs"] and any(attr["name"] == "Default" for attr in member["ext_attrs"]): + # get the Default ext_attr + default_ext_attr = next(attr for attr in member["ext_attrs"] if attr["name"] == "Default") + if default_ext_attr["rhs"]["value"] is not None: + default_type = default_ext_attr["rhs"]["value"] + + # Check if member["default"]["value"] is an empty dict + if ( + member["default"] + and isinstance(member["default"]["value"], dict) + and member["default"]["value"] is not None + ): + default_value = "{}" + elif member["default"] and member["default"]["value"] is not None: + default_value = member["default"]["value"] + else: + default_value = None + output[member["name"]] = { "type": _flatten_type(member["idl_type"], parent_ext_attrs=member["ext_attrs"]), # type: ignore "required": bool(member["required"]), - "default": member["default"]["value"] if member["default"] and member["default"]["value"] else None, + "default": default_value, + "default_type": default_type, } else: msg = f"Unsupported WebIDL type '{member['type']}'." diff --git a/code-gen/src/poly_scribe_code_gen/py_gen.py b/code-gen/src/poly_scribe_code_gen/py_gen.py index 069ce23..bd441b6 100644 --- a/code-gen/src/poly_scribe_code_gen/py_gen.py +++ b/code-gen/src/poly_scribe_code_gen/py_gen.py @@ -147,6 +147,15 @@ def _transform_types(parsed_idl: ParsedIDL) -> ParsedIDL: for struct_name, struct_data in parsed_idl["structs"].items(): for member_data in struct_data["members"].values(): member_data["type"] = _transformer(member_data["type"], parsed_idl["inheritance_data"]) + + if member_data["default"] == "{}" and member_data["default_type"] is not None: + # member_data["default"] = f"Field(default={member_data['default_type']}())" + member_data["default"] = f"{member_data['default_type']}()" + + if member_data["default"] == "{}" and member_data["default_type"] is None: + # member_data["default"] = f"Field(default={member_data['type']}())" + member_data["default"] = f"{member_data['type']}()" + if not member_data["required"]: member_data["type"] = f"Optional[{member_data['type']}]" diff --git a/code-gen/src/poly_scribe_code_gen/templates/reflect.jinja b/code-gen/src/poly_scribe_code_gen/templates/reflect.jinja index e4ad170..12bc11a 100644 --- a/code-gen/src/poly_scribe_code_gen/templates/reflect.jinja +++ b/code-gen/src/poly_scribe_code_gen/templates/reflect.jinja @@ -69,7 +69,7 @@ namespace {{ package }} { {% if member_data.block_comment %} {{ member_data.block_comment|indent(8) -}} {% endif %} - {{ member_data.type }} {{ member_name }}{% if member_data.default %} = {{ member_data.default }}{% endif %}; + {{ member_data.type }} {{ member_name }}{% if member_data.default is not none %} = {% if member_data.default_type %}{{ member_data.default_type }}{% endif %} {{ member_data.default }}{% endif %}; {% endfor %} }; {%- if not loop.last %} diff --git a/code-gen/tests/cpp_gen_test.py b/code-gen/tests/cpp_gen_test.py index e1a3a4a..66980be 100644 --- a/code-gen/tests/cpp_gen_test.py +++ b/code-gen/tests/cpp_gen_test.py @@ -637,3 +637,38 @@ def test_render_template_struct_with_multi_inheritance() -> None: assert "C_t content;".replace(" ", "") in struct_body.replace(" ", "") assert 'using X_t = rfl::TaggedUnion<"type", X, B, C, M, N>;'.replace(" ", "") in result.replace(" ", "") + + +def test_render_template_struct_with_empty_type_default() -> None: + idl = """ + dictionary Base { + }; + + dictionary Foo : Base { + }; + + dictionary Bar : Base { + }; + + dictionary Data { + [Default=Foo] Base base = {}; + }; + """ + + parsed_idl = _validate_and_parse(idl) + + result = cpp_gen._render_template(parsed_idl, {"package": "test"}) + + pattern = re.compile(r"struct (\w+) \{((?:[^{}]|\{[^{}]*\})*)\};", re.MULTILINE) + matches = pattern.findall(result) + + assert len(matches) == 4 + assert "Base" in [match[0] for match in matches] + assert "Foo" in [match[0] for match in matches] + assert "Bar" in [match[0] for match in matches] + assert "Data" in [match[0] for match in matches] + + for match in matches: + struct_body = match[1] + if match[0] == "Data": + assert "Base_t base = Foo{};".replace(" ", "") in struct_body.replace(" ", "") diff --git a/code-gen/tests/parse_idl_test.py b/code-gen/tests/parse_idl_test.py index 5df0ca2..1f573d3 100644 --- a/code-gen/tests/parse_idl_test.py +++ b/code-gen/tests/parse_idl_test.py @@ -119,11 +119,17 @@ def test__validate_and_parse_struct() -> None: struct_data = parsed_idl["structs"]["FooBar"] assert struct_data["inheritance"] is None struct_members = struct_data["members"] - assert struct_members["foo"] == {"type": "int", "default": None, "required": False} + assert struct_members["foo"] == { + "type": "int", + "default": None, + "required": False, + "default_type": None, + } assert struct_members["bar"] == { "type": "float", "default": None, "required": False, + "default_type": None, } assert struct_members["baz"] == { "type": { @@ -136,6 +142,7 @@ def test__validate_and_parse_struct() -> None: }, "default": None, "required": False, + "default_type": None, } assert struct_members["qux"] == { "type": { @@ -148,6 +155,7 @@ def test__validate_and_parse_struct() -> None: }, "default": None, "required": False, + "default_type": None, } assert struct_members["quux"] == { # Fails due to ext attrs! "type": { @@ -160,6 +168,7 @@ def test__validate_and_parse_struct() -> None: }, "default": None, "required": False, + "default_type": None, } struct_data = parsed_idl["structs"]["BazQux"] @@ -176,6 +185,7 @@ def test__validate_and_parse_struct() -> None: }, "default": None, "required": False, + "default_type": None, } @@ -246,16 +256,19 @@ def test__validate_and_parse_struct_default_values_and_required() -> None: "type": "int", "default": "42", "required": False, + "default_type": None, } assert struct_members["default_float"] == { "type": "float", "default": "3.14", "required": False, + "default_type": None, } assert struct_members["required_int"] == { "type": "int", "default": None, "required": True, + "default_type": None, } @@ -907,6 +920,7 @@ def test__validate_and_parse_string_default_value() -> None: "type": "string", "default": "default_value", "required": False, + "default_type": None, } @@ -925,3 +939,50 @@ def test__find_comments_are_associated_with_correct_type() -> None: # Check that no inline comment is associated with the type "Cls" for key in comment_data["inline_comments"]: assert "Cls" not in key, f"Unexpected inline comment key containing 'Cls': {key}" + + +def test__validate_and_parse_default_empty() -> None: + idl = """ + dictionary Foo { + int bar = {}; + }; + """ + + parsed_idl = parsing._validate_and_parse(idl) + + struct_data = parsed_idl["structs"]["Foo"] + struct_members = struct_data["members"] + assert struct_members["bar"] == { + "type": "int", + "default": "{}", + "required": False, + "default_type": None, + } + + +def test__validate_and_parse_default_empty_type_defined() -> None: + idl = """ + dictionary Base { + }; + + dictionary Foo : Base { + }; + + dictionary Bar : Base { + }; + + dictionary Data { + [Default=Foo] Base base = {}; + }; + """ + + parsed_idl = parsing._validate_and_parse(idl) + + struct_data = parsed_idl["structs"]["Data"] + struct_members = struct_data["members"] + assert struct_members["base"] == { + "type": "Base", + "default": "{}", + "required": False, + "default_type": "Foo", + } diff --git a/code-gen/tests/py_gen_test.py b/code-gen/tests/py_gen_test.py index cfa0175..12e3052 100644 --- a/code-gen/tests/py_gen_test.py +++ b/code-gen/tests/py_gen_test.py @@ -563,3 +563,48 @@ def test__render_template_string_default_value() -> None: struct_body = match[2] if match[0] == "Foo": assert 'foo: Optional[str] = "bar"'.replace(" ", "") in struct_body.replace(" ", "") + + +def test_render_template_struct_with_empty_type_default() -> None: + idl = """ + dictionary Base { + }; + + dictionary Foo : Base { + }; + + dictionary Bar : Base { + int value = {}; + }; + + dictionary Data { + [Default=Foo] Base base = {}; + }; + """ + + parsed_idl = _validate_and_parse(idl) + + result = py_gen._render_template(parsed_idl, {"package": "foo"}) + + pattern = re.compile(r"class\s+(\w+)\((\w*)\):\s*(.*?)\n\n", re.DOTALL) + matches = pattern.findall(result) + + assert len(matches) == 4 + assert "Base" in [match[0] for match in matches] + assert "Foo" in [match[0] for match in matches] + assert "Bar" in [match[0] for match in matches] + assert "Data" in [match[0] for match in matches] + + for match in matches: + struct_body = match[2] + if match[0] == "Data": + assert 'base: Optional[Annotated[Union[Foo, Bar, Base],Field(discriminator="type")]] = Foo()'.replace( + " ", "" + ) in struct_body.replace(" ", "") + elif match[0] == "Base": + assert 'type: Literal["Base"] = "Base"'.replace(" ", "") in struct_body.replace(" ", "") + elif match[0] == "Foo": + assert 'type: Literal["Foo"] = "Foo"'.replace(" ", "") in struct_body.replace(" ", "") + elif match[0] == "Bar": + assert 'type: Literal["Bar"] = "Bar"'.replace(" ", "") in struct_body.replace(" ", "") + assert "value: Optional[int] = int()".replace(" ", "") in struct_body.replace(" ", "")