Skip to content

Commit c2103e7

Browse files
A couple of extensions to rewriter (#2001)
Extends the rewriter with a couple of features: * A debugging mode to perform the pattern matching (without any graph modifications) and to report instances that get the best score for a match (even if incomplete). Helps quickly identify causes for mismatches when we expect a match. * Rewrite-rules can now specify a pre/post visitor method called before applying it to a graph/function. This is useful for rules that need to create "cached" values that are reused across multiple instances of the pattern. --------- Co-authored-by: Justin Chu <[email protected]>
1 parent a942e95 commit c2103e7

File tree

3 files changed

+276
-36
lines changed

3 files changed

+276
-36
lines changed

onnxscript/rewriter/_ir_utils.py

+17-12
Original file line numberDiff line numberDiff line change
@@ -3,14 +3,29 @@
33
from __future__ import annotations
44

55
import math
6-
from typing import Callable
6+
from typing import Callable, Sequence
77

88
import numpy as np
99

1010
import onnxscript.ir as ir
1111
from onnxscript.optimizer import basic_constant_propagation
1212

1313

14+
def display_nodes(nodes: Sequence[ir.Node]) -> None:
15+
"""Display a list of nodes in the order they appear in the graph."""
16+
if nodes:
17+
graph = nodes[0].graph
18+
if graph:
19+
# Display nodes in same order as in graph:
20+
# Currently doesn't handle (control-flow) subgraphs
21+
for node in graph:
22+
if node in nodes:
23+
node.display()
24+
else:
25+
for node in nodes:
26+
node.display()
27+
28+
1429
def display_slice(x: ir.Value | ir.Node, backward: bool = True, depth_limit: int = 5) -> None:
1530
"""Display the (backward or forward) subgraph from a given value or node upto a certain depth."""
1631
slice = []
@@ -33,17 +48,7 @@ def visit(node: ir.Node, depth):
3348
visit(x, 0)
3449
elif isinstance(x, ir.Value) and x.producer() is not None:
3550
visit(x.producer(), 0) # type: ignore[arg-type]
36-
if slice:
37-
graph = slice[0].graph
38-
if graph:
39-
# Display nodes in same order as in graph:
40-
# Currently doesn't handle (control-flow) subgraphs
41-
for node in graph:
42-
if node in slice:
43-
node.display()
44-
else:
45-
for node in reversed(slice):
46-
node.display()
51+
display_nodes(slice)
4752

4853

4954
def get_const_value(value: ir.Value) -> ir.TensorProtocol | None:

0 commit comments

Comments
 (0)