Commit 44658d1
Author: wenyuchi.wyc
Parent: 74fdf9c

Support fusing BN into ConvTranspose.
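For context (a note on this rewrite, not part of the commit): the pass folds BatchNormalization parameters (scale γ, bias β, running mean μ, variance σ², epsilon ε) into the preceding convolution by rescaling its weights per output channel and rewriting its bias, so the BN node can be deleted. A minimal standalone sketch of that per-channel arithmetic, mirroring the pass's DO_COMPUTATION macro on hypothetical plain vectors rather than the pass's Tensor class:

#include <cmath>
#include <vector>

// Sketch only: fold BN into a preceding convolution's parameters.
// s = gamma / sqrt(var + eps);  b' = (b_conv - mean) * s + beta.
// The weight tensor is then multiplied by s along its output-channel axis.
void fold_bn_params(std::vector<float>& scale, std::vector<float>& new_bias,
                    const std::vector<float>& gamma, const std::vector<float>& beta,
                    const std::vector<float>& mean, const std::vector<float>& var,
                    const std::vector<float>& conv_bias, float eps) {
  const size_t C = gamma.size();  // number of output channels
  scale.resize(C);
  new_bias.resize(C);
  for (size_t c = 0; c < C; ++c) {
    scale[c] = gamma[c] / std::sqrt(var[c] + eps);
    new_bias[c] = (conv_bias[c] - mean[c]) * scale[c] + beta[c];
  }
}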

File tree: 1 file changed, +55 −4

onnxoptimizer/passes/fuse_bn_into_conv.h (+55 −4)
@@ -69,7 +69,42 @@ struct FuseBNIntoConv final : public PredicateBasedPass {
     }
   }
 
-  bool modify_conv(Node* conv, Node* bn, Graph& graph) {
+  void scale_by_dim(Tensor& W, Tensor& s, const int axis) {
+    ONNX_ASSERT(W.sizes().size() > 1 && s.sizes().size() == 1 && s.sizes()[0] == W.sizes()[axis]);
+    ONNX_ASSERT(s.elem_type() == W.elem_type());
+    const int64_t inner_size = W.size_from_dim(axis + 1);
+    const int64_t outer_size = axis > 0 ? std::accumulate(W.sizes().begin(), W.sizes().begin() + axis, 1, std::multiplies<int>()) : 1;
+    const int64_t axis_size = W.sizes()[axis];
+
+#define DO_SCALE(TENSOR_TYPE)                       \
+  TENSOR_TYPE* ptr = W.data<TENSOR_TYPE>();         \
+  const TENSOR_TYPE* s_ptr = s.data<TENSOR_TYPE>(); \
+  int64_t counter = 0;                              \
+  for (int64_t i = 0; i < outer_size; ++i) {        \
+    for (int64_t j = 0; j < axis_size; ++j) {       \
+      for (int64_t k = 0; k < inner_size; ++k) {    \
+        ptr[counter++] *= s_ptr[j];                 \
+      }                                             \
+    }                                               \
+  }
+
+    switch (s.elem_type()) {
+      case ONNX_NAMESPACE::TensorProto_DataType_FLOAT: {
+        DO_SCALE(float)
+        break;
+      }
+      case ONNX_NAMESPACE::TensorProto_DataType_DOUBLE: {
+        DO_SCALE(double)
+        break;
+      }
+      default:
+        TENSOR_ASSERTM(
+            false, "Operation scale_by_dim not supported for data type %s", to_string(W.elem_type()).c_str());
+    }
+#undef DO_SCALE
+  }
+
+  bool modify_conv(Node* conv, Node* bn, Graph& graph, const bool is_conv) {
     const auto& bn_inputs = bn->inputs();
     const auto& conv_inputs = conv->inputs();
     auto end_iter = graph.initializers().end();
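Why scale_by_dim takes an axis (editorial note): ONNX Conv weights have shape (M, C/group, kH, kW), with the output channels M on axis 0, while ConvTranspose weights have shape (C, M/group, kH, kW), with output channels on axis 1. The BN scale is per output channel, so the old scale_by_first_dim only handled Conv. A self-contained sketch of the same outer/axis/inner loop on a raw buffer (illustrative only; the pass operates on its own Tensor class):

#include <cstdint>
#include <functional>
#include <numeric>
#include <vector>

// Sketch of scale_by_dim's loop structure: every element whose index
// along `axis` is j gets multiplied by s[j].
void scale_by_dim_sketch(std::vector<float>& w,
                         const std::vector<int64_t>& dims,
                         const std::vector<float>& s, int axis) {
  const int64_t inner = std::accumulate(dims.begin() + axis + 1, dims.end(),
                                        int64_t{1}, std::multiplies<int64_t>());
  const int64_t outer = std::accumulate(dims.begin(), dims.begin() + axis,
                                        int64_t{1}, std::multiplies<int64_t>());
  const int64_t axis_size = dims[axis];
  int64_t counter = 0;
  for (int64_t i = 0; i < outer; ++i)
    for (int64_t j = 0; j < axis_size; ++j)
      for (int64_t k = 0; k < inner; ++k)
        w[counter++] *= s[j];
}

With axis = 0 this reproduces the removed scale_by_first_dim behavior; axis = 1 matches the ConvTranspose weight layout.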
@@ -136,7 +171,6 @@ struct FuseBNIntoConv final : public PredicateBasedPass {
   var.add(eps);              \
   var.sqrt();                \
   s.divide(var);             \
-  W.scale_by_first_dim(s);   \
   bc.subtract(m);            \
   bc.multiply(s);            \
   bc.add(bbn);
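The W.scale_by_first_dim(s) call is dropped from the macro because the scaling axis now depends on the op kind; the axis-aware scale_by_dim call happens after the macro, in the hunk below.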
@@ -154,21 +188,38 @@ struct FuseBNIntoConv final : public PredicateBasedPass {
       return false;
     }
 #undef DO_COMPUTATION
+    if (is_conv) {
+      scale_by_dim(W, s, 0);
+    } else {
+      scale_by_dim(W, s, 1);
+    }
     replace_inputs(W, bc, conv, graph);
     return true;
   }
 
-  bool patternMatchPredicate(Node* node) override {
+  inline bool matchConvBn(Node *node) {
     return node->kind() == kBatchNormalization &&
            node->inputs()[0]->node()->kind() == kConv;
   }
+
+  inline bool matchConvTransposeBn(Node *node) {
+    return node->kind() == kBatchNormalization &&
+           node->inputs()[0]->node()->kind() == kConvTranspose;
+  }
+
+  bool patternMatchPredicate(Node *node) override {
+    return matchConvBn(node) || matchConvTransposeBn(node);
+  }
+
   bool runTransform(Node* n, Graph& graph,
                     NodeDestroyType& destroy_current) override {
+    const bool is_conv = matchConvBn(n);
+
     Node* bn = n;
     Node* conv = n->inputs()[0]->node();
     auto origInput = bn->inputs()[0];
     if (origInput->uses().size() > 1 || bn->outputs().size() > 1 ||
-        !modify_conv(conv, bn, graph)) {
+        !modify_conv(conv, bn, graph, is_conv)) {
       destroy_current = NodeDestroyType::DestroyZero;
       return false;
     }
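A hedged usage sketch: invoking just this pass from C++, assuming the Optimize entry point declared in onnxoptimizer/optimize.h and a ModelProto already parsed by the caller (the pass is registered under the name "fuse_bn_into_conv"):

#include <string>
#include <vector>

#include "onnxoptimizer/optimize.h"

// Sketch: apply only fuse_bn_into_conv. After this commit the same pass
// also rewrites ConvTranspose -> BatchNormalization pairs, provided the
// convolution's output feeds only the BN node and BN has a single output.
ONNX_NAMESPACE::ModelProto FuseBN(const ONNX_NAMESPACE::ModelProto& model) {
  const std::vector<std::string> passes{"fuse_bn_into_conv"};
  return ONNX_NAMESPACE::optimization::Optimize(model, passes);
}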
