diff --git a/.github/workflows/doc_build.yml b/.github/workflows/doc_build.yml index f1b7108c..27965a7d 100644 --- a/.github/workflows/doc_build.yml +++ b/.github/workflows/doc_build.yml @@ -14,6 +14,10 @@ jobs: uses: actions/setup-python@v5 with: python-version: "3.12" + - name: Install system dependencies + run: | + sudo apt-get update + sudo apt-get install -y graphviz - name: Install dependencies run: | python -m pip install --upgrade pip diff --git a/.github/workflows/doc_build_deploy.yml b/.github/workflows/doc_build_deploy.yml index d843e35b..513d73bc 100644 --- a/.github/workflows/doc_build_deploy.yml +++ b/.github/workflows/doc_build_deploy.yml @@ -35,6 +35,10 @@ jobs: uses: actions/setup-python@v5 with: python-version: "3.12" + - name: Install system dependencies + run: | + sudo apt-get update + sudo apt-get install -y graphviz - name: Install dependencies run: | python -m pip install --upgrade pip diff --git a/.github/workflows/tests.yml b/.github/workflows/tests.yml index c22b7d2f..0802d707 100644 --- a/.github/workflows/tests.yml +++ b/.github/workflows/tests.yml @@ -23,12 +23,15 @@ jobs: uses: actions/setup-python@v5 with: python-version: ${{ matrix.python-version }} + - name: Install system dependencies + run: | + sudo apt-get update + sudo apt-get install -y graphviz - name: Install dependencies run: | python -m pip install --upgrade pip python -m pip install --upgrade uv - python -m uv pip install pytest - python -m uv pip install -e ./cuequivariance + python -m uv pip install -e ./cuequivariance[dev] - name: Test with pytest run: | pytest --doctest-modules -x -m "not slow" cuequivariance diff --git a/cuequivariance/cuequivariance/segmented_polynomials/__init__.py b/cuequivariance/cuequivariance/segmented_polynomials/__init__.py index 42cadb24..3a3f39b9 100644 --- a/cuequivariance/cuequivariance/segmented_polynomials/__init__.py +++ b/cuequivariance/cuequivariance/segmented_polynomials/__init__.py @@ -23,6 +23,7 @@ from .operation import Operation 
from .segmented_polynomial import SegmentedPolynomial +from .visualization import visualize_polynomial __all__ = [ @@ -37,4 +38,5 @@ "dispatch", "Operation", "SegmentedPolynomial", + "visualize_polynomial", ] diff --git a/cuequivariance/cuequivariance/segmented_polynomials/segmented_polynomial.py b/cuequivariance/cuequivariance/segmented_polynomials/segmented_polynomial.py index cba9429b..62e34388 100644 --- a/cuequivariance/cuequivariance/segmented_polynomials/segmented_polynomial.py +++ b/cuequivariance/cuequivariance/segmented_polynomials/segmented_polynomial.py @@ -17,7 +17,7 @@ import copy import dataclasses import itertools -from typing import Any, Callable, Sequence +from typing import Any, Callable, Optional, Sequence import numpy as np @@ -998,7 +998,10 @@ def jvp( self, has_tangent: list[bool] ) -> tuple[ SegmentedPolynomial, - Callable[[tuple[list[Any], list[Any]]], tuple[list[Any], list[Any]]], + Callable[ + [tuple[list[Any], list[Any]], Optional[Callable[[Any], Any]]], + tuple[list[Any], list[Any]], + ], ]: """Compute the Jacobian-vector product of the polynomial. 
@@ -1023,14 +1026,20 @@ def jvp( ): new_operations.append((ope, multiplicator * stp)) - def mapping(x: tuple[list[Any], list[Any]]) -> tuple[list[Any], list[Any]]: + def mapping( + x: tuple[list[Any], list[Any]], + into_grad: Optional[Callable[[Any], Any]] = None, + ) -> tuple[list[Any], list[Any]]: inputs, outputs = x inputs, outputs = list(inputs), list(outputs) assert len(inputs) == self.num_inputs assert len(outputs) == self.num_outputs + into_grad = into_grad if callable(into_grad) else lambda x: x - new_inputs = inputs + [x for has, x in zip(has_tangent, inputs) if has] - new_outputs = outputs + new_inputs = inputs + [ + into_grad(x) for has, x in zip(has_tangent, inputs) if has + ] + new_outputs = [into_grad(x) for x in outputs] return new_inputs, new_outputs @@ -1088,7 +1097,10 @@ def backward( self, requires_gradient: list[bool], has_cotangent: list[bool] ) -> tuple[ SegmentedPolynomial, - Callable[[tuple[list[Any], list[Any]]], tuple[list[Any], list[Any]]], + Callable[ + [tuple[list[Any], list[Any]], Optional[Callable[[Any], Any]]], + tuple[list[Any], list[Any]], + ], ]: """Compute the backward pass of the polynomial for gradient computation. @@ -1106,7 +1118,10 @@ def backward( has_cotangent, ) - def mapping(x: tuple[list[Any], list[Any]]) -> tuple[list[Any], list[Any]]: - return map2(map1(x)) + def mapping( + x: tuple[list[Any], list[Any]], + into_grad: Optional[Callable[[Any], Any]] = None, + ) -> tuple[list[Any], list[Any]]: + return map2(map1(x, into_grad)) return p, mapping diff --git a/cuequivariance/cuequivariance/segmented_polynomials/visualization.py b/cuequivariance/cuequivariance/segmented_polynomials/visualization.py new file mode 100644 index 00000000..c03674f1 --- /dev/null +++ b/cuequivariance/cuequivariance/segmented_polynomials/visualization.py @@ -0,0 +1,102 @@ +# SPDX-FileCopyrightText: Copyright (c) 2024-2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved. 
+# SPDX-License-Identifier: Apache-2.0 +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +from __future__ import annotations + +from typing import TYPE_CHECKING + +if TYPE_CHECKING: + import graphviz + + import cuequivariance as cue + + +def visualize_polynomial( + poly: "cue.SegmentedPolynomial", + input_names: list[str], + output_names: list[str], +) -> "graphviz.Digraph": + """ + Create a graphviz diagram showing the dataflow from inputs through STPs to outputs. + + Args: + poly: The SegmentedPolynomial to visualize. + input_names: Names for each input operand (length must match poly.num_inputs). + output_names: Names for each output operand (length must match poly.num_outputs). + + Returns: + A graphviz.Digraph object that can be rendered, saved, or displayed. + + Example: + >>> import cuequivariance as cue + >>> from cuequivariance.segmented_polynomials.visualization import visualize_polynomial + >>> poly = cue.descriptors.spherical_harmonics(cue.SO3(1), [1, 2, 3]).polynomial + >>> graph = visualize_polynomial(poly, ["x"], ["Y"]) + >>> graph.render("spherical_harmonics", format="png", cleanup=True) # doctest: +SKIP + >>> # Or in Jupyter: + >>> # graph # Displays inline + + Raises: + ValueError: If the number of names doesn't match the number of inputs/outputs. + ImportError: If graphviz is not installed. 
+ """ + # Validate parameters first + if len(input_names) != poly.num_inputs: + raise ValueError( + f"Expected {poly.num_inputs} input names, got {len(input_names)}" + ) + if len(output_names) != poly.num_outputs: + raise ValueError( + f"Expected {poly.num_outputs} output names, got {len(output_names)}" + ) + + # Import graphviz (checked after parameter validation) + try: + import graphviz + except ImportError as e: + raise ImportError( + "graphviz is required for visualization. Install it with: pip install graphviz" + ) from e + + # Create directed graph + dot = graphviz.Digraph(comment="Segmented Polynomial Flow") + dot.attr(rankdir="LR") # Left to right layout + dot.attr("node", shape="box") + + # Create input nodes + for i, (name, operand) in enumerate(zip(input_names, poly.inputs)): + label = f"{name}\\n{operand.num_segments} segments\\nsize={operand.size}" + dot.node(f"input_{i}", label, style="filled", fillcolor="lightblue") + + # Create output nodes + for i, (name, operand) in enumerate(zip(output_names, poly.outputs)): + label = f"{name}\\n{operand.num_segments} segments\\nsize={operand.size}" + dot.node(f"output_{i}", label, style="filled", fillcolor="lightgreen") + + # Create STP nodes and edges + for stp_idx, (operation, stp) in enumerate(poly.operations): + # Create STP node + stp_label = f"{stp.subscripts}\\n{stp.num_paths} paths" + dot.node(f"stp_{stp_idx}", stp_label, style="filled", fillcolor="lightyellow") + + # Create edges from inputs to this STP + for operand_idx in operation.input_buffers(poly.num_inputs): + dot.edge(f"input_{operand_idx}", f"stp_{stp_idx}") + + # Create edge from this STP to output + output_buffer = operation.output_buffer(poly.num_inputs) + output_idx = output_buffer - poly.num_inputs + dot.edge(f"stp_{stp_idx}", f"output_{output_idx}") + + return dot diff --git a/cuequivariance/pyproject.toml b/cuequivariance/pyproject.toml index 1b1cfb44..7d571252 100644 --- a/cuequivariance/pyproject.toml +++ 
b/cuequivariance/pyproject.toml @@ -40,6 +40,12 @@ classifiers = [ "Programming Language :: Python :: 3.12", ] +[project.optional-dependencies] +dev = [ + "pytest", + "graphviz", +] + [tool.hatch.version] path = "cuequivariance/VERSION" pattern = "(?P<version>\\d+\\.\\d+\\.\\d+(?:[a-z]+\\d+)?)" diff --git a/docs/requirements.txt b/docs/requirements.txt index aac8c51d..e5a595e2 100644 --- a/docs/requirements.txt +++ b/docs/requirements.txt @@ -8,4 +8,5 @@ ipykernel matplotlib jupyter-sphinx e3nn -flax \ No newline at end of file +flax +graphviz \ No newline at end of file diff --git a/docs/tutorials/poly.rst b/docs/tutorials/poly.rst index d76d3ab2..da0b4b7f 100644 --- a/docs/tutorials/poly.rst +++ b/docs/tutorials/poly.rst @@ -151,4 +151,86 @@ This hierarchical structure allows for efficient representation and computation .. jupyter-execute:: p.operations - \ No newline at end of file + +Visualization +------------- + +You can visualize the dataflow of a :class:`cue.SegmentedPolynomial <cuequivariance.SegmentedPolynomial>` using graphviz. This creates a diagram showing how inputs flow through segmented tensor products to produce outputs. + +First, install graphviz: + +.. code-block:: bash + + pip install graphviz + +Then create a visualization: + +.. 
jupyter-execute:: + + from cuequivariance.segmented_polynomials import visualize_polynomial + + # Visualize the spherical harmonics polynomial + sh_poly = cue.descriptors.spherical_harmonics(cue.SO3(1), [1, 2]).polynomial + graph = visualize_polynomial(sh_poly, input_names=["x"], output_names=["Y"]) + + # Display the graph (in Jupyter it renders inline) + graph + +The diagram shows: + +* **Input nodes** (blue): Display the input name, number of segments, and total size +* **STP nodes** (yellow): Show the subscripts and number of computation paths +* **Output nodes** (green): Display the output name, number of segments, and total size +* **Edges**: Represent the dataflow, with multiple edges drawn when an input is used multiple times + +You can save the diagram to a file: + +.. jupyter-execute:: + :hide-output: + + # Save as PNG (or 'svg', 'pdf', etc.) + graph.render('spherical_harmonics', format='png', cleanup=True) + +For more complex examples: + +.. jupyter-execute:: + + # Visualize a linear layer + irreps_in = cue.Irreps("O3", "8x0e + 8x1o") + irreps_out = cue.Irreps("O3", "4x0e + 4x1o") + linear_poly = cue.descriptors.linear(irreps_in, irreps_out).polynomial + + graph = visualize_polynomial(linear_poly, input_names=["weights", "input"], output_names=["output"]) + graph + +.. jupyter-execute:: + + # Visualize a tensor product + irreps = cue.Irreps("O3", "0e + 1o") + tp_poly = cue.descriptors.channelwise_tensor_product(irreps, irreps, irreps).polynomial + + graph = visualize_polynomial(tp_poly, input_names=["weights", "x1", "x2"], output_names=["y"]) + graph + +Visualizing Backward Pass +^^^^^^^^^^^^^^^^^^^^^^^^^^ + +You can also visualize the backward pass of a polynomial. The mapping function returned by :meth:`cue.SegmentedPolynomial.backward <cuequivariance.SegmentedPolynomial.backward>` accepts an optional ``into_grad`` parameter that can transform operand names, which is useful for labeling gradients: + +.. 
jupyter-execute:: + + # Create a polynomial and compute its backward pass + irreps = cue.Irreps("O3", "0e + 1o") + tp_poly = cue.descriptors.channelwise_tensor_product(irreps, irreps, irreps).polynomial + + # Compute backward pass (all inputs require gradients, output has cotangent) + poly_bwd, m = tp_poly.backward([True, True, True], [True]) + + # Transform operand names using the mapping function with into_grad + # The mapping function takes (inputs, outputs) and returns (new_inputs, new_outputs) + operand_names = (["weights", "x1", "x2"], ["y"]) + operand_names_bwd = m(operand_names, lambda n: f"d{n}") + + # Visualize the backward polynomial + graph = visualize_polynomial(poly_bwd, input_names=operand_names_bwd[0], output_names=operand_names_bwd[1]) + graph diff --git a/test_visualization.py b/test_visualization.py new file mode 100644 index 00000000..5d714b25 --- /dev/null +++ b/test_visualization.py @@ -0,0 +1,49 @@ +#!/usr/bin/env python3 +"""Test script for the visualize_polynomial function (works without graphviz).""" + +import cuequivariance as cue + + +def test_visualization_api(): + """Test that the visualization function has the correct API without rendering.""" + from cuequivariance.segmented_polynomials import visualize_polynomial + + # Create a simple polynomial + sh_poly = cue.descriptors.spherical_harmonics(cue.SO3(1), [1, 2]).polynomial + + print("Testing visualize_polynomial API...") + print(f"Polynomial: {sh_poly}") + print(f" num_inputs: {sh_poly.num_inputs}") + print(f" num_outputs: {sh_poly.num_outputs}") + print(f" num_operations: {len(sh_poly.operations)}") + print() + + # Test error handling for wrong number of names + try: + visualize_polynomial(sh_poly, ["x", "y"], ["Y"]) # Too many input names + print("❌ Should have raised ValueError for wrong number of inputs") + except ValueError as e: + print(f"✓ Correctly raised ValueError: {e}") + + try: + visualize_polynomial(sh_poly, ["x"], ["Y", "Z"]) # Too many output names + print("❌ Should 
have raised ValueError for wrong number of outputs") + except ValueError as e: + print(f"✓ Correctly raised ValueError: {e}") + + # Test that it raises ImportError if graphviz is not installed + try: + graph = visualize_polynomial(sh_poly, ["x"], ["Y"]) + print("✓ graphviz is installed, graph created successfully") + print(f" Graph type: {type(graph)}") + # Print the DOT source + print("\nGenerated DOT source:") + print(graph.source) + except ImportError as e: + print(f"✓ Correctly raised ImportError when graphviz not installed: {e}") + + print("\n✓ All API tests passed!") + + +if __name__ == "__main__": + test_visualization_api()