
Commit f586110

[MLU] add mlu kernel for elementwise_max_grad (PaddlePaddle#43608)
* [MLU] add mlu kernel for elementwise_max_grad
* [MLU] modify mlu kernel elementwise_min_grad impl
1 parent 2353db3 commit f586110

4 files changed, +381 -230 lines changed


Diff for: paddle/fluid/operators/elementwise/elementwise_max_op_mlu.cc (+12)

@@ -27,11 +27,23 @@ class ElementwiseMaxMLUKernel : public framework::OpKernel<T> {
   }
 };
 
+template <typename T>
+class ElementwiseMaxGradMLUKernel : public framework::OpKernel<T> {
+ public:
+  void Compute(const framework::ExecutionContext& ctx) const override {
+    MLUMinMaxGradHelper<MAXIMUM_GRAD, T>(ctx);
+  }
+};
+
 }  // namespace operators
 }  // namespace paddle
 
 namespace ops = paddle::operators;
 REGISTER_OP_MLU_KERNEL(elementwise_max, ops::ElementwiseMaxMLUKernel<int>,
                        ops::ElementwiseMaxMLUKernel<float>,
                        ops::ElementwiseMaxMLUKernel<paddle::platform::float16>);
+REGISTER_OP_MLU_KERNEL(
+    elementwise_max_grad, ops::ElementwiseMaxGradMLUKernel<int>,
+    ops::ElementwiseMaxGradMLUKernel<float>,
+    ops::ElementwiseMaxGradMLUKernel<paddle::platform::float16>);
 #endif
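Read alongside the MLUMinMaxGradHelper added to elementwise_mlu.h below, the newly registered elementwise_max_grad kernel computes the following (a compact restatement of the code path, with L denoting the loss and out = max(x, y)):

\[
\frac{\partial L}{\partial x} = \frac{\partial L}{\partial \text{out}} \odot \mathbf{1}[x \ge y],
\qquad
\frac{\partial L}{\partial y} = \frac{\partial L}{\partial \text{out}} - \frac{\partial L}{\partial x}
\]

followed by a sum-reduction over the broadcast axes whenever dx or dy has a smaller shape than dout. Ties (x = y) route the full gradient to x, since the mask uses CNNL_LOGIC_OP_GE.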

Diff for: paddle/fluid/operators/elementwise/elementwise_min_op_mlu.cc (+1 -86)

@@ -34,92 +34,7 @@ template <typename T>
 class ElementwiseMinGradMLUKernel : public framework::OpKernel<T> {
  public:
   void Compute(const framework::ExecutionContext& ctx) const override {
-    auto* x = ctx.Input<Tensor>("X");
-    auto* y = ctx.Input<Tensor>("Y");
-    auto* dout = ctx.Input<Tensor>(framework::GradVarName("Out"));
-    auto* dx = ctx.Output<Tensor>(framework::GradVarName("X"));
-    auto* dy = ctx.Output<Tensor>(framework::GradVarName("Y"));
-    int axis = ctx.Attr<int>("axis");
-
-    const auto& x_dims = x->dims();
-    const auto& y_dims = y->dims();
-    axis = (axis < 0 ? (std::abs(x_dims.size() - y_dims.size()) + axis + 1)
-                     : axis);
-    int max_dim = std::max(x_dims.size(), y_dims.size());
-    std::vector<int> x_dims_array(max_dim);
-    std::vector<int> y_dims_array(max_dim);
-    std::vector<int> out_dims_array(max_dim);
-    GetBroadcastDimsArrays(x_dims, y_dims, x_dims_array.data(),
-                           y_dims_array.data(), out_dims_array.data(), max_dim,
-                           axis);
-
-    // mask = LessEqual(x, y)
-    Tensor mask(x->dtype());
-    mask.Resize(phi::make_ddim(out_dims_array));
-    mask.mutable_data<T>(ctx.GetPlace());
-
-    cnnlDataType_t data_type = ToCnnlDataType<T>();
-    MLUCnnlTensorDesc x_desc(max_dim, x_dims_array.data(), data_type);
-    MLUCnnlTensorDesc y_desc(max_dim, y_dims_array.data(), data_type);
-    MLUCnnlTensorDesc mask_desc(max_dim, out_dims_array.data(), data_type);
-    MLUCnnl::Logic(ctx, CNNL_LOGIC_OP_LE, x_desc.get(), GetBasePtr(x),
-                   y_desc.get(), GetBasePtr(y), mask_desc.get(),
-                   GetBasePtr(&mask));
-
-    // dx = Mul(dz, mask)
-    Tensor dx_temp(x->dtype());
-    dx_temp.Resize(dout->dims());
-    dx_temp.mutable_data<T>(ctx.GetPlace());
-    MLUCnnlTensorDesc dout_desc(*dout);
-    MLUCnnlOpTensorDesc mul_op_desc(CNNL_OP_TENSOR_MUL, data_type,
-                                    CNNL_NOT_PROPAGATE_NAN);
-    MLUCnnl::OpTensor(ctx, mul_op_desc.get(), dout_desc.get(), GetBasePtr(dout),
-                      dout_desc.get(), GetBasePtr(&mask), dout_desc.get(),
-                      GetBasePtr(&dx_temp), data_type);
-
-    // dy = Sub(dz, dx)
-    Tensor dy_temp(y->dtype());
-    dy_temp.Resize(dout->dims());
-    dy_temp.mutable_data<T>(ctx.GetPlace());
-    MLUCnnlOpTensorDesc sub_op_desc(CNNL_OP_TENSOR_SUB, data_type,
-                                    CNNL_NOT_PROPAGATE_NAN);
-    MLUCnnl::OpTensor(ctx, sub_op_desc.get(), dout_desc.get(), GetBasePtr(dout),
-                      dout_desc.get(), GetBasePtr(&dx_temp), dout_desc.get(),
-                      GetBasePtr(&dy_temp), data_type);
-
-    if (dx) {
-      if (dx->dims() != dout->dims()) {
-        dx->mutable_data<T>(ctx.GetPlace());
-        std::vector<int> reduce_axes;
-        GetReduceAxes(axis, dx_temp.dims(), dx->dims(), &reduce_axes);
-        MLUCnnlReduceDesc reduction_desc(
-            reduce_axes, CNNL_REDUCE_ADD, data_type, CNNL_NOT_PROPAGATE_NAN,
-            CNNL_REDUCE_NO_INDICES, CNNL_32BIT_INDICES);
-        MLUCnnlTensorDesc dx_desc(*dx);
-        MLUCnnl::Reduce(ctx, true /*need_workspace*/, reduction_desc.get(),
-                        nullptr, dout_desc.get(), GetBasePtr(&dx_temp), 0,
-                        nullptr, nullptr, dx_desc.get(), GetBasePtr(dx));
-      } else {
-        dx->ShareDataWith(dx_temp);
-      }
-    }
-
-    if (dy) {
-      if (dy->dims() != dout->dims()) {
-        dy->mutable_data<T>(ctx.GetPlace());
-        std::vector<int> reduce_axes;
-        GetReduceAxes(axis, dy_temp.dims(), dy->dims(), &reduce_axes);
-        MLUCnnlReduceDesc reduction_desc(
-            reduce_axes, CNNL_REDUCE_ADD, data_type, CNNL_NOT_PROPAGATE_NAN,
-            CNNL_REDUCE_NO_INDICES, CNNL_32BIT_INDICES);
-        MLUCnnlTensorDesc dy_desc(*dy);
-        MLUCnnl::Reduce(ctx, true /*need_workspace*/, reduction_desc.get(),
-                        nullptr, dout_desc.get(), GetBasePtr(&dy_temp), 0,
-                        nullptr, nullptr, dy_desc.get(), GetBasePtr(dy));
-      } else {
-        dy->ShareDataWith(dy_temp);
-      }
-    }
+    MLUMinMaxGradHelper<MINIMUM_GRAD, T>(ctx);
   }
 };

Diff for: paddle/fluid/operators/elementwise/elementwise_mlu.h (+96)

@@ -224,6 +224,102 @@ void MLUUnaryOp(const framework::ExecutionContext& ctx) {
                  out_desc.get(), GetBasePtr(out));
 }
 
+// ------------------ MLUElementwiseGradOp -----------------
+enum MINMAX_GRAD_FUNCTOR {
+  MAXIMUM_GRAD,
+  MINIMUM_GRAD,
+};
+template <MINMAX_GRAD_FUNCTOR Functor, typename Tin, typename Tout = Tin>
+void MLUMinMaxGradHelper(const framework::ExecutionContext& ctx) {
+  auto* x = ctx.Input<Tensor>("X");
+  auto* y = ctx.Input<Tensor>("Y");
+  auto* dout = ctx.Input<Tensor>(framework::GradVarName("Out"));
+  auto* dx = ctx.Output<Tensor>(framework::GradVarName("X"));
+  auto* dy = ctx.Output<Tensor>(framework::GradVarName("Y"));
+  int axis = ctx.Attr<int>("axis");
+
+  const auto& x_dims = x->dims();
+  const auto& y_dims = y->dims();
+  axis =
+      (axis < 0 ? (std::abs(x_dims.size() - y_dims.size()) + axis + 1) : axis);
+  int max_dim = std::max(x_dims.size(), y_dims.size());
+  std::vector<int> x_dims_array(max_dim);
+  std::vector<int> y_dims_array(max_dim);
+  std::vector<int> out_dims_array(max_dim);
+  GetBroadcastDimsArrays(x_dims, y_dims, x_dims_array.data(),
+                         y_dims_array.data(), out_dims_array.data(), max_dim,
+                         axis);
+
+  // mask = Logic(x, y) only support min & max
+  cnnlLogicOp_t logic =
+      Functor == MAXIMUM_GRAD ? CNNL_LOGIC_OP_GE : CNNL_LOGIC_OP_LE;
+  Tensor mask(x->dtype());
+  mask.Resize(phi::make_ddim(out_dims_array));
+  mask.mutable_data<Tin>(ctx.GetPlace());
+
+  cnnlDataType_t data_type = ToCnnlDataType<Tin>();
+  MLUCnnlTensorDesc x_desc(max_dim, x_dims_array.data(), data_type);
+  MLUCnnlTensorDesc y_desc(max_dim, y_dims_array.data(), data_type);
+  MLUCnnlTensorDesc mask_desc(max_dim, out_dims_array.data(), data_type);
+  MLUCnnl::Logic(ctx, logic, x_desc.get(), GetBasePtr(x), y_desc.get(),
+                 GetBasePtr(y), mask_desc.get(), GetBasePtr(&mask));
+
+  // dx = Mul(dz, mask)
+  Tensor dx_temp(x->dtype());
+  dx_temp.Resize(dout->dims());
+  dx_temp.mutable_data<Tout>(ctx.GetPlace());
+  MLUCnnlTensorDesc dout_desc(*dout);
+  MLUCnnlOpTensorDesc mul_op_desc(CNNL_OP_TENSOR_MUL, data_type,
+                                  CNNL_NOT_PROPAGATE_NAN);
+  MLUCnnl::OpTensor(ctx, mul_op_desc.get(), dout_desc.get(), GetBasePtr(dout),
+                    dout_desc.get(), GetBasePtr(&mask), dout_desc.get(),
+                    GetBasePtr(&dx_temp), data_type);
+
+  // dy = Sub(dz, dx)
+  Tensor dy_temp(y->dtype());
+  dy_temp.Resize(dout->dims());
+  dy_temp.mutable_data<Tout>(ctx.GetPlace());
+  MLUCnnlOpTensorDesc sub_op_desc(CNNL_OP_TENSOR_SUB, data_type,
+                                  CNNL_NOT_PROPAGATE_NAN);
+  MLUCnnl::OpTensor(ctx, sub_op_desc.get(), dout_desc.get(), GetBasePtr(dout),
+                    dout_desc.get(), GetBasePtr(&dx_temp), dout_desc.get(),
+                    GetBasePtr(&dy_temp), data_type);
+
+  if (dx) {
+    if (dx->dims() != dout->dims()) {
+      dx->mutable_data<Tout>(ctx.GetPlace());
+      std::vector<int> reduce_axes;
+      GetReduceAxes(axis, dx_temp.dims(), dx->dims(), &reduce_axes);
+      MLUCnnlReduceDesc reduction_desc(
+          reduce_axes, CNNL_REDUCE_ADD, data_type, CNNL_NOT_PROPAGATE_NAN,
+          CNNL_REDUCE_NO_INDICES, CNNL_32BIT_INDICES);
+      MLUCnnlTensorDesc dx_desc(*dx);
+      MLUCnnl::Reduce(ctx, true /*need_workspace*/, reduction_desc.get(),
+                      nullptr, dout_desc.get(), GetBasePtr(&dx_temp), 0,
+                      nullptr, nullptr, dx_desc.get(), GetBasePtr(dx));
+    } else {
+      dx->ShareDataWith(dx_temp);
+    }
+  }
+
+  if (dy) {
+    if (dy->dims() != dout->dims()) {
+      dy->mutable_data<Tout>(ctx.GetPlace());
+      std::vector<int> reduce_axes;
+      GetReduceAxes(axis, dy_temp.dims(), dy->dims(), &reduce_axes);
+      MLUCnnlReduceDesc reduction_desc(
+          reduce_axes, CNNL_REDUCE_ADD, data_type, CNNL_NOT_PROPAGATE_NAN,
+          CNNL_REDUCE_NO_INDICES, CNNL_32BIT_INDICES);
+      MLUCnnlTensorDesc dy_desc(*dy);
+      MLUCnnl::Reduce(ctx, true /*need_workspace*/, reduction_desc.get(),
+                      nullptr, dout_desc.get(), GetBasePtr(&dy_temp), 0,
+                      nullptr, nullptr, dy_desc.get(), GetBasePtr(dy));
+    } else {
+      dy->ShareDataWith(dy_temp);
+    }
+  }
+}
+
 }  // namespace operators
 }  // namespace paddle
 #endif
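To make the data flow of MLUMinMaxGradHelper easier to follow, here is a minimal host-side sketch of the same mask/mul/sub scheme, assuming x, y, and dout already share one shape so the broadcast-reduction branch is not needed. The names (MinMaxGradReference, kMaximumGrad) and the plain std::vector storage are hypothetical stand-ins for illustration, not part of this diff.

// Minimal host-side sketch of the mask/mul/sub scheme in MLUMinMaxGradHelper.
// Hypothetical reference code; the real kernel runs through CNNL descriptors
// and also reduces over broadcast axes, which this same-shape sketch omits.
#include <cstddef>
#include <cstdio>
#include <vector>

enum MinMaxGradFunctor { kMaximumGrad, kMinimumGrad };

template <MinMaxGradFunctor Functor>
void MinMaxGradReference(const std::vector<float>& x,
                         const std::vector<float>& y,
                         const std::vector<float>& dout,
                         std::vector<float>* dx, std::vector<float>* dy) {
  dx->resize(x.size());
  dy->resize(y.size());
  for (std::size_t i = 0; i < x.size(); ++i) {
    // mask = (x >= y) for the maximum grad, (x <= y) for the minimum grad,
    // so a tie sends the whole gradient to x in both cases.
    const float mask = (Functor == kMaximumGrad) ? (x[i] >= y[i] ? 1.f : 0.f)
                                                 : (x[i] <= y[i] ? 1.f : 0.f);
    (*dx)[i] = dout[i] * mask;      // dx = Mul(dout, mask)
    (*dy)[i] = dout[i] - (*dx)[i];  // dy = Sub(dout, dx)
  }
}

int main() {
  const std::vector<float> x = {1.f, 5.f, 3.f};
  const std::vector<float> y = {2.f, 5.f, 1.f};
  const std::vector<float> dout = {1.f, 1.f, 1.f};
  std::vector<float> dx, dy;
  MinMaxGradReference<kMaximumGrad>(x, y, dout, &dx, &dy);
  for (std::size_t i = 0; i < dx.size(); ++i) {
    std::printf("dx[%zu]=%.1f dy[%zu]=%.1f\n", i, dx[i], i, dy[i]);
  }
  return 0;
}

With x = {1, 5, 3}, y = {2, 5, 1}, and dout all ones, the maximum case yields dx = {0, 1, 1} and dy = {1, 0, 0}: the tie at index 1 gives its whole gradient to x because the mask there uses >=.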
