Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion code-gen/src/poly_scribe_code_gen/cli/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -101,7 +101,7 @@ def poly_scribe_code_gen() -> int:
if args.py:
spec = importlib.util.spec_from_file_location(module_name, args.py)
else:
source_dir = args.py_package / "src" / args.py_package.name
source_dir = args.py_package / "src" / additional_data["package"]
init_file = source_dir / "__init__.py"
if not init_file.exists():
msg = f"Python package '{args.py_package}' does not contain an __init__.py file"
Expand Down
10 changes: 8 additions & 2 deletions code-gen/src/poly_scribe_code_gen/py_gen.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,7 +14,7 @@
import black
import isort
import jinja2
from docstring_parser import DocstringStyle, compose
from docstring_parser import Docstring, DocstringStyle, compose

from poly_scribe_code_gen._types import AdditionalData, ParsedIDL

Expand Down Expand Up @@ -49,7 +49,7 @@ def generate_python_package(parsed_idl: ParsedIDL, additional_data: AdditionalDa

out_dir.mkdir(parents=True, exist_ok=True)

source_dir = out_dir / "src" / out_dir.name
source_dir = out_dir / "src" / additional_data["package"]
source_dir.mkdir(parents=True, exist_ok=True)

generate_python(parsed_idl, additional_data, source_dir / "__init__.py")
Expand Down Expand Up @@ -186,17 +186,23 @@ def _transform_types(parsed_idl: ParsedIDL) -> ParsedIDL:

for derived_types in parsed_idl["inheritance_data"].values():
if struct_name in derived_types and not any(member == "type" for member in struct_data["members"]):
doc_string = Docstring()
doc_string.short_description = "Discriminator field"
struct_data["members"]["type"] = {
"type": f'Literal["{struct_name}"]',
"default": f'"{struct_name}"',
"block_comment": doc_string,
}

if struct_name in parsed_idl["inheritance_data"] and not any(
member == "type" for member in struct_data["members"]
):
doc_string = Docstring()
doc_string.short_description = "Discriminator field"
struct_data["members"]["type"] = {
"type": f'Literal["{struct_name}"]',
"default": f'"{struct_name}"',
"block_comment": doc_string,
}

for type_def in parsed_idl["typedefs"].values():
Expand Down
29 changes: 29 additions & 0 deletions code-gen/src/poly_scribe_code_gen/templates/python.jinja
Original file line number Diff line number Diff line change
Expand Up @@ -66,6 +66,22 @@ class {{ struct_name }}{% if struct_data["inheritance"] %}({{ struct_data["inher
{% endfor %}

def load(model_type: Type[T], file: Union[Path, str]) -> T:
"""
Load a model from a file.

This function loads a file from the file system and tries to parse it as a given type.

Args:
model_type: The type of the model to load.
file: The file to load the model from.

Returns:
An instance of the model type.

Raises:
FileNotFoundError: If the file does not exist.
ValueError: If the file extension is not supported.
"""
if isinstance(file, str):
file = Path(file).resolve()
elif isinstance(file, Path):
Expand All @@ -92,6 +108,19 @@ def load(model_type: Type[T], file: Union[Path, str]) -> T:


def save(file: Union[Path, str], model: BaseModel):
"""
Save a model to a file.

This function saves a data structure to the file system.

Args:
file: The file to save the model to.
model: The model to save.

Raises:
TypeError: If the file argument is not a Path, str, or stream.
ValueError: If the file extension is not supported.
"""
if isinstance(file, str): # local path to file
file = Path(file).resolve()
elif isinstance(file, Path):
Expand Down
6 changes: 4 additions & 2 deletions code-gen/tests/py_gen_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -467,7 +467,7 @@ def test_generate_python_package(tmp_path: Path) -> None:

validate_pyproject_toml(toml_result, additional_data)

init_file = tmp_path / "src" / tmp_path.name / "__init__.py"
init_file = tmp_path / "src" / additional_data["package"] / "__init__.py"
assert init_file.exists()
with open(init_file) as f:
content = f.read()
Expand Down Expand Up @@ -527,7 +527,7 @@ def test_render_template_comments() -> None:
pattern = re.compile(r'"""\s*(.*?)\s*"""', re.DOTALL)
matches = pattern.findall(result)

assert len(matches) == 8
assert len(matches) == 10
assert "Typedef comment\n\ninline typedef comment" in matches[0]
assert "My Enum comment" in matches[1]
assert "Enum value 1 comment" in matches[2]
Expand All @@ -541,6 +541,8 @@ def test_render_template_comments() -> None:
assert "inline comment" in matches[5]
assert "Short comment for foo" in matches[6]
assert "Short comment for bar" in matches[7]
assert "Load" in matches[8]
assert "Save" in matches[9]


def test__render_template_string_default_value() -> None:
Expand Down
8 changes: 4 additions & 4 deletions test/integration_test.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,5 @@
import pytest
import integration_data
import integration_space
import os
import subprocess

Expand Down Expand Up @@ -30,7 +30,7 @@ def test_integration_data(test_num):
json_data = data_struct.model_dump_json()
assert json_data is not None

new_data_struct = integration_data.IntegrationTest.model_validate_json(json_data)
new_data_struct = integration_space.IntegrationTest.model_validate_json(json_data)

assert data_struct == new_data_struct

Expand All @@ -48,11 +48,11 @@ def test_integration_data_round_trip(input_format, output_format):
py_out = Path(tmp_dir).absolute() / f"integration_py_out.{output_format}"
cpp_out = Path(tmp_dir).absolute() / f"integration_cpp_out.{input_format}"

integration_data.save(py_out, data_struct)
integration_space.save(py_out, data_struct)

subprocess.run([cpp_exe, cpp_out, py_out], check=True)

new_data = integration_data.load(integration_data.IntegrationTest, cpp_out)
new_data = integration_space.load(integration_space.IntegrationTest, cpp_out)

compare_integration_data(data_struct, new_data)

Expand Down
2 changes: 1 addition & 1 deletion test/test_gen_data.cmake
Original file line number Diff line number Diff line change
Expand Up @@ -89,7 +89,7 @@ if (NOT EXISTS "${expected_python_package}")
endif ()

# Check if the generated Python package contains the expected __init__.py file
set (expected_python_init_file "${expected_python_package}/src/integration_data/__init__.py")
set (expected_python_init_file "${expected_python_package}/src/${expected_namespace}/__init__.py")
if (NOT EXISTS "${expected_python_init_file}")
message (SEND_ERROR "Expected Python __init__.py file does not exist: ${expected_python_init_file}")
endif ()
Expand Down
22 changes: 11 additions & 11 deletions test/utils.py
Original file line number Diff line number Diff line change
@@ -1,15 +1,15 @@
import random
import string

import integration_data
import integration_space


def random_string(length: int) -> str:
    """Return a string of `length` characters drawn uniformly from ASCII letters."""
    letters = string.ascii_letters
    return "".join(random.choice(letters) for _ in range(length))


def gen_random_base():
obj = integration_data.Base(
obj = integration_space.Base(
vec_3d=[random.random() for _ in range(3)],
union_member=random.choice([random.random(), random.randint(0, 100), None]),
str_vec=[random_string(5) for _ in range(random.choice([1, 2, 5]))],
Expand All @@ -20,7 +20,7 @@ def gen_random_base():
def gen_random_derived_one():
base_dict = gen_random_base().model_dump()
base_dict.pop("type", None)
obj = integration_data.DerivedOne(
obj = integration_space.DerivedOne(
**base_dict,
string_map={
random_string(5): random_string(5) for _ in range(random.choice([1, 2, 5]))
Expand All @@ -32,17 +32,17 @@ def gen_random_derived_one():
def gen_random_derived_two():
    """Build a random DerivedTwo instance from a randomly generated Base."""
    fields = gen_random_base().model_dump()
    # Drop the discriminator so the DerivedTwo constructor sets its own.
    fields.pop("type", None)
    return integration_space.DerivedTwo(**fields)


def gen_random_non_poly_derived():
    """Build a NonPolyDerived instance with a random integer `value` in [0, 100]."""
    return integration_space.NonPolyDerived(value=random.randint(0, 100))
return obj


def gen_random_integration_test():
obj = integration_data.IntegrationTest(
obj = integration_space.IntegrationTest(
object_map={
random_string(5): random.choice(
[gen_random_derived_one, gen_random_derived_two]
Expand All @@ -58,15 +58,15 @@ def gen_random_integration_test():
for _ in range(2)
],
enum_value=random.choice(
[integration_data.Enumeration.value1, integration_data.Enumeration.value2]
[integration_space.Enumeration.value1, integration_space.Enumeration.value2]
),
non_poly_derived=gen_random_non_poly_derived(),
)
return obj


def compare_integration_data(
lhs: integration_data.IntegrationTest, rhs: integration_data.IntegrationTest
lhs: integration_space.IntegrationTest, rhs: integration_space.IntegrationTest
):
assert lhs.non_poly_derived == rhs.non_poly_derived
assert lhs.enum_value == rhs.enum_value
Expand All @@ -85,15 +85,15 @@ def compare_integration_data(
compare_poly_structure(lhs.object_array[i], rhs.object_array[i])


def compare_poly_structure(lhs: integration_space.Base, rhs: integration_space.Base):
    """Assert that two polymorphic structures are equal, with float tolerance.

    Compares the common ``Base`` fields (``vec_3d``, ``union_member``,
    ``str_vec``), then the subtype-specific fields when ``lhs`` is a
    ``DerivedOne`` or ``DerivedTwo``.

    Args:
        lhs: The expected structure.
        rhs: The structure compared against ``lhs``; assumed to be the same
            subtype as ``lhs``.
    """
    assert all(abs(a - b) < 1e-6 for a, b in zip(lhs.vec_3d, rhs.vec_3d))
    if isinstance(lhs.union_member, float) and isinstance(rhs.union_member, float):
        assert abs(lhs.union_member - rhs.union_member) < 1e-6
    else:
        assert lhs.union_member == rhs.union_member
    assert lhs.str_vec == rhs.str_vec

    if isinstance(lhs, integration_space.DerivedOne):
        assert lhs.string_map == rhs.string_map
    # FIX: this branch still referenced the old module name `integration_data`,
    # which is no longer imported after the rename to `integration_space`;
    # hitting it would raise NameError instead of comparing the structures.
    elif isinstance(lhs, integration_space.DerivedTwo):
        assert lhs.optional_value == rhs.optional_value
Loading