Skip to content

Commit 3536960

Browse files
authored
[IR] Implement model.graphs() (#2200)
Implement model.graphs() as a way to retrieve the main graph and all subgraphs of it in the model. Given (1) how useful the method is (2) I couldn't find an appropriate name for it in `traversal.py` (3) Users familiar with onnxruntime optimization tools expect this method. In PyTorch a similar `modules()` method exists. I created this method as a core method instead of an iterator in `traversal.py`. Depends on #2183
1 parent d1a8215 commit 3536960

File tree

2 files changed

+74
-0
lines changed

2 files changed

+74
-0
lines changed

onnxscript/ir/_core.py

+19
Original file line numberDiff line numberDiff line change
@@ -2563,6 +2563,25 @@ def __repr__(self) -> str:
25632563
graph={textwrap.indent(repr(self.graph), " " * 4).strip()}
25642564
)"""
25652565

2566+
def graphs(self) -> Iterable[Graph]:
2567+
"""Get all graphs and subgraphs in the model.
2568+
2569+
This is a convenience method to traverse the model. Consider using
2570+
`onnxscript.ir.traversal.RecursiveGraphIterator` for more advanced
2571+
traversals on nodes.
2572+
"""
2573+
# NOTE(justinchuby): Given
2574+
# (1) how useful the method is
2575+
# (2) I couldn't find an appropriate name for it in `traversal.py`
2576+
# (3) Users familiar with onnxruntime optimization tools expect this method
2577+
# I created this method as a core method instead of an iterator in
2578+
# `traversal.py`.
2579+
seen_graphs: set[Graph] = set()
2580+
for node in onnxscript.ir.traversal.RecursiveGraphIterator(self.graph):
2581+
if node.graph is not None and node.graph not in seen_graphs:
2582+
seen_graphs.add(node.graph)
2583+
yield node.graph
2584+
25662585

25672586
class Function(_protocols.FunctionProtocol, Sequence[Node], _display.PrettyPrintable):
25682587
"""IR functions.

onnxscript/ir/_core_test.py

+55
Original file line numberDiff line numberDiff line change
@@ -1152,6 +1152,61 @@ def test_topological_sort_subgraph(self):
11521152
)
11531153

11541154

1155+
class ModelTest(unittest.TestCase):
1156+
def test_graphs_returns_all_subgraphs(self):
1157+
# main_graph: nodes=[a,b,c,d,>,if], edges=[(a,>),(b,>),(>,if)], subgraphs={if:[then_graph,else_graph]}
1158+
# then_graph: nodes=[sub], edges=[(c,sub),(d,sub)]
1159+
# else_graph: nodes=[add], edges=[(c,add),(d,add)]
1160+
v0 = _core.Value(name="va")
1161+
v1 = _core.Value(name="vb")
1162+
v2 = _core.Value(name="vc")
1163+
v3 = _core.Value(name="vd")
1164+
node0 = _core.Node("", "a", inputs=(v0,), num_outputs=1)
1165+
node1 = _core.Node("", "b", inputs=(v1,), num_outputs=1)
1166+
node2 = _core.Node("", "c", inputs=(v2,), num_outputs=1)
1167+
node3 = _core.Node("", "d", inputs=(v3,), num_outputs=1)
1168+
node4 = _core.Node(
1169+
"", "sub", inputs=(node2.outputs[0], node3.outputs[0]), num_outputs=1
1170+
)
1171+
node5 = _core.Node(
1172+
"", "add", inputs=(node2.outputs[0], node3.outputs[0]), num_outputs=1
1173+
)
1174+
node6 = _core.Node("", ">", inputs=(node0.outputs[0], node1.outputs[0]), num_outputs=1)
1175+
then_graph = _core.Graph(
1176+
inputs=(node2.outputs[0], node3.outputs[0]),
1177+
outputs=(node4.outputs[0],),
1178+
nodes=(node4,),
1179+
name="then_graph",
1180+
)
1181+
else_graph = _core.Graph(
1182+
inputs=(node2.outputs[0], node3.outputs[0]),
1183+
outputs=(node5.outputs[0],),
1184+
nodes=(node5,),
1185+
name="else_graph",
1186+
)
1187+
node7 = _core.Node(
1188+
"",
1189+
"if",
1190+
inputs=(node6.outputs[0],),
1191+
num_outputs=1,
1192+
attributes=[
1193+
ir.AttrGraph("then_branch", then_graph),
1194+
ir.AttrGraph("else_branch", else_graph),
1195+
],
1196+
)
1197+
main_graph = _core.Graph(
1198+
inputs=(v0, v1, v2, v3),
1199+
outputs=(node7.outputs[0],),
1200+
nodes=(node0, node1, node2, node6, node7),
1201+
name="main_graph",
1202+
)
1203+
model = _core.Model(main_graph, ir_version=10)
1204+
self.assertEqual(
1205+
tuple(model.graphs()),
1206+
(main_graph, then_graph, else_graph),
1207+
)
1208+
1209+
11551210
class TypeTest(unittest.TestCase):
11561211
@parameterized.parameterized.expand(
11571212
[

0 commit comments

Comments
 (0)