|
| 1 | +# Copyright (c) Microsoft Corporation. |
| 2 | +# Licensed under the MIT License. |
"""Performs the following transformations:
- Relu(Relu(X)) -> Relu(X)
- Relu(Clip(X)) -> Clip(X)
- Clip(Relu(X)) -> Clip(X)
- Clip(Clip(X)) -> Clip(X)
"""
| 9 | + |
| 10 | +import abc |
| 11 | + |
| 12 | +import numpy as np |
| 13 | +import onnx_ir as ir |
| 14 | + |
| 15 | +from onnxscript.rewriter import pattern as orp |
| 16 | + |
| 17 | + |
class FuseSuccessiveRelu(orp.RewriteRuleClassBase):
    """Collapses two back-to-back Relu nodes: ``Relu(Relu(X))`` becomes ``Relu(X)``."""

    def pattern(self, op, x):
        # Match an inner Relu whose output feeds an outer Relu.
        inner_relu = op.Relu(x)
        return op.Relu(inner_relu)

    def rewrite(self, op, x):
        # A single Relu applied to the original input replaces the matched pair.
        return op.Relu(x)
| 26 | + |
| 27 | + |
class _FuseReluClipBase(orp.RewriteRuleClassBase, abc.ABC):
    """Shared logic for rules that fuse a Clip with an adjacent Relu or Clip.

    Subclasses provide ``pattern`` and ``compute_clip_min_max``. This base class
    validates the match in ``check`` and emits the single fused ``Clip`` in
    ``rewrite``.
    """

    def rewrite(self, op, x, **kwargs):
        """Emit one ``Clip`` on ``x`` using the fused bounds.

        Args:
            op: Rewrite-time builder; its ``initializer`` and ``Clip`` are used here.
            x: Input value of the matched chain.
            **kwargs: Matched outputs. ``out_first_clip`` is required;
                ``out_second_clip`` is only read when present.

        Returns:
            The result of ``op.Clip`` applied to ``x`` and the fused bounds.
        """
        first_clip_node = kwargs.get("out_first_clip").producer()
        second_clip_node = None

        if out_second_clip := kwargs.get("out_second_clip"):
            second_clip_node = out_second_clip.producer()

        min_clip, max_clip = self.compute_clip_min_max(first_clip_node, second_clip_node)
        clip_min_max = []

        if min_clip is not None:
            clip_min_max.append(
                op.initializer(min_clip, name=f"{first_clip_node.inputs[0].name}_min")
            )

        if max_clip is not None:
            # ONNX Clip expects min and max inputs in order.
            # If min is not provided, we insert None to maintain correct argument positions.
            if min_clip is None:
                clip_min_max.append(None)

            clip_min_max.append(
                op.initializer(max_clip, name=f"{first_clip_node.inputs[0].name}_max")
            )

        return op.Clip(x, *clip_min_max)

    @abc.abstractmethod
    def compute_clip_min_max(
        self, first_clip_node: ir.Node, second_clip_node: ir.Node | None = None
    ):
        """Return the fused ``(min, max)`` pair; either entry may be ``None``.

        ``rewrite`` passes each non-``None`` entry to ``op.initializer``.
        """

    @staticmethod
    def _constant_as_numpy(value):
        """Return the constant behind ``value`` as a numpy array, or ``None`` if absent.

        Uses the same accessor as ``check`` so both agree on what counts as a
        constant.

        Raises:
            ValueError: If ``value`` is present but no constant tensor is available.
        """
        if value is None:
            return None
        tensor = ir.convenience.get_const_tensor(value)
        if tensor is None:
            raise ValueError(f"{value.name} is not a constant.")
        return tensor.numpy()

    def extract_min_max(self, node: ir.Node):
        """Read the constant min/max bounds of a Clip node.

        Args:
            node: The Clip node to inspect.

        Returns:
            ``(min_clip, max_clip, dtype)``. A bound is ``None`` when the
            corresponding input is missing or explicitly ``None``; ``dtype`` is
            the numpy dtype of the node's first input.
        """
        # Infer dtype from node first input
        dtype = node.inputs[0].dtype.numpy()
        inputs = node.inputs

        # Either optional input may be absent or explicitly None (for example,
        # only one of the two bounds supplied), so both get the same handling.
        min_clip = self._constant_as_numpy(inputs[1]) if len(inputs) > 1 else None
        max_clip = self._constant_as_numpy(inputs[2]) if len(inputs) > 2 else None

        return min_clip, max_clip, dtype

    def check(self, context, **kwargs):
        """Condition to check if we need to replace the pattern.

        The pattern is applied only when the min and max inputs of the Clip nodes are
        not graph inputs and are constant values (i.e., provided by Constant nodes or
        initializers), and when the dtype of each Clip's first input is known.

        Returns:
            MatchResult:
                Success if we need to replace the pattern, Failure otherwise.
        """
        del context  # Unused
        check_result = orp.MatchResult()

        clip_nodes = [kwargs.get("out_first_clip").producer()]
        if out_second_clip := kwargs.get("out_second_clip"):
            clip_nodes.append(out_second_clip.producer())

        for clip_node in clip_nodes:
            # extract_min_max needs this dtype; reject the match here instead
            # of crashing later during the rewrite.
            data_input = clip_node.inputs[0]
            if data_input.dtype is None:
                return check_result.fail(f"{data_input.name} has an unknown dtype.")

            # Check if Clip min/max are not graph inputs and are constant values
            for m in clip_node.inputs[1:]:
                if m is None:
                    continue

                if m.is_graph_input():
                    return check_result.fail(f"{m.name} is a graph input.")

                if ir.convenience.get_const_tensor(m) is None:
                    return check_result.fail(f"{m.name} is not a constant.")

        return check_result
| 111 | + |
| 112 | + |
class FuseSuccessiveClip(_FuseReluClipBase):
    """Replaces ``Clip(Clip(X))`` with ``Clip(X)``."""

    def pattern(self, op, x):
        return op.Clip(
            op.Clip(x, _allow_other_inputs=True, _outputs=["out_first_clip"]),
            _allow_other_inputs=True,
            _outputs=["out_second_clip"],
        )

    def compute_clip_min_max(self, first_clip_node: ir.Node, second_clip_node: ir.Node):
        """Compute the bounds of the single Clip equivalent to the two matched Clips.

        With ``Clip(x, lo, hi) = min(max(x, lo), hi)``, the composition
        ``Clip(Clip(x, lo1, hi1), lo2, hi2)`` equals ``Clip(x, L, H)`` where
        ``L = max(lo1, lo2)`` and ``H = min(max(hi1, lo2), hi2)``. A missing
        bound acts as -inf / +inf.

        Args:
            first_clip_node: The inner Clip node.
            second_clip_node: The outer Clip node.

        Returns:
            ``(min_clip, max_clip)`` as tensors; an entry is ``None`` when
            neither Clip constrains that side.
        """
        min_clip1, max_clip1, dtype = self.extract_min_max(first_clip_node)
        min_clip2, max_clip2, _ = self.extract_min_max(second_clip_node)

        # The outer Clip's lower bound is applied after the inner upper bound,
        # so it lifts any inner upper bound that lies below it. Without this,
        # Clip(Clip(x, 0, 1), 2, 3) (always 2) would fuse into Clip(x, 2, 1),
        # which always yields 1.
        if max_clip1 is not None and min_clip2 is not None:
            max_clip1 = np.array(np.maximum(max_clip1, min_clip2), dtype=dtype)

        def combine(val1, val2, reducer):
            # Merge two optional bounds; a missing bound leaves the other as is.
            if val1 is not None and val2 is not None:
                return ir.tensor(np.array(reducer(val1, val2), dtype=dtype))
            elif val1 is not None:
                return ir.tensor(val1)
            elif val2 is not None:
                return ir.tensor(val2)
            return None

        min_clip = combine(min_clip1, min_clip2, np.maximum)
        max_clip = combine(max_clip1, max_clip2, np.minimum)

        return min_clip, max_clip
| 140 | + |
| 141 | + |
class FuseSuccessiveClipRelu(_FuseReluClipBase):
    """Replaces ``Clip(Relu(X))`` with ``Clip(X)``."""

    def pattern(self, op, x):
        relu_out = op.Relu(x)
        return op.Clip(relu_out, _allow_other_inputs=True, _outputs=["out_first_clip"])

    def compute_clip_min_max(self, first_clip_node: ir.Node, _):
        """Return the fused bounds: the Clip's own bounds with the min raised to at least 0.

        The second argument is ignored; this pattern contains a single Clip.
        """
        lower, upper, dtype = self.extract_min_max(first_clip_node)

        # Relu clamps at 0, so a missing lower bound is treated as 0.
        effective_lower = 0 if lower is None else lower
        fused_min = ir.tensor(np.array(np.maximum(0.0, effective_lower), dtype=dtype))

        # The upper bound, when present, is carried over unchanged.
        fused_max = None if upper is None else ir.tensor(upper)
        return fused_min, fused_max
| 160 | + |
| 161 | + |
class FuseSuccessiveReluClip(FuseSuccessiveClipRelu):
    """Replaces ``Relu(Clip(X))`` with ``Clip(X)``."""

    def pattern(self, op, x):
        return op.Relu(op.Clip(x, _allow_other_inputs=True, _outputs=["out_first_clip"]))

    def compute_clip_min_max(self, first_clip_node: ir.Node, _):
        """Return the bounds of the single Clip equivalent to ``Relu(Clip(X))``.

        Here Relu runs after the Clip, so it floors both bounds at 0:
        ``max(min(max(x, lo), hi), 0) == min(max(x, max(lo, 0)), max(hi, 0))``.
        Reusing the ``Clip(Relu(X))`` bounds would be wrong for a negative max:
        ``Relu(Clip(x, lo, -1))`` is always 0, not -1.

        The second argument is ignored; this pattern contains a single Clip.
        """
        min_clip, max_clip, dtype = self.extract_min_max(first_clip_node)

        if min_clip is None:
            # The minimum clipping value is implicitly 0 (Relu clamps at 0)
            min_clip = 0

        fused_min = ir.tensor(np.array(np.maximum(0.0, min_clip), dtype=dtype))

        fused_max = None
        if max_clip is not None:
            # Relu also lifts the upper bound to at least 0.
            fused_max = ir.tensor(np.array(np.maximum(0.0, max_clip), dtype=dtype))

        return fused_min, fused_max
| 167 | + |
| 168 | + |
# Module-level rule objects, one per fusion class.
(
    fuse_successive_relu_rule,
    fuse_successive_clip_rule,
    fuse_successive_clip_relu_rule,
    fuse_successive_relu_clip_rule,
) = (
    FuseSuccessiveRelu().rule(),
    FuseSuccessiveClip().rule(),
    FuseSuccessiveClipRelu().rule(),
    FuseSuccessiveReluClip().rule(),
)
| 173 | + |
| 174 | + |
def fuse_relus_clips_rules() -> orp.RewriteRuleSet:
    """Builds the rule set that fuses successive Relu/Clip nodes.

    Returns:
        RewriteRuleSet
    """
    # Order is important: keep the rules in exactly this sequence.
    ordered_rules = [
        fuse_successive_clip_relu_rule,
        fuse_successive_relu_clip_rule,
        fuse_successive_relu_rule,
        fuse_successive_clip_rule,
    ]
    return orp.RewriteRuleSet(ordered_rules)
0 commit comments