@@ -38,11 +38,22 @@ struct FuseAddBiasIntoConv final : public PredicateBasedPass {
38
38
}
39
39
static Node *makeSqueezeOrUnsqueeze (Graph &graph, std::vector<int64_t > &axes,
40
40
Value *input, Node *target_node,
41
- BuiltinSymbol k) {
41
+ BuiltinSymbol k, bool is_input_qdq ) {
42
42
assert (k == kSqueeze || k == kUnsqueeze );
43
43
Node *squeeze = graph.create (k, 1 );
44
- int opset_version = getOpsetVersion (graph);
44
+ Node *dequant_node = nullptr ;
45
+ Node *quant_node = nullptr ;
46
+ // Insert the squeeze op upstream of the QuantizeLinear -> DequantizeLinear (Q/DQ) pair
47
+ if (is_input_qdq) {
48
+ dequant_node = input->node ();
49
+ quant_node = dequant_node->input (0 )->node ();
50
+ target_node = quant_node;
51
+ input = target_node->input (0 );
52
+ dequant_node->output ()->clearMetadata ();
53
+ quant_node->output ()->clearMetadata ();
54
+ }
45
55
squeeze->addInput (input);
56
+ int opset_version = getOpsetVersion (graph);
46
57
int version_threshold = 13 ;
47
58
if (opset_version < version_threshold && opset_version != 0 ) {
48
59
squeeze->is_ (kaxes, std::move (axes));
@@ -54,7 +65,13 @@ struct FuseAddBiasIntoConv final : public PredicateBasedPass {
54
65
Value *tv = graph.addInitializerAndInput (t);
55
66
squeeze->addInput (tv);
56
67
}
68
+ if (is_input_qdq) {
69
+ quant_node->replaceInput (0 , squeeze->output ());
70
+ }
57
71
squeeze->insertBefore (target_node);
72
+ if (is_input_qdq) {
73
+ return dequant_node;
74
+ }
58
75
return squeeze;
59
76
}
60
77
bool runTransform (Node *n, Graph &graph,
@@ -115,13 +132,13 @@ struct FuseAddBiasIntoConv final : public PredicateBasedPass {
115
132
if (bias_shape.size () > 1 ) {
116
133
std::vector<int64_t > axes (bias_shape.size () - 1 );
117
134
std::iota (axes.begin (), axes.end (), 0 );
118
- Node *squeeze = makeSqueezeOrUnsqueeze (graph, axes, conv_3rd_input,
119
- orig_conv->node (), kSqueeze );
135
+ Node *squeeze = makeSqueezeOrUnsqueeze (
136
+ graph, axes, conv_3rd_input, orig_conv->node (), kSqueeze , false );
120
137
conv_3rd_input = squeeze->output ();
121
138
} else if (bias_shape.size () == 0 ) {
122
139
std::vector<int64_t > axes = {0 };
123
- Node *unsqueeze = makeSqueezeOrUnsqueeze (graph, axes, conv_3rd_input,
124
- orig_conv->node (), kUnsqueeze );
140
+ Node *unsqueeze = makeSqueezeOrUnsqueeze (
141
+ graph, axes, conv_3rd_input, orig_conv->node (), kUnsqueeze , false );
125
142
conv_3rd_input = unsqueeze->output ();
126
143
}
127
144
if (M > 1 ) {
@@ -149,17 +166,25 @@ struct FuseAddBiasIntoConv final : public PredicateBasedPass {
149
166
bias_shape[1 + bias_shape.size () - static_cast <unsigned >(rank)]
150
167
.dim == M) {
151
168
ONNX_ASSERT (bias_shape.size () > 1 );
169
+ const bool is_input_qdq =
170
+ orig_bias->node ()->kind () == Symbol (" DequantizeLinear" ) &&
171
+ orig_bias->node ()->input (0 )->node ()->kind () ==
172
+ Symbol (" QuantizeLinear" );
152
173
if (orig_bias->node ()->kind () != kParam &&
153
174
orig_conv->node ()->isBefore (orig_bias->node ())) {
175
+ if (is_input_qdq) {
176
+ orig_bias->node ()->input (0 )->node ()->moveBefore (orig_conv->node ());
177
+ }
154
178
orig_bias->node ()->moveBefore (orig_conv->node ());
155
179
}
156
180
std::vector<int64_t > axes (bias_shape.size ());
157
181
std::iota (axes.begin (), axes.end (), static_cast <int64_t >(0 ));
158
182
axes.erase (axes.begin () +
159
183
(1 + bias_shape.size () - static_cast <unsigned >(rank)));
160
- Node *squeeze = makeSqueezeOrUnsqueeze (graph, axes, orig_bias,
161
- orig_conv->node (), kSqueeze );
162
- orig_conv->node ()->addInput (squeeze->output ());
184
+
185
+ Node *new_bias = makeSqueezeOrUnsqueeze (
186
+ graph, axes, orig_bias, orig_conv->node (), kSqueeze , is_input_qdq);
187
+ orig_conv->node ()->addInput (new_bias->output ());
163
188
} else {
164
189
return false ;
165
190
}
0 commit comments