Skip to content

Commit 8086dfe

Browse files
committed
Add sign to idempotent ops list
1 parent 0e49375 commit 8086dfe

File tree

2 files changed

+20
-1
lines changed

2 files changed

+20
-1
lines changed

onnxoptimizer/passes/eliminate_consecutive_idempotent_ops.h

+1-1
Original file line numberDiff line numberDiff line change
@@ -24,7 +24,7 @@ struct EliminateConsecutiveIdempotentOps final : public PredicateBasedPass {
2424

2525
bool patternMatchPredicate(Node* node) override {
2626
static const std::unordered_set<std::string> idempotent_ops = {
27-
"Ceil", "Floor", "Round", "Relu", "Reshape"};
27+
"Ceil", "Floor", "Round", "Relu", "Reshape", "Sign"};
2828
for (const auto& op : idempotent_ops) {
2929
// TODO: support uses().size() > 1 for ops except Reshape
3030
if (CheckKind(node, Symbol(op), 0, Symbol(op)) &&

onnxoptimizer/test/optimizer_test.py

+19
Original file line numberDiff line numberDiff line change
@@ -4597,6 +4597,25 @@ def test_eliminate_consecutive_idempotent_op(self):
45974597
assert optimized_model.graph.node[0].op_type == "Constant"
45984598
assert optimized_model.graph.node[1].op_type == "Reshape"
45994599

4600+
def test_eliminate_consecutive_idempotent_sign_op(self):
4601+
model = parser.parse_model("""
4602+
<
4603+
ir_version: 7,
4604+
opset_import:["": 11]
4605+
>
4606+
agraph (float[1, 2, 3] X) => (float[1, 2, 3] Z)
4607+
{
4608+
T1 = Sign(X)
4609+
T2 = Sign(T1)
4610+
T3 = Sign(T2)
4611+
Z = Sign(T3)
4612+
}
4613+
""")
4614+
4615+
optimized_model = self._optimized(
4616+
model, ['eliminate_consecutive_idempotent_ops', 'eliminate_deadend'], True)
4617+
assert len(optimized_model.graph.node) == 1
4618+
assert optimized_model.graph.node[0].op_type == "Sign"
46004619

46014620
if __name__ == "__main__":
46024621
unittest.main()

0 commit comments

Comments
 (0)