
Commit e3658ae

Author: yaozhixin

add affine_channel and clip (PaddlePaddle#826)

* add affine_channel and clip

1 parent: 7e057f4

File tree

4 files changed: +790 -143 lines

paddle/fluid/platform/device/ipu/popart_canonicalization/nn_ops.cc

Lines changed: 144 additions & 44 deletions
@@ -35,20 +35,32 @@ Node *conv2d_handler(Graph *graph, Node *node) {
   auto stride_ = BOOST_GET_CONST(std::vector<int>, op->GetAttr("strides"));
   auto stride = std::vector<int64_t>{stride_.begin(), stride_.end()};
   if (!op->Input("Bias").empty()) {
-    return CreateConv(graph, node,
+    return CreateConv(graph,
+                      node,
                       {
                           GetInputVarNode("Input", node),
                           GetInputVarNode("Filter", node),
                           GetInputVarNode("Bias", node),
                       },
-                      node->outputs, dilations, group_, {}, pads, stride);
+                      node->outputs,
+                      dilations,
+                      group_,
+                      {},
+                      pads,
+                      stride);
   } else {
-    return CreateConv(graph, node,
+    return CreateConv(graph,
+                      node,
                       {
                           GetInputVarNode("Input", node),
                           GetInputVarNode("Filter", node),
                       },
-                      node->outputs, dilations, group_, {}, pads, stride);
+                      node->outputs,
+                      dilations,
+                      group_,
+                      {},
+                      pads,
+                      stride);
   }
 }
 
@@ -83,7 +95,11 @@ Node *batch_norm_handler(Graph *graph, Node *node) {
   auto momentum = BOOST_GET_CONST(float, op->GetAttr("momentum"));
   auto epsilon = BOOST_GET_CONST(float, op->GetAttr("epsilon"));
   // data_layout
-  return CreateBaseOp(graph, node, "popart_batchnormalization", inputs, outputs,
+  return CreateBaseOp(graph,
+                      node,
+                      "popart_batchnormalization",
+                      inputs,
+                      outputs,
                       {
                           {"momentum", momentum},
                           {"epsilon", epsilon},
@@ -105,18 +121,18 @@ Node *pool2d_handler(Graph *graph, Node *node) {
       }
       // adaptive maxpool op is max_pool2d_with_index. Only process avgpool
       // here.
-      return CreateBaseOp(graph, node, "popart_globalaveragepool", node->inputs,
-                          node->outputs);
+      return CreateBaseOp(
+          graph, node, "popart_globalaveragepool", node->inputs, node->outputs);
     }
   }
 
   if (global_pooling) {
     if (pooling_type == "max") {
-      return CreateBaseOp(graph, node, "popart_globalmaxpool", node->inputs,
-                          node->outputs);
+      return CreateBaseOp(
+          graph, node, "popart_globalmaxpool", node->inputs, node->outputs);
     } else if (pooling_type == "avg") {
-      return CreateBaseOp(graph, node, "popart_globalaveragepool", node->inputs,
-                          node->outputs);
+      return CreateBaseOp(
+          graph, node, "popart_globalaveragepool", node->inputs, node->outputs);
     } else {
       PADDLE_THROW(platform::errors::InvalidArgument(
           "op pool2d with unkonwn pooling_type: %s", pooling_type));
@@ -147,7 +163,10 @@ Node *pool2d_handler(Graph *graph, Node *node) {
     int64_t num_outputs = 1;
     auto dilations = std::vector<int64_t>{};
     int64_t storage_order = 0;
-    return CreateBaseOp(graph, node, "popart_maxpool", node->inputs,
+    return CreateBaseOp(graph,
+                        node,
+                        "popart_maxpool",
+                        node->inputs,
                         node->outputs,
                         {
                             {"num_outputs", num_outputs},
@@ -160,7 +179,10 @@ Node *pool2d_handler(Graph *graph, Node *node) {
                         });
   } else if (pooling_type == "avg") {
     int64_t count_include_pad = 0;
-    return CreateBaseOp(graph, node, "popart_averagepool", node->inputs,
+    return CreateBaseOp(graph,
+                        node,
+                        "popart_averagepool",
+                        node->inputs,
                         node->outputs,
                         {
                             {"kernel_shape", kernel_shape},
@@ -182,7 +204,10 @@ Node *max_pool2d_with_index_handler(Graph *graph, Node *node) {
     PADDLE_THROW(platform::errors::InvalidArgument(
         "Only support pool_size=1 with adaptive mode."));
   }
-  return CreateBaseOp(graph, node, "popart_globalmaxpool", node->inputs,
+  return CreateBaseOp(graph,
+                      node,
+                      "popart_globalmaxpool",
+                      node->inputs,
                       {GetOutputVarNode("Out", node)});
 }
 
@@ -199,8 +224,8 @@ Node *group_norm_handler(Graph *graph, Node *node) {
   std::vector<Node *> outputs_ = {GetOutputVarNode("Y", node),
                                   GetOutputVarNode("Mean", node),
                                   GetOutputVarNode("Variance", node)};
-  return CreateBaseOp(graph, node, "popart_groupnormalization_v2", inputs_,
-                      outputs_, attrs_);
+  return CreateBaseOp(
+      graph, node, "popart_groupnormalization_v2", inputs_, outputs_, attrs_);
 }
 
 Node *instance_norm_handler(Graph *graph, Node *node) {
@@ -212,8 +237,8 @@ Node *instance_norm_handler(Graph *graph, Node *node) {
                                  GetInputVarNode("Scale", node),
                                  GetInputVarNode("Bias", node)};
   std::vector<Node *> outputs_ = {GetOutputVarNode("Y", node)};
-  return CreateBaseOp(graph, node, "popart_instancenormalization", inputs_,
-                      outputs_, attrs_);
+  return CreateBaseOp(
+      graph, node, "popart_instancenormalization", inputs_, outputs_, attrs_);
 }
 
 Node *layer_norm_handler(Graph *graph, Node *node) {
@@ -227,13 +252,16 @@ Node *layer_norm_handler(Graph *graph, Node *node) {
       AttributeMap{{"epsilon", epsilon_}, {"num_groups", groups_}};
 
   if (input_shape_.size() == 2) {
-    return CreateBaseOp(
-        graph, node, "popart_groupnormalization_v2",
-        {GetInputVarNode("X", node), GetInputVarNode("Scale", node),
-         GetInputVarNode("Bias", node)},
-        {GetOutputVarNode("Y", node), GetOutputVarNode("Mean", node),
-         GetOutputVarNode("Variance", node)},
-        groupnorm_attrs_);
+    return CreateBaseOp(graph,
+                        node,
+                        "popart_groupnormalization_v2",
+                        {GetInputVarNode("X", node),
+                         GetInputVarNode("Scale", node),
+                         GetInputVarNode("Bias", node)},
+                        {GetOutputVarNode("Y", node),
+                         GetOutputVarNode("Mean", node),
+                         GetOutputVarNode("Variance", node)},
+                        groupnorm_attrs_);
   }
 
   std::vector<int64_t> norm_shape_{1, 1};
@@ -251,15 +279,23 @@ Node *layer_norm_handler(Graph *graph, Node *node) {
                               {"dtype", ONNXDataType::INT64}};
   auto reshape1_const =
       CreateBaseOp(graph, node, "popart_constant", {}, {}, attrs1);
-  auto new_node_reshape1 = CreateBaseOp(
-      graph, node, "popart_reshape",
-      {GetInputVarNode("X", node), reshape1_const->outputs[0]}, {}, {});
+  auto new_node_reshape1 =
+      CreateBaseOp(graph,
+                   node,
+                   "popart_reshape",
+                   {GetInputVarNode("X", node), reshape1_const->outputs[0]},
+                   {},
+                   {});
 
   auto out_Y_ = MakeVarNode(graph, node);
-  CreateBaseOp(graph, node, "popart_groupnormalization_v2",
-               {new_node_reshape1->outputs[0], GetInputVarNode("Scale", node),
+  CreateBaseOp(graph,
+               node,
+               "popart_groupnormalization_v2",
+               {new_node_reshape1->outputs[0],
+                GetInputVarNode("Scale", node),
                 GetInputVarNode("Bias", node)},
-               {out_Y_, GetOutputVarNode("Mean", node),
+               {out_Y_,
+                GetOutputVarNode("Mean", node),
                 GetOutputVarNode("Variance", node)},
                groupnorm_attrs_);
 
@@ -269,9 +305,12 @@ Node *layer_norm_handler(Graph *graph, Node *node) {
                               {"dtype", ONNXDataType::INT64}};
   auto reshape2_const =
       CreateBaseOp(graph, node, "popart_constant", {}, {}, attrs2);
-  auto new_node_reshape2 = CreateBaseOp(graph, node, "popart_reshape",
+  auto new_node_reshape2 = CreateBaseOp(graph,
+                                        node,
+                                        "popart_reshape",
                                         {out_Y_, reshape2_const->outputs[0]},
-                                        {GetOutputVarNode("Y", node)}, {});
+                                        {GetOutputVarNode("Y", node)},
+                                        {});
   return new_node_reshape2;
 }
 
@@ -292,18 +331,27 @@ Node *dropout_handler(Graph *graph, Node *node) {
 
   if (is_test_) {
     if (dropout_implementation_ == "upscale_in_train") {
-      return CreateBaseOp(graph, node, "popart_identity",
+      return CreateBaseOp(graph,
+                          node,
+                          "popart_identity",
                           {GetInputVarNode("X", node)},
-                          {GetOutputVarNode("Out", node)}, {});
+                          {GetOutputVarNode("Out", node)},
+                          {});
     } else if (dropout_implementation_ == "downgrade_in_infer") {
       auto scale =
-          CreateConst(graph, node, {}, {},
+          CreateConst(graph,
+                      node,
+                      {},
+                      {},
                       {{"value", std::vector<float>{1 - dropout_prob_}},
                        {"dims", std::vector<int64_t>{1}},
                        {"dtype", GetOutputVarDType(node)}});
-      return CreateBaseOp(graph, node, "popart_mul",
+      return CreateBaseOp(graph,
+                          node,
+                          "popart_mul",
                           {GetInputVarNode("X", node), scale->outputs[0]},
-                          {GetOutputVarNode("Out", node)}, {});
+                          {GetOutputVarNode("Out", node)},
+                          {});
     } else {
       PADDLE_THROW(
           platform::errors::InvalidArgument("Invalid dropout_implementation"));
@@ -312,9 +360,12 @@ Node *dropout_handler(Graph *graph, Node *node) {
     if (dropout_implementation_ == "upscale_in_train") {
       auto attrs_ =
           AttributeMap{{"num_outputs", (int64_t)1}, {"ratio", dropout_prob_}};
-      return CreateBaseOp(graph, node, "popart_dropout",
+      return CreateBaseOp(graph,
+                          node,
+                          "popart_dropout",
                           {GetInputVarNode("X", node)},
-                          {GetOutputVarNode("Out", node)}, attrs_);
+                          {GetOutputVarNode("Out", node)},
+                          attrs_);
     } else if (dropout_implementation_ == "downgrade_in_infer") {
       PADDLE_THROW(platform::errors::InvalidArgument(
           "Do not support downgrade_in_infer with training"));
@@ -388,28 +439,77 @@ Node *conv2d_transpose_handler(Graph *graph, Node *node) {
                              {"pads", paddings},
                              {"strides", strides}};
   if (!op->Input("Bias").empty()) {
-    return CreateBaseOp(graph, node, "popart_convtranspose",
+    return CreateBaseOp(graph,
+                        node,
+                        "popart_convtranspose",
                         {
                             GetInputVarNode("Input", node),
                             GetInputVarNode("Filter", node),
                             GetInputVarNode("Bias", node),
                         },
-                        node->outputs, attrs);
+                        node->outputs,
+                        attrs);
   } else {
-    return CreateBaseOp(graph, node, "popart_convtranspose",
+    return CreateBaseOp(graph,
+                        node,
+                        "popart_convtranspose",
                         {
                             GetInputVarNode("Input", node),
                             GetInputVarNode("Filter", node),
                         },
-                        node->outputs, attrs);
+                        node->outputs,
+                        attrs);
+  }
+}
+
+Node *affine_channel_handler(Graph *graph, Node *node) {
+  auto *op = node->Op();
+
+  auto data_layout = BOOST_GET_CONST(std::string, op->GetAttr("data_layout"));
+  if (data_layout != "NCHW") {
+    platform::errors::InvalidArgument("Only support NCHW as data_format.");
+  }
+
+  auto *scale = GetInputVarNode("Scale", node);
+  auto *bias = GetInputVarNode("Bias", node);
+  auto scale_shape = scale->Var()->GetShape();
+  auto bias_shape = bias->Var()->GetShape();
+  if (scale_shape.size() <= 1 || bias_shape.size() <= 1) {
+    auto attrs = AttributeMap{{"value", std::vector<int64_t>{1, -1, 1, 1}},
+                              {"dims", std::vector<int64_t>{4}},
+                              {"dtype", ONNXDataType::INT64}};
+    auto new_shape_const = CreateConst(graph, node, {}, {}, attrs);
+
+    scale = CreateBaseOp(graph,
+                         node,
+                         "popart_reshape",
+                         {scale, new_shape_const->outputs[0]},
+                         {},
+                         {})
+                ->outputs[0];
+    bias = CreateBaseOp(graph,
+                        node,
+                        "popart_reshape",
+                        {bias, new_shape_const->outputs[0]},
+                        {},
+                        {})
+               ->outputs[0];
   }
+  auto *out = CreateBaseOp(
+      graph, node, "popart_mul", {GetInputVarNode("X", node), scale}, {});
+  return CreateBaseOp(graph,
+                      node,
+                      "popart_add",
+                      {out->outputs[0], bias},
+                      {GetOutputVarNode("Out", node)});
 }
 
 }  // namespace
 }  // namespace ipu
 }  // namespace platform
 }  // namespace paddle
 
+REGISTER_HANDLER(affine_channel, affine_channel_handler);
 REGISTER_HANDLER(pool2d, pool2d_handler);
 REGISTER_HANDLER(max_pool2d_with_index, max_pool2d_with_index_handler);
 REGISTER_HANDLER(batch_norm, batch_norm_handler);
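Note on the new handler: affine_channel applies a per-channel scale and bias to an NCHW tensor, which is why the handler above reshapes Scale and Bias to shape {1, -1, 1, 1} and then lowers the op to popart_mul followed by popart_add. The standalone sketch below illustrates the arithmetic being expressed; it is not part of the commit, and the function name and toy sizes are purely illustrative.

// Reference sketch of affine_channel semantics on a flat NCHW buffer:
//   out[n][c][h][w] = x[n][c][h][w] * scale[c] + bias[c]
#include <cstddef>
#include <iostream>
#include <vector>

std::vector<float> AffineChannelReference(const std::vector<float> &x,
                                          const std::vector<float> &scale,
                                          const std::vector<float> &bias,
                                          size_t N, size_t C, size_t H,
                                          size_t W) {
  std::vector<float> out(x.size());
  for (size_t n = 0; n < N; ++n) {
    for (size_t c = 0; c < C; ++c) {
      for (size_t hw = 0; hw < H * W; ++hw) {
        // Index into the flattened NCHW layout.
        size_t idx = (n * C + c) * H * W + hw;
        out[idx] = x[idx] * scale[c] + bias[c];
      }
    }
  }
  return out;
}

int main() {
  // Tiny example: N=1, C=2, H=W=2.
  std::vector<float> x{1, 2, 3, 4, 5, 6, 7, 8};
  std::vector<float> scale{2.f, 0.5f};
  std::vector<float> bias{1.f, -1.f};
  auto out = AffineChannelReference(x, scale, bias, 1, 2, 2, 2);
  for (float v : out) std::cout << v << " ";  // 3 5 7 9 1.5 2 2.5 3
  std::cout << std::endl;
  return 0;
}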
