Commit ec8ff8e

upgrade onnx submodule, support qdq
Signed-off-by: daquexian <[email protected]>
1 parent cc8fd26 commit ec8ff8e
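
For context, the pass touched here, fuse_add_bias_into_conv, folds a trailing Add into the bias input of a Conv. This commit extends it to the case where the added bias arrives through a QuantizeLinear/DequantizeLinear (Q/DQ) pair, the representation quantization tooling uses for quantized tensors. A minimal sketch of invoking just this pass, assuming onnxoptimizer is installed (file names are placeholders):

import onnx
import onnxoptimizer

# Load a model whose Conv bias flows through QuantizeLinear -> DequantizeLinear.
model = onnx.load("model_with_qdq_bias.onnx")  # placeholder path
# Run only the pass this commit modifies.
optimized = onnxoptimizer.optimize(model, ["fuse_add_bias_into_conv"])
onnx.save(optimized, "model_fused.onnx")  # placeholder path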

File tree

3 files changed (+70, -10 lines)

onnxoptimizer/passes/fuse_add_bias_into_conv.h

+33, -9
@@ -38,11 +38,21 @@ struct FuseAddBiasIntoConv final : public PredicateBasedPass {
   }
   static Node *makeSqueezeOrUnsqueeze(Graph &graph, std::vector<int64_t> &axes,
                                       Value *input, Node *target_node,
-                                      BuiltinSymbol k) {
+                                      BuiltinSymbol k, bool is_input_qdq) {
     assert(k == kSqueeze || k == kUnsqueeze);
     Node *squeeze = graph.create(k, 1);
-    int opset_version = getOpsetVersion(graph);
+    Node *dequant_node = nullptr;
+    Node *quant_node = nullptr;
+    if (is_input_qdq) {
+      dequant_node = input->node();
+      quant_node = dequant_node->input(0)->node();
+      target_node = quant_node;
+      input = target_node->input(0);
+      dequant_node->output()->clearMetadata();
+      quant_node->output()->clearMetadata();
+    }
     squeeze->addInput(input);
+    int opset_version = getOpsetVersion(graph);
     int version_threshold = 13;
     if (opset_version < version_threshold && opset_version != 0) {
       squeeze->is_(kaxes, std::move(axes));
@@ -54,7 +64,13 @@ struct FuseAddBiasIntoConv final : public PredicateBasedPass {
       Value *tv = graph.addInitializerAndInput(t);
       squeeze->addInput(tv);
     }
+    if (is_input_qdq) {
+      quant_node->replaceInput(0, squeeze->output());
+    }
     squeeze->insertBefore(target_node);
+    if (is_input_qdq) {
+      return dequant_node;
+    }
     return squeeze;
   }
   bool runTransform(Node *n, Graph &graph,
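
With is_input_qdq set, the helper no longer splices the Squeeze onto the DequantizeLinear output. It walks back through the Q/DQ pair, inserts the Squeeze on the QuantizeLinear's float input, rewires QuantizeLinear to consume the Squeeze output, clears the now-stale shape metadata on both Q/DQ outputs (their shapes change once the bias is squeezed), and returns the DequantizeLinear node instead of the Squeeze so the caller can feed its output to Conv. Sketched on the bias subgraph of the new test below (shapes are illustrative; the Q/DQ scale and zero_point inputs are omitted):

before:  A (16,1,1) -> QuantizeLinear -> DequantizeLinear -> Add (bias operand)
after:   A (16,1,1) -> Squeeze(axes=[1,2]) -> QuantizeLinear -> DequantizeLinear -> Conv (bias, shape (16,))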
@@ -115,13 +131,13 @@ struct FuseAddBiasIntoConv final : public PredicateBasedPass {
     if (bias_shape.size() > 1) {
       std::vector<int64_t> axes(bias_shape.size() - 1);
       std::iota(axes.begin(), axes.end(), 0);
-      Node *squeeze = makeSqueezeOrUnsqueeze(graph, axes, conv_3rd_input,
-                                             orig_conv->node(), kSqueeze);
+      Node *squeeze = makeSqueezeOrUnsqueeze(
+          graph, axes, conv_3rd_input, orig_conv->node(), kSqueeze, false);
       conv_3rd_input = squeeze->output();
     } else if (bias_shape.size() == 0) {
       std::vector<int64_t> axes = {0};
-      Node *unsqueeze = makeSqueezeOrUnsqueeze(graph, axes, conv_3rd_input,
-                                               orig_conv->node(), kUnsqueeze);
+      Node *unsqueeze = makeSqueezeOrUnsqueeze(
+          graph, axes, conv_3rd_input, orig_conv->node(), kUnsqueeze, false);
       conv_3rd_input = unsqueeze->output();
     }
     if (M > 1) {
@@ -149,17 +165,25 @@ struct FuseAddBiasIntoConv final : public PredicateBasedPass {
           bias_shape[1 + bias_shape.size() - static_cast<unsigned>(rank)]
                   .dim == M) {
       ONNX_ASSERT(bias_shape.size() > 1);
+      const bool is_input_qdq =
+          orig_bias->node()->kind() == Symbol("DequantizeLinear") &&
+          orig_bias->node()->input(0)->node()->kind() ==
+              Symbol("QuantizeLinear");
       if (orig_bias->node()->kind() != kParam &&
           orig_conv->node()->isBefore(orig_bias->node())) {
+        if (is_input_qdq) {
+          orig_bias->node()->input(0)->node()->moveBefore(orig_conv->node());
+        }
         orig_bias->node()->moveBefore(orig_conv->node());
       }
       std::vector<int64_t> axes(bias_shape.size());
       std::iota(axes.begin(), axes.end(), static_cast<int64_t>(0));
       axes.erase(axes.begin() +
                  (1 + bias_shape.size() - static_cast<unsigned>(rank)));
-      Node *squeeze = makeSqueezeOrUnsqueeze(graph, axes, orig_bias,
-                                             orig_conv->node(), kSqueeze);
-      orig_conv->node()->addInput(squeeze->output());
+
+      Node *new_bias = makeSqueezeOrUnsqueeze(
+          graph, axes, orig_bias, orig_conv->node(), kSqueeze, is_input_qdq);
+      orig_conv->node()->addInput(new_bias->output());
     } else {
       return false;
     }
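
This hunk shows how detection and fusion connect: is_input_qdq is computed structurally (the bias value must be produced by a DequantizeLinear fed directly by a QuantizeLinear), and when the Q/DQ pair sits after the Conv, the QuantizeLinear is moved before the Conv first, so that the subsequent moveBefore of the DequantizeLinear leaves the graph topologically ordered. The two earlier call sites pass false because they squeeze the Conv's existing third input, which is not Q/DQ-wrapped. A rough Python-level equivalent of the detection, operating on an onnx GraphProto (the helper name bias_is_qdq is illustrative, not part of the pass):

from onnx import GraphProto

def bias_is_qdq(graph, bias_name):
    # type: (GraphProto, str) -> bool
    # Map each tensor name to the node that produces it.
    producers = {out: node for node in graph.node for out in node.output}
    dq = producers.get(bias_name)
    if dq is None or dq.op_type != "DequantizeLinear":
        return False
    q = producers.get(dq.input[0])
    return q is not None and q.op_type == "QuantizeLinear"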

onnxoptimizer/test/optimizer_test.py

+36
@@ -166,6 +166,7 @@ def _optimized(self, graph_or_model, opts, fixed_point=False, compare_result=Tru
             graph_or_model, producer_name='onnx-test', opset_imports=opset_imports, **kwargs)
         checker.check_model(orig_model)
         optimized_model = onnxoptimizer.optimize(orig_model, opts, fixed_point)
+        print(str(optimized_model))
         checker.check_model(optimized_model)
         if compare_result and len(optimized_model.graph.node) > 0:
             if has_ort:
@@ -1150,6 +1151,40 @@ def test_fuse_add_bias_into_conv_with_non_constant_bias(self):
         assert optimized_model.graph.node[2].op_type == 'Conv'
         assert optimized_model.graph.output[0].name == 'C'

+    def test_fuse_add_bias_into_conv_with_quanted_bias(self):  # type: () -> None
+        nodes = [helper.make_node("Conv", ["X", "Y"], ["Z"]),
+                 helper.make_node("QuantizeLinear", ["A", "scale", "zero_point"], ["B"], axis=0),
+                 helper.make_node("DequantizeLinear", ["B", "scale", "zero_point"], ["C"], axis=0),
+                 helper.make_node("Add", ["Z", "C"], ["D"])]
+        graph = helper.make_graph(
+            nodes,
+            "test",
+            [helper.make_tensor_value_info("X", TensorProto.FLOAT, (1, 5, 3, 3)),
+             helper.make_tensor_value_info(
+                 "Y", TensorProto.FLOAT, (16, 5, 3, 3)),
+             helper.make_tensor_value_info("A", TensorProto.FLOAT, (16, 1, 1))],
+            [helper.make_tensor_value_info(
+                "D", TensorProto.FLOAT, (1, 16, 1, 1))],
+            [helper.make_tensor("scale", TensorProto.FLOAT,
+                                dims=(16,),
+                                vals=np.random.rand(16).astype(np.float32).tobytes(),
+                                raw=True),
+             helper.make_tensor("zero_point", TensorProto.INT8,
+                                dims=(16,),
+                                vals=np.zeros([16]).astype(np.int8).tobytes(),
+                                raw=True)],
+            value_info=[helper.make_tensor_value_info(
+                "C", TensorProto.FLOAT, (16, 1, 1))]
+        )
+        optimized_model = self._optimized(graph, ["fuse_add_bias_into_conv"], opset_imports=[helper.make_opsetid("", 13)])
+
+        assert len(list(optimized_model.graph.node)) == 4
+        assert optimized_model.graph.node[0].op_type == 'Squeeze'
+        assert optimized_model.graph.node[1].op_type == 'QuantizeLinear'
+        assert optimized_model.graph.node[2].op_type == 'DequantizeLinear'
+        assert optimized_model.graph.node[3].op_type == 'Conv'
+        assert optimized_model.graph.output[0].name == 'D'
+
     def test_fuse_matmul_add_bias_into_gemm(self):  # type: () -> None
         matmul = helper.make_node("MatMul", ["X", "Y"], ["Z"])
         add = helper.make_node("Add", ["Z", "B"], ["A"])
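
The new test builds exactly the pattern sketched earlier (Conv, Q/DQ on the bias, Add) and asserts that the optimized graph is Squeeze -> QuantizeLinear -> DequantizeLinear -> Conv: the Add is gone and the squeezed, still-quantized bias feeds the Conv's third input. Something like the following should run it in isolation, assuming a pytest-style invocation works against this test file:

python -m pytest onnxoptimizer/test/optimizer_test.py -k test_fuse_add_bias_into_conv_with_quanted_bias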

third_party/onnx

Submodule onnx updated 302 files
