Skip to content

Commit 0880593

Browse files
committed
where-not optimization
Signed-off-by: Ananya <[email protected]>
1 parent 0e49375 commit 0880593

File tree

3 files changed

+84
-0
lines changed

3 files changed

+84
-0
lines changed

onnxoptimizer/pass_registry.h

+2
Original file line numberDiff line numberDiff line change
@@ -60,6 +60,7 @@
6060
#include "onnxoptimizer/passes/fuse_consecutive_unsqueezes.h"
6161
#include "onnxoptimizer/passes/eliminate_nop_with_unit.h"
6262
#include "onnxoptimizer/passes/rewrite_input_dtype.h"
63+
#include "onnxoptimizer/passes/rewrite_where.h"
6364

6465
namespace ONNX_NAMESPACE {
6566
namespace optimization {
@@ -118,6 +119,7 @@ struct GlobalPassRegistry {
118119
registerPass<EliminateDuplicateInitializer>();
119120
registerPass<AdjustSliceAndMatmul>();
120121
registerPass<RewriteInputDtype>();
122+
registerPass<RewriteWhere>();
121123
}
122124

123125
~GlobalPassRegistry() {

onnxoptimizer/passes/rewrite_where.h

+56
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,56 @@
1+
/*
2+
* SPDX-License-Identifier: Apache-2.0
3+
*/
4+
5+
// ATTENTION: The code in this file is highly EXPERIMENTAL.
6+
// Adventurous users should note that the APIs will probably change.
7+
8+
#pragma once
9+
10+
#include "onnxoptimizer/pass.h"
11+
#include "onnxoptimizer/passes/pass_util.h"
12+
13+
namespace ONNX_NAMESPACE {
14+
namespace optimization {
15+
16+
// where(not(b), x, y) -> where(b, y, x)
17+
// https://github.com/microsoft/onnxruntime/blob/v1.15.1/onnxruntime/core/optimizer/not_where_fusion.h
18+
struct RewriteWhere final : public PredicateBasedPass {
19+
explicit RewriteWhere()
20+
: PredicateBasedPass(PassType::Nop, PassEfficiency::Partial,
21+
PassOptimizationType::Compute) {}
22+
23+
std::string getPassName() const override {
24+
return "rewrite_where";
25+
}
26+
27+
bool patternMatchPredicate(Node* node) override {
28+
bool isWhere = CheckKind(node, Symbol("Where"));
29+
if (isWhere) {
30+
return CheckKind(node->inputs()[0]->node(), Symbol("Not"));
31+
}
32+
return false;
33+
}
34+
bool runTransform(Node* node, Graph& graph,
35+
NodeDestroyType& destroy_current) override {
36+
destroy_current = NodeDestroyType::DestroyZero;
37+
Node* previous_node = node->input(0)->node();
38+
if (previous_node->output()->uses().size() == 1) {
39+
const bool replacing_success =
40+
tryReplacingAllUsesWith(node->input(0), previous_node->input(0));
41+
if (!replacing_success) {
42+
return false;
43+
}
44+
auto x = node->inputs()[1];
45+
auto y = node->inputs()[2];
46+
node->replaceInput(1, y);
47+
node->replaceInput(2, x);
48+
previous_node->destroy();
49+
return true;
50+
}
51+
return false;
52+
}
53+
};
54+
55+
} // namespace optimization
56+
} // namespace ONNX_NAMESPACE

onnxoptimizer/test/optimizer_test.py

+26
Original file line numberDiff line numberDiff line change
@@ -4597,6 +4597,32 @@ def test_eliminate_consecutive_idempotent_op(self):
45974597
assert optimized_model.graph.node[0].op_type == "Constant"
45984598
assert optimized_model.graph.node[1].op_type == "Reshape"
45994599

4600+
def test_rewrite_where(self):
4601+
model = parser.parse_model("""
4602+
<
4603+
ir_version: 7,
4604+
opset_import:["": 11]
4605+
>
4606+
agraph (bool[4] A, float[4] X, float[4] Y) => (float[4] F, float[4] H)
4607+
{
4608+
B = Not(A)
4609+
Z = Where(B, X, Y)
4610+
F = Sign(Z)
4611+
M = And(A,A)
4612+
G = Where(M, X, Y)
4613+
H = Sign(G)
4614+
}
4615+
""")
4616+
4617+
optimized_model = self._optimized(
4618+
model,["rewrite_where"], True)
4619+
4620+
assert len(optimized_model.graph.node) == 5
4621+
assert set([i.op_type for i in optimized_model.graph.node]) == {'Where', 'And', 'Sign'}
4622+
assert optimized_model.graph.node[0].input == ['A', 'Y', 'X']
4623+
assert optimized_model.graph.node[3].input == ['M', 'X', 'Y']
4624+
4625+
46004626

46014627
if __name__ == "__main__":
46024628
unittest.main()

0 commit comments

Comments
 (0)