diff --git a/code-gen/src/poly_scribe_code_gen/cpp_gen.py b/code-gen/src/poly_scribe_code_gen/cpp_gen.py index 7aa7702..dae21f9 100644 --- a/code-gen/src/poly_scribe_code_gen/cpp_gen.py +++ b/code-gen/src/poly_scribe_code_gen/cpp_gen.py @@ -169,6 +169,14 @@ def _handle_rfl_tagged_union(parsed_idl: ParsedIDL) -> ParsedIDL: new_inheritance_data = {} for key in parsed_idl["inheritance_data"]: new_inheritance_data[f"{key}_t"] = parsed_idl["inheritance_data"][key] + + # check if the values of the current key are also in the inheritance data and add their value to the new inheritance data as well + types_to_add = [] + for value in parsed_idl["inheritance_data"][key]: + if value in parsed_idl["inheritance_data"]: + types_to_add.extend(parsed_idl["inheritance_data"][value]) + new_inheritance_data[f"{key}_t"].extend(types_to_add) + new_inheritance_data[f"{key}_t"].insert(0, key) parsed_idl["inheritance_data"] = new_inheritance_data return parsed_idl diff --git a/code-gen/tests/cpp_gen_test.py b/code-gen/tests/cpp_gen_test.py index 6ece935..e1a3a4a 100644 --- a/code-gen/tests/cpp_gen_test.py +++ b/code-gen/tests/cpp_gen_test.py @@ -572,3 +572,68 @@ def test__render_template_boolean_default_value() -> None: struct_body = match[1] if match[0] == "Foo": assert "bool foo = true;".replace(" ", "") in struct_body.replace(" ", "") + + +def test_render_template_struct_with_multi_inheritance() -> None: + idl = """ +dictionary X { + required int foo; +}; + +dictionary B : X { + required int bar; +}; + +dictionary C : X { + required float baz; +}; + +dictionary M : C { + required int qux; +}; + +dictionary N : C { + required int quux; +}; + +dictionary Y { + required N content; +}; +""" + parsed_idl = _validate_and_parse(idl) + + result = cpp_gen._render_template(parsed_idl, {"package": "test"}) + + pattern = re.compile(r"struct (\w+) \{([^}]*)\};", re.MULTILINE) + matches = pattern.findall(result) + + assert len(matches) == 6 + assert "X" in [match[0] for match in matches] + assert "B" in [match[0] for match in matches] + assert "C" in [match[0] for match in matches] + assert "Y" in [match[0] for match in matches] + assert "M" in [match[0] for match in matches] + assert "N" in [match[0] for match in matches] + + for match in matches: + struct_body = match[1] + if match[0] == "X": + assert "int foo;".replace(" ", "") in struct_body.replace(" ", "") + elif match[0] == "B": + assert "int bar;".replace(" ", "") in struct_body.replace(" ", "") + assert "int foo;".replace(" ", "") in struct_body.replace(" ", "") + elif match[0] == "C": + assert "float baz;".replace(" ", "") in struct_body.replace(" ", "") + assert "int foo;".replace(" ", "") in struct_body.replace(" ", "") + elif match[0] == "M": + assert "int qux;".replace(" ", "") in struct_body.replace(" ", "") + assert "float baz;".replace(" ", "") in struct_body.replace(" ", "") + assert "int foo;".replace(" ", "") in struct_body.replace(" ", "") + elif match[0] == "N": + assert "int quux;".replace(" ", "") in struct_body.replace(" ", "") + assert "float baz;".replace(" ", "") in struct_body.replace(" ", "") + assert "int foo;".replace(" ", "") in struct_body.replace(" ", "") + elif match[0] == "Y": + assert "C_t content;".replace(" ", "") in struct_body.replace(" ", "") + + assert 'using X_t = rfl::TaggedUnion<"type", X, B, C, M, N>;'.replace(" ", "") in result.replace(" ", "")