Skip to content

Commit 746c02f

Browse files
authored
Eliminate nop cast (onnx#10)
* Add pass to eliminate no-op casts Signed-off-by: crisp-snakey <[email protected]> * Add basic test for no-op cast elimination Signed-off-by: crisp-snakey <[email protected]> * Formatting Signed-off-by: crisp-snakey <[email protected]> * Propagate sizes Signed-off-by: crisp-snakey <[email protected]> * Propagate output name Signed-off-by: crisp-snakey <[email protected]>
1 parent e2d3d56 commit 746c02f

File tree

3 files changed

+57
-0
lines changed

3 files changed

+57
-0
lines changed

onnxoptimizer/pass_registry.h

+2
Original file line numberDiff line numberDiff line change
@@ -10,6 +10,7 @@
1010

1111
#include "onnxoptimizer/passes/eliminate_deadend.h"
1212
#include "onnxoptimizer/passes/eliminate_identity.h"
13+
#include "onnxoptimizer/passes/eliminate_nop_cast.h"
1314
#include "onnxoptimizer/passes/eliminate_nop_dropout.h"
1415
#include "onnxoptimizer/passes/eliminate_nop_monotone_argmax.h"
1516
#include "onnxoptimizer/passes/eliminate_nop_pad.h"
@@ -44,6 +45,7 @@ struct GlobalPassRegistry {
4445
// Register the optimization passes to the optimizer.
4546
registerPass<NopEmptyPass>();
4647
registerPass<EliminateDeadEnd>();
48+
registerPass<EliminateNopCast>();
4749
registerPass<EliminateNopDropout>();
4850
registerPass<EliminateIdentity>();
4951
registerPass<EliminateNopMonotoneArgmax>();
+43
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,43 @@
1+
// ATTENTION: The code in this file is highly EXPERIMENTAL.
2+
// Adventurous users should note that the APIs will probably change.
3+
4+
#pragma once
5+
6+
#include "onnxoptimizer/pass.h"
7+
8+
namespace ONNX_NAMESPACE {
9+
namespace optimization {
10+
11+
struct EliminateNopCast final : public PredicateBasedPass {
12+
explicit EliminateNopCast()
13+
: PredicateBasedPass(
14+
PassType::Nop,
15+
PassEfficiency::Complete,
16+
PassOptimizationType::Compute) {}
17+
18+
std::string getPassName() const override {
19+
return "eliminate_nop_cast";
20+
}
21+
22+
bool patternMatchPredicate(Node* node) override {
23+
return (node->kind() == kCast && node->hasAttribute(kto) &&
24+
node->input()->elemType() == node->i(kto));
25+
}
26+
27+
bool runTransform(Node* node, Graph& graph, NodeDestroyType& destroy_current)
28+
override {
29+
if (node->output()->has_sizes()) {
30+
node->input()->setSizes(node->output()->sizes());
31+
}
32+
if (std::find(graph.outputs().rbegin(), graph.outputs().rend(),
33+
node->output()) != graph.outputs().rend()) {
34+
node->input()->setUniqueName(node->output()->uniqueName());
35+
}
36+
node->output()->replaceAllUsesWith(node->input());
37+
destroy_current = NodeDestroyType::DestroyOne;
38+
return true;
39+
}
40+
};
41+
42+
} // namespace optimization
43+
} // namespace ONNX_NAMESPACE

onnxoptimizer/test/optimizer_test.py

+12
Original file line numberDiff line numberDiff line change
@@ -161,6 +161,18 @@ def test_eliminate_identity_multiple_uses(self): # type: () -> None
161161
assert node.op_type != "Identity"
162162
assert len(optimized_model.graph.node) == 2
163163

164+
def test_nop_cast(self):
165+
cast = helper.make_node("Cast", ["A"], ["B"], to=TensorProto.FLOAT)
166+
graph = helper.make_graph(
167+
[cast],
168+
"test",
169+
[helper.make_tensor_value_info("A", TensorProto.FLOAT, (2, 3))],
170+
[helper.make_tensor_value_info("B", TensorProto.FLOAT, (2, 3))])
171+
172+
optimized_model = self._optimized(graph, ["eliminate_nop_cast"])
173+
174+
assert len(optimized_model.graph.node) == 0
175+
164176
def test_nop_transpose_graph_output(self): # type: () -> None
165177
add = helper.make_node("Add", ["X", "Y"], ["A"])
166178
trans = helper.make_node("Transpose", ["A"], ["B"], perm=[0, 1])

0 commit comments

Comments
 (0)