Skip to content

Commit cc8fd26

Browse files
committed
Support fusing a non-constant bias into Conv
Signed-off-by: daquexian <[email protected]>
1 parent 9fb5721 commit cc8fd26

File tree

2 files changed

+29
-5
lines changed

2 files changed

+29
-5
lines changed

onnxoptimizer/passes/fuse_add_bias_into_conv.h

+4-5
Original file line numberDiff line numberDiff line change
@@ -64,16 +64,15 @@ struct FuseAddBiasIntoConv final : public PredicateBasedPass {
6464
destroy_current = NodeDestroyType::DestroyZero;
6565
auto orig_conv = n->inputs()[0];
6666
auto orig_bias = n->inputs()[1];
67-
// check if bias is Const or in graph's initializers
68-
if (orig_bias->node()->kind() != kConstant &&
69-
orig_bias->node()->kind() != kParam) {
70-
return false;
71-
}
7267
// check if conv is only used by Add
7368
if (orig_conv->uses().size() > 1) {
7469
return false;
7570
}
7671
auto conv_shape = orig_conv->sizes();
72+
// We need the size of bias
73+
if (!orig_bias->has_sizes()) {
74+
return false;
75+
}
7776
auto bias_shape = orig_bias->sizes();
7877
auto weight_shape = orig_conv->node()->inputs()[1]->sizes();
7978
int64_t M = -1;

onnxoptimizer/test/optimizer_test.py

+25
Original file line numberDiff line numberDiff line change
@@ -1125,6 +1125,31 @@ def test_fuse_add_bias_into_conv_squeeze_4d_bias_no_fuse(self):
11251125
assert optimized_model.graph.node[0].op_type == 'Conv'
11261126
assert optimized_model.graph.node[1].op_type == 'Add'
11271127

1128+
1129+
def test_fuse_add_bias_into_conv_with_non_constant_bias(self):  # type: () -> None
1130+
nodes = [helper.make_node("Conv", ["X", "Y"], ["Z"]),
1131+
helper.make_node("Sin", ["A"], ["B"]),
1132+
helper.make_node("Add", ["Z", "B"], ["C"])]
1133+
graph = helper.make_graph(
1134+
nodes,
1135+
"test",
1136+
[helper.make_tensor_value_info("X", TensorProto.FLOAT, (1, 5, 3, 3)),
1137+
helper.make_tensor_value_info(
1138+
"Y", TensorProto.FLOAT, (16, 5, 3, 3)),
1139+
helper.make_tensor_value_info("A", TensorProto.FLOAT, (16, 1, 1))],
1140+
[helper.make_tensor_value_info(
1141+
"C", TensorProto.FLOAT, (1, 16, 1, 1))],
1142+
value_info=[helper.make_tensor_value_info(
1143+
"B", TensorProto.FLOAT, (16, 1, 1))]
1144+
)
1145+
optimized_model = self._optimized(graph, ["fuse_add_bias_into_conv"])
1146+
1147+
assert len(list(optimized_model.graph.node)) == 3
1148+
assert optimized_model.graph.node[0].op_type == 'Sin'
1149+
assert optimized_model.graph.node[1].op_type == 'Squeeze'
1150+
assert optimized_model.graph.node[2].op_type == 'Conv'
1151+
assert optimized_model.graph.output[0].name == 'C'
1152+
11281153
def test_fuse_matmul_add_bias_into_gemm(self): # type: () -> None
11291154
matmul = helper.make_node("MatMul", ["X", "Y"], ["Z"])
11301155
add = helper.make_node("Add", ["Z", "B"], ["A"])

0 commit comments

Comments
 (0)