@@ -69,7 +69,42 @@ struct FuseBNIntoConv final : public PredicateBasedPass {
     }
   }

-  bool modify_conv(Node* conv, Node* bn, Graph& graph) {
+  void scale_by_dim(Tensor& W, Tensor& s, const int axis) {
+    ONNX_ASSERT(W.sizes().size() > 1 && s.sizes().size() == 1 && s.sizes()[0] == W.sizes()[axis]);
+    ONNX_ASSERT(s.elem_type() == W.elem_type());
+    const int64_t inner_size = W.size_from_dim(axis + 1);
+    const int64_t outer_size = axis > 0
+        ? std::accumulate(W.sizes().begin(), W.sizes().begin() + axis, 1, std::multiplies<int>())
+        : 1;
+    const int64_t axis_size = W.sizes()[axis];
+
+#define DO_SCALE(TENSOR_TYPE)                        \
+  TENSOR_TYPE* ptr = W.data<TENSOR_TYPE>();          \
+  const TENSOR_TYPE* s_ptr = s.data<TENSOR_TYPE>();  \
+  int64_t counter = 0;                               \
+  for (int64_t i = 0; i < outer_size; ++i) {         \
+    for (int64_t j = 0; j < axis_size; ++j) {        \
+      for (int64_t k = 0; k < inner_size; ++k) {     \
+        ptr[counter++] *= s_ptr[j];                  \
+      }                                              \
+    }                                                \
+  }
+
+    switch (s.elem_type()) {
+      case ONNX_NAMESPACE::TensorProto_DataType_FLOAT: {
+        DO_SCALE(float)
+        break;
+      }
+      case ONNX_NAMESPACE::TensorProto_DataType_DOUBLE: {
+        DO_SCALE(double)
+        break;
+      }
+      default:
+        TENSOR_ASSERTM(
+            false, "Operation scale_by_dim not supported for data type %s", to_string(W.elem_type()).c_str());
+    }
+#undef DO_SCALE
+  }
+
+  bool modify_conv(Node* conv, Node* bn, Graph& graph, const bool is_conv) {
     const auto& bn_inputs = bn->inputs();
     const auto& conv_inputs = conv->inputs();
     auto end_iter = graph.initializers().end();
@@ -136,7 +171,6 @@ struct FuseBNIntoConv final : public PredicateBasedPass {
     var.add(eps);            \
     var.sqrt();              \
     s.divide(var);           \
-    W.scale_by_first_dim(s); \
     bc.subtract(m);          \
     bc.multiply(s);          \
     bc.add(bbn);
@@ -154,21 +188,38 @@ struct FuseBNIntoConv final : public PredicateBasedPass {
154
188
return false ;
155
189
}
156
190
#undef DO_COMPUTATION
191
+ if (is_conv) {
192
+ scale_by_dim (W, s, 0 );
193
+ } else {
194
+ scale_by_dim (W, s, 1 );
195
+ }
157
196
replace_inputs (W, bc, conv, graph);
158
197
return true ;
159
198
}
160
199
161
- bool patternMatchPredicate (Node* node) override {
200
+ inline bool matchConvBn (Node * node) {
162
201
return node->kind () == kBatchNormalization &&
163
202
node->inputs ()[0 ]->node ()->kind () == kConv ;
164
203
}
204
+
205
+ inline bool matchConvTransposeBn (Node *node) {
206
+ return node->kind () == kBatchNormalization &&
207
+ node->inputs ()[0 ]->node ()->kind () == kConvTranspose ;
208
+ }
209
+
210
+ bool patternMatchPredicate (Node *node) override {
211
+ return matchConvBn (node) || matchConvTransposeBn (node);
212
+ }
213
+
165
214
bool runTransform (Node* n, Graph& graph,
166
215
NodeDestroyType& destroy_current) override {
216
+ const bool is_conv = matchConvBn (n);
217
+
167
218
Node* bn = n;
168
219
Node* conv = n->inputs ()[0 ]->node ();
169
220
auto origInput = bn->inputs ()[0 ];
170
221
if (origInput->uses ().size () > 1 || bn->outputs ().size () > 1 ||
171
- !modify_conv (conv, bn, graph)) {
222
+ !modify_conv (conv, bn, graph, is_conv )) {
172
223
destroy_current = NodeDestroyType::DestroyZero;
173
224
return false ;
174
225
}
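
For context on the axis choice: Conv weights are laid out (M, C/group, kH, kW), with output channels on axis 0, while ConvTranspose weights are (C, M/group, kH, kW), with output channels on axis 1. The fused batch-norm scale s = scale / sqrt(var + eps) is applied per output channel, which is why the generalized `scale_by_dim` takes an axis argument instead of always scaling the first dimension. Below is a minimal standalone sketch of the same (outer, axis, inner) indexing on a plain `std::vector`, independent of the onnx `Tensor` class; `scale_by_dim_ref` is a hypothetical name used only to illustrate the loop structure above.

```cpp
#include <cassert>
#include <cstdint>
#include <functional>
#include <iostream>
#include <numeric>
#include <vector>

// Hypothetical standalone analogue of scale_by_dim: multiply every element of
// a dense row-major tensor W (shape `dims`) by s[j], where j is the index
// along `axis`. Mirrors the (outer, axis, inner) loop nest in the pass above.
void scale_by_dim_ref(std::vector<float>& W, const std::vector<int64_t>& dims,
                      const std::vector<float>& s, int axis) {
  assert(dims.size() > 1 && s.size() == static_cast<size_t>(dims[axis]));
  const int64_t inner_size = std::accumulate(
      dims.begin() + axis + 1, dims.end(), int64_t{1}, std::multiplies<int64_t>());
  const int64_t outer_size = std::accumulate(
      dims.begin(), dims.begin() + axis, int64_t{1}, std::multiplies<int64_t>());
  int64_t counter = 0;
  for (int64_t i = 0; i < outer_size; ++i)
    for (int64_t j = 0; j < dims[axis]; ++j)
      for (int64_t k = 0; k < inner_size; ++k)
        W[counter++] *= s[j];
}

int main() {
  // Conv case: output channels on axis 0, so scale along axis 0.
  // For ConvTranspose the same call would use axis 1 instead.
  std::vector<int64_t> dims = {2, 2, 1, 1};  // tiny (M=2, C=2, 1, 1) weight
  std::vector<float> W = {1, 1, 1, 1};
  std::vector<float> s = {2.0f, 3.0f};       // one scale per output channel
  scale_by_dim_ref(W, dims, s, /*axis=*/0);
  for (float w : W) std::cout << w << ' ';   // prints: 2 2 3 3
  std::cout << '\n';
}
```

The bias side of the fusion, bc = (bc - mean) * s + b_bn, stays inside the existing DO_COMPUTATION macro; only the weight scaling moved out of the macro so it can be parameterized by axis per operator kind.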