@@ -35,20 +35,32 @@ Node *conv2d_handler(Graph *graph, Node *node) {
   auto stride_ = BOOST_GET_CONST(std::vector<int>, op->GetAttr("strides"));
   auto stride = std::vector<int64_t>{stride_.begin(), stride_.end()};
   if (!op->Input("Bias").empty()) {
-    return CreateConv(graph, node,
+    return CreateConv(graph,
+                      node,
                       {
                           GetInputVarNode("Input", node),
                           GetInputVarNode("Filter", node),
                           GetInputVarNode("Bias", node),
                       },
-                      node->outputs, dilations, group_, {}, pads, stride);
+                      node->outputs,
+                      dilations,
+                      group_,
+                      {},
+                      pads,
+                      stride);
   } else {
-    return CreateConv(graph, node,
+    return CreateConv(graph,
+                      node,
                       {
                           GetInputVarNode("Input", node),
                           GetInputVarNode("Filter", node),
                       },
-                      node->outputs, dilations, group_, {}, pads, stride);
+                      node->outputs,
+                      dilations,
+                      group_,
+                      {},
+                      pads,
+                      stride);
   }
 }
 
@@ -83,7 +95,11 @@ Node *batch_norm_handler(Graph *graph, Node *node) {
   auto momentum = BOOST_GET_CONST(float, op->GetAttr("momentum"));
   auto epsilon = BOOST_GET_CONST(float, op->GetAttr("epsilon"));
   // data_layout
-  return CreateBaseOp(graph, node, "popart_batchnormalization", inputs, outputs,
+  return CreateBaseOp(graph,
+                      node,
+                      "popart_batchnormalization",
+                      inputs,
+                      outputs,
                       {
                           {"momentum", momentum},
                           {"epsilon", epsilon},
@@ -105,18 +121,18 @@ Node *pool2d_handler(Graph *graph, Node *node) {
       }
       // adaptive maxpool op is max_pool2d_with_index. Only process avgpool
       // here.
-      return CreateBaseOp(graph, node, "popart_globalaveragepool", node->inputs,
-                          node->outputs);
+      return CreateBaseOp(
+          graph, node, "popart_globalaveragepool", node->inputs, node->outputs);
     }
   }
 
   if (global_pooling) {
     if (pooling_type == "max") {
-      return CreateBaseOp(graph, node, "popart_globalmaxpool", node->inputs,
-                          node->outputs);
+      return CreateBaseOp(
+          graph, node, "popart_globalmaxpool", node->inputs, node->outputs);
     } else if (pooling_type == "avg") {
-      return CreateBaseOp(graph, node, "popart_globalaveragepool", node->inputs,
-                          node->outputs);
+      return CreateBaseOp(
+          graph, node, "popart_globalaveragepool", node->inputs, node->outputs);
     } else {
       PADDLE_THROW(platform::errors::InvalidArgument(
           "op pool2d with unknown pooling_type: %s", pooling_type));
@@ -147,7 +163,10 @@ Node *pool2d_handler(Graph *graph, Node *node) {
     int64_t num_outputs = 1;
     auto dilations = std::vector<int64_t>{};
     int64_t storage_order = 0;
-    return CreateBaseOp(graph, node, "popart_maxpool", node->inputs,
+    return CreateBaseOp(graph,
+                        node,
+                        "popart_maxpool",
+                        node->inputs,
                         node->outputs,
                         {
                             {"num_outputs", num_outputs},
@@ -160,7 +179,10 @@ Node *pool2d_handler(Graph *graph, Node *node) {
                         });
   } else if (pooling_type == "avg") {
     int64_t count_include_pad = 0;
-    return CreateBaseOp(graph, node, "popart_averagepool", node->inputs,
+    return CreateBaseOp(graph,
+                        node,
+                        "popart_averagepool",
+                        node->inputs,
                         node->outputs,
                         {
                             {"kernel_shape", kernel_shape},
@@ -182,7 +204,10 @@ Node *max_pool2d_with_index_handler(Graph *graph, Node *node) {
     PADDLE_THROW(platform::errors::InvalidArgument(
         "Only support pool_size=1 with adaptive mode."));
   }
-  return CreateBaseOp(graph, node, "popart_globalmaxpool", node->inputs,
+  return CreateBaseOp(graph,
+                      node,
+                      "popart_globalmaxpool",
+                      node->inputs,
                       {GetOutputVarNode("Out", node)});
 }
 
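Note on the pool_size=1 restriction enforced above: adaptive pooling with a 1x1 target output collapses each channel's entire HxW plane to a single value, which is exactly global pooling, so the handler can map max_pool2d_with_index onto popart_globalmaxpool. A minimal sketch of the equivalence (illustrative C++, not part of the handler API):

    #include <algorithm>
    #include <vector>
    // Adaptive max pooling to a 1x1 output degenerates to a global max
    // over the H*W plane of each channel.
    float global_max_pool(const std::vector<float> &plane) {
      float m = plane.front();
      for (float v : plane) m = std::max(m, v);
      return m;
    }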
@@ -199,8 +224,8 @@ Node *group_norm_handler(Graph *graph, Node *node) {
   std::vector<Node *> outputs_ = {GetOutputVarNode("Y", node),
                                   GetOutputVarNode("Mean", node),
                                   GetOutputVarNode("Variance", node)};
-  return CreateBaseOp(graph, node, "popart_groupnormalization_v2", inputs_,
-                      outputs_, attrs_);
+  return CreateBaseOp(
+      graph, node, "popart_groupnormalization_v2", inputs_, outputs_, attrs_);
 }
 
 Node *instance_norm_handler(Graph *graph, Node *node) {
@@ -212,8 +237,8 @@ Node *instance_norm_handler(Graph *graph, Node *node) {
                                  GetInputVarNode("Scale", node),
                                  GetInputVarNode("Bias", node)};
   std::vector<Node *> outputs_ = {GetOutputVarNode("Y", node)};
-  return CreateBaseOp(graph, node, "popart_instancenormalization", inputs_,
-                      outputs_, attrs_);
+  return CreateBaseOp(
+      graph, node, "popart_instancenormalization", inputs_, outputs_, attrs_);
 }
 
 Node *layer_norm_handler(Graph *graph, Node *node) {
@@ -227,13 +252,16 @@ Node *layer_norm_handler(Graph *graph, Node *node) {
       AttributeMap{{"epsilon", epsilon_}, {"num_groups", groups_}};
 
   if (input_shape_.size() == 2) {
-    return CreateBaseOp(
-        graph, node, "popart_groupnormalization_v2",
-        {GetInputVarNode("X", node), GetInputVarNode("Scale", node),
-         GetInputVarNode("Bias", node)},
-        {GetOutputVarNode("Y", node), GetOutputVarNode("Mean", node),
-         GetOutputVarNode("Variance", node)},
-        groupnorm_attrs_);
+    return CreateBaseOp(graph,
+                        node,
+                        "popart_groupnormalization_v2",
+                        {GetInputVarNode("X", node),
+                         GetInputVarNode("Scale", node),
+                         GetInputVarNode("Bias", node)},
+                        {GetOutputVarNode("Y", node),
+                         GetOutputVarNode("Mean", node),
+                         GetOutputVarNode("Variance", node)},
+                        groupnorm_attrs_);
   }
 
   std::vector<int64_t> norm_shape_{1, 1};
@@ -251,15 +279,23 @@ Node *layer_norm_handler(Graph *graph, Node *node) {
                           {"dtype", ONNXDataType::INT64}};
   auto reshape1_const =
       CreateBaseOp(graph, node, "popart_constant", {}, {}, attrs1);
-  auto new_node_reshape1 = CreateBaseOp(
-      graph, node, "popart_reshape",
-      {GetInputVarNode("X", node), reshape1_const->outputs[0]}, {}, {});
+  auto new_node_reshape1 =
+      CreateBaseOp(graph,
+                   node,
+                   "popart_reshape",
+                   {GetInputVarNode("X", node), reshape1_const->outputs[0]},
+                   {},
+                   {});
 
   auto out_Y_ = MakeVarNode(graph, node);
-  CreateBaseOp(graph, node, "popart_groupnormalization_v2",
-               {new_node_reshape1->outputs[0], GetInputVarNode("Scale", node),
+  CreateBaseOp(graph,
+               node,
+               "popart_groupnormalization_v2",
+               {new_node_reshape1->outputs[0],
+                GetInputVarNode("Scale", node),
                 GetInputVarNode("Bias", node)},
-               {out_Y_, GetOutputVarNode("Mean", node),
+               {out_Y_,
+                GetOutputVarNode("Mean", node),
                 GetOutputVarNode("Variance", node)},
                groupnorm_attrs_);
 
@@ -269,9 +305,12 @@ Node *layer_norm_handler(Graph *graph, Node *node) {
                           {"dtype", ONNXDataType::INT64}};
   auto reshape2_const =
       CreateBaseOp(graph, node, "popart_constant", {}, {}, attrs2);
-  auto new_node_reshape2 = CreateBaseOp(graph, node, "popart_reshape",
+  auto new_node_reshape2 = CreateBaseOp(graph,
+                                        node,
+                                        "popart_reshape",
                                         {out_Y_, reshape2_const->outputs[0]},
-                                        {GetOutputVarNode("Y", node)}, {});
+                                        {GetOutputVarNode("Y", node)},
+                                        {});
   return new_node_reshape2;
 }
 
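For inputs of rank greater than 2, the layer_norm lowering above flattens the tensor to 2-D around the normalization boundary, runs popart_groupnormalization_v2 on the flattened view, and reshapes the result back; that is why the handler builds norm_shape_ starting from {1, 1}. A sketch of the shape bookkeeping (illustrative C++; the helper name is hypothetical, and it assumes Paddle's begin_norm_axis attribute as the split point):

    #include <cstdint>
    #include <vector>
    // Collapse dims before begin_norm_axis into dim 0 and the rest into
    // dim 1, e.g. [N, C, H, W] with begin_norm_axis = 1 -> {N, C*H*W}.
    std::vector<int64_t> flatten_for_norm(const std::vector<int64_t> &shape,
                                          int begin_norm_axis) {
      std::vector<int64_t> out{1, 1};
      for (int i = 0; i < static_cast<int>(shape.size()); ++i) {
        out[i < begin_norm_axis ? 0 : 1] *= shape[i];
      }
      return out;
    }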
@@ -292,18 +331,27 @@ Node *dropout_handler(Graph *graph, Node *node) {
 
   if (is_test_) {
     if (dropout_implementation_ == "upscale_in_train") {
-      return CreateBaseOp(graph, node, "popart_identity",
+      return CreateBaseOp(graph,
+                          node,
+                          "popart_identity",
                           {GetInputVarNode("X", node)},
-                          {GetOutputVarNode("Out", node)}, {});
+                          {GetOutputVarNode("Out", node)},
+                          {});
     } else if (dropout_implementation_ == "downgrade_in_infer") {
       auto scale =
-          CreateConst(graph, node, {}, {},
+          CreateConst(graph,
+                      node,
+                      {},
+                      {},
                       {{"value", std::vector<float>{1 - dropout_prob_}},
                        {"dims", std::vector<int64_t>{1}},
                        {"dtype", GetOutputVarDType(node)}});
-      return CreateBaseOp(graph, node, "popart_mul",
+      return CreateBaseOp(graph,
+                          node,
+                          "popart_mul",
                           {GetInputVarNode("X", node), scale->outputs[0]},
-                          {GetOutputVarNode("Out", node)}, {});
+                          {GetOutputVarNode("Out", node)},
+                          {});
     } else {
       PADDLE_THROW(
           platform::errors::InvalidArgument("Invalid dropout_implementation"));
@@ -312,9 +360,12 @@ Node *dropout_handler(Graph *graph, Node *node) {
     if (dropout_implementation_ == "upscale_in_train") {
       auto attrs_ =
           AttributeMap{{"num_outputs", (int64_t)1}, {"ratio", dropout_prob_}};
-      return CreateBaseOp(graph, node, "popart_dropout",
+      return CreateBaseOp(graph,
+                          node,
+                          "popart_dropout",
                           {GetInputVarNode("X", node)},
-                          {GetOutputVarNode("Out", node)}, attrs_);
+                          {GetOutputVarNode("Out", node)},
+                          attrs_);
     } else if (dropout_implementation_ == "downgrade_in_infer") {
       PADDLE_THROW(platform::errors::InvalidArgument(
           "Do not support downgrade_in_infer with training"));
@@ -388,28 +439,77 @@ Node *conv2d_transpose_handler(Graph *graph, Node *node) {
                              {"pads", paddings},
                              {"strides", strides}};
   if (!op->Input("Bias").empty()) {
-    return CreateBaseOp(graph, node, "popart_convtranspose",
+    return CreateBaseOp(graph,
+                        node,
+                        "popart_convtranspose",
                         {
                             GetInputVarNode("Input", node),
                             GetInputVarNode("Filter", node),
                             GetInputVarNode("Bias", node),
                         },
-                        node->outputs, attrs);
+                        node->outputs,
+                        attrs);
   } else {
-    return CreateBaseOp(graph, node, "popart_convtranspose",
+    return CreateBaseOp(graph,
+                        node,
+                        "popart_convtranspose",
                         {
                             GetInputVarNode("Input", node),
                             GetInputVarNode("Filter", node),
                         },
-                        node->outputs, attrs);
+                        node->outputs,
+                        attrs);
+  }
+}
+
+Node *affine_channel_handler(Graph *graph, Node *node) {
+  auto *op = node->Op();
+
+  auto data_layout = BOOST_GET_CONST(std::string, op->GetAttr("data_layout"));
+  if (data_layout != "NCHW") {
+    PADDLE_THROW(platform::errors::InvalidArgument(
+        "Only support NCHW as data_format."));
+  }
+
+  auto *scale = GetInputVarNode("Scale", node);
+  auto *bias = GetInputVarNode("Bias", node);
+  auto scale_shape = scale->Var()->GetShape();
+  auto bias_shape = bias->Var()->GetShape();
+  if (scale_shape.size() <= 1 || bias_shape.size() <= 1) {
+    auto attrs = AttributeMap{{"value", std::vector<int64_t>{1, -1, 1, 1}},
+                              {"dims", std::vector<int64_t>{4}},
+                              {"dtype", ONNXDataType::INT64}};
+    auto new_shape_const = CreateConst(graph, node, {}, {}, attrs);
+
+    scale = CreateBaseOp(graph,
+                         node,
+                         "popart_reshape",
+                         {scale, new_shape_const->outputs[0]},
+                         {},
+                         {})
+                ->outputs[0];
+    bias = CreateBaseOp(graph,
+                        node,
+                        "popart_reshape",
+                        {bias, new_shape_const->outputs[0]},
+                        {},
+                        {})
+               ->outputs[0];
   }
+  auto *out = CreateBaseOp(
+      graph, node, "popart_mul", {GetInputVarNode("X", node), scale}, {});
+  return CreateBaseOp(graph,
+                      node,
+                      "popart_add",
+                      {out->outputs[0], bias},
+                      {GetOutputVarNode("Out", node)});
 }
 
 }  // namespace
 }  // namespace ipu
 }  // namespace platform
 }  // namespace paddle
 
+REGISTER_HANDLER(affine_channel, affine_channel_handler);
 REGISTER_HANDLER(pool2d, pool2d_handler);
 REGISTER_HANDLER(max_pool2d_with_index, max_pool2d_with_index_handler);
 REGISTER_HANDLER(batch_norm, batch_norm_handler);
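The new affine_channel handler registered above lowers the op to a broadcast multiply-add, Out = X * Scale + Bias, with Scale and Bias reshaped to {1, -1, 1, 1} so they broadcast per channel over an NCHW input. Reference semantics as a sketch (illustrative C++, not part of the PopART API):

    // out[n][c][h][w] = x[n][c][h][w] * scale[c] + bias[c]
    // hw is H*W; the {1, -1, 1, 1} reshape in the handler achieves the
    // same per-channel broadcast via popart_mul / popart_add.
    void affine_channel_ref(const float *x, const float *scale,
                            const float *bias, float *out,
                            int n, int c, int hw) {
      for (int i = 0; i < n * c * hw; ++i) {
        int ch = (i / hw) % c;
        out[i] = x[i] * scale[ch] + bias[ch];
      }
    }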