Commit 34dc350

[Rewriter]: fuse successive Relu/Clip nodes (#2410)
This PR adds the following transformations:

- Relu(Relu(X)) -> Relu
- Relu(Clip(X)) -> Clip
- Clip(Relu(X)) -> Clip
- Clip(Clip(X)) -> Clip

---------

Co-authored-by: Justin Chu <[email protected]>
1 parent d708a7d commit 34dc350
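
These rewrites are semantics-preserving because a Relu is a clip at zero and nested clips intersect their ranges: the effective minimum is the larger of the two minimums and the effective maximum is the smaller of the two maximums. A quick NumPy illustration of these identities (not part of the commit; the values are arbitrary, with overlapping clip ranges):

    import numpy as np

    x = np.linspace(-10.0, 10.0, 41, dtype=np.float32)

    def relu(v):
        return np.maximum(v, 0.0)

    # Relu(Relu(x)) == Relu(x)
    assert np.array_equal(relu(relu(x)), relu(x))

    # Clip(Relu(x), lo, hi) == Clip(x, max(lo, 0), hi)
    lo, hi = -2.0, 6.0
    assert np.array_equal(np.clip(relu(x), lo, hi), np.clip(x, max(lo, 0.0), hi))

    # Clip(Clip(x, lo1, hi1), lo2, hi2) == Clip(x, max(lo1, lo2), min(hi1, hi2))
    lo1, hi1, lo2, hi2 = -3.0, 5.0, -1.0, 7.0
    assert np.array_equal(
        np.clip(np.clip(x, lo1, hi1), lo2, hi2),
        np.clip(x, max(lo1, lo2), min(hi1, hi2)),
    )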

File tree: 4 files changed, +567 -4 lines changed

onnxscript/rewriter/__init__.py

Lines changed: 2 additions & 0 deletions
@@ -20,6 +20,7 @@
     broadcast_to_matmul,
     cast_constant_of_shape,
     collapse_slices,
+    fuse_relus_clips,
     no_op,
     pattern,
     redundant_scatter_nd,
@@ -32,6 +33,7 @@
     *broadcast_to_matmul.rules.rules,
     *cast_constant_of_shape.rules.rules,
     *collapse_slices.rules.rules,
+    *fuse_relus_clips.fuse_relus_clips_rules().rules,
     *basic_rules.basic_optimization_rules().rules,
     *redundant_scatter_nd.rules.rules,
 )
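
Because the fusion rules are now registered in the rewriter's default rule list above, the generic rewrite pass picks them up automatically. A minimal sketch, assuming the top-level onnxscript.rewriter.rewrite entry point and a hypothetical model.onnx on disk:

    import onnx
    from onnxscript import rewriter

    model = onnx.load("model.onnx")  # hypothetical model containing Relu/Clip chains
    optimized = rewriter.rewrite(model)  # default rules now include the Relu/Clip fusions
    onnx.save(optimized, "model_optimized.onnx")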

onnxscript/rewriter/fuse_relus_clips.py

Lines changed: 190 additions & 0 deletions
@@ -0,0 +1,190 @@
# Copyright (c) Microsoft Corporation.
# Licensed under the MIT License.
"""Performs the following transformations:
- Relu(Relu(X)) -> Relu
- Relu(Clip(X)) -> Clip
- Clip(Relu(X)) -> Clip
- Clip(Clip(X)) -> Clip
"""

import abc

import numpy as np
import onnx_ir as ir

from onnxscript.rewriter import pattern as orp


class FuseSuccessiveRelu(orp.RewriteRuleClassBase):
    """Replaces ``Relu(Relu(X))`` with ``Relu(X)``."""

    def rewrite(self, op, x):
        return op.Relu(x)

    def pattern(self, op, x):
        return op.Relu(op.Relu(x))


class _FuseReluClipBase(orp.RewriteRuleClassBase, abc.ABC):
    def rewrite(self, op, x, **kwargs):
        first_clip_node = kwargs.get("out_first_clip").producer()
        second_clip_node = None

        if out_second_clip := kwargs.get("out_second_clip"):
            second_clip_node = out_second_clip.producer()

        min_clip, max_clip = self.compute_clip_min_max(first_clip_node, second_clip_node)
        clip_min_max = []

        if min_clip is not None:
            clip_min_max.append(
                op.initializer(min_clip, name=f"{first_clip_node.inputs[0].name}_min")
            )

        if max_clip is not None:
            # ONNX Clip expects min and max inputs in order.
            # If min is not provided, we insert None to maintain correct argument positions.
            if min_clip is None:
                clip_min_max.append(None)

            clip_min_max.append(
                op.initializer(max_clip, name=f"{first_clip_node.inputs[0].name}_max")
            )

        return op.Clip(x, *clip_min_max)

    @abc.abstractmethod
    def compute_clip_min_max(
        self, first_clip_node: ir.Node, second_clip_node: ir.Node | None = None
    ):
        pass

    def extract_min_max(self, node: ir.Node):
        # Infer dtype from node first input
        dtype = node.inputs[0].dtype.numpy()
        min_clip, max_clip = None, None

        if len(node.inputs) > 1:
            min_input = node.inputs[1]
            # If only a max is provided, min is implicitly None, so we check that
            if min_input is not None:
                min_clip = min_input.const_value.numpy()

        if len(node.inputs) > 2:
            max_clip = node.inputs[2].const_value.numpy()

        return min_clip, max_clip, dtype

    def check(self, context, **kwargs):
        """Condition to check if we need to replace the pattern.

        The pattern is applied only when the min and max inputs of the Clip nodes are
        not graph inputs and are constant values (i.e., provided by Constant nodes or initializers).

        Returns:
            MatchResult:
                Success if we need to replace the pattern, Failure otherwise.
        """
        del context  # Unused
        check_result = orp.MatchResult()

        # Check if Clip min/max are not graph inputs and are constant values
        clip_min_max = []

        first_clip_node = kwargs.get("out_first_clip").producer()
        clip_min_max.extend([inp for inp in first_clip_node.inputs[1:] if inp is not None])

        if out_second_clip := kwargs.get("out_second_clip"):
            second_clip_node = out_second_clip.producer()
            clip_min_max.extend(
                [inp for inp in second_clip_node.inputs[1:] if inp is not None]
            )

        for m in clip_min_max:
            if m.is_graph_input():
                return check_result.fail(f"{m.name} is a graph input.")

            if ir.convenience.get_const_tensor(m) is None:
                return check_result.fail(f"{m.name} is not a constant.")

        return check_result


class FuseSuccessiveClip(_FuseReluClipBase):
    """Replaces ``Clip(Clip(X))`` with ``Clip(X)``."""

    def pattern(self, op, x):
        return op.Clip(
            op.Clip(x, _allow_other_inputs=True, _outputs=["out_first_clip"]),
            _allow_other_inputs=True,
            _outputs=["out_second_clip"],
        )

    def compute_clip_min_max(self, first_clip_node: ir.Node, second_clip_node: ir.Node):
        min_clip1, max_clip1, dtype = self.extract_min_max(first_clip_node)
        min_clip2, max_clip2, _ = self.extract_min_max(second_clip_node)

        def combine(val1, val2, op):
            if val1 is not None and val2 is not None:
                return ir.tensor(np.array(op(val1, val2), dtype=dtype))
            elif val1 is not None:
                return ir.tensor(val1)
            elif val2 is not None:
                return ir.tensor(val2)
            return None

        min_clip = combine(min_clip1, min_clip2, np.maximum)
        max_clip = combine(max_clip1, max_clip2, np.minimum)

        return min_clip, max_clip


class FuseSuccessiveClipRelu(_FuseReluClipBase):
    """Replaces ``Clip(Relu(X))`` with ``Clip(X)``."""

    def pattern(self, op, x):
        return op.Clip(op.Relu(x), _allow_other_inputs=True, _outputs=["out_first_clip"])

    def compute_clip_min_max(self, first_clip_node: ir.Node, _):
        min_clip, max_clip, dtype = self.extract_min_max(first_clip_node)

        if min_clip is None:
            # The minimum clipping value is implicitly 0 (Relu clamps at 0)
            min_clip = 0

        min_clip = ir.tensor(np.array(np.maximum(0.0, min_clip), dtype=dtype))

        if max_clip is not None:
            max_clip = ir.tensor(max_clip)
        return min_clip, max_clip


class FuseSuccessiveReluClip(FuseSuccessiveClipRelu):
    """Replaces ``Relu(Clip(X))`` with ``Clip(X)``."""

    def pattern(self, op, x):
        return op.Relu(op.Clip(x, _allow_other_inputs=True, _outputs=["out_first_clip"]))


fuse_successive_relu_rule = FuseSuccessiveRelu().rule()
fuse_successive_clip_rule = FuseSuccessiveClip().rule()
fuse_successive_clip_relu_rule = FuseSuccessiveClipRelu().rule()
fuse_successive_relu_clip_rule = FuseSuccessiveReluClip().rule()


def fuse_relus_clips_rules() -> orp.RewriteRuleSet:
    """Returns a set of rewrite rules that fuse successive Relu/Clip nodes.

    Returns:
        RewriteRuleSet
    """

    # Order is important
    return orp.RewriteRuleSet(
        [
            fuse_successive_clip_relu_rule,
            fuse_successive_relu_clip_rule,
            fuse_successive_relu_rule,
            fuse_successive_clip_rule,
        ]
    )
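
To exercise the new rules in isolation, the rule set returned by fuse_relus_clips_rules() can be passed to the rewriter. A minimal sketch, assuming the onnxscript.rewriter.rewrite entry point accepts a ModelProto together with rules via its pattern_rewrite_rules parameter; the toy ReLU6 model below is illustrative and not taken from this PR's tests:

    import onnx
    from onnx import TensorProto, helper

    from onnxscript import rewriter
    from onnxscript.rewriter import fuse_relus_clips

    # Toy graph computing Clip(Relu(X), 0, 6), i.e. ReLU6 spelled as two nodes.
    graph = helper.make_graph(
        nodes=[
            helper.make_node("Relu", ["X"], ["t"]),
            helper.make_node("Clip", ["t", "min_v", "max_v"], ["Y"]),
        ],
        name="relu6",
        inputs=[helper.make_tensor_value_info("X", TensorProto.FLOAT, ["N"])],
        outputs=[helper.make_tensor_value_info("Y", TensorProto.FLOAT, ["N"])],
        initializer=[
            helper.make_tensor("min_v", TensorProto.FLOAT, [], [0.0]),
            helper.make_tensor("max_v", TensorProto.FLOAT, [], [6.0]),
        ],
    )
    model = helper.make_model(graph, opset_imports=[helper.make_opsetid("", 18)])
    # Shape inference fills in the intermediate value's dtype, which extract_min_max reads.
    model = onnx.shape_inference.infer_shapes(model)

    rewritten = rewriter.rewrite(
        model, pattern_rewrite_rules=fuse_relus_clips.fuse_relus_clips_rules()
    )
    # Expected outcome: the Relu is folded away, leaving a single Clip(X, 0, 6) node.
    print([node.op_type for node in rewritten.graph.node])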
