diff --git a/onnxscript/ir/passes/common/topological_sort.py b/onnxscript/ir/passes/common/topological_sort.py new file mode 100644 index 000000000..9be183cf0 --- /dev/null +++ b/onnxscript/ir/passes/common/topological_sort.py @@ -0,0 +1,33 @@ +# Copyright (c) Microsoft Corporation. +# Licensed under the MIT License. +"""Pass for topologically sorting the graphs.""" + +from __future__ import annotations + +__all__ = [ + "TopologicalSortPass", +] + + +from onnxscript import ir + + +class TopologicalSortPass(ir.passes.InPlacePass): + """Topologically sort graphs and functions in a model.""" + + def call(self, model: ir.Model) -> ir.passes.PassResult: + original_nodes = list(model.graph) + model.graph.sort() + sorted_nodes = list(model.graph) + for function in model.functions.values(): + original_nodes.extend(function) + function.sort() + sorted_nodes.extend(function) + + # Compare node orders to determine if any changes were made + modified = False + for node, new_node in zip(original_nodes, sorted_nodes): + if node is not new_node: + modified = True + break + return ir.passes.PassResult(model=model, modified=modified) diff --git a/onnxscript/ir/passes/common/topological_sort_test.py b/onnxscript/ir/passes/common/topological_sort_test.py new file mode 100644 index 000000000..ca9d1377f --- /dev/null +++ b/onnxscript/ir/passes/common/topological_sort_test.py @@ -0,0 +1,50 @@ +# Copyright (c) Microsoft Corporation. +# Licensed under the MIT License. +"""Unit tests for the TopologicalSortPass.""" + +import unittest + +from onnxscript import ir +from onnxscript.ir.passes.common import topological_sort + + +class TopologicalSortPassTest(unittest.TestCase): + def setUp(self): + self.node_a = ir.node("A", inputs=[], name="node_a") + self.node_b = ir.node("B", inputs=self.node_a.outputs, name="node_b") + self.node_c = ir.node("C", inputs=self.node_b.outputs, name="node_c") + + def test_topological_sort_modified_true(self): + graph = ir.Graph( + inputs=self.node_a.inputs, + outputs=self.node_c.outputs, + nodes=[self.node_c, self.node_b, self.node_a], # Unsorted nodes + name="test_graph", + ) + model = ir.Model(graph, ir_version=10) + result = topological_sort.TopologicalSortPass()(model) + self.assertTrue(result.modified) + self.assertEqual( + tuple(result.model.graph), + (self.node_a, self.node_b, self.node_c), + ) + + def test_topological_sort_modified_false(self): + """Test that modified is False when the input model is already sorted.""" + sorted_graph = ir.Graph( + inputs=self.node_a.inputs, + outputs=self.node_c.outputs, + nodes=[self.node_a, self.node_b, self.node_c], # Sorted nodes + name="test_graph", + ) + sorted_model = ir.Model(sorted_graph, ir_version=10) + result = topological_sort.TopologicalSortPass()(sorted_model) + self.assertFalse(result.modified) + self.assertEqual( + tuple(result.model.graph), + (self.node_a, self.node_b, self.node_c), + ) + + +if __name__ == "__main__": + unittest.main()