Skip to content

Commit 8e3c881

Browse files
committed
upgrade onnx submodule, support qdq
Signed-off-by: daquexian <[email protected]>
1 parent cc8fd26 commit 8e3c881

File tree

3 files changed

+70
-10
lines changed

3 files changed

+70
-10
lines changed

onnxoptimizer/passes/fuse_add_bias_into_conv.h

+34-9
Original file line numberDiff line numberDiff line change
@@ -38,11 +38,22 @@ struct FuseAddBiasIntoConv final : public PredicateBasedPass {
3838
}
3939
static Node *makeSqueezeOrUnsqueeze(Graph &graph, std::vector<int64_t> &axes,
4040
Value *input, Node *target_node,
41-
BuiltinSymbol k) {
41+
BuiltinSymbol k, bool is_input_qdq) {
4242
assert(k == kSqueeze || k == kUnsqueeze);
4343
Node *squeeze = graph.create(k, 1);
44-
int opset_version = getOpsetVersion(graph);
44+
Node *dequant_node = nullptr;
45+
Node *quant_node = nullptr;
46+
// insert squeeze op before qdq
47+
if (is_input_qdq) {
48+
dequant_node = input->node();
49+
quant_node = dequant_node->input(0)->node();
50+
target_node = quant_node;
51+
input = target_node->input(0);
52+
dequant_node->output()->clearMetadata();
53+
quant_node->output()->clearMetadata();
54+
}
4555
squeeze->addInput(input);
56+
int opset_version = getOpsetVersion(graph);
4657
int version_threshold = 13;
4758
if (opset_version < version_threshold && opset_version != 0) {
4859
squeeze->is_(kaxes, std::move(axes));
@@ -54,7 +65,13 @@ struct FuseAddBiasIntoConv final : public PredicateBasedPass {
5465
Value *tv = graph.addInitializerAndInput(t);
5566
squeeze->addInput(tv);
5667
}
68+
if (is_input_qdq) {
69+
quant_node->replaceInput(0, squeeze->output());
70+
}
5771
squeeze->insertBefore(target_node);
72+
if (is_input_qdq) {
73+
return dequant_node;
74+
}
5875
return squeeze;
5976
}
6077
bool runTransform(Node *n, Graph &graph,
@@ -115,13 +132,13 @@ struct FuseAddBiasIntoConv final : public PredicateBasedPass {
115132
if (bias_shape.size() > 1) {
116133
std::vector<int64_t> axes(bias_shape.size() - 1);
117134
std::iota(axes.begin(), axes.end(), 0);
118-
Node *squeeze = makeSqueezeOrUnsqueeze(graph, axes, conv_3rd_input,
119-
orig_conv->node(), kSqueeze);
135+
Node *squeeze = makeSqueezeOrUnsqueeze(
136+
graph, axes, conv_3rd_input, orig_conv->node(), kSqueeze, false);
120137
conv_3rd_input = squeeze->output();
121138
} else if (bias_shape.size() == 0) {
122139
std::vector<int64_t> axes = {0};
123-
Node *unsqueeze = makeSqueezeOrUnsqueeze(graph, axes, conv_3rd_input,
124-
orig_conv->node(), kUnsqueeze);
140+
Node *unsqueeze = makeSqueezeOrUnsqueeze(
141+
graph, axes, conv_3rd_input, orig_conv->node(), kUnsqueeze, false);
125142
conv_3rd_input = unsqueeze->output();
126143
}
127144
if (M > 1) {
@@ -149,17 +166,25 @@ struct FuseAddBiasIntoConv final : public PredicateBasedPass {
149166
bias_shape[1 + bias_shape.size() - static_cast<unsigned>(rank)]
150167
.dim == M) {
151168
ONNX_ASSERT(bias_shape.size() > 1);
169+
const bool is_input_qdq =
170+
orig_bias->node()->kind() == Symbol("DequantizeLinear") &&
171+
orig_bias->node()->input(0)->node()->kind() ==
172+
Symbol("QuantizeLinear");
152173
if (orig_bias->node()->kind() != kParam &&
153174
orig_conv->node()->isBefore(orig_bias->node())) {
175+
if (is_input_qdq) {
176+
orig_bias->node()->input(0)->node()->moveBefore(orig_conv->node());
177+
}
154178
orig_bias->node()->moveBefore(orig_conv->node());
155179
}
156180
std::vector<int64_t> axes(bias_shape.size());
157181
std::iota(axes.begin(), axes.end(), static_cast<int64_t>(0));
158182
axes.erase(axes.begin() +
159183
(1 + bias_shape.size() - static_cast<unsigned>(rank)));
160-
Node *squeeze = makeSqueezeOrUnsqueeze(graph, axes, orig_bias,
161-
orig_conv->node(), kSqueeze);
162-
orig_conv->node()->addInput(squeeze->output());
184+
185+
Node *new_bias = makeSqueezeOrUnsqueeze(
186+
graph, axes, orig_bias, orig_conv->node(), kSqueeze, is_input_qdq);
187+
orig_conv->node()->addInput(new_bias->output());
163188
} else {
164189
return false;
165190
}

onnxoptimizer/test/optimizer_test.py

+35
Original file line numberDiff line numberDiff line change
@@ -1150,6 +1150,41 @@ def test_fuse_add_bias_into_conv_with_non_constant_bias(self):
11501150
assert optimized_model.graph.node[2].op_type == 'Conv'
11511151
assert optimized_model.graph.output[0].name == 'C'
11521152

1153+
# type: () -> None
1154+
def test_fuse_add_bias_into_conv_with_quanted_bias(self):
1155+
nodes = [helper.make_node("Conv", ["X", "Y"], ["Z"]),
1156+
helper.make_node("QuantizeLinear", ["A", "scale", "zero_point"], ["B"], axis=0),
1157+
helper.make_node("DequantizeLinear", ["B", "scale", "zero_point"], ["C"], axis=0),
1158+
helper.make_node("Add", ["Z", "C"], ["D"])]
1159+
graph = helper.make_graph(
1160+
nodes,
1161+
"test",
1162+
[helper.make_tensor_value_info("X", TensorProto.FLOAT, (1, 5, 3, 3)),
1163+
helper.make_tensor_value_info(
1164+
"Y", TensorProto.FLOAT, (16, 5, 3, 3)),
1165+
helper.make_tensor_value_info("A", TensorProto.FLOAT, (16, 1, 1))],
1166+
[helper.make_tensor_value_info(
1167+
"D", TensorProto.FLOAT, (1, 16, 1, 1))],
1168+
[helper.make_tensor("scale", TensorProto.FLOAT,
1169+
dims=(16,),
1170+
vals=np.random.rand(16).astype(np.float32).tobytes(),
1171+
raw=True),
1172+
helper.make_tensor("zero_point", TensorProto.INT8,
1173+
dims=(16,),
1174+
vals=np.zeros([16]).astype(np.int8).tobytes(),
1175+
raw=True)],
1176+
value_info=[helper.make_tensor_value_info(
1177+
"C", TensorProto.FLOAT, (16, 1, 1))]
1178+
)
1179+
optimized_model = self._optimized(graph, ["fuse_add_bias_into_conv"], opset_imports=[helper.make_opsetid("", 13)])
1180+
1181+
assert len(list(optimized_model.graph.node)) == 4
1182+
assert optimized_model.graph.node[0].op_type == 'Squeeze'
1183+
assert optimized_model.graph.node[1].op_type == 'QuantizeLinear'
1184+
assert optimized_model.graph.node[2].op_type == 'DequantizeLinear'
1185+
assert optimized_model.graph.node[3].op_type == 'Conv'
1186+
assert optimized_model.graph.output[0].name == 'D'
1187+
11531188
def test_fuse_matmul_add_bias_into_gemm(self): # type: () -> None
11541189
matmul = helper.make_node("MatMul", ["X", "Y"], ["Z"])
11551190
add = helper.make_node("Add", ["Z", "B"], ["A"])

third_party/onnx

Submodule onnx updated 302 files

0 commit comments

Comments (0)