@@ -13,46 +13,98 @@ See the License for the specific language governing permissions and
1313limitations under the License. */
1414
1515#pragma once
16- #include < string>
17- #include < vector>
18-
19- #include " paddle/fluid/framework/eigen.h"
20- #include " paddle/fluid/framework/tensor.h"
21- #include " paddle/fluid/operators/amp/fp16_type_traits.h"
22- #include " paddle/fluid/platform/device_context.h"
16+ #include < cmath>
17+ #include < limits>
18+ #include " paddle/fluid/operators/reduce_ops/reduce_op.cu.h"
2319#include " paddle/fluid/platform/hostdevice.h"
24- #include " paddle/fluid/platform/macros.h"
20+ #ifdef __HIPCC__
21+ #include < hip/hip_runtime.h>
22+ #endif
2523
2624namespace paddle {
2725namespace operators {
2826
/**
 * Binary "minimum" functor.
 *
 * Tx is the type forwarded to Transformer; Ty (defaults to Tx) is the type
 * of the operands and of the result of operator().
 */
template <typename Tx, typename Ty = Tx>
struct CustomMin {
  // Per-element transform associated with this functor.
  // NOTE(review): presumably a pass-through, judging by the name -- confirm
  // against detail::IdentityFunctor.
  using Transformer = detail::IdentityFunctor<Tx>;

  // Seed value (per the method name): the largest finite Ty, so every finite
  // operand compares less than or equal to it.
  // const-qualified so it can be called on const functors, like operator().
  inline Ty initial() const {
    return static_cast<Ty>(std::numeric_limits<Ty>::max());
  }

  // Returns b only when it is strictly smaller than a; ties and unordered
  // comparisons (b < a evaluating to false) keep a.
  __device__ __forceinline__ Ty operator()(const Ty &a, const Ty &b) const {
    return (b < a) ? b : a;
  }
};
3539
/**
 * Binary "maximum" functor.
 *
 * Tx is the type forwarded to Transformer; Ty (defaults to Tx) is the type
 * of the operands and of the result of operator().
 */
template <typename Tx, typename Ty = Tx>
struct CustomMax {
  // Per-element transform associated with this functor.
  // NOTE(review): presumably a pass-through, judging by the name -- confirm
  // against detail::IdentityFunctor.
  using Transformer = detail::IdentityFunctor<Tx>;

  // Seed value (per the method name): lowest() rather than min(), because
  // for floating-point types min() is the smallest positive value, not the
  // most negative one.
  // const-qualified so it can be called on const functors, like operator().
  inline Ty initial() const {
    return static_cast<Ty>(std::numeric_limits<Ty>::lowest());
  }

  // Returns b only when it is strictly greater than a; ties and unordered
  // comparisons (b > a evaluating to false) keep a.
  __device__ __forceinline__ Ty operator()(const Ty &a, const Ty &b) const {
    return (b > a) ? b : a;
  }
};
4252
// for cub::Reduce
/**
 * Binary addition functor.
 *
 * Tx and Ty are both forwarded to Transformer; Ty (defaults to Tx) is the
 * type of the operands and of the result of operator().
 */
template <typename Tx, typename Ty = Tx>
struct CustomSum {
  // Per-element transform associated with this functor.
  // NOTE(review): presumably converts Tx to Ty without changing the value --
  // confirm against detail::IdentityFunctor.
  using Transformer = detail::IdentityFunctor<Tx, Ty>;

  // Seed value (per the method name): zero, the additive identity.
  // const-qualified so it can be called on const functors, like operator().
  inline Ty initial() const { return static_cast<Ty>(0.0f); }

  // Returns the sum of the two operands.
  __device__ __forceinline__ Ty operator()(const Ty &a, const Ty &b) const {
    return b + a;
  }
};
4964
/**
 * Binary functor for mean reductions; the combine step itself is a plain
 * addition.
 *
 * Tx is the type forwarded to Transformer; Ty (defaults to Tx) is the type
 * of the operands and of the result of operator().
 */
template <typename Tx, typename Ty = Tx>
struct CustomMean {
  // Per-element transform associated with this functor.
  // NOTE(review): presumably scales each element by 1/count so that the
  // summed result is the mean -- confirm against detail::DivideFunctor.
  using Transformer = detail::DivideFunctor<Tx>;

  // Seed value (per the method name): zero, the additive identity.
  // const-qualified so it can be called on const functors, like operator().
  inline Ty initial() const { return static_cast<Ty>(0.0f); }

  // Returns the sum of the two operands; no division happens here.
  __device__ __forceinline__ Ty operator()(const Ty &a, const Ty &b) const {
    return b + a;
  }
};
75+
/**
 * Binary multiplication functor.
 *
 * Tx is the type forwarded to Transformer; Ty (defaults to Tx) is the type
 * of the operands and of the result of operator().
 */
template <typename Tx, typename Ty = Tx>
struct CustomMul {
  // Per-element transform associated with this functor.
  // NOTE(review): presumably a pass-through, judging by the name -- confirm
  // against detail::IdentityFunctor.
  using Transformer = detail::IdentityFunctor<Tx>;

  // Seed value (per the method name): one, the multiplicative identity.
  // const-qualified so it can be called on const functors, like operator().
  inline Ty initial() const { return static_cast<Ty>(1.0f); }

  // Returns the product of the two operands.
  __device__ __forceinline__ Ty operator()(const Ty &a, const Ty &b) const {
    return b * a;
  }
};
5686
/**
 * Binary logical-OR functor.
 *
 * Tx is the type forwarded to Transformer; Ty (defaults to Tx) is the type
 * of the operands and of the result of operator().
 */
template <typename Tx, typename Ty = Tx>
struct CustomLogicalOr {
  // Per-element transform associated with this functor.
  // NOTE(review): presumably a pass-through, judging by the name -- confirm
  // against detail::IdentityFunctor.
  using Transformer = detail::IdentityFunctor<Tx>;

  // Seed value (per the method name): false, the identity of logical OR.
  // const-qualified so it can be called on const functors, like operator().
  inline Ty initial() const { return static_cast<Ty>(false); }

  // Evaluates b || a; the boolean result is converted back to Ty on return.
  __device__ __forceinline__ Ty operator()(const Ty &a, const Ty &b) const {
    return b || a;
  }
};
97+
/**
 * Binary logical-AND functor.
 *
 * Tx is the type forwarded to Transformer; Ty (defaults to Tx) is the type
 * of the operands and of the result of operator().
 */
template <typename Tx, typename Ty = Tx>
struct CustomLogicalAnd {
  // Per-element transform associated with this functor.
  // NOTE(review): presumably a pass-through, judging by the name -- confirm
  // against detail::IdentityFunctor.
  using Transformer = detail::IdentityFunctor<Tx>;

  // Seed value (per the method name): true, the identity of logical AND.
  // const-qualified so it can be called on const functors, like operator().
  inline Ty initial() const { return static_cast<Ty>(true); }

  // Evaluates b && a; the boolean result is converted back to Ty on return.
  __device__ __forceinline__ Ty operator()(const Ty &a, const Ty &b) const {
    return b && a;
  }
};
108+
57109} // namespace operators
58110} // namespace paddle
0 commit comments