@@ -1125,6 +1125,31 @@ def test_fuse_add_bias_into_conv_squeeze_4d_bias_no_fuse(self):
1125
1125
assert optimized_model .graph .node [0 ].op_type == 'Conv'
1126
1126
assert optimized_model .graph .node [1 ].op_type == 'Add'
1127
1127
1128
+ # type: () -> None
1129
+ def test_fuse_add_bias_into_conv_with_non_constant_bias (self ):
1130
+ nodes = [helper .make_node ("Conv" , ["X" , "Y" ], ["Z" ]),
1131
+ helper .make_node ("Sin" , ["A" ], ["B" ]),
1132
+ helper .make_node ("Add" , ["Z" , "B" ], ["C" ])]
1133
+ graph = helper .make_graph (
1134
+ nodes ,
1135
+ "test" ,
1136
+ [helper .make_tensor_value_info ("X" , TensorProto .FLOAT , (1 , 5 , 3 , 3 )),
1137
+ helper .make_tensor_value_info (
1138
+ "Y" , TensorProto .FLOAT , (16 , 5 , 3 , 3 )),
1139
+ helper .make_tensor_value_info ("A" , TensorProto .FLOAT , (16 , 1 , 1 ))],
1140
+ [helper .make_tensor_value_info (
1141
+ "C" , TensorProto .FLOAT , (1 , 16 , 1 , 1 ))],
1142
+ value_info = [helper .make_tensor_value_info (
1143
+ "B" , TensorProto .FLOAT , (16 , 1 , 1 ))]
1144
+ )
1145
+ optimized_model = self ._optimized (graph , ["fuse_add_bias_into_conv" ])
1146
+
1147
+ assert len (list (optimized_model .graph .node )) == 3
1148
+ assert optimized_model .graph .node [0 ].op_type == 'Sin'
1149
+ assert optimized_model .graph .node [1 ].op_type == 'Squeeze'
1150
+ assert optimized_model .graph .node [2 ].op_type == 'Conv'
1151
+ assert optimized_model .graph .output [0 ].name == 'C'
1152
+
1128
1153
def test_fuse_matmul_add_bias_into_gemm (self ): # type: () -> None
1129
1154
matmul = helper .make_node ("MatMul" , ["X" , "Y" ], ["Z" ])
1130
1155
add = helper .make_node ("Add" , ["Z" , "B" ], ["A" ])
0 commit comments