@@ -61,45 +61,5 @@ struct BaseActivationFunctor {
61
61
62
62
USE_PHI_FUNCTOR (Mish)
63
63
64
- template <typename T>
65
- struct SoftReluFunctor : public BaseActivationFunctor <T> {
66
- float threshold;
67
- typename BaseActivationFunctor<T>::AttrPair GetAttrs () {
68
- return {{" threshold" , &threshold}};
69
- }
70
-
71
- template <typename Device, typename X, typename Out>
72
- void operator ()(Device d, X x, Out out) const {
73
- auto tmp = static_cast <T>(threshold);
74
- auto temp = x.cwiseMax (-tmp).cwiseMin (tmp);
75
- out.device (d) = (static_cast <T>(1 ) + temp.exp ()).log ();
76
- }
77
- };
78
-
79
- template <typename T>
80
- struct SoftReluGradFunctor : public BaseActivationFunctor <T> {
81
- float threshold;
82
- typename BaseActivationFunctor<T>::AttrPair GetAttrs () {
83
- return {{" threshold" , &threshold}};
84
- }
85
- template <typename Device,
86
- typename X,
87
- typename Out,
88
- typename dOut,
89
- typename dX>
90
- void operator ()(Device d, X x UNUSED, Out out, dOut dout, dX dx) const {
91
- auto tmp = static_cast <T>(threshold);
92
- auto temp = ((out > -tmp) * (out < tmp)).template cast <T>();
93
- dx.device (d) = dout * (static_cast <T>(1 ) - (-out).exp ()) * temp;
94
- }
95
-
96
- static constexpr ActBwdOpFwdDeps FwdDeps () {
97
- return ActBwdOpFwdDeps::kDepOut ;
98
- }
99
- };
100
-
101
64
} // namespace operators
102
65
} // namespace paddle
103
-
104
// Applies __macro to each activation op registered through this legacy
// functor path (only soft_relu remains here).
#define FOR_EACH_ACTIVATION_OP(__macro) \
  __macro(soft_relu, SoftRelu, SoftReluFunctor, SoftReluGradFunctor);
0 commit comments