
Commit 532b10c

Author: wenyuchi.wyc (committed)
Support fusing Add into ConvTranspose.
Signed-off-by: wenyuchi.wyc <[email protected]>
1 parent 807cff7 · commit 532b10c
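
For context: the fuse_add_bias_into_conv pass rewrites a Conv followed by a broadcast Add into a single Conv with a bias input; this commit teaches the same pass to match ConvTranspose as well. A minimal sketch of exercising the new path through the public onnxoptimizer.optimize entry point, assuming onnxoptimizer is installed, reusing the shapes from the new tests below, and pinning opset 11 (an assumption made here so Squeeze keeps its axes attribute):

    import onnxoptimizer
    from onnx import TensorProto, helper

    graph = helper.make_graph(
        [
            helper.make_node("ConvTranspose", ["X", "Y"], ["Z"], strides=(2, 2)),
            helper.make_node("Add", ["Z", "A"], ["B"]),  # per-channel bias A
        ],
        "demo",
        [
            helper.make_tensor_value_info("X", TensorProto.FLOAT, (1, 3, 160, 160)),
            helper.make_tensor_value_info("Y", TensorProto.FLOAT, (3, 16, 2, 2)),
            helper.make_tensor_value_info("A", TensorProto.FLOAT, (16, 1, 1)),
        ],
        [helper.make_tensor_value_info("B", TensorProto.FLOAT, (1, 16, 320, 320))],
    )
    model = helper.make_model(graph, opset_imports=[helper.make_opsetid("", 11)])

    optimized = onnxoptimizer.optimize(model, ["fuse_add_bias_into_conv"])
    # Per the use_weight_shape test below, the Add is folded away:
    print([n.op_type for n in optimized.graph.node])  # ['Squeeze', 'ConvTranspose']
    # (The exact node list can vary with the pinned opset version.)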

File tree: 2 files changed (+218 -4 lines)


onnxoptimizer/passes/fuse_add_bias_into_conv.h (+16 -4)
@@ -33,10 +33,21 @@ struct FuseAddBiasIntoConv final : public PredicateBasedPass {
   std::string getPassName() const override {
     return "fuse_add_bias_into_conv";
   }
+
+  inline bool matchConvAdd(Node *node) {
+    return node->kind() == kAdd && node->inputs()[0]->node()->kind() == kConv &&
+           node->inputs()[0]->node()->inputs().size() == 2;
+  }
+
+  inline bool matchConvTransposeAdd(Node *node) {
+    return node->kind() == kAdd && node->inputs()[0]->node()->kind() == kConvTranspose &&
+           node->inputs()[0]->node()->inputs().size() == 2;
+  }
+
   bool patternMatchPredicate(Node *node) override {
-    return CheckKind(node, kAdd, 0, kConv) &&
-           GetInputsOfPreNode(node, 0).size() == 2;
+    return matchConvAdd(node) || matchConvTransposeAdd(node);
   }
+
   static Node *makeSqueezeOrUnsqueeze(Graph &graph, std::vector<int64_t> &axes,
                                       Value *input, Node *target_node,
                                       BuiltinSymbol k) {
@@ -62,6 +73,7 @@ struct FuseAddBiasIntoConv final : public PredicateBasedPass {
                     NodeDestroyType &destroy_current) override {
     // due to current broadcasting's constraint, Conv has to be the first
     // operand
+    const bool is_conv = matchConvAdd(n);
     destroy_current = NodeDestroyType::DestroyZero;
     auto orig_conv = n->inputs()[0];
     auto orig_bias = n->inputs()[1];
@@ -86,8 +98,8 @@ struct FuseAddBiasIntoConv final : public PredicateBasedPass {
     }
     // try to get feature M and rank from weight_shape
     if (weight_shape.size() > 0 && weight_shape[0].is_int) {
-      ONNX_ASSERT(M == -1 || M == weight_shape[0].dim);
-      M = weight_shape[0].dim;
+      ONNX_ASSERT(M == -1 || M == weight_shape[0].dim || M == weight_shape[1].dim);
+      M = is_conv ? weight_shape[0].dim : weight_shape[1].dim;
       ONNX_ASSERT(rank == -1 ||
                   rank == static_cast<int64_t>(weight_shape.size()));
       rank = weight_shape.size();
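
The switch to weight_shape[1] in the last hunk follows from the ONNX weight layouts: Conv weights are laid out (M, C/group, kH, kW), while ConvTranspose weights are (C, M/group, kH, kW), so the output-channel count M is read from index 0 for Conv but index 1 for ConvTranspose (with group == 1, as in the tests below). A minimal sketch of the two layouts (shapes per the ONNX operator spec; the array names are illustrative):

    import numpy as np

    # Conv weight layout: (M, C/group, kH, kW) -> output channels at index 0.
    conv_w = np.zeros((16, 3, 2, 2), dtype=np.float32)

    # ConvTranspose weight layout: (C, M/group, kH, kW) -> output channels at index 1.
    deconv_w = np.zeros((3, 16, 2, 2), dtype=np.float32)

    # Both describe 16 output channels, which is the M the pass needs
    # in order to size the fused bias.
    assert conv_w.shape[0] == deconv_w.shape[1] == 16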

onnxoptimizer/test/optimizer_test.py (+202)
@@ -1424,6 +1424,208 @@ def test_fuse_add_bias_into_conv_squeeze_4d_bias_no_fuse(self):
         assert optimized_model.graph.node[0].op_type == "Conv"
         assert optimized_model.graph.node[1].op_type == "Add"
 
+    def test_fuse_add_bias_into_conv_transpose_with_scalar_bias(self):  # type: () -> None
+        nodes = [
+            helper.make_node("ConvTranspose", ["X", "Y"], ["Z"], strides=(2, 2)),
+            helper.make_node("Add", ["Z", "A"], ["B"]),
+        ]
+        graph = helper.make_graph(
+            nodes,
+            "test",
+            [
+                helper.make_tensor_value_info("X", TensorProto.FLOAT, (1, 3, 160, 160)),
+                helper.make_tensor_value_info("Y", TensorProto.FLOAT, (3, 16, 2, 2)),
+                helper.make_tensor_value_info("A", TensorProto.FLOAT, ()),
+            ],
+            [helper.make_tensor_value_info("B", TensorProto.FLOAT, (1, 16, 320, 320))],
+        )
+        optimized_model = self._optimized(graph, ["fuse_add_bias_into_conv"])
+
+        # Unsqueeze, Constant, Tile, ConvTranspose
+        assert len(optimized_model.graph.node) == 4
+        assert optimized_model.graph.node[0].op_type == "Unsqueeze"
+        assert optimized_model.graph.node[1].op_type == "Constant"
+        assert optimized_model.graph.node[2].op_type == "Tile"
+        assert optimized_model.graph.node[3].op_type == "ConvTranspose"
+
+    def test_fuse_add_bias_into_conv_transpose_use_weight_shape(self):  # type: () -> None
+        nodes = [
+            helper.make_node("ConvTranspose", ["X", "Y"], ["Z"], strides=(2, 2)),
+            helper.make_node("Add", ["Z", "A"], ["B"]),
+        ]
+        # FIXME(daquexian): It looks like subgraph cannot get value info from parent subgraph
+        # nodes.extend(self._make_fake_loop_op(
+        #     [helper.make_node("Conv", ["_X", "Y"], ["_Z"]),
+        #      helper.make_node("Add", ["_Z", "A"], ["_B2"])],
+        #     [(TensorProto.FLOAT, (1, 5, 3, 3), "X")],
+        #     [(TensorProto.FLOAT, (1, 16, 1, 1), "B2")]))
+        graph = helper.make_graph(
+            nodes,
+            "test",
+            [
+                helper.make_tensor_value_info("X", TensorProto.FLOAT, (1, 3, 160, 160)),
+                helper.make_tensor_value_info("Y", TensorProto.FLOAT, (3, 16, 2, 2)),
+                helper.make_tensor_value_info("A", TensorProto.FLOAT, (16, 1, 1)),
+            ],
+            [helper.make_tensor_value_info("B", TensorProto.FLOAT, (1, 16, 320, 320))],
+        )
+        optimized_model = self._optimized(graph, ["fuse_add_bias_into_conv"])
+
+        # # Squeeze, Conv, Constant (trip count), Constant (condition), Loop
+        # assert len(list(optimized_model.graph.node)) == 5
+        assert len(list(optimized_model.graph.node)) == 2
+        assert optimized_model.graph.node[0].op_type == "Squeeze"
+        assert optimized_model.graph.node[1].op_type == "ConvTranspose"
+        assert optimized_model.graph.output[0].name == "B"
+        # # Squeeze, Conv
+        # assert len(optimized_model.graph.node[4].attribute[0].g.node) == 2
+        # assert optimized_model.graph.node[4].attribute[0].g.node[0].op_type == 'Squeeze'
+        # assert optimized_model.graph.node[4].attribute[0].g.node[1].op_type == 'Conv'
+        # # Output 1 since 0 is 'cond'
+        # assert optimized_model.graph.node[4].attribute[0].g.output[1].name == 'B2'
+
+    # type: () -> None
+    def test_fuse_add_bias_into_conv_transpose_use_weight_shape_with_tile(self):
+        conv = helper.make_node("ConvTranspose", ["X", "Y"], ["Z"], strides=(2, 2))
+        add = helper.make_node("Add", ["Z", "A"], ["B"])
+        graph = helper.make_graph(
+            [conv, add],
+            "test",
+            [
+                helper.make_tensor_value_info("X", TensorProto.FLOAT, (1, 3, 160, 160)),
+                helper.make_tensor_value_info("Y", TensorProto.FLOAT, (3, 16, 2, 2)),
+                helper.make_tensor_value_info("A", TensorProto.FLOAT, (1,)),
+            ],
+            [helper.make_tensor_value_info("B", TensorProto.FLOAT, (1, 16, 320, 320))],
+        )
+        optimized_model = self._optimized(graph, ["fuse_add_bias_into_conv"])
+
+        assert len(list(optimized_model.graph.node)) == 3
+        assert len(optimized_model.graph.value_info) == 1
+        assert (
+            optimized_model.graph.value_info[0].type.tensor_type.elem_type
+            == TensorProto.INT64
+        )
+        assert len(optimized_model.graph.value_info[0].type.tensor_type.shape.dim) == 1
+        assert optimized_model.graph.node[0].op_type == "Constant"
+        assert optimized_model.graph.node[1].op_type == "Tile"
+        assert optimized_model.graph.node[2].op_type == "ConvTranspose"
+        assert optimized_model.graph.output[0].name == "B"
+
+    def test_fuse_add_bias_into_conv_transpose_use_conv_shape(self):  # type: () -> None
+        sub = helper.make_node("Sub", ["M", "N"], ["Y"])
+        conv = helper.make_node("ConvTranspose", ["X", "Y"], ["Z"], strides=(2, 2))
+        add = helper.make_node("Add", ["Z", "A"], ["B"])
+        graph = helper.make_graph(
+            [sub, conv, add],
+            "test",
+            [
+                helper.make_tensor_value_info("X", TensorProto.FLOAT, (1, 3, 160, 160)),
+                helper.make_tensor_value_info("M", TensorProto.FLOAT, (3, 16, 2, 2)),
+                helper.make_tensor_value_info("N", TensorProto.FLOAT, (3, 16, 2, 2)),
+                helper.make_tensor_value_info("A", TensorProto.FLOAT, (1, 16, 1, 1)),
+            ],
+            [helper.make_tensor_value_info("B", TensorProto.FLOAT, (1, 16, 320, 320))],
+            value_info=[
+                helper.make_tensor_value_info("Z", TensorProto.FLOAT, (1, 16, 320, 320))
+            ],
+        )
+        optimized_model = self._optimized(graph, ["fuse_add_bias_into_conv"])
+
+        assert len(optimized_model.graph.node) == 3
+        assert optimized_model.graph.node[0].op_type == "Sub"
+        assert optimized_model.graph.node[1].op_type == "Squeeze"
+        assert optimized_model.graph.node[2].op_type == "ConvTranspose"
+        assert optimized_model.graph.output[0].name == "B"
+        assert (
+            optimized_model.graph.output[0].type.tensor_type.elem_type
+            == TensorProto.FLOAT
+        )
+        assert len(optimized_model.graph.output[0].type.tensor_type.shape.dim) == 4
+
+    # type: () -> None
+    def test_fuse_add_bias_into_conv_transpose_use_move_constant(self):
+        conv = helper.make_node("ConvTranspose", ["X", "Y"], ["Z"], strides=(2, 2))
+        constant = helper.make_node(
+            "Constant",
+            [],
+            ["A"],
+            value=helper.make_tensor(
+                name="bias",
+                data_type=TensorProto.FLOAT,
+                dims=(16, 1, 1),
+                vals=np.random.randn(16).astype(np.float32).tolist(),
+            ),
+        )
+        add = helper.make_node("Add", ["Z", "A"], ["B"])
+        graph = helper.make_graph(
+            [conv, constant, add],
+            "test",
+            [
+                helper.make_tensor_value_info("X", TensorProto.FLOAT, (1, 3, 160, 160)),
+                helper.make_tensor_value_info("Y", TensorProto.FLOAT, (3, 16, 2, 2)),
+            ],
+            [helper.make_tensor_value_info("B", TensorProto.FLOAT, (1, 16, 320, 320))],
+            value_info=[
+                helper.make_tensor_value_info("A", TensorProto.FLOAT, (16, 1, 1)),
+            ],
+        )
+        optimized_model = self._optimized(graph, ["fuse_add_bias_into_conv"])
+
+        assert len(optimized_model.graph.node) == 3
+        assert optimized_model.graph.node[0].op_type == "Constant"
+        assert optimized_model.graph.node[1].op_type == "Squeeze"
+        assert optimized_model.graph.node[2].op_type == "ConvTranspose"
+        assert optimized_model.graph.output[0].name == "B"
+        assert (
+            optimized_model.graph.output[0].type.tensor_type.elem_type
+            == TensorProto.FLOAT
+        )
+        assert len(optimized_model.graph.output[0].type.tensor_type.shape.dim) == 4
+
+    # type: () -> None
+    def test_fuse_add_bias_into_conv_transpose_squeeze_1d_bias_no_fuse(self):
+        conv = helper.make_node("ConvTranspose", ["X", "Y"], ["Z"], strides=(2, 2))
+        add = helper.make_node("Add", ["Z", "A"], ["B"])
+        graph = helper.make_graph(
+            [conv, add],
+            "test",
+            [
+                helper.make_tensor_value_info("X", TensorProto.FLOAT, (1, 3, 160, 160)),
+                helper.make_tensor_value_info("Y", TensorProto.FLOAT, (3, 16, 2, 2)),
+                helper.make_tensor_value_info("A", TensorProto.FLOAT, (320,)),
+            ],
+            [helper.make_tensor_value_info("B", TensorProto.FLOAT, (1, 16, 320, 320))],
+            value_info=[
+                helper.make_tensor_value_info("Z", TensorProto.FLOAT, (1, 16, 320, 320)),
+            ],
+        )
+        optimized_model = self._optimized(graph, ["fuse_add_bias_into_conv"])
+
+        assert len(list(optimized_model.graph.node)) == 2
+        assert optimized_model.graph.node[0].op_type == "ConvTranspose"
+        assert optimized_model.graph.node[1].op_type == "Add"
+
+    # type: () -> None
+    def test_fuse_add_bias_into_conv_transpose_squeeze_4d_bias_no_fuse(self):
+        conv = helper.make_node("ConvTranspose", ["X", "Y"], ["Z"], strides=(2, 2))
+        add = helper.make_node("Add", ["Z", "A"], ["B"])
+        graph = helper.make_graph(
+            [conv, add],
+            "test",
+            [
+                helper.make_tensor_value_info("X", TensorProto.FLOAT, (1, 3, 160, 160)),
+                helper.make_tensor_value_info("Y", TensorProto.FLOAT, (3, 16, 2, 2)),
+                helper.make_tensor_value_info("A", TensorProto.FLOAT, (1, 16, 320, 320)),
+            ],
+            [helper.make_tensor_value_info("B", TensorProto.FLOAT, (1, 16, 320, 320))],
+        )
+        optimized_model = self._optimized(graph, ["fuse_add_bias_into_conv"])
+
+        assert len(list(optimized_model.graph.node)) == 2
+        assert optimized_model.graph.node[0].op_type == "ConvTranspose"
+        assert optimized_model.graph.node[1].op_type == "Add"
+
     def test_fuse_matmul_add_bias_into_gemm(self):  # type: () -> None
         matmul = helper.make_node("MatMul", ["X", "Y"], ["Z"])
         add = helper.make_node("Add", ["Z", "B"], ["A"])
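
The two no-fuse tests above pin down the boundary of the rewrite: the optional bias input B of Conv/ConvTranspose is a 1-D tensor of length M (16 here), so only an Add whose operand squeezes down to one value per output channel can be folded; biases that vary along spatial dimensions are left as a standalone Add. A rough illustration with plain NumPy broadcasting, reusing the shapes from the tests (not the pass's actual check):

    import numpy as np

    Z = np.zeros((1, 16, 320, 320), dtype=np.float32)  # ConvTranspose output

    # Fusible: (16, 1, 1) squeezes to a length-16 per-channel vector,
    # exactly what the ConvTranspose bias input B expects.
    per_channel = np.zeros((16, 1, 1), dtype=np.float32)
    assert (Z + per_channel).shape == Z.shape

    # Not fusible: (320,) broadcasts along the last spatial axis, and a full
    # (1, 16, 320, 320) tensor varies per pixel; neither reduces to a
    # per-channel vector, so the pass leaves the Add in place.
    spatial = np.zeros((320,), dtype=np.float32)
    assert (Z + spatial).shape == Z.shape  # a valid Add, but not a channel bias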
