@@ -224,6 +224,102 @@ void MLUUnaryOp(const framework::ExecutionContext& ctx) {
                    out_desc.get(), GetBasePtr(out));
}

+// ------------------ MLUElementwiseGradOp -----------------
+enum MINMAX_GRAD_FUNCTOR {
+  MAXIMUM_GRAD,
+  MINIMUM_GRAD,
+};
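+// Computes the max/min gradients via a comparison mask: dout * mask flows to
+// dx, the remainder goes to dy, and each result is sum-reduced back to its
+// input's shape when that input was broadcast in the forward pass.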
+template <MINMAX_GRAD_FUNCTOR Functor, typename Tin, typename Tout = Tin>
+void MLUMinMaxGradHelper(const framework::ExecutionContext& ctx) {
+  auto* x = ctx.Input<Tensor>("X");
+  auto* y = ctx.Input<Tensor>("Y");
+  auto* dout = ctx.Input<Tensor>(framework::GradVarName("Out"));
+  auto* dx = ctx.Output<Tensor>(framework::GradVarName("X"));
+  auto* dy = ctx.Output<Tensor>(framework::GradVarName("Y"));
+  int axis = ctx.Attr<int>("axis");
+
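+  // Map a negative axis onto the rank difference between X and Y before
+  // computing the broadcast shape.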
+  const auto& x_dims = x->dims();
+  const auto& y_dims = y->dims();
+  axis =
+      (axis < 0 ? (std::abs(x_dims.size() - y_dims.size()) + axis + 1) : axis);
+  int max_dim = std::max(x_dims.size(), y_dims.size());
+  std::vector<int> x_dims_array(max_dim);
+  std::vector<int> y_dims_array(max_dim);
+  std::vector<int> out_dims_array(max_dim);
+  GetBroadcastDimsArrays(x_dims, y_dims, x_dims_array.data(),
+                         y_dims_array.data(), out_dims_array.data(), max_dim,
+                         axis);
+
+  // mask = Logic(x, y); only min & max are supported.
+  cnnlLogicOp_t logic =
+      Functor == MAXIMUM_GRAD ? CNNL_LOGIC_OP_GE : CNNL_LOGIC_OP_LE;
+  Tensor mask(x->dtype());
+  mask.Resize(phi::make_ddim(out_dims_array));
+  mask.mutable_data<Tin>(ctx.GetPlace());
+
+  cnnlDataType_t data_type = ToCnnlDataType<Tin>();
+  MLUCnnlTensorDesc x_desc(max_dim, x_dims_array.data(), data_type);
+  MLUCnnlTensorDesc y_desc(max_dim, y_dims_array.data(), data_type);
+  MLUCnnlTensorDesc mask_desc(max_dim, out_dims_array.data(), data_type);
+  MLUCnnl::Logic(ctx, logic, x_desc.get(), GetBasePtr(x), y_desc.get(),
+                 GetBasePtr(y), mask_desc.get(), GetBasePtr(&mask));
+
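+  // The mask is 1 where x won the comparison (ties included) and 0 elsewhere,
+  // so the mul/sub below split dout between the two inputs.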
+  // dx = Mul(dout, mask)
+  Tensor dx_temp(x->dtype());
+  dx_temp.Resize(dout->dims());
+  dx_temp.mutable_data<Tout>(ctx.GetPlace());
+  MLUCnnlTensorDesc dout_desc(*dout);
+  MLUCnnlOpTensorDesc mul_op_desc(CNNL_OP_TENSOR_MUL, data_type,
+                                  CNNL_NOT_PROPAGATE_NAN);
+  MLUCnnl::OpTensor(ctx, mul_op_desc.get(), dout_desc.get(), GetBasePtr(dout),
+                    dout_desc.get(), GetBasePtr(&mask), dout_desc.get(),
+                    GetBasePtr(&dx_temp), data_type);
+
+  // dy = Sub(dout, dx); with a 0/1 mask this equals Mul(dout, 1 - mask).
+  Tensor dy_temp(y->dtype());
+  dy_temp.Resize(dout->dims());
+  dy_temp.mutable_data<Tout>(ctx.GetPlace());
+  MLUCnnlOpTensorDesc sub_op_desc(CNNL_OP_TENSOR_SUB, data_type,
+                                  CNNL_NOT_PROPAGATE_NAN);
+  MLUCnnl::OpTensor(ctx, sub_op_desc.get(), dout_desc.get(), GetBasePtr(dout),
+                    dout_desc.get(), GetBasePtr(&dx_temp), dout_desc.get(),
+                    GetBasePtr(&dy_temp), data_type);
+
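+  // If X was broadcast in the forward pass, sum-reduce dx_temp over the
+  // broadcast axes to recover X's original shape.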
+  if (dx) {
+    if (dx->dims() != dout->dims()) {
+      dx->mutable_data<Tout>(ctx.GetPlace());
+      std::vector<int> reduce_axes;
+      GetReduceAxes(axis, dx_temp.dims(), dx->dims(), &reduce_axes);
+      MLUCnnlReduceDesc reduction_desc(
+          reduce_axes, CNNL_REDUCE_ADD, data_type, CNNL_NOT_PROPAGATE_NAN,
+          CNNL_REDUCE_NO_INDICES, CNNL_32BIT_INDICES);
+      MLUCnnlTensorDesc dx_desc(*dx);
+      MLUCnnl::Reduce(ctx, true /*need_workspace*/, reduction_desc.get(),
+                      nullptr, dout_desc.get(), GetBasePtr(&dx_temp), 0,
+                      nullptr, nullptr, dx_desc.get(), GetBasePtr(dx));
+    } else {
+      dx->ShareDataWith(dx_temp);
+    }
+  }
+
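+  // Same shape restoration for dy when Y was broadcast.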
+  if (dy) {
+    if (dy->dims() != dout->dims()) {
+      dy->mutable_data<Tout>(ctx.GetPlace());
+      std::vector<int> reduce_axes;
+      GetReduceAxes(axis, dy_temp.dims(), dy->dims(), &reduce_axes);
+      MLUCnnlReduceDesc reduction_desc(
+          reduce_axes, CNNL_REDUCE_ADD, data_type, CNNL_NOT_PROPAGATE_NAN,
+          CNNL_REDUCE_NO_INDICES, CNNL_32BIT_INDICES);
+      MLUCnnlTensorDesc dy_desc(*dy);
+      MLUCnnl::Reduce(ctx, true /*need_workspace*/, reduction_desc.get(),
+                      nullptr, dout_desc.get(), GetBasePtr(&dy_temp), 0,
+                      nullptr, nullptr, dy_desc.get(), GetBasePtr(dy));
+    } else {
+      dy->ShareDataWith(dy_temp);
+    }
+  }
+}
+
} // namespace operators
} // namespace paddle
#endif
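
For context, a minimal sketch of how this helper is typically invoked from an op kernel. The kernel class below is illustrative and not part of this diff; it assumes the usual framework::OpKernel pattern used by Paddle device kernels.

// Hypothetical grad kernel dispatching to the helper added above.
template <typename T>
class ElementwiseMaxGradMLUKernel : public framework::OpKernel<T> {
 public:
  void Compute(const framework::ExecutionContext& ctx) const override {
    // MAXIMUM_GRAD selects the CNNL_LOGIC_OP_GE mask path.
    MLUMinMaxGradHelper<MAXIMUM_GRAD, T>(ctx);
  }
};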