Skip to content

Commit af49eff

Browse files
authored
[Passes] Consolidate DCE passes into common passes (#2143)
Consolidate DCE passes into common passes (unused_removal) for them to be available for pass users. Refactored usage. Added a pass to remove unused opset imports.
1 parent a36ec86 commit af49eff

11 files changed

+264
-200
lines changed

onnxscript/ir/passes/_pass_infra.py

+2-1
Original file line numberDiff line numberDiff line change
@@ -249,7 +249,8 @@ def call(self, model: ir.Model) -> PassResult:
249249
overall_modified = False
250250
for step in range(self.steps):
251251
try:
252-
step_result = super().__call__(model)
252+
# Call the call method of Sequential
253+
step_result = super().call(model)
253254
except Exception as e:
254255
raise PassError(f"An error occurred at step {step}") from e
255256
model = step_result.model
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,188 @@
1+
# Copyright (c) Microsoft Corporation.
2+
# Licensed under the MIT License.
3+
from __future__ import annotations
4+
5+
__all__ = [
6+
"RemoveUnusedNodesPass",
7+
"RemoveUnusedFunctionsPass",
8+
"RemoveUnusedOpsetsPass",
9+
]
10+
11+
import logging
12+
13+
import onnx
14+
15+
from onnxscript import ir
16+
17+
logger = logging.getLogger(__name__)
18+
19+
20+
def _remove_unused_optional_outputs(
21+
node: ir.Node, graph_outputs: frozenset[ir.Value], onnx_opset_version: int
22+
) -> None:
23+
try:
24+
if node.domain not in {"", "onnx.ai"}:
25+
return
26+
op_schema = onnx.defs.get_schema(node.op_type, onnx_opset_version, domain=node.domain)
27+
except Exception: # pylint: disable=broad-exception-caught
28+
logger.info(
29+
"Failed to get schema for %s, skipping optional output removal",
30+
node,
31+
stack_info=True,
32+
)
33+
return
34+
35+
if node.op_type == "BatchNormalization":
36+
# BatchNormalization op has 3 outputs: Y, running_mean, running_var
37+
# If running_mean and running_var are not used, remove them, and the training_mode attribute
38+
def is_used_output(i: int) -> bool:
39+
if i < len(node.outputs):
40+
val = node.outputs[i]
41+
return val in graph_outputs or bool(val.uses())
42+
return False
43+
44+
if is_used_output(1) or is_used_output(2):
45+
return
46+
if len(node.outputs) > 1:
47+
node.outputs[1].name = ""
48+
if len(node.outputs) > 2:
49+
node.outputs[2].name = ""
50+
node.attributes.pop("training_mode", None)
51+
return
52+
53+
optional_info = []
54+
for o in op_schema.outputs:
55+
# Current ops do not have optional outputs if they have variable number of outputs
56+
if o.option == onnx.defs.OpSchema.FormalParameterOption.Variadic:
57+
return
58+
optional_info.append(o.option == onnx.defs.OpSchema.FormalParameterOption.Optional)
59+
# If no optional outputs in spec, skip delete operations
60+
if len([o == 1 for o in optional_info]) == 0:
61+
return
62+
63+
for i, out in enumerate(node.outputs):
64+
if out not in graph_outputs and (not out.uses()) and optional_info[i] is True:
65+
out.name = ""
66+
67+
68+
def _remove_unused_nodes_in_graph_like(function_or_graph: ir.Function | ir.Graph) -> int:
69+
graph_outputs = frozenset(function_or_graph.outputs)
70+
onnx_opset_version = function_or_graph.opset_imports.get("", None)
71+
count = 0
72+
for node in reversed(function_or_graph):
73+
removable = True
74+
for output in node.outputs:
75+
if output in graph_outputs or output.uses():
76+
removable = False
77+
break
78+
if removable:
79+
function_or_graph.remove(node, safe=True)
80+
count += 1
81+
else:
82+
if onnx_opset_version is not None:
83+
_remove_unused_optional_outputs(node, graph_outputs, onnx_opset_version)
84+
for attr in node.attributes.values():
85+
if not isinstance(attr, ir.Attr):
86+
continue
87+
if attr.type == ir.AttributeType.GRAPH:
88+
count += _remove_unused_nodes_in_graph_like(attr.as_graph())
89+
elif attr.type == ir.AttributeType.GRAPHS:
90+
for graph in attr.as_graphs():
91+
count += _remove_unused_nodes_in_graph_like(graph)
92+
return count
93+
94+
95+
class RemoveUnusedNodesPass(ir.passes.InPlacePass):
96+
def call(self, model: ir.Model) -> ir.passes.PassResult:
97+
count = _remove_unused_nodes_in_graph_like(model.graph)
98+
graph_outputs = frozenset(model.graph.outputs)
99+
initializers = model.graph.initializers
100+
for init in list(initializers.values()):
101+
if not (init in graph_outputs or init.uses()):
102+
assert init.name is not None
103+
del initializers[init.name]
104+
count += 1
105+
for function in model.functions.values():
106+
count += _remove_unused_nodes_in_graph_like(function)
107+
if count:
108+
logger.info("Removed %s unused nodes", count)
109+
return ir.passes.PassResult(model, modified=bool(count))
110+
111+
112+
class RemoveUnusedFunctionsPass(ir.passes.InPlacePass):
113+
def __init__(self):
114+
super().__init__()
115+
self._used: set[ir.OperatorIdentifier] | None = None
116+
117+
def call(self, model: ir.Model) -> ir.passes.PassResult:
118+
self._used = set()
119+
for node in ir.traversal.RecursiveGraphIterator(model.graph):
120+
self._call_node(model, node)
121+
122+
# Update the model to remove unused functions
123+
unused = set(model.functions) - self._used
124+
if not unused:
125+
logger.info("No unused functions to remove")
126+
return ir.passes.PassResult(model, modified=False)
127+
128+
for op_identifier in unused:
129+
del model.functions[op_identifier]
130+
131+
logger.info("Removed %s unused functions", len(unused))
132+
logger.debug("Functions left: %s", list(model.functions))
133+
logger.debug("Functions removed: %s", unused)
134+
135+
self._used = None
136+
return ir.passes.PassResult(model, modified=bool(unused))
137+
138+
def _call_function(self, model: ir.Model, function: ir.Function) -> None:
139+
assert self._used is not None
140+
if function.identifier() in self._used:
141+
# The function and its nodes are already recorded as used
142+
return
143+
self._used.add(function.identifier())
144+
for node in ir.traversal.RecursiveGraphIterator(function):
145+
self._call_node(model, node)
146+
147+
def _call_node(self, model: ir.Model, node: ir.Node) -> None:
148+
op_identifier = node.op_identifier()
149+
if op_identifier not in model.functions:
150+
return
151+
self._call_function(model, model.functions[op_identifier])
152+
153+
154+
class RemoveUnusedOpsetsPass(ir.passes.InPlacePass):
155+
"""Remove unused opset imports from the model and functions.
156+
157+
Attributes:
158+
process_functions: Whether to process functions in the model. If True, the pass will
159+
remove unused opset imports from functions as well. If False, only the main graph
160+
will be processed.
161+
"""
162+
163+
def __init__(self, process_functions: bool = True):
164+
super().__init__()
165+
self.process_functions = process_functions
166+
167+
def _process_graph_like(
168+
self, graph_like: ir.Graph | ir.Function, used_domains: set[str]
169+
) -> bool:
170+
for node in ir.traversal.RecursiveGraphIterator(graph_like):
171+
used_domains.add(node.domain)
172+
unused = set(graph_like.opset_imports) - used_domains
173+
for domain in unused:
174+
del graph_like.opset_imports[domain]
175+
return bool(unused)
176+
177+
def call(self, model: ir.Model) -> ir.passes.PassResult:
178+
# Record domains of all functions
179+
used_domains = set()
180+
for function in model.functions.values():
181+
used_domains.add(function.domain)
182+
modified = self._process_graph_like(model.graph, used_domains=used_domains)
183+
184+
if self.process_functions:
185+
for function in model.functions.values():
186+
modified |= self._process_graph_like(function, used_domains=set())
187+
188+
return ir.passes.PassResult(model, modified=modified)

onnxscript/optimizer/_remove_unused_test.py onnxscript/ir/passes/common/unused_removal_test.py

+11-11
Original file line numberDiff line numberDiff line change
@@ -25,7 +25,7 @@ def remove_unused_nodes(self, model: onnx.ModelProto):
2525
def test_remove_unused_nodes(self):
2626
model = onnx.parser.parse_model(
2727
"""
28-
<ir_version: 7, opset_import: [ "" : 17]>
28+
<ir_version: 10, opset_import: [ "" : 17]>
2929
agraph (float[N] x) => (float[N] z) {
3030
two = Constant <value_float=2.0> ()
3131
four = Add(two, two)
@@ -40,7 +40,7 @@ def test_remove_unused_nodes(self):
4040
def test_remove_unused_initializers(self):
4141
model = onnx.parser.parse_model(
4242
"""
43-
<ir_version: 7, opset_import: [ "" : 17]>
43+
<ir_version: 10, opset_import: [ "" : 17]>
4444
agraph (float[N] x) => (float[N] z)
4545
<float two = {2.0}> {
4646
four = Add(two, two)
@@ -57,7 +57,7 @@ def test_remove_unused_initializers(self):
5757
def test_partially_used_nodes(self):
5858
model = onnx.parser.parse_model(
5959
"""
60-
<ir_version: 7, opset_import: [ "" : 17]>
60+
<ir_version: 10, opset_import: [ "" : 17]>
6161
agraph (float[N] x) => (float[M] z) {
6262
w1, w2, w3 = Split (x)
6363
z = Mul(w3, w3)
@@ -71,7 +71,7 @@ def test_partially_used_nodes(self):
7171
def test_remove_unused_optional_outputs_maxpool(self):
7272
model = onnx.parser.parse_model(
7373
"""
74-
<ir_version: 7, opset_import: [ "" : 17]>
74+
<ir_version: 10, opset_import: [ "" : 17]>
7575
agraph (float[1, 1, 5, 5] x) => (float[1, 1, 5, 5] z) {
7676
z, indices = MaxPool <pads = [2, 2, 2, 2], kernel_shape = [5, 5]> (x)
7777
}
@@ -88,7 +88,7 @@ def test_remove_unused_optional_outputs_maxpool(self):
8888
def test_remove_unused_optional_outputs_dropout_in_function(self):
8989
model = onnx.parser.parse_model(
9090
"""
91-
<ir_version: 7, opset_import: [ "" : 17, "pkg.custom": 1]>
91+
<ir_version: 10, opset_import: [ "" : 17, "pkg.custom": 1]>
9292
agraph (float[1, 1, 5, 5] x) => (float[1, 1, 5, 5] z)
9393
{
9494
z = pkg.custom.afunction (x)
@@ -113,7 +113,7 @@ def test_remove_unused_optional_outputs_dropout_in_function(self):
113113
def test_remove_used_optional_outputs_maxpool(self):
114114
model = onnx.parser.parse_model(
115115
"""
116-
<ir_version: 7, opset_import: [ "" : 17]>
116+
<ir_version: 10, opset_import: [ "" : 17]>
117117
agraph (float[1, 1, 5, 5] x) => (float[1, 1, 5, 5] y, float[1, 1, 5, 5] z) {
118118
y, z = MaxPool <pads = [2, 2, 2, 2], kernel_shape = [5, 5]> (x)
119119
}
@@ -130,7 +130,7 @@ def test_remove_used_optional_outputs_maxpool(self):
130130
def test_remove_multiple_unused_optional_outputs_layernorm(self):
131131
model = onnx.parser.parse_model(
132132
"""
133-
<ir_version: 7, opset_import: [ "" : 17]>
133+
<ir_version: 10, opset_import: [ "" : 17]>
134134
agraph (float[1, 3, 5, 5] x) => (float[1, 3, 5, 5] z) {
135135
scale = Constant <value_ints=[3]> ()
136136
B = Constant <value_ints=[3]> ()
@@ -149,7 +149,7 @@ def test_remove_multiple_unused_optional_outputs_layernorm(self):
149149
def test_remove_trailing_unused_optional_outputs_layernorm(self):
150150
model = onnx.parser.parse_model(
151151
"""
152-
<ir_version: 7, opset_import: [ "" : 17]>
152+
<ir_version: 10, opset_import: [ "" : 17]>
153153
agraph (float[1, 3, 5, 5] x) => (float[1, 3, 5, 5] z, float[1, 3, 5, 5] mean) {
154154
scale = Constant <value_ints=[3]> ()
155155
B = Constant <value_ints=[3]> ()
@@ -168,7 +168,7 @@ def test_remove_trailing_unused_optional_outputs_layernorm(self):
168168
def test_avoid_remove_non_trailing_unused_optional_outputs_layernorm(self):
169169
model = onnx.parser.parse_model(
170170
"""
171-
<ir_version: 7, opset_import: [ "" : 17]>
171+
<ir_version: 10, opset_import: [ "" : 17]>
172172
agraph (float[1, 3, 5, 5] x) => (float[1, 3, 5, 5] z, float[1, 3, 5, 5] InvStdDev) {
173173
scale = Constant <value_ints=[3]> ()
174174
B = Constant <value_ints=[3]> ()
@@ -187,7 +187,7 @@ def test_avoid_remove_non_trailing_unused_optional_outputs_layernorm(self):
187187
def test_remove_trailing_unused_optional_outputs_batchnorm(self):
188188
model = onnx.parser.parse_model(
189189
"""
190-
<ir_version: 7, opset_import: [ "" : 17]>
190+
<ir_version: 10, opset_import: [ "" : 17]>
191191
agraph (float[1, 3, 5, 5] x, float[3] scale, float[3] B) => (float[1, 3, 5, 5] z) {
192192
z, mean_out, var_out = BatchNormalization <training_mode=1> (x, scale, B, mean, var)
193193
}
@@ -204,7 +204,7 @@ def test_remove_trailing_unused_optional_outputs_batchnorm(self):
204204
def test_avoid_remove_used_optional_outputs_batchnorm(self):
205205
model = onnx.parser.parse_model(
206206
"""
207-
<ir_version: 7, opset_import: [ "" : 17]>
207+
<ir_version: 10, opset_import: [ "" : 17]>
208208
agraph (float[1, 3, 5, 5] x, float[3] scale, float[3] B) => (float[1, 3, 5, 5] z, float[3] mean_out, float[3] var_out) {
209209
z, mean_out, var_out = BatchNormalization <training_mode=1> (x, scale, B, mean, var)
210210
}

onnxscript/optimizer/__init__.py

+29-1
Original file line numberDiff line numberDiff line change
@@ -14,13 +14,13 @@
1414

1515
import onnx
1616

17+
import onnxscript.ir.passes.common.unused_removal
1718
import onnxscript.optimizer._constant_folding as constant_folding
1819
import onnxscript.optimizer._legacy._optimizer as legacy_optimizer
1920
import onnxscript.optimizer._legacy.constant_folding as legacy_constant_folding
2021
from onnxscript import ir
2122
from onnxscript.optimizer._inliner import inline
2223
from onnxscript.optimizer._optimizer import optimize_ir
23-
from onnxscript.optimizer._remove_unused import remove_unused_nodes
2424

2525
basic_constant_propagation = constant_folding.basic_constant_propagation
2626
fold_constants_ir = constant_folding.fold_constants
@@ -40,3 +40,31 @@ def fold_constants(model: ir.Model | onnx.ModelProto, *args, **kwargs) -> bool:
4040
return constant_folding.fold_constants(model, *args, **kwargs)
4141
else:
4242
return legacy_constant_folding.fold_constants(model, *args, **kwargs)
43+
44+
45+
def remove_unused_nodes(model: ir.Model | onnx.ModelProto) -> None:
46+
"""Removes unused nodes from a model inplace."""
47+
if isinstance(model, ir.Model):
48+
onnxscript.ir.passes.common.unused_removal.RemoveUnusedNodesPass()(model)
49+
else:
50+
model_ir = ir.serde.deserialize_model(model)
51+
model_ir = onnxscript.ir.passes.common.unused_removal.RemoveUnusedNodesPass()(
52+
model_ir
53+
).model
54+
new_proto = ir.serde.serialize_model(model_ir)
55+
model.Clear()
56+
model.CopyFrom(new_proto)
57+
58+
59+
def remove_unused_functions(model: ir.Model | onnx.ModelProto) -> None:
60+
"""Removes unused functions from a model inplace."""
61+
if isinstance(model, ir.Model):
62+
onnxscript.ir.passes.common.unused_removal.RemoveUnusedFunctionsPass()(model)
63+
else:
64+
model_ir = ir.serde.deserialize_model(model)
65+
model_ir = onnxscript.ir.passes.common.unused_removal.RemoveUnusedFunctionsPass()(
66+
model_ir
67+
).model
68+
new_proto = ir.serde.serialize_model(model_ir)
69+
model.Clear()
70+
model.CopyFrom(new_proto)

onnxscript/optimizer/_legacy/_optimizer.py

+3-4
Original file line numberDiff line numberDiff line change
@@ -8,15 +8,14 @@
88
import onnx
99
import onnx.shape_inference
1010

11+
import onnxscript.optimizer
1112
from onnxscript import rewriter
1213
from onnxscript.optimizer._legacy._simple_function_folding import (
1314
inline_functions_with_unused_outputs,
1415
inline_simple_functions,
1516
)
1617
from onnxscript.optimizer._legacy.constant_folding import fold_constants
1718
from onnxscript.optimizer._optimizer import _DEFAULT_REWRITE_RULES
18-
from onnxscript.optimizer._remove_unused import remove_unused_nodes
19-
from onnxscript.optimizer._remove_unused_function import remove_unused_functions
2019

2120
logger = logging.getLogger(__name__)
2221

@@ -71,9 +70,9 @@ def optimize(
7170
model, external_data_folder, onnx_shape_inference=onnx_shape_inference
7271
)
7372

74-
remove_unused_nodes(model)
73+
onnxscript.optimizer.remove_unused_nodes(model)
7574
inline_simple_functions(model)
76-
model = remove_unused_functions(model)
75+
onnxscript.optimizer.remove_unused_functions(model)
7776
inline_functions_with_unused_outputs(model)
7877
# NOTE: This is general rewrite rules
7978
model = rewriter.rewrite(model, pattern_rewrite_rules=_DEFAULT_REWRITE_RULES)

0 commit comments

Comments
 (0)