Skip to content

Commit cf3ddd3

Browse files
authored
Pass compat of conv_transpose_bias_mkldnn_fuse_pass (#33708)
1 parent 1828426 commit cf3ddd3

File tree

3 files changed

+117
-1
lines changed

3 files changed

+117
-1
lines changed

paddle/fluid/framework/ir/mkldnn/conv_bias_mkldnn_fuse_pass.cc

Lines changed: 102 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -25,6 +25,102 @@ namespace paddle {
2525
namespace framework {
2626
namespace ir {
2727

28+
ConvBiasFusePass::ConvBiasFusePass() {
29+
AddOpCompat(OpCompat("conv2d"))
30+
.AddInput("Input")
31+
.IsTensor()
32+
.End()
33+
.AddInput("Filter")
34+
.IsTensor()
35+
.End()
36+
.AddInput("Bias")
37+
.IsTensor()
38+
.IsOptional()
39+
.End()
40+
.AddOutput("Output")
41+
.IsTensor()
42+
.End()
43+
.AddAttr("strides")
44+
.End()
45+
.AddAttr("paddings")
46+
.End()
47+
.AddAttr("padding_algorithm")
48+
.IsStringIn({"EXPLICIT", "SAME", "VALID"})
49+
.End()
50+
.AddAttr("groups")
51+
.IsNumGE(1)
52+
.End()
53+
.AddAttr("dilations")
54+
.End()
55+
.AddAttr("data_format")
56+
.IsStringIn({"NCHW", "NHWC"})
57+
.End();
58+
59+
AddOpCompat(OpCompat("elementwise_add"))
60+
.AddInput("X")
61+
.IsTensor()
62+
.End()
63+
.AddInput("Y")
64+
.IsTensor()
65+
.End()
66+
.AddOutput("Out")
67+
.IsTensor()
68+
.End()
69+
.AddAttr("axis")
70+
.IsNumEQ(-1)
71+
.End();
72+
}
73+
74+
Conv2DTransposeBiasFusePass::Conv2DTransposeBiasFusePass() {
75+
AddOpCompat(OpCompat("conv2d_transpose"))
76+
.AddInput("Input")
77+
.IsTensor()
78+
.End()
79+
.AddInput("Filter")
80+
.IsTensor()
81+
.End()
82+
.AddInput("Bias")
83+
.IsTensor()
84+
.End()
85+
.AddOutput("Output")
86+
.IsTensor()
87+
.End()
88+
.AddAttr("output_padding")
89+
.End()
90+
.AddAttr("output_size")
91+
.IsNumGE(1)
92+
.End()
93+
.AddAttr("groups")
94+
.IsNumGE(1)
95+
.End()
96+
.AddAttr("dilations")
97+
.End()
98+
.AddAttr("strides")
99+
.End()
100+
.AddAttr("paddings")
101+
.End()
102+
.AddAttr("padding_algorithm")
103+
.IsStringIn({"EXPLICIT", "SAME", "VALID"})
104+
.End()
105+
.AddAttr("data_format")
106+
.IsStringIn({"NCHW", "NHWC"})
107+
.End();
108+
109+
AddOpCompat(OpCompat("elementwise_add"))
110+
.AddInput("X")
111+
.IsTensor()
112+
.End()
113+
.AddInput("Y")
114+
.IsTensor()
115+
.End()
116+
.AddOutput("Out")
117+
.IsTensor()
118+
.End()
119+
.AddAttr("axis")
120+
.IsNumEQ(-1)
121+
.End();
122+
}
123+
28124
template <typename BinaryOperation>
29125
LoDTensor tensor_apply_eltwise(const LoDTensor& vec_a, const LoDTensor& vec_b,
30126
BinaryOperation f) {
@@ -80,6 +176,12 @@ void ConvBiasFusePass::ApplyImpl(ir::Graph* graph) const {
80176
subgraph.count(conv_input), 0,
81177
platform::errors::NotFound("Detector did not find conv input."));
82178

179+
// check compat
180+
if (!IsCompat(subgraph, g)) {
181+
VLOG(3) << "Pass in op compat failed.";
182+
return;
183+
}
184+
83185
// check if fuse can be done and if MKL-DNN should be used
84186
FuseOptions fuse_option = FindFuseOption(*conv, *eltwise);
85187
if (fuse_option == DO_NOT_FUSE || fuse_option == FUSE_NATIVE) {

paddle/fluid/framework/ir/mkldnn/conv_bias_mkldnn_fuse_pass.h

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -29,6 +29,7 @@ class Graph;
2929

3030
class ConvBiasFusePass : public FusePassBase {
3131
public:
32+
ConvBiasFusePass();
3233
virtual ~ConvBiasFusePass() {}
3334
virtual std::string type() const { return "conv2d"; }
3435

@@ -41,6 +42,7 @@ class ConvBiasFusePass : public FusePassBase {
4142
*/
4243
class Conv2DTransposeBiasFusePass : public ConvBiasFusePass {
4344
public:
45+
Conv2DTransposeBiasFusePass();
4446
std::string type() const override { return "conv2d_transpose"; }
4547
};
4648

paddle/fluid/framework/ir/mkldnn/conv_bias_mkldnn_fuse_pass_tester.cc

Lines changed: 13 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -31,8 +31,19 @@ void SetOp(ProgramDesc* prog, const std::string& type, const std::string& name,
3131
auto* op = prog->MutableBlock(0)->AppendOp();
3232
op->SetType(type);
3333
if (type == "conv2d") {
34+
const std::vector<int> strides({1, 1});
35+
const std::vector<int> paddings({0, 0});
36+
const std::vector<int> dilations({1, 1});
3437
op->SetAttr("use_mkldnn", true);
3538
op->SetAttr("name", name);
39+
op->SetAttr("strides", strides);
40+
op->SetAttr("groups", 1);
41+
op->SetAttr("paddings", paddings);
42+
op->SetAttr("padding_algorithm", std::string("EXPLICIT"));
43+
op->SetAttr("dilations", dilations);
44+
op->SetAttr("data_format", std::string("NCHW"));
45+
46+
op->SetOutput("Output", outputs);
3647
op->SetInput("Input", {inputs[0]});
3748
op->SetInput("Filter", {inputs[1]});
3849
if (inputs.size() > 2)
@@ -41,10 +52,11 @@ void SetOp(ProgramDesc* prog, const std::string& type, const std::string& name,
4152
op->SetInput("Bias", {});
4253
} else if (type == "elementwise_add") {
4354
op->SetAttr("use_mkldnn", true);
55+
op->SetAttr("axis", -1);
4456
op->SetInput("X", {inputs[0]});
4557
op->SetInput("Y", {inputs[1]});
58+
op->SetOutput("Out", outputs);
4659
}
47-
op->SetOutput("Out", outputs);
4860
op->SetAttr(OpProtoAndCheckerMaker::OpRoleAttrName(),
4961
static_cast<int>(OpRole::kForward));
5062
}

0 commit comments

Comments
 (0)