diff --git a/code-gen/src/poly_scribe_code_gen/cpp_gen.py b/code-gen/src/poly_scribe_code_gen/cpp_gen.py index dae21f9..635a7c7 100644 --- a/code-gen/src/poly_scribe_code_gen/cpp_gen.py +++ b/code-gen/src/poly_scribe_code_gen/cpp_gen.py @@ -90,7 +90,7 @@ def _transform_types(parsed_idl: ParsedIDL) -> ParsedIDL: for struct_data in parsed_idl["structs"].values(): for member_data in struct_data["members"].values(): member_data["type"] = _transformer(member_data["type"], parsed_idl["inheritance_data"]) - if not member_data["required"] and member_data["default"] is None: + if not member_data["required"]: member_data["type"] = f"std::optional<{member_data['type']}>" if "std::string" in member_data["type"] and member_data["default"]: diff --git a/code-gen/src/poly_scribe_code_gen/py_gen.py b/code-gen/src/poly_scribe_code_gen/py_gen.py index 3d0b0ee..d32dfaf 100644 --- a/code-gen/src/poly_scribe_code_gen/py_gen.py +++ b/code-gen/src/poly_scribe_code_gen/py_gen.py @@ -170,7 +170,7 @@ def _transform_types(parsed_idl: ParsedIDL) -> ParsedIDL: type_str = type_str[1:-1] member_data["default"] = f"{type_str}()" - if not member_data["required"] and member_data["default"] is None: + if not member_data["required"]: member_data["type"] = f"Optional[{member_data['type']}]" if "str" in member_data["type"] and member_data["default"] is not None: diff --git a/code-gen/tests/cpp_gen_test.py b/code-gen/tests/cpp_gen_test.py index 12c99b1..6ca426b 100644 --- a/code-gen/tests/cpp_gen_test.py +++ b/code-gen/tests/cpp_gen_test.py @@ -404,7 +404,7 @@ def test_render_template_default_member_values() -> None: result = cpp_gen._render_template(parsed_idl, {"package": "test"}) - assert "int foo = 42;".replace(" ", "") in result.replace(" ", "") + assert "std::optional foo = 42;".replace(" ", "") in result.replace(" ", "") assert "std::optional bar;".replace(" ", "") in result.replace(" ", "") assert "namespace test" in result @@ -549,7 +549,7 @@ def test__render_template_string_default_value() -> None: for match in matches: struct_body = match[1] if match[0] == "Foo": - assert 'std::string foo = "bar";'.replace(" ", "") in struct_body.replace(" ", "") + assert 'std::optional foo = "bar";'.replace(" ", "") in struct_body.replace(" ", "") def test__render_template_boolean_default_value() -> None: @@ -572,8 +572,8 @@ def test__render_template_boolean_default_value() -> None: for match in matches: struct_body = match[1] if match[0] == "Foo": - assert "bool foo = true;".replace(" ", "") in struct_body.replace(" ", "") - assert "bool bar = false;".replace(" ", "") in struct_body.replace(" ", "") + assert "std::optional foo = true;".replace(" ", "") in struct_body.replace(" ", "") + assert "std::optional bar = false;".replace(" ", "") in struct_body.replace(" ", "") def test_render_template_struct_with_multi_inheritance() -> None: @@ -673,4 +673,4 @@ def test_render_template_struct_with_empty_type_default() -> None: for match in matches: struct_body = match[1] if match[0] == "Data": - assert "Base_t base = Foo{};".replace(" ", "") in struct_body.replace(" ", "") + assert "std::optional base = Foo{};".replace(" ", "") in struct_body.replace(" ", "") diff --git a/code-gen/tests/py_gen_test.py b/code-gen/tests/py_gen_test.py index 17dcf3c..fa587b9 100644 --- a/code-gen/tests/py_gen_test.py +++ b/code-gen/tests/py_gen_test.py @@ -285,7 +285,7 @@ def test_render_template_default_member_values() -> None: result = py_gen._render_template(parsed_idl, {"package": "test"}) - assert "foo: int = 42".replace(" ", "") in result.replace(" ", "") + assert "foo: Optional[int] = 42".replace(" ", "") in result.replace(" ", "") assert "bar: Optional[int] = None".replace(" ", "") in result.replace(" ", "") @@ -564,7 +564,7 @@ def test__render_template_string_default_value() -> None: for match in matches: struct_body = match[2] if match[0] == "Foo": - assert 'foo: str = "bar"'.replace(" ", "") in struct_body.replace(" ", "") + assert 'foo: Optional[str] = "bar"'.replace(" ", "") in struct_body.replace(" ", "") def test_render_template_struct_with_empty_type_default() -> None: @@ -600,7 +600,7 @@ def test_render_template_struct_with_empty_type_default() -> None: for match in matches: struct_body = match[2] if match[0] == "Data": - assert 'base: Annotated[Union["Foo", "Bar", "Base"],Field(discriminator="type")] = Foo()'.replace( + assert 'base: Optional[Annotated[Union["Foo", "Bar", "Base"],Field(discriminator="type")]] = Foo()'.replace( " ", "" ) in struct_body.replace(" ", "") elif match[0] == "Base": @@ -609,7 +609,7 @@ def test_render_template_struct_with_empty_type_default() -> None: assert 'type: Literal["Foo"] = "Foo"'.replace(" ", "") in struct_body.replace(" ", "") elif match[0] == "Bar": assert 'type: Literal["Bar"] = "Bar"'.replace(" ", "") in struct_body.replace(" ", "") - assert "value: int = int()".replace(" ", "") in struct_body.replace(" ", "") + assert "value: Optional[int] = int()".replace(" ", "") in struct_body.replace(" ", "") def test__render_template_boolean_default_value() -> None: @@ -630,5 +630,5 @@ def test__render_template_boolean_default_value() -> None: for match in matches: struct_body = match[2] if match[0] == "Foo": - assert "foo: bool = True".replace(" ", "") in struct_body.replace(" ", "") - assert "bar: bool = False".replace(" ", "") in struct_body.replace(" ", "") + assert "foo: Optional[bool] = True".replace(" ", "") in struct_body.replace(" ", "") + assert "bar: Optional[bool] = False".replace(" ", "") in struct_body.replace(" ", "")