diff --git a/cadquery/assembly.py b/cadquery/assembly.py index 1aac2cb3d..88735d478 100644 --- a/cadquery/assembly.py +++ b/cadquery/assembly.py @@ -11,12 +11,13 @@ cast, get_args, ) -from typing_extensions import Literal +from typing_extensions import Literal, Self from typish import instance_of from uuid import uuid1 as uuid +from warnings import warn from .cq import Workplane -from .occ_impl.shapes import Shape, Compound +from .occ_impl.shapes import Shape, Compound, isSubshape, compound from .occ_impl.geom import Location from .occ_impl.assembly import Color from .occ_impl.solver import ( @@ -34,6 +35,7 @@ exportGLTF, STEPExportModeLiterals, ) +from .occ_impl.importers.assembly import importStep as _importStep from .selectors import _expression_grammar as _selector_grammar from .utils import deprecate @@ -155,6 +157,10 @@ def _copy(self) -> "Assembly": rv = self.__class__(self.obj, self.loc, self.name, self.color, self.metadata) + rv._subshape_colors = dict(self._subshape_colors) + rv._subshape_names = dict(self._subshape_names) + rv._subshape_layers = dict(self._subshape_layers) + for ch in self.children: ch_copy = ch._copy() ch_copy.parent = rv @@ -172,7 +178,7 @@ def add( loc: Optional[Location] = None, name: Optional[str] = None, color: Optional[Color] = None, - ) -> "Assembly": + ) -> Self: """ Add a subassembly to the current assembly. @@ -194,7 +200,7 @@ def add( name: Optional[str] = None, color: Optional[Color] = None, metadata: Optional[Dict[str, Any]] = None, - ) -> "Assembly": + ) -> Self: """ Add a subassembly to the current assembly with explicit location and name. @@ -342,11 +348,11 @@ def _subloc(self, name: str) -> Tuple[Location, str]: @overload def constrain( self, q1: str, q2: str, kind: ConstraintKind, param: Any = None - ) -> "Assembly": + ) -> Self: ... @overload - def constrain(self, q1: str, kind: ConstraintKind, param: Any = None) -> "Assembly": + def constrain(self, q1: str, kind: ConstraintKind, param: Any = None) -> Self: ... 
@overload @@ -358,13 +364,13 @@ def constrain( s2: Shape, kind: ConstraintKind, param: Any = None, - ) -> "Assembly": + ) -> Self: ... @overload def constrain( self, id1: str, s1: Shape, kind: ConstraintKind, param: Any = None, - ) -> "Assembly": + ) -> Self: ... def constrain(self, *args, param=None): @@ -409,7 +415,7 @@ def constrain(self, *args, param=None): return self - def solve(self, verbosity: int = 0) -> "Assembly": + def solve(self, verbosity: int = 0) -> Self: """ Solve the constraints. """ @@ -504,7 +510,7 @@ def save( tolerance: float = 0.1, angularTolerance: float = 0.1, **kwargs, - ) -> "Assembly": + ) -> Self: """ Save assembly to a file. @@ -560,7 +566,7 @@ def export( tolerance: float = 0.1, angularTolerance: float = 0.1, **kwargs, - ) -> "Assembly": + ) -> Self: """ Save assembly to a file. @@ -609,9 +615,26 @@ def export( return self @classmethod - def load(cls, path: str) -> "Assembly": + def importStep(cls, path: str) -> Self: + """ + Reads an assembly from a STEP file. + + :param path: Path and filename of the STEP file to read. + :return: An Assembly object. + """ + + assy = cls() + _importStep(assy, path) + + return assy + + @classmethod + def load(cls, path: str) -> Self: + """ + Alias of importStep for now. + """ - raise NotImplementedError + return cls.importStep(path) @property def shapes(self) -> List[Shape]: @@ -712,12 +735,81 @@ def addSubshape( :return: The modified assembly. """ + # check if the subshape belongs to the stored object + if any(isSubshape(s, obj) for obj in self.shapes): + assy = self + else: + warn( + "Current node does not contain any Shapes, searching in subnodes. In the future this will result in an error."
+ ) + + found = False + for ch in self.children: + if any(isSubshape(s, obj) for obj in ch.shapes): + assy = ch + found = True + break + + if not found: + raise ValueError( + f"{s} is not a subshape of the current node or its children" + ) + # Handle any metadata we were passed if name: - self._subshape_names[s] = name + assy._subshape_names[s] = name if color: - self._subshape_colors[s] = color + assy._subshape_colors[s] = color if layer: - self._subshape_layers[s] = layer + assy._subshape_layers[s] = layer return self + + def __getitem__(self, name: str) -> "Assembly": + """ + [] based access to children. + """ + + return self.objects[name] + + def _ipython_key_completions_(self) -> List[str]: + """ + IPython autocompletion helper. + """ + + return list(self.objects.keys()) + + def __contains__(self, name: str) -> bool: + + return name in self.objects + + def __getattr__(self, name: str) -> "Assembly": + """ + . based access to children. + """ + + if name in self.objects: + return self.objects[name] + + raise AttributeError + + def __dir__(self): + """ + Modified __dir__ for autocompletion. + """ + + return list(self.__dict__) + list(ch.name for ch in self.children) + + def __getstate__(self): + """ + Explicit getstate needed due to getattr. + """ + + return self.__dict__ + + def __setstate__(self, d): + """ + Explicit setstate needed due to getattr. 
+ """ + + self.__dict__ = d diff --git a/cadquery/func.py b/cadquery/func.py index 65f0c8266..d1fc8450e 100644 --- a/cadquery/func.py +++ b/cadquery/func.py @@ -51,4 +51,5 @@ setThreads, project, faceOn, + isSubshape, ) diff --git a/cadquery/occ_impl/assembly.py b/cadquery/occ_impl/assembly.py index 61a6bac4b..d76422bb7 100644 --- a/cadquery/occ_impl/assembly.py +++ b/cadquery/occ_impl/assembly.py @@ -10,13 +10,19 @@ List, cast, ) -from typing_extensions import Protocol +from typing_extensions import Protocol, Self from math import degrees, radians from OCP.TDocStd import TDocStd_Document from OCP.TCollection import TCollection_ExtendedString -from OCP.XCAFDoc import XCAFDoc_DocumentTool, XCAFDoc_ColorType, XCAFDoc_ColorGen +from OCP.XCAFDoc import ( + XCAFDoc_DocumentTool, + XCAFDoc_ColorType, + XCAFDoc_ColorGen, +) from OCP.XCAFApp import XCAFApp_Application +from OCP.BinXCAFDrivers import BinXCAFDrivers +from OCP.XmlXCAFDrivers import XmlXCAFDrivers from OCP.TDataStd import TDataStd_Name from OCP.TDF import TDF_Label from OCP.TopLoc import TopLoc_Location @@ -24,6 +30,7 @@ Quantity_ColorRGBA, Quantity_Color, Quantity_TOC_sRGB, + Quantity_TOC_RGB, ) from OCP.BRepAlgoAPI import BRepAlgoAPI_Fuse from OCP.TopTools import TopTools_ListOfShape @@ -67,7 +74,7 @@ def __init__(self, name: str): ... @overload - def __init__(self, r: float, g: float, b: float, a: float = 0): + def __init__(self, r: float, g: float, b: float, a: float = 0, srgb: bool = True): """ Construct a Color from RGB(A) values. @@ -75,6 +82,7 @@ def __init__(self, r: float, g: float, b: float, a: float = 0): :param g: green value, 0-1 :param b: blue value, 0-1 :param a: alpha value, 0-1 (default: 0) + :param srgb: srgb/linear rgb switch, bool (default: True) """ ... 
@@ -106,6 +114,14 @@ def __init__(self, *args, **kwargs): self.wrapped = Quantity_ColorRGBA( Quantity_Color(r, g, b, Quantity_TOC_sRGB), a ) + elif len(args) == 5: + r, g, b, a, srgb = args + self.wrapped = Quantity_ColorRGBA( + Quantity_Color( + r, g, b, Quantity_TOC_sRGB if srgb else Quantity_TOC_RGB + ), + a, + ) else: raise ValueError(f"Unsupported arguments: {args}, {kwargs}") @@ -136,6 +152,15 @@ def __setstate__(self, data: Tuple[float, float, float, float]): class AssemblyProtocol(Protocol): + def __init__( + self, + obj: AssemblyObjects = None, + loc: Optional[Location] = None, + name: Optional[str] = None, + color: Optional[Color] = None, + ): + ... + @property def loc(self) -> Location: ... @@ -148,6 +173,10 @@ def loc(self, value: Location) -> None: def name(self) -> str: ... + @name.setter + def name(self, value: str) -> None: + ... + @property def parent(self) -> Optional["AssemblyProtocol"]: ... @@ -156,10 +185,22 @@ def parent(self) -> Optional["AssemblyProtocol"]: def color(self) -> Optional[Color]: ... + @color.setter + def color(self, value: Optional[Color]) -> None: + ... + @property def obj(self) -> AssemblyObjects: ... + @obj.setter + def obj(self, value: AssemblyObjects) -> None: + ... + + @property + def objects(self) -> Dict[str, Self]: + ... + @property def shapes(self) -> Iterable[Shape]: ... @@ -180,6 +221,50 @@ def _subshape_colors(self) -> Dict[Shape, Color]: def _subshape_layers(self) -> Dict[Shape, str]: ... + @overload + def add( + self, + obj: Self, + loc: Optional[Location] = None, + name: Optional[str] = None, + color: Optional[Color] = None, + ) -> Self: + ... + + @overload + def add( + self, + obj: AssemblyObjects, + loc: Optional[Location] = None, + name: Optional[str] = None, + color: Optional[Color] = None, + metadata: Optional[Dict[str, Any]] = None, + ) -> Self: + ... 
+ + def add( + self, + obj: Union[Self, AssemblyObjects], + loc: Optional[Location] = None, + name: Optional[str] = None, + color: Optional[Color] = None, + metadata: Optional[Dict[str, Any]] = None, + **kwargs: Any, + ) -> Self: + """ + Add a subassembly to the current assembly. + """ + ... + + def addSubshape( + self, + s: Shape, + name: Optional[str] = None, + color: Optional[Color] = None, + layer: Optional[str] = None, + ) -> Self: + ... + def traverse(self) -> Iterable[Tuple[str, "AssemblyProtocol"]]: ... @@ -191,6 +276,12 @@ def __iter__( ) -> Iterator[Tuple[Shape, str, Location, Optional[Color]]]: ... + def __getitem__(self, name: str) -> Self: + ... + + def __contains__(self, name: str) -> bool: + ... + def setName(l: TDF_Label, name: str, tool): @@ -208,31 +299,40 @@ def toCAF( mesh: bool = False, tolerance: float = 1e-3, angularTolerance: float = 0.1, + binary: bool = True, ) -> Tuple[TDF_Label, TDocStd_Document]: # prepare a doc app = XCAFApp_Application.GetApplication_s() - doc = TDocStd_Document(TCollection_ExtendedString("XmlOcaf")) + if binary: + BinXCAFDrivers.DefineFormat_s(app) + doc = TDocStd_Document(TCollection_ExtendedString("BinXCAF")) + else: + XmlXCAFDrivers.DefineFormat_s(app) + doc = TDocStd_Document(TCollection_ExtendedString("XmlXCAF")) + app.InitDocument(doc) tool = XCAFDoc_DocumentTool.ShapeTool_s(doc.Main()) tool.SetAutoNaming_s(False) ctool = XCAFDoc_DocumentTool.ColorTool_s(doc.Main()) + ltool = XCAFDoc_DocumentTool.LayerTool_s(doc.Main()) # used to store labels with unique part-color combinations - unique_objs: Dict[Tuple[Color, AssemblyObjects], TDF_Label] = {} + unique_objs: Dict[Tuple[Color | None, AssemblyObjects], TDF_Label] = {} # used to cache unique, possibly meshed, compounds; allows to avoid redundant meshing operations if same object is referenced multiple times in an assy compounds: Dict[AssemblyObjects, Compound] = {} - def _toCAF(el, ancestor, color) -> TDF_Label: + def _toCAF(el: AssemblyProtocol, ancestor: 
TDF_Label | None) -> TDF_Label: - # create a subassy - subassy = tool.NewShape() - setName(subassy, el.name, tool) + # create a subassy if needed + if el.children: + subassy = tool.NewShape() + setName(subassy, el.name, tool) # define the current color - current_color = el.color if el.color else color + current_color = el.color if el.color else None # add a leaf with the actual part if needed if el.obj: @@ -254,36 +354,82 @@ def _toCAF(el, ancestor, color) -> TDF_Label: compounds[key1] = compound tool.SetShape(lab, compound.wrapped) - setName(lab, f"{el.name}_part", tool) + setName(lab, f"{el.name}_part" if el.children else el.name, tool) unique_objs[key0] = lab # handle colors when exporting to STEP if coloredSTEP and current_color: setColor(lab, current_color, ctool) - tool.AddComponent(subassy, lab, TopLoc_Location()) + # handle subshape names/colors/layers + subshape_colors = el._subshape_colors + subshape_names = el._subshape_names + subshape_layers = el._subshape_layers + + for k in ( + subshape_colors.keys() | subshape_names.keys() | subshape_layers.keys() + ): + + subshape_label = tool.AddSubShape(lab, k.wrapped) + + # Sanity check, this is in principle enforced when calling addSubshape + assert not subshape_label.IsNull(), "Invalid subshape" + + # Set the name + if k in subshape_names: + TDataStd_Name.Set_s( + subshape_label, TCollection_ExtendedString(subshape_names[k]), + ) + + # Set the individual subshape color + if k in subshape_colors: + ctool.SetColor( + subshape_label, subshape_colors[k].wrapped, XCAFDoc_ColorGen, + ) + + # Also add a layer to hold the subshape label data + if k in subshape_layers: + layer_label = ltool.AddLayer( + TCollection_ExtendedString(subshape_layers[k]) + ) + ltool.SetLayer(subshape_label, layer_label) + + if el.children: + lab = tool.AddComponent(subassy, lab, TopLoc_Location()) + setName(lab, f"{el.name}_part", tool) + elif ancestor is not None: + lab = tool.AddComponent(ancestor, lab, el.loc.wrapped) + setName(lab, 
f"{el.name}", tool) # handle colors when *not* exporting to STEP if not coloredSTEP and current_color: - setColor(subassy, current_color, ctool) + if el.children: + setColor(subassy, current_color, ctool) + + if el.obj: + setColor(lab, current_color, ctool) # add children recursively for child in el.children: - _toCAF(child, subassy, current_color) + _toCAF(child, subassy) - if ancestor: + if ancestor and el.children: tool.AddComponent(ancestor, subassy, el.loc.wrapped) rv = subassy + elif ancestor: + rv = ancestor else: # update the top level location rv = TDF_Label() # NB: additional label is needed to apply the location + + # set location, is location is identity return subassy tool.SetLocation(subassy, assy.loc.wrapped, rv) setName(rv, assy.name, tool) return rv # process the whole assy recursively - top = _toCAF(assy, None, None) + top = _toCAF(assy, None) tool.UpdateAssemblies() diff --git a/cadquery/occ_impl/exporters/assembly.py b/cadquery/occ_impl/exporters/assembly.py index 328d8c04a..4d34342cf 100644 --- a/cadquery/occ_impl/exporters/assembly.py +++ b/cadquery/occ_impl/exporters/assembly.py @@ -3,12 +3,11 @@ from tempfile import TemporaryDirectory from shutil import make_archive -from itertools import chain from typing import Optional from typing_extensions import Literal from vtkmodules.vtkIOExport import vtkJSONSceneExporter, vtkVRMLExporter -from vtkmodules.vtkRenderingCore import vtkRenderer, vtkRenderWindow +from vtkmodules.vtkRenderingCore import vtkRenderWindow from OCP.XSControl import XSControl_WorkSession from OCP.STEPCAFControl import STEPCAFControl_Writer @@ -19,10 +18,16 @@ from OCP.TDocStd import TDocStd_Document from OCP.XCAFApp import XCAFApp_Application from OCP.XCAFDoc import XCAFDoc_DocumentTool, XCAFDoc_ColorGen -from OCP.XmlDrivers import ( - XmlDrivers_DocumentStorageDriver, - XmlDrivers_DocumentRetrievalDriver, +from OCP.XmlXCAFDrivers import ( + XmlXCAFDrivers_DocumentRetrievalDriver, + XmlXCAFDrivers_DocumentStorageDriver, ) 
+from OCP.BinXCAFDrivers import ( + BinXCAFDrivers_DocumentRetrievalDriver, + BinXCAFDrivers_DocumentStorageDriver, +) + + from OCP.TCollection import TCollection_ExtendedString, TCollection_AsciiString from OCP.PCDM import PCDM_StoreStatus from OCP.RWGltf import RWGltf_CafWriter @@ -33,7 +38,6 @@ from ..assembly import AssemblyProtocol, toCAF, toVTK, toFusedCAF from ..geom import Location from ..shapes import Shape, Compound -from ..assembly import Color class ExportModes: @@ -82,9 +86,6 @@ def exportAssembly( fuzzy_tol = kwargs["fuzzy_tol"] if "fuzzy_tol" in kwargs else None glue = kwargs["glue"] if "glue" in kwargs else False - # Use the assembly name if the user set it - assembly_name = assy.name if assy.name else str(uuid.uuid1()) - # Handle the doc differently based on which mode we are using if mode == "fused": _, doc = toFusedCAF(assy, glue, fuzzy_tol) @@ -98,6 +99,7 @@ def exportAssembly( writer.SetNameMode(True) Interface_Static.SetIVal_s("write.surfacecurve.mode", pcurves) Interface_Static.SetIVal_s("write.precision.mode", precision_mode) + Interface_Static.SetIVal_s("write.stepcaf.subshapes.name", 1) writer.Transfer(doc, STEPControl_StepModelType.STEPControl_AsIs) status = writer.Write(path) @@ -174,10 +176,14 @@ def _process_child(child: AssemblyProtocol, assy_label: TDF_Label): # Handle shape name, color and location part_label = shape_tool.AddShape(shape.wrapped, False) + # NB: this might overwrite the name if shape is referenced multiple times TDataStd_Name.Set_s(part_label, TCollection_ExtendedString(name)) + if color: color_tool.SetColor(part_label, color.wrapped, XCAFDoc_ColorGen) - shape_tool.AddComponent(assy_label, part_label, loc.wrapped) + + comp_label = shape_tool.AddComponent(assy_label, part_label, loc.wrapped) + TDataStd_Name.Set_s(comp_label, TCollection_ExtendedString(name)) # If this assembly has shape metadata, add it to the shape if ( @@ -267,29 +273,39 @@ def _process_assembly( return status == 
IFSelect_ReturnStatus.IFSelect_RetDone -def exportCAF(assy: AssemblyProtocol, path: str) -> bool: +def exportCAF(assy: AssemblyProtocol, path: str, binary: bool = False) -> bool: """ - Export an assembly to a OCAF xml file (internal OCCT format). + Export an assembly to an XCAF xml or xbf file (internal OCCT formats). """ folder, fname = os.path.split(path) name, ext = os.path.splitext(fname) ext = ext[1:] if ext[0] == "." else ext - _, doc = toCAF(assy) + _, doc = toCAF(assy, binary=binary) app = XCAFApp_Application.GetApplication_s() - store = XmlDrivers_DocumentStorageDriver( - TCollection_ExtendedString("Copyright: Open Cascade, 2001-2002") - ) - ret = XmlDrivers_DocumentRetrievalDriver() + store: BinXCAFDrivers_DocumentStorageDriver | XmlXCAFDrivers_DocumentStorageDriver + ret: BinXCAFDrivers_DocumentRetrievalDriver | XmlXCAFDrivers_DocumentRetrievalDriver + + # XBF + if binary: + ret = XmlXCAFDrivers_DocumentRetrievalDriver() + format_name = TCollection_AsciiString("BinXCAF") + format_desc = TCollection_AsciiString("Binary XCAF Document") + store = BinXCAFDrivers_DocumentStorageDriver() + ret = BinXCAFDrivers_DocumentRetrievalDriver() + # XML + else: + format_name = TCollection_AsciiString("XmlXCAF") + format_desc = TCollection_AsciiString("Xml XCAF Document") + store = XmlXCAFDrivers_DocumentStorageDriver( + TCollection_ExtendedString("Copyright: Open Cascade, 2001-2002") + ) + ret = XmlXCAFDrivers_DocumentRetrievalDriver() app.DefineFormat( - TCollection_AsciiString("XmlOcaf"), - TCollection_AsciiString("Xml XCAF Document"), - TCollection_AsciiString(ext), - ret, - store, + format_name, format_desc, TCollection_AsciiString(ext), ret, store, ) doc.SetRequestedFolder(TCollection_ExtendedString(folder)) diff --git a/cadquery/occ_impl/importers/assembly.py b/cadquery/occ_impl/importers/assembly.py new file mode 100644 index 000000000..7951b1366 --- /dev/null +++ b/cadquery/occ_impl/importers/assembly.py @@ -0,0 +1,216 @@ +from OCP.TopoDS import TopoDS_Shape 
+from OCP.TCollection import TCollection_ExtendedString +from OCP.Quantity import Quantity_ColorRGBA +from OCP.TDF import TDF_Label, TDF_LabelSequence, TDF_AttributeIterator +from OCP.IFSelect import IFSelect_RetDone +from OCP.TDocStd import TDocStd_Document +from OCP.TDataStd import TDataStd_Name +from OCP.TNaming import TNaming_NamedShape +from OCP.STEPCAFControl import STEPCAFControl_Reader +from OCP.XCAFDoc import XCAFDoc_ColorSurf, XCAFDoc_DocumentTool, XCAFDoc_GraphNode +from OCP.Interface import Interface_Static + +from ..assembly import AssemblyProtocol, Color +from ..geom import Location +from ..shapes import Shape + + +def importStep(assy: AssemblyProtocol, path: str): + """ + Import a step file into an assembly. + + :param assy: An Assembly object that will be packed with the contents of the STEP file. + :param path: Path and filename to the STEP file to read. + + :return: None + """ + + def _process_label(lbl: TDF_Label, parent: AssemblyProtocol): + """ + Recursive method to process the assembly in a top-down manner. 
+ """ + + # Look for components + comp_labels = TDF_LabelSequence() + shape_tool.GetComponents_s(lbl, comp_labels) + + for i in range(comp_labels.Length()): + comp_label = comp_labels.Value(i + 1) + + # Get the location of the component label + loc = shape_tool.GetLocation_s(comp_label) + cq_loc = Location(loc) if loc else Location() + + if shape_tool.IsReference_s(comp_label): + ref_label = TDF_Label() + shape_tool.GetReferredShape_s(comp_label, ref_label) + + # Find the name of this referenced part + ref_name_attr = TDataStd_Name() + if ref_label.FindAttribute(TDataStd_Name.GetID_s(), ref_name_attr): + ref_name = str(ref_name_attr.Get().ToExtString()) + + if shape_tool.IsAssembly_s(ref_label): + + sub_assy = assy.__class__(name=ref_name) + + # Recursively process subassemblies + _ = _process_label(ref_label, sub_assy) + + # Add the subassy + parent.add(sub_assy, name=ref_name, loc=cq_loc) + + elif shape_tool.IsSimpleShape_s(ref_label): + + # A single shape needs to be added to the assembly + final_shape = shape_tool.GetShape_s(ref_label) + cq_shape = Shape.cast(final_shape) + + # Find the shape color, if there is one set for this shape + color = Quantity_ColorRGBA() + + # Extract the color, if present on the shape + if color_tool.GetColor(final_shape, XCAFDoc_ColorSurf, color): + rgb = color.GetRGB() + cq_color = Color( + rgb.Red(), rgb.Green(), rgb.Blue(), color.Alpha() + ) + else: + cq_color = None + + # this if/else is needed to handle different structures of STEP files + # "*"/"*_part" based naming is the default strucutre produced by CQ + if ref_name.endswith("_part") and ref_name.startswith(parent.name): + parent.obj = cq_shape + parent.loc = cq_loc + parent.color = cq_color + + # change the current assy to handle subshape data + current = parent + else: + tmp = assy.__class__( + cq_shape, loc=cq_loc, name=ref_name, color=cq_color + ) + parent.add(tmp) + + # change the current assy to handle subshape data + current = parent[ref_name] + + # iterate over 
subshape and handle names, layers and colors + subshape_labels = TDF_LabelSequence() + shape_tool.GetSubShapes_s(ref_label, subshape_labels) + + for child_label in subshape_labels: + + # Save the shape so that we can add it to the subshape data + cur_shape: TopoDS_Shape = shape_tool.GetShape_s(child_label) + + # Handle subshape name + name_attr = TDataStd_Name() + + if child_label.IsAttribute(TDataStd_Name.GetID_s()): + child_label.FindAttribute( + TDataStd_Name.GetID_s(), name_attr + ) + + current.addSubshape( + Shape.cast(cur_shape), + name=name_attr.Get().ToExtString(), + ) + + # Find the layer name, if there is one set for this shape + layers = TDF_LabelSequence() + layer_tool.GetLayers(child_label, layers) + + for lbl in layers: + name_attr = TDataStd_Name() + lbl.FindAttribute(TDataStd_Name.GetID_s(), name_attr) + + # Extract the layer name for the shape here + layer_name = name_attr.Get().ToExtString() + + # Add the layer as a subshape entry on the assembly + current.addSubshape(Shape.cast(cur_shape), layer=layer_name) + + # Find the subshape color, if there is one set for this shape + color = Quantity_ColorRGBA() + # Extract the color, if present on the shape + if color_tool.GetColor(cur_shape, XCAFDoc_ColorSurf, color): + rgb = color.GetRGB() + cq_color = Color( + rgb.Red(), rgb.Green(), rgb.Blue(), color.Alpha(), + ) + + # Save the color info via the assembly subshape mechanism + current.addSubshape(Shape.cast(cur_shape), color=cq_color) + + return parent + + # Document that the step file will be read into + doc = TDocStd_Document(TCollection_ExtendedString("XmlOcaf")) + + # Create and configure a STEP reader + step_reader = STEPCAFControl_Reader() + step_reader.SetColorMode(True) + step_reader.SetNameMode(True) + step_reader.SetLayerMode(True) + step_reader.SetSHUOMode(True) + + Interface_Static.SetIVal_s("read.stepcaf.subshapes.name", 1) + + # Read the STEP file + status = step_reader.ReadFile(path) + if status != IFSelect_RetDone: + raise 
ValueError(f"Error reading STEP file: {path}") + + # Transfer the contents of the STEP file to the document + step_reader.Transfer(doc) + + # Shape and color tools for extracting XCAF data + shape_tool = XCAFDoc_DocumentTool.ShapeTool_s(doc.Main()) + color_tool = XCAFDoc_DocumentTool.ColorTool_s(doc.Main()) + layer_tool = XCAFDoc_DocumentTool.LayerTool_s(doc.Main()) + + # Collect all the labels representing shapes in the document + labels = TDF_LabelSequence() + shape_tool.GetFreeShapes(labels) + + # Get the top-level label, which should represent an assembly + top_level_label = labels.Value(1) + # Make sure there is a top-level assembly + if shape_tool.IsTopLevel(top_level_label) and shape_tool.IsAssembly_s( + top_level_label + ): + # Set the name of the top-level assembly to match the top-level label + name_attr = TDataStd_Name() + top_level_label.FindAttribute(TDataStd_Name.GetID_s(), name_attr) + + # Manipulation of .objects is needed to maintain consistency + assy.objects.pop(assy.name) + assy.name = str(name_attr.Get().ToExtString()) + assy.objects[assy.name] = assy + + # Get the location of the top-level component + comp_labels = TDF_LabelSequence() + shape_tool.GetComponents_s(top_level_label, comp_labels) + comp_label = comp_labels.Value(1) + loc = shape_tool.GetLocation_s(comp_label) + assy.loc = Location(loc) + + # Start the recursive processing of labels + imported_assy = assy.__class__() + _process_label(top_level_label, imported_assy) + + # Handle a possible extra top-level node. This is done because cq.Assembly.export + # adds an extra top-level node which will cause a cascade of + # extras on successive round-trips. exportStepMeta does not add the extra top-level + # node and so does not exhibit this behavior. 
+ if assy.name in imported_assy: + imported_assy = imported_assy[assy.name] + + # Copy all of the children over to the main assembly object + for child in imported_assy.children: + assy.add(child, name=child.name, color=child.color, loc=child.loc) + + else: + raise ValueError("Step file does not contain an assembly") diff --git a/cadquery/occ_impl/nurbs.py b/cadquery/occ_impl/nurbs.py new file mode 100644 index 000000000..4e8bc7fbf --- /dev/null +++ b/cadquery/occ_impl/nurbs.py @@ -0,0 +1,1991 @@ +# %% imports +import numpy as np +import scipy.sparse as sp + +from numba import njit as _njit + +from typing import NamedTuple, Optional, Tuple, List, Union, cast + +from math import comb + +from numpy.typing import NDArray +from numpy import linspace, ndarray + +from casadi import ldl, ldl_solve + +from OCP.Geom import Geom_BSplineCurve, Geom_BSplineSurface +from OCP.TColgp import TColgp_Array1OfPnt, TColgp_Array2OfPnt +from OCP.TColStd import ( + TColStd_Array1OfInteger, + TColStd_Array1OfReal, +) +from OCP.gp import gp_Pnt +from OCP.BRepBuilderAPI import BRepBuilderAPI_MakeEdge, BRepBuilderAPI_MakeFace + +from .shapes import Face, Edge + +from multimethod import multidispatch + +njit = _njit(cache=True, error_model="numpy", fastmath=True, nogil=True, parallel=False) + +njiti = _njit( + cache=True, inline="always", error_model="numpy", fastmath=True, parallel=False +) + + +# %% internal helpers + + +def _colPtsArray(pts: NDArray) -> TColgp_Array1OfPnt: + + rv = TColgp_Array1OfPnt(1, pts.shape[0]) + + for i, p in enumerate(pts): + rv.SetValue(i + 1, gp_Pnt(*p)) + + return rv + + +def _colPtsArray2(pts: NDArray) -> TColgp_Array2OfPnt: + + assert pts.ndim == 3 + + nu, nv, _ = pts.shape + + rv = TColgp_Array2OfPnt(1, len(pts), 1, len(pts[0])) + + for i, row in enumerate(pts): + for j, pt in enumerate(row): + rv.SetValue(i + 1, j + 1, gp_Pnt(*pt)) + + return rv + + +def _colRealArray(knots: NDArray) -> TColStd_Array1OfReal: + + rv = TColStd_Array1OfReal(1, len(knots)) + + 
for i, el in enumerate(knots): + rv.SetValue(i + 1, el) + + return rv + + +def _colIntArray(knots: NDArray) -> TColStd_Array1OfInteger: + + rv = TColStd_Array1OfInteger(1, len(knots)) + + for i, el in enumerate(knots): + rv.SetValue(i + 1, el) + + return rv + + +# %% vocabulary types + +Array = ndarray # NDArray[np.floating] +ArrayI = ndarray # NDArray[np.int_] + + +class COO(NamedTuple): + """ + COO sparse matrix container. + """ + + i: ArrayI + j: ArrayI + v: Array + + def coo(self): + + return sp.coo_matrix((self.v, (self.i, self.j))) + + def csc(self): + + return self.coo().tocsc() + + def csr(self): + + return self.coo().tocsr() + + +class Curve(NamedTuple): + """ + B-spline curve container. + """ + + pts: Array + knots: Array + order: int + periodic: bool + + def curve(self) -> Geom_BSplineCurve: + + if self.periodic: + mults = _colIntArray(np.ones_like(self.knots, dtype=int)) + knots = _colRealArray(self.knots) + else: + unique_knots, mults_arr = np.unique(self.knots, return_counts=True) + knots = _colRealArray(unique_knots) + mults = _colIntArray(mults_arr) + + return Geom_BSplineCurve( + _colPtsArray(self.pts), knots, mults, self.order, self.periodic, + ) + + def edge(self) -> Edge: + + return Edge(BRepBuilderAPI_MakeEdge(self.curve()).Shape()) + + @classmethod + def fromEdge(cls, e: Edge): + + assert ( + e.geomType() == "BSPLINE" + ), "B-spline geometry required, try converting first." 
+ + g = e._geomAdaptor().BSpline() + + knots = np.repeat(list(g.Knots()), list(g.Multiplicities())) + pts = np.array([(p.X(), p.Y(), p.Z()) for p in g.Poles()]) + order = g.Degree() + periodic = g.IsPeriodic() + + return cls(pts, knots, order, periodic) + + def __call__(self, us: Array) -> Array: + + return nbCurve( + np.atleast_1d(us), self.order, self.knots, self.pts, self.periodic + ) + + def der(self, us: NDArray, dorder: int) -> NDArray: + + return nbCurveDer( + np.atleast_1d(us), self.order, dorder, self.knots, self.pts, self.periodic + ) + + +class Surface(NamedTuple): + """ + B-spline surface container. + """ + + pts: Array + uknots: Array + vknots: Array + uorder: int + vorder: int + uperiodic: bool + vperiodic: bool + + def surface(self) -> Geom_BSplineSurface: + + unique_knots, mults_arr = np.unique(self.uknots, return_counts=True) + uknots = _colRealArray(unique_knots) + umults = _colIntArray(mults_arr) + + unique_knots, mults_arr = np.unique(self.vknots, return_counts=True) + vknots = _colRealArray(unique_knots) + vmults = _colIntArray(mults_arr) + + return Geom_BSplineSurface( + _colPtsArray2(self.pts), + uknots, + vknots, + umults, + vmults, + self.uorder, + self.vorder, + self.uperiodic, + self.vperiodic, + ) + + def face(self, tol: float = 1e-3) -> Face: + + return Face(BRepBuilderAPI_MakeFace(self.surface(), tol).Shape()) + + @classmethod + def fromFace(cls, f: Face): + """ + Construct a surface from a face. + """ + + assert ( + f.geomType() == "BSPLINE" + ), "B-spline geometry required, try converting first." 
+ + g = cast(Geom_BSplineSurface, f._geomAdaptor()) + + uknots = np.repeat(list(g.UKnots()), list(g.UMultiplicities())) + vknots = np.repeat(list(g.VKnots()), list(g.VMultiplicities())) + + tmp = [] + for i in range(1, g.NbUPoles() + 1): + tmp.append( + [ + [g.Pole(i, j).X(), g.Pole(i, j).Y(), g.Pole(i, j).Z(),] + for j in range(1, g.NbVPoles() + 1) + ] + ) + + pts = np.array(tmp) + + uorder = g.UDegree() + vorder = g.VDegree() + + uperiodic = g.IsUPeriodic() + vperiodic = g.IsVPeriodic() + + return cls(pts, uknots, vknots, uorder, vorder, uperiodic, vperiodic) + + def __call__(self, u: Array, v: Array) -> Array: + """ + Evaluate surface at (u,v) points. + """ + + return nbSurface( + np.atleast_1d(u), + np.atleast_1d(v), + self.uorder, + self.vorder, + self.uknots, + self.vknots, + self.pts, + self.uperiodic, + self.vperiodic, + ) + + def der(self, u: Array, v: Array, dorder: int) -> Array: + """ + Evaluate surface and derivatives at (u,v) points. + """ + + return nbSurfaceDer( + np.atleast_1d(u), + np.atleast_1d(v), + self.uorder, + self.vorder, + dorder, + self.uknots, + self.vknots, + self.pts, + self.uperiodic, + self.vperiodic, + ) + + def normal(self, u: Array, v: Array) -> Tuple[Array, Array]: + """ + Evaluate surface normals. + """ + + ders = self.der(u, v, 1) + + du = ders[:, 1, 0, :].squeeze() + dv = ders[:, 0, 1, :].squeeze() + + rv = np.atleast_2d(np.cross(du, dv)) + rv /= np.linalg.norm(rv, axis=1)[:, np.newaxis] + + return rv, ders[:, 0, 0, :].squeeze() + + +# %% basis functions + + +@njiti +def _preprocess( + u: Array, order: int, knots: Array, periodic: float +) -> Tuple[Array, Array, Optional[int], Optional[int], int]: + """ + Helper for handling peridocity. This function extends the knot vector, + wraps the parameters and calculates the delta span. 
+ """ + + # handle periodicity + if periodic: + period = knots[-1] - knots[0] + u_ = u % period + knots_ext = extendKnots(order, knots) + minspan = 0 + maxspan = len(knots) - 1 + deltaspan = order - 1 + else: + u_ = u + knots_ext = knots + minspan = order + maxspan = knots.shape[0] - order - 1 + deltaspan = 0 + + return u_, knots_ext, minspan, maxspan, deltaspan + + +@njiti +def extendKnots(order: int, knots: Array) -> Array: + """ + Knot vector extension for periodic b-splines. + + Parameters + ---------- + order : int + B-spline order. + knots : Array + Knot vector. + + Returns + ------- + knots_ext : Array + Extended knots vector. + + """ + + return np.concat((knots[-order:-1] - knots[-1], knots, knots[-1] + knots[1:order])) + + +@njiti +def nbFindSpan( + u: float, + order: int, + knots: Array, + low: Optional[int] = None, + high: Optional[int] = None, +) -> int: + """ + NURBS book A2.1 with modifications to handle periodic usecases. + + Parameters + ---------- + u : float + Parameter value. + order : int + Spline order. + knots : ndarray + Knot vector. + + Returns + ------- + Span index. + + """ + + if low is None: + low = order + + if high is None: + high = knots.shape[0] - order - 1 + + mid = (low + high) // 2 + + if u >= knots[-1]: + return high - 1 # handle last span + elif u < knots[0]: + return low + + while u < knots[mid] or u >= knots[mid + 1]: + if u < knots[mid]: + high = mid + else: + low = mid + + mid = (low + high) // 2 + + return mid + + +@njiti +def nbBasis(i: int, u: float, order: int, knots: Array, out: Array): + """ + NURBS book A2.2 + + Parameters + ---------- + i : int + Span index. + u : float + Parameter value. + order : int + B-spline order. + knots : ndarray + Knot vector. + out : ndarray + B-spline basis function values. + + Returns + ------- + None. 
+ + """ + + out[0] = 1.0 + + left = np.zeros_like(out) + right = np.zeros_like(out) + + for j in range(1, order + 1): + left[j] = u - knots[i + 1 - j] + right[j] = knots[i + j] - u + + saved = 0.0 + + for r in range(j): + temp = out[r] / (right[r + 1] + left[j - r]) + out[r] = saved + right[r + 1] * temp + saved = left[j - r] * temp + + out[j] = saved + + +@njiti +def nbBasisDer(i: int, u: float, order: int, dorder: int, knots: Array, out: Array): + """ + NURBS book A2.3 + + Parameters + ---------- + i : int + Span index. + u : float + Parameter value. + order : int + B-spline order. + dorder : int + Derivative order. + knots : ndarray + Knot vector. + out : ndarray + B-spline basis function and derivative values. + + Returns + ------- + None. + + """ + + ndu = np.zeros((order + 1, order + 1)) + + left = np.zeros(order + 1) + right = np.zeros(order + 1) + + a = np.zeros((2, order + 1)) + + ndu[0, 0] = 1 + + for j in range(1, order + 1): + left[j] = u - knots[i + 1 - j] + right[j] = knots[i + j] - u + + saved = 0.0 + + for r in range(j): + ndu[j, r] = right[r + 1] + left[j - r] + temp = ndu[r, j - 1] / ndu[j, r] + + ndu[r, j] = saved + right[r + 1] * temp + saved = left[j - r] * temp + + ndu[j, j] = saved + + # store the basis functions + out[0, :] = ndu[:, order] + + # calculate and store derivatives + + # loop over basis functions + for r in range(order + 1): + s1 = 0 + s2 = 1 + + a[0, 0] = 1 + + # loop over derivative orders + for k in range(1, dorder + 1): + d = 0.0 + rk = r - k + pk = order - k + + if r >= k: + a[s2, 0] = a[s1, 0] / ndu[pk + 1, rk] + d = a[s2, 0] * ndu[rk, pk] + + if rk >= -1: + j1 = 1 + else: + j1 = -rk + + if r - 1 <= pk: + j2 = k - 1 + else: + j2 = order - r + + for j in range(j1, j2 + 1): + a[s2, j] = (a[s1, j] - a[s1, j - 1]) / ndu[pk + 1, rk + j] + d += a[s2, j] * ndu[rk + j, pk] + + if r <= pk: + a[s2, k] = -a[s1, k - 1] / ndu[pk + 1, r] + d += a[s2, k] * ndu[r, pk] + + # store the kth derivative of rth basis + out[k, r] = d + + # switch 
+ s1, s2 = s2, s1 + + # multiply recursively by the order + r = order + + for k in range(1, dorder + 1): + out[k, :] *= r + r *= order - k + + +# %% evaluation + + +@njit +def nbCurve( + u: Array, order: int, knots: Array, pts: Array, periodic: bool = False +) -> Array: + """ + NURBS book A3.1 with modifications to handle periodicity. + + Parameters + ---------- + u : Array + Parameter values. + order : int + B-spline order. + knots : Array + Knot vector. + pts : Array + Control points. + periodic : bool, optional + Periodicity flag. The default is False. + + Returns + ------- + Array + Curve values. + + """ + + # number of control points + nb = pts.shape[0] + + u_, knots_ext, minspan, maxspan, deltaspan = _preprocess(u, order, knots, periodic) + + # number of param values + nu = np.size(u) + + # chunck size + n = order + 1 + + # temp chunck storage + temp = np.zeros(n) + + # initialize + out = np.zeros((nu, 3)) + + for i in range(nu): + ui = u_[i] + + # find span + span = nbFindSpan(ui, order, knots, minspan, maxspan) + deltaspan + + # evaluate chunk + nbBasis(span, ui, order, knots_ext, temp) + + # multiply by ctrl points + for j in range(order + 1): + out[i, :] += temp[j] * pts[(span - order + j) % nb, :] + + return out + + +@njit +def nbCurveDer( + u: Array, order: int, dorder: int, knots: Array, pts: Array, periodic: bool = False +) -> Array: + """ + NURBS book A3.2 with modifications to handle periodicity. + + Parameters + ---------- + u : Array + Parameter values. + order : int + B-spline order. + dorder : int + Derivative order. + knots : Array + Knot vector. + pts : Array + Control points. + periodic : bool, optional + Periodicity flag. The default is False. + + + Returns + ------- + Array + Curve values and derivatives. 
+ + """ + # number of control points + nb = pts.shape[0] + + # handle periodicity + u_, knots_ext, minspan, maxspan, deltaspan = _preprocess(u, order, knots, periodic) + + # number of param values + nu = np.size(u) + + # chunck size + n = order + 1 + + # temp chunck storage + temp = np.zeros((dorder + 1, n)) + + # initialize + out = np.zeros((nu, dorder + 1, 3)) + + for i in range(nu): + ui = u_[i] + + # find span + span = nbFindSpan(ui, order, knots, minspan, maxspan) + deltaspan + + # evaluate chunk + nbBasisDer(span, ui, order, dorder, knots_ext, temp) + + # multiply by ctrl points + for j in range(order + 1): + for k in range(dorder + 1): + out[i, k, :] += temp[k, j] * pts[(span - order + j) % nb, :] + + return out + + +@njit +def nbSurface( + u: Array, + v: Array, + uorder: int, + vorder: int, + uknots: Array, + vknots: Array, + pts: Array, + uperiodic: bool = False, + vperiodic: bool = False, +) -> Array: + """ + NURBS book A3.5 with modifications to handle periodicity. + + Parameters + ---------- + u : Array + U parameter values. + v : Array + V parameter values. + uorder : int + B-spline u order. + vorder : int + B-spline v order. + uknots : Array + U knot vector.. + vknots : Array + V knot vector.. + pts : Array + Control points. + uperiodic : bool, optional + U periodicity flag. The default is False. + vperiodic : bool, optional + V periodicity flag. The default is False. + + Returns + ------- + Array + Surface values. 
+ + """ + + # number of control points + nub = pts.shape[0] + nvb = pts.shape[1] + + # handle periodicity + u_, uknots_ext, minspanu, maxspanu, deltaspanu = _preprocess( + u, uorder, uknots, uperiodic + ) + v_, vknots_ext, minspanv, maxspanv, deltaspanv = _preprocess( + v, vorder, vknots, vperiodic + ) + + # number of param values + nu = np.size(u) + + # chunck sizes + un = uorder + 1 + vn = vorder + 1 + + # temp chunck storage + utemp = np.zeros(un) + vtemp = np.zeros(vn) + + # initialize + out = np.zeros((nu, 3)) + + for i in range(nu): + ui = u_[i] + vi = v_[i] + + # find span + uspan = nbFindSpan(ui, uorder, uknots, minspanu, maxspanu) + deltaspanu + vspan = nbFindSpan(vi, vorder, vknots, minspanv, maxspanv) + deltaspanv + + # evaluate chunk + nbBasis(uspan, ui, uorder, uknots_ext, utemp) + nbBasis(vspan, vi, vorder, vknots_ext, vtemp) + + uind = uspan - uorder + temp = np.empty(3) + + # multiply by ctrl points: Nu.T*P*Nv + for j in range(vorder + 1): + + temp[:] = 0.0 + vind = vspan - vorder + j + + # calculate Nu.T*P + for k in range(uorder + 1): + temp += utemp[k] * pts[(uind + k) % nub, vind % nvb, :] + + # multiple by Nv + out[i, :] += vtemp[j] * temp + + return out + + +@njit +def nbSurfaceDer( + u: Array, + v: Array, + uorder: int, + vorder: int, + dorder: int, + uknots: Array, + vknots: Array, + pts: Array, + uperiodic: bool = False, + vperiodic: bool = False, +) -> Array: + """ + NURBS book A3.6 with modifications to handle periodicity. + + Parameters + ---------- + u : Array + U parameter values. + v : Array + V parameter values. + uorder : int + B-spline u order. + vorder : int + B-spline v order. + dorder : int + Maximum derivative order. + uknots : Array + U knot vector.. + vknots : Array + V knot vector.. + pts : Array + Control points. + uperiodic : bool, optional + U periodicity flag. The default is False. + vperiodic : bool, optional + V periodicity flag. The default is False. + + Returns + ------- + Array + Surface and derivative values. 
+ + """ + + # max derivative orders + du = min(dorder, uorder) + dv = min(dorder, vorder) + + # number of control points + nub = pts.shape[0] + nvb = pts.shape[1] + + # handle periodicity + u_, uknots_ext, minspanu, maxspanu, deltaspanu = _preprocess( + u, uorder, uknots, uperiodic + ) + v_, vknots_ext, minspanv, maxspanv, deltaspanv = _preprocess( + v, vorder, vknots, vperiodic + ) + + # number of param values + nu = np.size(u) + + # chunck sizes + un = uorder + 1 + vn = vorder + 1 + + # temp chunck storage + + utemp = np.zeros((du + 1, un)) + vtemp = np.zeros((dv + 1, vn)) + + # initialize + out = np.zeros((nu, du + 1, dv + 1, 3)) + + for i in range(nu): + ui = u_[i] + vi = v_[i] + + # find span + uspan = nbFindSpan(ui, uorder, uknots, minspanu, maxspanu) + deltaspanu + vspan = nbFindSpan(vi, vorder, vknots, minspanv, maxspanv) + deltaspanv + + # evaluate chunk + nbBasisDer(uspan, ui, uorder, du, uknots_ext, utemp) + nbBasisDer(vspan, vi, vorder, dv, vknots_ext, vtemp) + + for k in range(du + 1): + + temp = np.zeros((vorder + 1, 3)) + + # Nu.T^(k)*pts + for s in range(vorder + 1): + for r in range(uorder + 1): + temp[s, :] += ( + utemp[k, r] + * pts[(uspan - uorder + r) % nub, (vspan - vorder + s) % nvb, :] + ) + + # ramaining derivative orders: dk + du <= dorder + dd = min(dorder - k, dv) + + # .. * Nv^(l) + for l in range(dd + 1): + for s in range(vorder + 1): + out[i, k, l, :] += vtemp[l, s] * temp[s, :] + + return out + + +# %% matrices + + +@njit +def designMatrix(u: Array, order: int, knots: Array, periodic: bool = False) -> COO: + """ + Create a sparse (possibly periodic) design matrix. 
+ """ + + # extend the knots + u_, knots_ext, minspan, maxspan, deltaspan = _preprocess(u, order, knots, periodic) + + # number of param values + nu = len(u) + + # number of basis functions + nb = maxspan + + # chunck size + n = order + 1 + + # temp chunck storage + temp = np.zeros(n) + + # initialize the empty matrix + rv = COO( + i=np.empty(n * nu, dtype=np.int64), + j=np.empty(n * nu, dtype=np.int64), + v=np.empty(n * nu), + ) + + # loop over param values + for i in range(nu): + ui = u_[i] + + # find the supporting span + span = nbFindSpan(ui, order, knots, minspan, maxspan) + deltaspan + + # evaluate non-zero functions + nbBasis(span, ui, order, knots_ext, temp) + + # update the matrix + rv.i[i * n : (i + 1) * n] = i + rv.j[i * n : (i + 1) * n] = ( + span - order + np.arange(n) + ) % nb # NB: this is due to peridicity + rv.v[i * n : (i + 1) * n] = temp + + return rv + + +@njit +def designMatrix2D( + u: Array, + v: Array, + uorder: int, + vorder: int, + uknots: Array, + vknots: Array, + uperiodic: bool = False, + vperiodic: bool = False, +) -> COO: + """ + Create a sparse tensor product design matrix. 
+ """ + + # extend the knots and preprocess + u_, uknots_ext, minspanu, maxspanu, deltaspanu = _preprocess( + u, uorder, uknots, uperiodic + ) + v_, vknots_ext, minspanv, maxspanv, deltaspanv = _preprocess( + v, vorder, vknots, vperiodic + ) + + # number of param values + ni = len(u) + + # chunck size + nu = uorder + 1 + nv = vorder + 1 + nj = nu * nv + + # number of basis + nu_total = maxspanu + nv_total = maxspanv + + # temp chunck storage + utemp = np.zeros(nu) + vtemp = np.zeros(nv) + + # initialize the empty matrix + rv = COO( + i=np.empty(ni * nj, dtype=np.int64), + j=np.empty(ni * nj, dtype=np.int64), + v=np.empty(ni * nj), + ) + + # loop over param values + for i in range(ni): + ui, vi = u_[i], v_[i] + + # find the supporting span + uspan = nbFindSpan(ui, uorder, uknots, minspanu, maxspanu) + deltaspanu + vspan = nbFindSpan(vi, vorder, vknots, minspanv, maxspanv) + deltaspanv + + # evaluate non-zero functions + nbBasis(uspan, ui, uorder, uknots_ext, utemp) + nbBasis(vspan, vi, vorder, vknots_ext, vtemp) + + # update the matrix + rv.i[i * nj : (i + 1) * nj] = i + rv.j[i * nj : (i + 1) * nj] = ( + ((uspan - uorder + np.arange(nu)) % nu_total) * nv_total + + ((vspan - vorder + np.arange(nv)) % nv_total)[:, np.newaxis] + ).ravel() + rv.v[i * nj : (i + 1) * nj] = (utemp * vtemp[:, np.newaxis]).ravel() + + return rv + + +@njit +def periodicDesignMatrix(u: Array, order: int, knots: Array) -> COO: + """ + Create a sparse periodic design matrix. + """ + + return designMatrix(u, order, knots, periodic=True) + + +@njit +def derMatrix(u: Array, order: int, dorder: int, knots: Array) -> list[COO]: + """ + Create a sparse design matrix and corresponding derivative matrices. 
+ """ + + # number of param values + nu = np.size(u) + + # chunck size + n = order + 1 + + # temp chunck storage + temp = np.zeros((dorder + 1, n)) + + # initialize the empty matrix + rv = [] + + for _ in range(dorder + 1): + rv.append( + COO( + i=np.empty(n * nu, dtype=np.int64), + j=np.empty(n * nu, dtype=np.int64), + v=np.empty(n * nu), + ) + ) + + # loop over param values + for i in range(nu): + ui = u[i] + + # find the supporting span + span = nbFindSpan(ui, order, knots) + + # evaluate non-zero functions + nbBasisDer(span, ui, order, dorder, knots, temp) + + # update the matrices + for di in range(dorder + 1): + rv[di].i[i * n : (i + 1) * n] = i + rv[di].j[i * n : (i + 1) * n] = span - order + np.arange(n) + rv[di].v[i * n : (i + 1) * n] = temp[di, :] + + return rv + + +@njit +def periodicDerMatrix(u: Array, order: int, dorder: int, knots: Array) -> list[COO]: + """ + Create a sparse periodic design matrix and corresponding derivative matrices. + """ + + # extend the knots + knots_ext = np.concat( + (knots[-order:-1] - knots[-1], knots, knots[-1] + knots[1:order]) + ) + + # number of param values + nu = len(u) + + # number of basis functions + nb = len(knots) - 1 + + # chunck size + n = order + 1 + + # temp chunck storage + temp = np.zeros((dorder + 1, n)) + + # initialize the empty matrix + rv = [] + + for _ in range(dorder + 1): + rv.append( + COO( + i=np.empty(n * nu, dtype=np.int64), + j=np.empty(n * nu, dtype=np.int64), + v=np.empty(n * nu), + ) + ) + + # loop over param values + for i in range(nu): + ui = u[i] + + # find the supporting span + span = nbFindSpan(ui, order, knots, 0, nb) + order - 1 + + # evaluate non-zero functions + nbBasisDer(span, ui, order, dorder, knots_ext, temp) + + # update the matrices + for di in range(dorder + 1): + rv[di].i[i * n : (i + 1) * n] = i + rv[di].j[i * n : (i + 1) * n] = ( + span - order + np.arange(n) + ) % nb # NB: this is due to peridicity + rv[di].v[i * n : (i + 1) * n] = temp[di, :] + + return rv + + +@njit 
+def periodicDiscretePenalty(us: Array, order: int) -> COO: + + if order not in (1, 2): + raise ValueError( + f"Only 1st and 2nd order penalty is supported, requested order {order}" + ) + + # number of rows + nb = len(us) + + # number of elements per row + ne = order + 1 + + # initialize the penlaty matrix + rv = COO( + i=np.empty(nb * ne, dtype=np.int64), + j=np.empty(nb * ne, dtype=np.int64), + v=np.empty(nb * ne), + ) + + if order == 1: + for ix in range(nb): + rv.i[ne * ix] = ix + rv.j[ne * ix] = (ix - 1) % nb + rv.v[ne * ix] = -0.5 + + rv.i[ne * ix + 1] = ix + rv.j[ne * ix + 1] = (ix + 1) % nb + rv.v[ne * ix + 1] = 0.5 + + elif order == 2: + for ix in range(nb): + rv.i[ne * ix] = ix + rv.j[ne * ix] = (ix - 1) % nb + rv.v[ne * ix] = 1 + + rv.i[ne * ix + 1] = ix + rv.j[ne * ix + 1] = ix + rv.v[ne * ix + 1] = -2 + + rv.i[ne * ix + 2] = ix + rv.j[ne * ix + 2] = (ix + 1) % nb + rv.v[ne * ix + 2] = 1 + + return rv + + +@njit +def discretePenalty(us: Array, order: int, splineorder: int = 3) -> COO: + + if order not in (1, 2): + raise ValueError( + f"Only 1st and 2nd order penalty is supported, requested order {order}" + ) + + # number of rows + nb = len(us) + + # number of elements per row + ne = order + 1 + + # initialize the penlaty matrix + rv = COO( + i=np.empty(nb * ne, dtype=np.int64), + j=np.empty(nb * ne, dtype=np.int64), + v=np.empty(nb * ne), + ) + + if order == 1: + for ix in range(nb): + if ix == 0: + rv.i[ne * ix] = ix + rv.j[ne * ix] = ix + rv.v[ne * ix] = -1 + + rv.i[ne * ix + 1] = ix + rv.j[ne * ix + 1] = ix + 1 + rv.v[ne * ix + 1] = 1 + elif ix < nb - 1: + rv.i[ne * ix] = ix + rv.j[ne * ix] = ix - 1 + rv.v[ne * ix] = -0.5 + + rv.i[ne * ix + 1] = ix + rv.j[ne * ix + 1] = ix + 1 + rv.v[ne * ix + 1] = 0.5 + else: + rv.i[ne * ix] = ix + rv.j[ne * ix] = ix - 1 + rv.v[ne * ix] = -1 + + rv.i[ne * ix + 1] = ix + rv.j[ne * ix + 1] = ix + rv.v[ne * ix + 1] = 1 + + elif order == 2: + for ix in range(nb): + if ix == 0: + rv.i[ne * ix] = ix + rv.j[ne * ix] = ix + 
rv.v[ne * ix] = 1 + + rv.i[ne * ix + 1] = ix + rv.j[ne * ix + 1] = ix + 1 + rv.v[ne * ix + 1] = -2 + + rv.i[ne * ix + 2] = ix + rv.j[ne * ix + 2] = ix + 2 + rv.v[ne * ix + 2] = 1 + elif ix < nb - 1: + rv.i[ne * ix] = ix + rv.j[ne * ix] = ix - 1 + rv.v[ne * ix] = 1 + + rv.i[ne * ix + 1] = ix + rv.j[ne * ix + 1] = ix + rv.v[ne * ix + 1] = -2 + + rv.i[ne * ix + 2] = ix + rv.j[ne * ix + 2] = ix + 1 + rv.v[ne * ix + 2] = 1 + else: + rv.i[ne * ix] = ix + rv.j[ne * ix] = ix - 2 + rv.v[ne * ix] = 1 + + rv.i[ne * ix + 1] = ix + rv.j[ne * ix + 1] = ix - 1 + rv.v[ne * ix + 1] = -2 + + rv.i[ne * ix + 2] = ix + rv.j[ne * ix + 2] = ix + rv.v[ne * ix + 2] = 1 + + return rv + + +@njit +def penaltyMatrix2D( + u: Array, + v: Array, + uorder: int, + vorder: int, + dorder: int, + uknots: Array, + vknots: Array, + uperiodic: bool = False, + vperiodic: bool = False, +) -> list[COO]: + """ + Create sparse tensor product 2D derivative matrices. + """ + + # extend the knots and preprocess + u_, uknots_ext, minspanu, maxspanu, deltaspanu = _preprocess( + u, uorder, uknots, uperiodic + ) + v_, vknots_ext, minspanv, maxspanv, deltaspanv = _preprocess( + v, vorder, vknots, vperiodic + ) + + # number of param values + ni = len(u) + + # chunck size + nu = uorder + 1 + nv = vorder + 1 + nj = nu * nv + + # number of basis + nu_total = maxspanu + nv_total = maxspanv + + # temp chunck storage + utemp = np.zeros((dorder + 1, nu)) + vtemp = np.zeros((dorder + 1, nv)) + + # initialize the emptry matrices + rv = [] + for i in range(dorder + 1): + rv.append( + COO( + i=np.empty(ni * nj, dtype=np.int64), + j=np.empty(ni * nj, dtype=np.int64), + v=np.empty(ni * nj), + ) + ) + + # loop over param values + for i in range(ni): + ui, vi = u_[i], v_[i] + + # find the supporting span + uspan = nbFindSpan(ui, uorder, uknots, minspanu, maxspanu) + deltaspanu + vspan = nbFindSpan(vi, vorder, vknots, minspanv, maxspanv) + deltaspanv + + # evaluate non-zero functions + nbBasisDer(uspan, ui, uorder, dorder, 
uknots_ext, utemp) + nbBasisDer(vspan, vi, vorder, dorder, vknots_ext, vtemp) + + # update the matrices - iterate over all derivative paris + for dv in range(dorder + 1): + + du = dorder - dv # NB: du + dv == dorder + + rv[dv].i[i * nj : (i + 1) * nj] = i + rv[dv].j[i * nj : (i + 1) * nj] = ( + ((uspan - uorder + np.arange(nu)) % nu_total) * nv_total + + ((vspan - vorder + np.arange(nv)) % nv_total)[:, np.newaxis] + ).ravel() + rv[dv].v[i * nj : (i + 1) * nj] = ( + utemp[du, :] * vtemp[dv, :, np.newaxis] + ).ravel() + + return rv + + +def uniformGrid( + uknots: Array, + vknots: Array, + uorder: int, + vorder: int, + uperiodic: bool, + vperiodic: bool, +) -> Tuple[Array, Array]: + """ + Create a uniform grid for evaluating penalties. + """ + + Up, Vp = np.meshgrid( + np.linspace( + uknots[0], uknots[-1], 2 * len(uknots) * uorder, endpoint=not uperiodic + ), + np.linspace( + vknots[0], vknots[-1], 2 * len(vknots) * vorder, endpoint=not vperiodic + ), + ) + up = Up.ravel() + vp = Vp.ravel() + + return up, vp + + +# %% construction + + +def parametrizeChord(data: Array) -> Array: + """ + Chord length parametrization. 
+ """ + + dists = np.linalg.norm(data - np.roll(data, 1), axis=1) + params = np.cumulative_sum(dists) + + return params / params[-1] + + +@multidispatch +def periodicApproximate( + data: Array, + us: Optional[Array] = None, + knots: int | Array = 50, + order: int = 3, + penalty: int = 4, + lam: float = 0, +) -> Curve: + + npts = data.shape[0] + + # parametrize the points if needed + if us is None: + us = linspace(0, 1, npts, endpoint=False) + + # construct the knot vector + if isinstance(knots, int): + knots_ = linspace(0, 1, knots) + else: + knots_ = np.array(knots) + + # construct the design matrix + C = periodicDesignMatrix(us, order, knots_).csc() + CtC = C.T @ C + + # add the penalty if requested + if lam: + up = linspace(0, 1, order * npts, endpoint=False) + + assert penalty <= order + 2 + + # discrete + exact derivatives + if penalty > order: + Pexact = periodicDerMatrix(up, order, order - 1, knots_)[-1].csc() + Pdiscrete = periodicDiscretePenalty(up, penalty - order).csc() + + P = Pdiscrete @ Pexact + + # only exact derivatives + else: + P = periodicDerMatrix(up, order, penalty, knots_)[-1].csc() + + CtC += lam * P.T @ P + + # factorize + D, L, P = ldl(CtC, True) + + # invert + pts = ldl_solve(C.T @ data, D, L, P).toarray() + + # convert to an edge + rv = Curve(pts, knots_, order, periodic=True) + + return rv + + +@periodicApproximate.register +def periodicApproximate( + data: List[Array], + us: Optional[Array] = None, + knots: int | Array = 50, + order: int = 3, + penalty: int = 4, + lam: float = 0, +) -> List[Curve]: + + rv = [] + + npts = data[0].shape[0] + + # parametrize the points + us = linspace(0, 1, npts, endpoint=False) + + # construct the knot vector + if isinstance(knots, int): + knots_ = linspace(0, 1, knots) + else: + knots_ = np.array(knots) + + # construct the design matrix + C = periodicDesignMatrix(us, order, knots_).csc() + CtC = C.T @ C + + # add the penalty if requested + if lam: + up = linspace(0, 1, order * npts, endpoint=False) + + 
assert penalty <= order + 2 + + # discrete + exact derivatives + if penalty > order: + Pexact = periodicDerMatrix(up, order, order - 1, knots_)[-1].csc() + Pdiscrete = periodicDiscretePenalty(up, penalty - order).csc() + + P = Pdiscrete @ Pexact + + # only exact derivatives + else: + P = periodicDerMatrix(up, order, penalty, knots_)[-1].csc() + + CtC += lam * P.T @ P + + # factorize + D, L, P = ldl(CtC, True) + + # invert every dataset + for dataset in data: + pts = ldl_solve(C.T @ dataset, D, L, P).toarray() + + # convert to an edge and store + rv.append(Curve(pts, knots_, order, periodic=True)) + + return rv + + +@multidispatch +def approximate( + data: Array, + us: Optional[Array] = None, + knots: int | Array = 50, + order: int = 3, + penalty: int = 4, + lam: float = 0, + tangents: Optional[Tuple[Array, Array]] = None, +) -> Curve: + + npts = data.shape[0] + + # parametrize the points + us = linspace(0, 1, npts) + + # construct the knot vector + if isinstance(knots, int): + knots_ = np.concatenate( + (np.repeat(0, order), linspace(0, 1, knots), np.repeat(1, order)) + ) + else: + knots_ = np.array(knots) + + # construct the design matrix + C = designMatrix(us, order, knots_).csc() + CtC = C.T @ C + + # add a penalty term if requested + if lam: + up = linspace(0, 1, order * npts) + + assert penalty <= order + 2 + + # discrete + exact derivatives + if penalty > order: + Pexact = derMatrix(up, order, order - 1, knots_)[-1].csc() + Pdiscrete = discretePenalty(up, penalty - order, order).csc() + + P = Pdiscrete @ Pexact + + # only exact derivatives + else: + P = derMatrix(up, order, penalty, knots_)[-1].csc() + + CtC += lam * P.T @ P + + # clamp first and last point + Cc = C[[0, -1], :] + bc = data[[0, -1], :] + nc = 2 # number of constraints + + # handle tangent constraints if needed + if tangents: + nc += 2 + + Cc2 = derMatrix(us[[0, -1]], order, 1, knots_)[-1].csc() + + Cc = sp.vstack((Cc, Cc2)) + bc = np.vstack((bc, *tangents)) + + # final matrix and vector + Aug 
= sp.bmat([[CtC, Cc.T], [Cc, None]]) + data_aug = np.vstack((C.T @ data, bc)) + + # factorize + D, L, P = ldl(Aug, False) + + # invert + pts = ldl_solve(data_aug, D, L, P).toarray()[:-nc, :] + + # convert to an edge + rv = Curve(pts, knots_, order, periodic=False) + + return rv + + +@approximate.register +def approximate( + data: List[Array], + us: Optional[Array] = None, + knots: int | Array = 50, + order: int = 3, + penalty: int = 4, + lam: float = 0, + tangents: Optional[Union[Tuple[Array, Array], List[Tuple[Array, Array]]]] = None, +) -> List[Curve]: + + rv = [] + + npts = data[0].shape[0] + + # parametrize the points + us = linspace(0, 1, npts) + + # construct the knot vector + if isinstance(knots, int): + knots_ = np.concatenate( + (np.repeat(0, order), linspace(0, 1, knots), np.repeat(1, order)) + ) + else: + knots_ = np.array(knots) + + # construct the design matrix + C = designMatrix(us, order, knots_).csc() + CtC = C.T @ C + + # add a penalty term if requested + if lam: + up = linspace(0, 1, order * npts) + + assert penalty <= order + 2 + + # discrete + exact derivatives + if penalty > order: + Pexact = derMatrix(up, order, order - 1, knots_)[-1].csc() + Pdiscrete = discretePenalty(up, penalty - order, order).csc() + + P = Pdiscrete @ Pexact + + # only exact derivatives + else: + P = derMatrix(up, order, penalty, knots_)[-1].csc() + + CtC += lam * P.T @ P + + # clamp first and last point + Cc = C[[0, -1], :] + + nc = 2 # number of constraints + + # handle tangent constraints if needed + if tangents: + nc += 2 + Cc2 = derMatrix(us[[0, -1]], order, 1, knots_)[-1].csc() + Cc = sp.vstack((Cc, Cc2)) + + # final matrix and vector + Aug = sp.bmat([[CtC, Cc.T], [Cc, None]]) + + # factorize + D, L, P = ldl(Aug, False) + + # invert all datasets + for ix, dataset in enumerate(data): + bc = dataset[[0, -1], :] # first and last point for clamping + + if tangents: + if len(tangents) == len(data): + bc = np.vstack((bc, *tangents[ix])) + else: + bc = np.vstack((bc, 
*tangents)) + + # construct the LHS of the linear system + dataset_aug = np.vstack((C.T @ dataset, bc)) + + # actual solver + pts = ldl_solve(dataset_aug, D, L, P).toarray()[:-nc, :] + + # convert to an edge + rv.append(Curve(pts, knots_, order, periodic=False)) + + return rv + + +def approximate2D( + data: Array, + u: Array, + v: Array, + uorder: int, + vorder: int, + uknots: int | Array = 50, + vknots: int | Array = 50, + uperiodic: bool = False, + vperiodic: bool = False, + penalty: int = 3, + lam: float = 0, +) -> Surface: + """ + Simple 2D surface approximation (without any penalty). + """ + + # process the knots + uknots_ = uknots if isinstance(uknots, Array) else np.linspace(0, 1, uknots) + vknots_ = vknots if isinstance(vknots, Array) else np.linspace(0, 1, vknots) + + # create the desing matrix + C = designMatrix2D( + u, v, uorder, vorder, uknots_, vknots_, uperiodic, vperiodic + ).csc() + + # handle penalties if requested + if lam: + # construct the penalty grid + up, vp = uniformGrid(uknots_, vknots_, uorder, vorder, uperiodic, vperiodic) + + # construct the derivative matrices + penalties = penaltyMatrix2D( + up, vp, uorder, vorder, penalty, uknots_, vknots_, uperiodic, vperiodic, + ) + + # augment the design matrix + tmp = [comb(penalty, i) * penalties[i].csc() for i in range(penalty + 1)] + Lu = uknots_[-1] - uknots_[0] # v lenght of the parametric domain + Lv = vknots_[-1] - vknots_[0] # u lenght of the parametric domain + P = Lu * Lv / len(up) * sp.vstack(tmp) + + CtC = C.T @ C + lam * P.T @ P + else: + CtC = C.T @ C + + # solve normal equations + D, L, P = ldl(CtC, False) + pts = ldl_solve(C.T @ data, D, L, P).toarray() + + # construt the result + rv = Surface( + pts.reshape((len(uknots_) - int(uperiodic), len(vknots_) - int(vperiodic), 3)), + uknots_, + vknots_, + uorder, + vorder, + uperiodic, + vperiodic, + ) + + return rv + + +def fairPenalty(surf: Surface, penalty: int, lam: float) -> Surface: + """ + Penalty-based surface fairing. 
+ """ + + uknots = surf.uknots + vknots = surf.vknots + pts = surf.pts.reshape((-1, 3)) + + # generate penalty grid + up, vp = uniformGrid( + uknots, vknots, surf.uorder, surf.vorder, surf.uperiodic, surf.vperiodic + ) + + # generate penalty matrix + penalties = penaltyMatrix2D( + up, + vp, + surf.uorder, + surf.vorder, + penalty, + surf.uknots, + surf.vknots, + surf.uperiodic, + surf.vperiodic, + ) + + tmp = [comb(penalty, i) * penalties[i].csc() for i in range(penalty + 1)] + Lu = uknots[-1] - uknots[0] # v lenght of the parametric domain + Lv = vknots[-1] - vknots[0] # u lenght of the parametric domain + P = Lu * Lv / len(up) * sp.vstack(tmp) + + # form and solve normal equations + CtC = sp.identity(pts.shape[0]) + lam * P.T @ P + + D, L, P = ldl(CtC, False) + pts_new = ldl_solve(pts, D, L, P).toarray() + + # construt the result + rv = Surface( + pts_new.reshape( + (len(uknots) - int(surf.uperiodic), len(vknots) - int(surf.vperiodic), 3) + ), + uknots, + vknots, + surf.uorder, + surf.vorder, + surf.uperiodic, + surf.vperiodic, + ) + + return rv + + +def periodicLoft(*curves: Curve, order: int = 3) -> Surface: + + nknots: int = len(curves) + 1 + + # collect control pts + pts = [el for el in np.stack([c.pts for c in curves]).swapaxes(0, 1)] + + # approximate + pts_new = [el.pts for el in periodicApproximate(pts, knots=nknots, order=order)] + + # construct the final surface + rv = Surface( + np.stack(pts_new).swapaxes(0, 1), + linspace(0, 1, nknots), + curves[0].knots, + order, + curves[0].order, + True, + curves[0].periodic, + ) + + return rv + + +def loft( + *curves: Curve, + order: int = 3, + lam: float = 1e-9, + penalty: int = 4, + tangents: Optional[List[Tuple[Array, Array]]] = None, +) -> Surface: + + nknots: int = len(curves) + + # collect control pts + pts = np.stack([c.pts for c in curves]) + + # approximate + pts_new = [] + + for j in range(pts.shape[1]): + pts_new.append( + approximate( + pts[:, j, :], + knots=nknots, + order=order, + lam=lam, + 
penalty=penalty, + tangents=tangents[j] if tangents else None, + ).pts + ) + + # construct the final surface + rv = Surface( + np.stack(pts_new).swapaxes(0, 1), + np.concatenate( + (np.repeat(0, order), linspace(0, 1, nknots), np.repeat(1, order)) + ), + curves[0].knots, + order, + curves[0].order, + False, + curves[0].periodic, + ) + + return rv + + +def reparametrize( + *curves: Curve, n: int = 100, knots: int = 100, w1: float = 1, w2: float = 1 +) -> List[Curve]: + + from scipy.optimize import fmin_l_bfgs_b + + n_curves = len(curves) + + u0_0 = np.linspace(0, 1, n, False) + u0 = np.tile(u0_0, n_curves) + + # scaling for the second cost term + scale = n * np.linalg.norm(curves[0](u0[0]) - curves[1](u0[n])) + + def cost(u: Array) -> float: + + rv1 = 0 + us = np.split(u, n_curves) + + pts = [] + + for i, ui in enumerate(us): + + # evaluate + pts.append(curves[i](ui)) + + # parametric distance between points on the same curve + rv1 += np.sum((ui[:-1] - ui[1:]) ** 2) + np.sum((ui[0] + 1 - ui[-1]) ** 2) + + rv2 = 0 + + for p1, p2 in zip(pts, pts[1:]): + + # geometric distance between points on adjecent curves + rv2 += np.sum(((p1 - p2) / scale) ** 2) + + return w1 * rv1 + w2 * rv2 + + def grad(u: Array) -> Array: + + rv1 = np.zeros_like(u) + us = np.split(u, n_curves) + + pts = [] + tgts = [] + + for i, ui in enumerate(us): + + # evaluate up to 1st derivative + tmp = curves[i].der(ui, 1) + + pts.append(tmp[:, 0, :].squeeze()) + tgts.append(tmp[:, 1, :].squeeze()) + + # parametric distance between points on the same curve + delta = np.roll(ui, -1) - ui + delta[-1] += 1 + delta *= -2 + delta -= np.roll(delta, 1) + + rv1[i * n : (i + 1) * n] = delta + + rv2 = np.zeros_like(u) + + for i, _ in enumerate(us): + # geometric distance between points on adjecent curves + + # first profile + if i == 0: + p1, p2, t = pts[i], pts[i + 1], tgts[i] + + rv2[i * n : (i + 1) * n] = (2 / scale ** 2 * (p1 - p2) * t).sum(1) + + # middle profile + elif i + 1 < n_curves: + p1, p2, t = 
pts[i], pts[i + 1], tgts[i] + p0 = pts[i - 1] + + rv2[i * n : (i + 1) * n] = (2 / scale ** 2 * (p1 - p2) * t).sum(1) + rv2[i * n : (i + 1) * n] += (-2 / scale ** 2 * (p0 - p1) * t).sum(1) + + # last profile + else: + p1, p2, t = pts[i - 1], pts[i], tgts[i] + + rv2[i * n : (i + 1) * n] = (-2 / scale ** 2 * (p1 - p2) * t).sum(1) + + return w1 * rv1 + w2 * rv2 + + usol, _, _ = fmin_l_bfgs_b(cost, u0, grad) + + us = np.split(usol, n_curves) + + return periodicApproximate( + [crv(u) for crv, u in zip(curves, us)], knots=knots, lam=0 + ) + + +def offset(surf: Surface, d: float, lam: float = 1e-3) -> Surface: + """ + Simple approximate offset. + """ + + # construct the knot grid + U, V = np.meshgrid( + np.linspace(surf.uknots[0], surf.uknots[-1], surf.uorder * len(surf.uknots)), + np.linspace(surf.vknots[0], surf.vknots[-1], surf.vorder * len(surf.uknots)), + ) + + us = U.ravel() + vs = V.ravel() + + # evaluate the normals + ns, pts = surf.normal(us, vs) + + # move the control points + pts += d * ns + + return approximate2D( + pts, + us, + vs, + surf.uorder, + surf.vorder, + surf.uknots, + surf.vknots, + surf.uperiodic, + surf.vperiodic, + lam=lam, + ) + + +# %% for removal? +@njit +def findSpan(v, knots): + + return np.searchsorted(knots, v, "right") - 1 + + +@njit +def findSpanLinear(v, knots): + + for rv in range(len(knots)): + if knots[rv] <= v and knots[rv + 1] > v: + return rv + + return -1 + + +@njit +def periodicKnots(degree: int, n_pts: int): + rv = np.arange(0.0, n_pts + degree + 1, 1.0) + rv /= rv[-1] + + return rv diff --git a/cadquery/occ_impl/shapes.py b/cadquery/occ_impl/shapes.py index 5e4e8b22b..fe90077e1 100644 --- a/cadquery/occ_impl/shapes.py +++ b/cadquery/occ_impl/shapes.py @@ -6558,6 +6558,20 @@ def check( return rv +def isSubshape(s1: Shape, s2: Shape) -> bool: + """ + Check is s1 is a subshape of s2. 
+ """ + + shape_map = TopTools_IndexedDataMapOfShapeListOfShape() + + TopExp.MapShapesAndAncestors_s( + s2.wrapped, shapetype(s1.wrapped), inverse_shape_LUT[s2.ShapeType()], shape_map + ) + + return shape_map.Contains(s1.wrapped) + + #%% properties diff --git a/conda/meta.yaml b/conda/meta.yaml index 2f976f844..a20da6c65 100644 --- a/conda/meta.yaml +++ b/conda/meta.yaml @@ -26,6 +26,8 @@ requirements: - multimethod >=1.11,<2.0 - casadi - typish + - numba + - scipy - trame - trame-vtk diff --git a/environment.yml b/environment.yml index 3d5f4ef69..372c2c23a 100644 --- a/environment.yml +++ b/environment.yml @@ -25,6 +25,8 @@ dependencies: - pathspec - click - appdirs + - numba + - scipy - trame - trame-vtk - pip diff --git a/mypy.ini b/mypy.ini index 7bc958faf..fdd66dd63 100644 --- a/mypy.ini +++ b/mypy.ini @@ -37,6 +37,9 @@ ignore_missing_imports = True [mypy-casadi.*] ignore_missing_imports = True +[mypy-numba.*] +ignore_missing_imports = True + [mypy-trame.*] ignore_missing_imports = True diff --git a/tests/test_assembly.py b/tests/test_assembly.py index 75acac7b1..cee324a70 100644 --- a/tests/test_assembly.py +++ b/tests/test_assembly.py @@ -17,7 +17,7 @@ exportVRML, ) from cadquery.occ_impl.assembly import toJSON, toCAF, toFusedCAF -from cadquery.occ_impl.shapes import Face, box, cone +from cadquery.occ_impl.shapes import Face, box, cone, plane from OCP.gp import gp_XYZ from OCP.TDocStd import TDocStd_Document @@ -34,7 +34,7 @@ from OCP.STEPCAFControl import STEPCAFControl_Reader from OCP.IFSelect import IFSelect_RetDone from OCP.TDF import TDF_ChildIterator -from OCP.Quantity import Quantity_ColorRGBA, Quantity_TOC_sRGB +from OCP.Quantity import Quantity_ColorRGBA, Quantity_TOC_sRGB, Quantity_NameOfColor from OCP.TopAbs import TopAbs_ShapeEnum @@ -366,6 +366,50 @@ def chassis0_assy(): return chassis +@pytest.fixture +def subshape_assy(): + """ + Builds an assembly with the needed subshapes to test the export and import of STEP files. 
+ """ + + # Create a simple assembly + assy = cq.Assembly(name="top_level") + cube_1 = cq.Workplane().box(10.0, 10.0, 10.0) + assy.add(cube_1, name="cube_1", color=cq.Color("green")) + + # Add subshape name, color and layer + assy["cube_1"].addSubshape( + cube_1.faces(">Z").val(), + name="cube_1_top_face", + color=cq.Color("red"), + layer="cube_1_top_face", + ) + + # Add a cylinder to the assembly + cyl_1 = cq.Workplane().cylinder(10.0, 2.5) + assy.add( + cyl_1, name="cyl_1", color=cq.Color("blue"), loc=cq.Location((0.0, 0.0, -10.0)) + ) + + # Add a subshape face for the cylinder + assy["cyl_1"].addSubshape( + cyl_1.faces(" TDocStd_Document: """Read STEP file, return XCAF document""" @@ -551,8 +595,6 @@ def test_color(): assert c3.wrapped.GetRGB().Red() == 1 assert c3.wrapped.Alpha() == 0.5 - c4 = cq.Color() - with pytest.raises(ValueError): cq.Color("?????") @@ -679,21 +721,24 @@ def test_meta_step_export(tmp_path_factory): assy.addSubshape(cube_1.faces(">Z").val(), name="cube_1_top_face") assy.addSubshape(cube_1.faces(">Z").val(), color=cq.Color(1.0, 0.0, 0.0)) assy.addSubshape(cube_1.faces(">Z").val(), layer="cube_1_top_face") - assy.addSubshape(cube_2.faces("Z").val(), name="cylinder_1_top_face") - assy.addSubshape(cylinder_1.faces(">Z").val(), color=cq.Color(1.0, 0.0, 0.0)) - assy.addSubshape(cylinder_1.faces(">Z").val(), layer="cylinder_1_top_face") - assy.addSubshape(cylinder_2.faces("Z").val(), name="cone_1_top_face") - assy.addSubshape(cone_1.faces(">Z").val(), color=cq.Color(1.0, 0.0, 0.0)) - assy.addSubshape(cone_1.faces(">Z").val(), layer="cone_1_top_face") - assy.addSubshape(cone_2.faces("Z").val(), name="cylinder_1_top_face") + assy.addSubshape(cylinder_1.faces(">Z").val(), color=cq.Color(1.0, 0.0, 0.0)) + assy.addSubshape(cylinder_1.faces(">Z").val(), layer="cylinder_1_top_face") + assy.addSubshape(cylinder_2.faces("Z").val(), name="cone_1_top_face") + assy.addSubshape(cone_1.faces(">Z").val(), color=cq.Color(1.0, 0.0, 0.0)) + 
assy.addSubshape(cone_1.faces(">Z").val(), layer="cone_1_top_face") + assy.addSubshape(cone_2.faces("Z").val(), name="cube_top_face") @@ -785,6 +829,376 @@ def test_meta_step_export_edge_cases(tmp_path_factory): assert success +def test_assembly_step_import(tmp_path_factory, subshape_assy): + """ + Test if the STEP import works correctly for an assembly with subshape data attached. + """ + + # Use a temporary directory + tmpdir = tmp_path_factory.mktemp("out") + assy_step_path = os.path.join(tmpdir, "assembly_with_subshapes.step") + + subshape_assy.export(assy_step_path) + + # Import the STEP file back in + imported_assy = cq.Assembly.importStep(assy_step_path) + + # Check that the assembly was imported successfully + assert imported_assy is not None + + # Check for appropriate part name + assert imported_assy.children[0].name == "cube_1" + # Check for approximate color match + assert pytest.approx(imported_assy.children[0].color.toTuple(), rel=0.01) == ( + 0.0, + 1.0, + 0.0, + 1.0, + ) + # Check for appropriate part name + assert imported_assy.children[1].name == "cyl_1" + # Check for approximate color match + assert pytest.approx(imported_assy.children[1].color.toTuple(), rel=0.01) == ( + 0.0, + 0.0, + 1.0, + 1.0, + ) + + # Make sure the shape locations were applied correctly + assert imported_assy.children[1].loc.toTuple()[0] == (0.0, 0.0, -10.0) + + # Check the top-level assembly name + assert imported_assy.name == "top_level" + + # Test a STEP file that does not contain an assembly + wp_step_path = os.path.join(tmpdir, "plain_workplane.step") + res = cq.Workplane().box(10, 10, 10) + res.export(wp_step_path) + + # Import the STEP file back in + with pytest.raises(ValueError): + imported_assy = cq.Assembly.importStep(wp_step_path) + + +def test_assembly_subshape_step_import(tmp_path_factory, subshape_assy): + """ + Test if a STEP file containing subshape information can be imported correctly. 
+ """ + + tmpdir = tmp_path_factory.mktemp("out") + assy_step_path = os.path.join(tmpdir, "subshape_assy.step") + + # Export the assembly + subshape_assy.export(assy_step_path) + + # Import the STEP file back in + imported_assy = cq.Assembly.load(assy_step_path) + assert imported_assy.name == "top_level" + + # Check the advanced face name + assert len(imported_assy.children[0]._subshape_names) == 1 + assert ( + list(imported_assy.children[0]._subshape_names.values())[0] == "cube_1_top_face" + ) + + # Check the color + color = list(imported_assy.children[0]._subshape_colors.values())[0] + assert Quantity_NameOfColor.Quantity_NOC_RED == color.wrapped.GetRGB().Name() + + # Check the layer info + layer_name = list(imported_assy["cube_1"]._subshape_layers.values())[0] + assert layer_name == "cube_1_top_face" + + layer_name = list(imported_assy["cyl_1"]._subshape_layers.values())[0] + assert layer_name == "cylinder_bottom_face" + + layer_name = list(imported_assy["cyl_1"]._subshape_layers.values())[1] + assert layer_name == "cylinder_bottom_wire" + + +def test_assembly_multi_subshape_step_import(tmp_path_factory): + """ + Test if a STEP file containing subshape information can be imported correctly. 
+ """ + + tmpdir = tmp_path_factory.mktemp("out") + assy_step_path = os.path.join(tmpdir, "multi_subshape_assy.step") + + # Create a basic assembly + cube_1 = cq.Workplane().box(10, 10, 10) + assy = cq.Assembly(name="top_level") + assy.add(cube_1, name="cube_1", color=cq.Color("green")) + cube_2 = cq.Workplane().box(5, 5, 5) + assy.add(cube_2, name="cube_2", color=cq.Color("blue"), loc=cq.Location(10, 10, 10)) + + # Add subshape name, color and layer + assy.addSubshape( + cube_1.faces(">Z").val(), + name="cube_1_top_face", + color=cq.Color("red"), + layer="cube_1_top_face", + ) + assy.addSubshape( + cube_2.faces(">X").val(), + name="cube_2_right_face", + color=cq.Color("red"), + layer="cube_2_right_face", + ) + + # Export the assembly + success = exportStepMeta(assy, assy_step_path) + assert success + + # Import the STEP file back in + imported_assy = cq.Assembly.importStep(assy_step_path) + + # Check that the top-level assembly name is correct + assert imported_assy.name == "top_level" + + # Check the advanced face name for the first cube + assert len(imported_assy.children[0]._subshape_names) == 1 + assert ( + list(imported_assy.children[0]._subshape_names.values())[0] == "cube_1_top_face" + ) + + # Check the color for the first cube + color = list(imported_assy.children[0]._subshape_colors.values())[0] + assert Quantity_NameOfColor.Quantity_NOC_RED == color.wrapped.GetRGB().Name() + + # Check the layer info for the first cube + layer_name = list(imported_assy.children[0]._subshape_layers.values())[0] + assert layer_name == "cube_1_top_face" + + # Check the advanced face name for the second cube + assert len(imported_assy.children[1]._subshape_names) == 1 + assert ( + list(imported_assy.children[1]._subshape_names.values())[0] + == "cube_2_right_face" + ) + + # Check the color + color = list(imported_assy.children[1]._subshape_colors.values())[0] + assert Quantity_NameOfColor.Quantity_NOC_RED == color.wrapped.GetRGB().Name() + + # Check the layer info + 
layer_name = list(imported_assy.children[1]._subshape_layers.values())[0] + assert layer_name == "cube_2_right_face" + + +def test_bad_step_file_import(tmp_path_factory): + """ + Test if a bad STEP file raises an error when importing. + """ + + tmpdir = tmp_path_factory.mktemp("out") + bad_step_path = os.path.join(tmpdir, "bad_step.step") + + # Check that an error is raised when trying to import a non-existent STEP file + with pytest.raises(ValueError): + # Export the assembly + cq.Assembly.importStep(bad_step_path) + + +def test_plain_assembly_import(tmp_path_factory): + """ + Test to make sure that importing plain assemblies has not been broken. + """ + + tmpdir = tmp_path_factory.mktemp("out") + plain_step_path = os.path.join(tmpdir, "plain_assembly_step.step") + + # Simple cubes + cube_1 = cq.Workplane().box(10, 10, 10) + cube_2 = cq.Workplane().box(5, 5, 5) + cube_3 = cq.Workplane().box(5, 5, 5) + cube_4 = cq.Workplane().box(5, 5, 5) + + assy = cq.Assembly(name="top_level", loc=cq.Location(10, 10, 10)) + assy.add(cube_1, color=cq.Color("green")) + assy.add(cube_2, loc=cq.Location((10, 10, 10)), color=cq.Color("red")) + assy.add(cube_3, loc=cq.Location((-10, -10, -10)), color=cq.Color("red")) + assy.add(cube_4, loc=cq.Location((10, -10, -10)), color=cq.Color("red")) + + # Export the assembly, but do not use the meta STEP export method + assy.export(plain_step_path) + + # Import the STEP file back in + imported_assy = cq.Assembly.importStep(plain_step_path) + assert imported_assy.name == "top_level" + + # Check the locations + assert imported_assy.children[0].loc.toTuple()[0] == (0.0, 0.0, 0.0,) + assert imported_assy.children[1].loc.toTuple()[0] == (10.0, 10.0, 10.0,) + assert imported_assy.children[2].loc.toTuple()[0] == (-10.0, -10.0, -10.0,) + assert imported_assy.children[3].loc.toTuple()[0] == (10.0, -10.0, -10.0,) + + # Make sure the location of the top-level assembly was preserved + assert imported_assy.loc.toTuple() == cq.Location((10, 10, 
10)).toTuple() + + # Check the colors + assert pytest.approx(imported_assy.children[0].color.toTuple(), rel=0.01) == ( + 0.0, + 1.0, + 0.0, + 1.0, + ) # green + assert pytest.approx(imported_assy.children[1].color.toTuple(), rel=0.01) == ( + 1.0, + 0.0, + 0.0, + 1.0, + ) # red + assert pytest.approx(imported_assy.children[2].color.toTuple(), rel=0.01) == ( + 1.0, + 0.0, + 0.0, + 1.0, + ) # red + assert pytest.approx(imported_assy.children[3].color.toTuple(), rel=0.01) == ( + 1.0, + 0.0, + 0.0, + 1.0, + ) # red + + +def test_copied_assembly_import(tmp_path_factory): + """ + Tests to make sure that copied children in assemblies work correctly. + """ + from cadquery import Assembly, Location, Color + from cadquery.func import box, rect + + # Create the temporary directory + tmpdir = tmp_path_factory.mktemp("out") + + # prepare the model + def make_model(name: str, COPY: bool): + name = os.path.join(tmpdir, name) + + b = box(1, 1, 1) + + assy = Assembly(name="test_assy") + assy.add(box(1, 2, 5), color=Color("green")) + + for v in rect(10, 10).vertices(): + assy.add( + b.copy() if COPY else b, loc=Location(v.Center()), color=Color("red") + ) + + assy.export(name) + + return assy + + make_model("test_assy_copy.step", True) + make_model("test_assy.step", False) + + # import the assy with copies + assy_copy = Assembly.importStep(os.path.join(tmpdir, "test_assy_copy.step")) + assert 5 == len(assy_copy.children) + + # import the assy without copies + assy_normal = Assembly.importStep(os.path.join(tmpdir, "test_assy.step")) + assert 5 == len(assy_normal.children) + + +def test_nested_subassembly_step_import(tmp_path_factory): + """ + Tests if the STEP import works correctly with nested subassemblies. 
+ """ + + tmpdir = tmp_path_factory.mktemp("out") + nested_step_path = os.path.join(tmpdir, "plain_assembly_step.step") + + # Create a simple assembly + assy = cq.Assembly() + assy.add(cq.Workplane().box(10, 10, 10), name="box_1") + + # Create a simple subassembly + subassy = cq.Assembly() + subassy.add(cq.Workplane().box(5, 5, 5), name="box_2", loc=cq.Location(10, 10, 10)) + + # Nest the subassembly + assy.add(subassy) + + # Export and then re-import the nested assembly STEP + assy.export(nested_step_path) + imported_assy = cq.Assembly.importStep(nested_step_path) + + # Check the locations + assert imported_assy.children[0].loc.toTuple()[0] == (0.0, 0.0, 0.0) + assert imported_assy.children[1].objects["box_2"].loc.toTuple()[0] == ( + 10.0, + 10.0, + 10.0, + ) + + +def test_assembly_step_import_roundtrip(tmp_path_factory): + """ + Tests that the assembly does not mutate during successive export-import round trips. + """ + + # Set up the temporary directory + tmpdir = tmp_path_factory.mktemp("out") + round_trip_step_path = os.path.join(tmpdir, "round_trip.step") + + # Create a sample assembly + assy_orig = cq.Assembly(name="top-level") + assy_orig.add(cq.Workplane().box(10, 10, 10), name="cube_1", color=cq.Color("red")) + subshape_assy = cq.Assembly(name="nested-assy") + subshape_assy.add( + cq.Workplane().cylinder(height=10.0, radius=2.5), + name="cylinder_1", + color=cq.Color("blue"), + loc=cq.Location((20, 20, 20)), + ) + assy_orig.add(subshape_assy) + + # First export + assy_orig.export(round_trip_step_path) + + # First import + assy = cq.Assembly.importStep(round_trip_step_path) + + # Second export + assy.export(round_trip_step_path) + + # Second import + assy = cq.Assembly.importStep(round_trip_step_path) + + # Check some general aspects of the assembly structure now + for k in assy_orig.objects: + assert k in assy + + for k in assy.objects: + assert k in assy_orig + + assert len(assy.children) == 2 + assert assy.name == "top-level" + assert 
assy.children[0].name == "cube_1" + assert assy.children[1].children[0].name == "cylinder_1" + + # First meta export + exportStepMeta(assy, round_trip_step_path) + + # First meta import + assy = cq.Assembly.importStep(round_trip_step_path) + + # Second meta export + exportStepMeta(assy, round_trip_step_path) + + # Second meta import + assy = cq.Assembly.importStep(round_trip_step_path) + + # Check some general aspects of the assembly structure now + assert len(assy.children) == 2 + assert assy.name == "top-level" + assert assy.children[0].name == "cube_1" + assert assy.children[1].children[0].name == "cylinder_1" + + @pytest.mark.parametrize( "assy_fixture, expected", [ @@ -907,7 +1321,7 @@ def test_save_stl_formats(nested_assy_sphere): assert os.path.exists("nested.stl") # Trying to read a binary file as UTF-8/ASCII should throw an error - with pytest.raises(UnicodeDecodeError) as info: + with pytest.raises(UnicodeDecodeError): with open("nested.stl", "r") as file: file.read() @@ -924,7 +1338,7 @@ def test_save_gltf(nested_assy_sphere): assert os.path.exists("nested.glb") # Trying to read a binary file as UTF-8/ASCII should throw an error - with pytest.raises(UnicodeDecodeError) as info: + with pytest.raises(UnicodeDecodeError): with open("nested.glb", "r") as file: file.read() @@ -939,7 +1353,7 @@ def test_exportGLTF(nested_assy_sphere): # Test binary export inferred from file extension cq.exporters.assembly.exportGLTF(nested_assy_sphere, "nested_export_gltf.glb") - with pytest.raises(UnicodeDecodeError) as info: + with pytest.raises(UnicodeDecodeError): with open("nested_export_gltf.glb", "r") as file: file.read() @@ -947,7 +1361,7 @@ def test_exportGLTF(nested_assy_sphere): cq.exporters.assembly.exportGLTF( nested_assy_sphere, "nested_export_gltf_2.glb", binary=True ) - with pytest.raises(UnicodeDecodeError) as info: + with pytest.raises(UnicodeDecodeError): with open("nested_export_gltf_2.glb", "r") as file: file.read() @@ -2000,3 +2414,22 @@ def 
test_step_color(tmp_path_factory): assert "0.47" in line assert "0.25" in line assert "0.18" in line + + +def test_special_methods(subshape_assy): + """ + Smoke-test some special methods. + """ + + assert "cube_1" in subshape_assy.__dir__() + assert "cube_1" in subshape_assy._ipython_key_completions_() + assert "cube_1" in subshape_assy + + subshape_assy["cube_1"] + subshape_assy.cube_1 + + with pytest.raises(KeyError): + subshape_assy["123456"] + + with pytest.raises(AttributeError): + subshape_assy.cube_123456 diff --git a/tests/test_nurbs.py b/tests/test_nurbs.py new file mode 100644 index 000000000..9f57a8a46 --- /dev/null +++ b/tests/test_nurbs.py @@ -0,0 +1,276 @@ +from cadquery.occ_impl.nurbs import ( + designMatrix, + periodicDesignMatrix, + designMatrix2D, + nbFindSpan, + nbBasis, + nbBasisDer, + Curve, + Surface, + approximate, + periodicApproximate, + periodicLoft, + loft, + reparametrize, +) + +from cadquery.func import circle + +import numpy as np +import scipy.sparse as sp + +from pytest import approx, fixture, mark + + +@fixture +def circles() -> list[Curve]: + + # u,v periodic + c1 = circle(1).toSplines() + c2 = circle(5) + + cs = [ + Curve.fromEdge(c1.moved(loc)) + for loc in c2.locations(np.linspace(0, 1, 10, False)) + ] + + return cs + + +@fixture +def trimmed_circles() -> list[Curve]: + + c1 = circle(1).trim(0, 1).toSplines() + c2 = circle(5) + + cs = [ + Curve.fromEdge(c1.moved(loc)) + for loc in c2.locations(np.linspace(0, 1, 10, False)) + ] + + return cs + + +@fixture +def rotated_circles() -> list[Curve]: + + pts1 = np.array([v.toTuple() for v in circle(1).sample(100)[0]]) + pts2 = np.array([v.toTuple() for v in circle(1).moved(z=1, rz=90).sample(100)[0]]) + + c1 = periodicApproximate(pts1) + c2 = periodicApproximate(pts2) + + return [c1, c2] + + +def test_periodic_dm(): + + knots = np.linspace(0, 1, 5) + params = np.linspace(0, 1, 100) + order = 3 + + res = periodicDesignMatrix(params, order, knots) + + C = sp.coo_array((res.v, (res.i, 
res.j))) + + assert C.shape[0] == len(params) + assert C.shape[1] == len(knots) - 1 + + +def test_dm_2d(): + + uknots = np.array([0, 0, 0, 0, 0.25, 0.5, 0.75, 1, 1, 1, 1]) + uparams = np.linspace(0, 1, 100) + uorder = 3 + + vknots = np.array([0, 0, 0, 0.5, 1, 1, 1]) + vparams = np.linspace(0, 1, 100) + vorder = 2 + + res = designMatrix2D(uparams, vparams, uorder, vorder, uknots, vknots) + + C = res.coo() + + assert C.shape[0] == len(uparams) + assert C.shape[1] == (len(uknots) - uorder - 1) * (len(vknots) - vorder - 1) + + +def test_dm(): + + knots = np.array([0, 0, 0, 0, 0.25, 0.5, 0.75, 1, 1, 1, 1]) + params = np.linspace(0, 1, 100) + order = 3 + + res = designMatrix(params, order, knots) + + C = sp.coo_array((res.v, (res.i, res.j))) + + assert C.shape[0] == len(params) + assert C.shape[1] == len(knots) - order - 1 + + +def test_der(): + + knots = np.array([0, 0, 0, 0, 0.25, 0.5, 0.75, 1, 1, 1, 1]) + params = np.linspace(0, 1, 100) + order = 3 + + out_der = np.zeros((order + 1, order + 1)) + out = np.zeros(order + 1) + + for p in params: + nbBasisDer(nbFindSpan(p, order, knots), p, order, order - 1, knots, out_der) + nbBasis(nbFindSpan(p, order, knots), p, order, knots, out) + + # sanity check + assert np.allclose(out_der[0, :], out) + + +def test_periodic_curve(): + + knots = np.linspace(0, 1, 5) + pts = np.array([[0, 0, 0], [0, 0, 1], [0, 1, 2], [0, 2, 0]]) + + crv = Curve(pts, knots, 3, True) + + # is it indeed periodic? 
+ assert crv.curve().IsPeriodic() + + # convert to an edge + e = crv.edge() + + assert e.isValid() + assert e.ShapeType() == "Edge" + + +def test_curve(): + + knots = np.array([0, 0, 0, 0, 1, 1, 1, 1]) + pts = np.array([[0, 0, 0], [0, 0, 1], [0, 1, 2], [0, 2, 0]]) + + crv = Curve(pts, knots, 3, False) + + # sanity check + assert not crv.curve().IsPeriodic() + + # convert to an edge + e = crv.edge() + + assert e.isValid() + assert e.ShapeType() == "Edge" + + # edge to curve + crv2 = Curve.fromEdge(e) + e2 = crv2.edge() + + assert e2.isValid() + + # check roundtrip + crv3 = Curve.fromEdge(e2) + + assert np.allclose(crv2.knots, crv3.knots) + assert np.allclose(crv2.pts, crv3.pts) + + +def test_surface(): + + uknots = vknots = np.array([0, 0, 1, 1]) + pts = np.array([[[0, 0, 0], [0, 1, 0]], [[1, 0, 0], [1, 1, 0]]]) + + srf = Surface(pts, uknots, vknots, 1, 1, False, False) + + # convert to a face + f = srf.face() + + assert f.isValid() + assert f.Area() == approx(1) + + # roundtrip + srf2 = Surface.fromFace(f) + + assert np.allclose(srf.uknots, srf2.uknots) + assert np.allclose(srf.vknots, srf2.vknots) + assert np.allclose(srf.pts, srf2.pts) + + +def test_approximate(): + + pts_ = circle(1).trim(0, 1).sample(100)[0] + pts = np.array([list(p) for p in pts_]) + + # regular approximate + crv = approximate(pts) + e = crv.edge() + + assert e.isValid() + assert e.Length() == approx(1) + + # approximate with a double penalty + crv = approximate(pts, penalty=4, lam=1e-9) + e = crv.edge() + + assert e.isValid() + assert e.Length() == approx(1) + + # approximate with a single penalty + crv = approximate(pts, penalty=2, lam=1e-9) + e = crv.edge() + + assert e.isValid() + assert e.Length() == approx(1) + + +def test_periodic_approximate(): + + pts_ = circle(1).sample(100)[0] + pts = np.array([list(p) for p in pts_]) + + crv = periodicApproximate(pts) + e = crv.edge() + + assert e.isValid() + assert e.Length() == approx(2 * np.pi) + + +def test_periodic_loft(circles, 
trimmed_circles): + + # u,v periodic + surf1 = periodicLoft(*circles) + + assert surf1.face().isValid() + + # u periodic + surf2 = periodicLoft(*trimmed_circles) + + assert surf2.face().isValid() + + +def test_loft(circles, trimmed_circles): + + # v periodic + surf1 = loft(*circles) + + assert surf1.face().isValid() + + # non-periodic + surf2 = loft(*trimmed_circles) + + assert surf2.face().isValid() + + +def test_reparametrize(rotated_circles): + + c1, c2 = rotated_circles + + # this surface will be twisted + surf = loft(c1, c2, order=2, lam=1e-6) + + # this should adjust the paramatrizations + c1r, c2r = reparametrize(c1, c2) + + # resulting loft should not be twisted + surfr = loft(c1r, c2r, order=2, lam=1e-6) + + # assert that the surface is indeed not twisted + assert surfr.face().Area() == approx(2 * np.pi, 1e-3) + assert surfr.face().Area() >= 1.01 * surf.face().Area()