@@ -38,11 +38,21 @@ struct FuseAddBiasIntoConv final : public PredicateBasedPass {
38
38
}
39
39
static Node *makeSqueezeOrUnsqueeze (Graph &graph, std::vector<int64_t > &axes,
40
40
Value *input, Node *target_node,
41
- BuiltinSymbol k) {
41
+ BuiltinSymbol k, bool is_input_qdq ) {
42
42
assert (k == kSqueeze || k == kUnsqueeze );
43
43
Node *squeeze = graph.create (k, 1 );
44
- int opset_version = getOpsetVersion (graph);
44
+ Node *dequant_node = nullptr ;
45
+ Node *quant_node = nullptr ;
46
+ if (is_input_qdq) {
47
+ dequant_node = input->node ();
48
+ quant_node = dequant_node->input (0 )->node ();
49
+ target_node = quant_node;
50
+ input = target_node->input (0 );
51
+ dequant_node->output ()->clearMetadata ();
52
+ quant_node->output ()->clearMetadata ();
53
+ }
45
54
squeeze->addInput (input);
55
+ int opset_version = getOpsetVersion (graph);
46
56
int version_threshold = 13 ;
47
57
if (opset_version < version_threshold && opset_version != 0 ) {
48
58
squeeze->is_ (kaxes, std::move (axes));
@@ -54,7 +64,13 @@ struct FuseAddBiasIntoConv final : public PredicateBasedPass {
54
64
Value *tv = graph.addInitializerAndInput (t);
55
65
squeeze->addInput (tv);
56
66
}
67
+ if (is_input_qdq) {
68
+ quant_node->replaceInput (0 , squeeze->output ());
69
+ }
57
70
squeeze->insertBefore (target_node);
71
+ if (is_input_qdq) {
72
+ return dequant_node;
73
+ }
58
74
return squeeze;
59
75
}
60
76
bool runTransform (Node *n, Graph &graph,
@@ -115,13 +131,13 @@ struct FuseAddBiasIntoConv final : public PredicateBasedPass {
115
131
if (bias_shape.size () > 1 ) {
116
132
std::vector<int64_t > axes (bias_shape.size () - 1 );
117
133
std::iota (axes.begin (), axes.end (), 0 );
118
- Node *squeeze = makeSqueezeOrUnsqueeze (graph, axes, conv_3rd_input,
119
- orig_conv->node (), kSqueeze );
134
+ Node *squeeze = makeSqueezeOrUnsqueeze (
135
+ graph, axes, conv_3rd_input, orig_conv->node (), kSqueeze , false );
120
136
conv_3rd_input = squeeze->output ();
121
137
} else if (bias_shape.size () == 0 ) {
122
138
std::vector<int64_t > axes = {0 };
123
- Node *unsqueeze = makeSqueezeOrUnsqueeze (graph, axes, conv_3rd_input,
124
- orig_conv->node (), kUnsqueeze );
139
+ Node *unsqueeze = makeSqueezeOrUnsqueeze (
140
+ graph, axes, conv_3rd_input, orig_conv->node (), kUnsqueeze , false );
125
141
conv_3rd_input = unsqueeze->output ();
126
142
}
127
143
if (M > 1 ) {
@@ -149,17 +165,25 @@ struct FuseAddBiasIntoConv final : public PredicateBasedPass {
149
165
bias_shape[1 + bias_shape.size () - static_cast <unsigned >(rank)]
150
166
.dim == M) {
151
167
ONNX_ASSERT (bias_shape.size () > 1 );
168
+ const bool is_input_qdq =
169
+ orig_bias->node ()->kind () == Symbol (" DequantizeLinear" ) &&
170
+ orig_bias->node ()->input (0 )->node ()->kind () ==
171
+ Symbol (" QuantizeLinear" );
152
172
if (orig_bias->node ()->kind () != kParam &&
153
173
orig_conv->node ()->isBefore (orig_bias->node ())) {
174
+ if (is_input_qdq) {
175
+ orig_bias->node ()->input (0 )->node ()->moveBefore (orig_conv->node ());
176
+ }
154
177
orig_bias->node ()->moveBefore (orig_conv->node ());
155
178
}
156
179
std::vector<int64_t > axes (bias_shape.size ());
157
180
std::iota (axes.begin (), axes.end (), static_cast <int64_t >(0 ));
158
181
axes.erase (axes.begin () +
159
182
(1 + bias_shape.size () - static_cast <unsigned >(rank)));
160
- Node *squeeze = makeSqueezeOrUnsqueeze (graph, axes, orig_bias,
161
- orig_conv->node (), kSqueeze );
162
- orig_conv->node ()->addInput (squeeze->output ());
183
+
184
+ Node *new_bias = makeSqueezeOrUnsqueeze (
185
+ graph, axes, orig_bias, orig_conv->node (), kSqueeze , is_input_qdq);
186
+ orig_conv->node ()->addInput (new_bias->output ());
163
187
} else {
164
188
return false ;
165
189
}
0 commit comments