@@ -38,11 +38,13 @@ namespace operators {
3838namespace detail {
3939
4040// Post processing function for sum, max, min, prod, any
41- template <typename T >
41+ template <typename Tx, typename Ty = Tx >
4242struct IdentityFunctor {
4343 HOSTDEVICE explicit inline IdentityFunctor () {}
4444
45- HOSTDEVICE inline T operator ()(const T& x) const { return x; }
45+ HOSTDEVICE inline Ty operator ()(const Tx& x) const {
46+ return static_cast <Ty>(x);
47+ }
4648};
4749
4850// Post processing function for mean
@@ -81,7 +83,7 @@ static inline std::vector<int> GetDimStrides(const std::vector<int>& dims,
8183#ifdef __HIPCC__
8284constexpr int kMaxBlock = 256 ;
8385#else
84- constexpr int kMaxBlock = 512 ;
86+ constexpr int kMaxBlock = 128 ;
8587#endif
8688
8789// get blockDim for reduceLastDim and reduceAny
@@ -544,8 +546,7 @@ __global__ void ReduceKernelFunction(
544546
545547template <typename Tx, typename Ty, int BlockDim, typename ReduceOp,
546548 typename TransformOp, int kRank , int kReduceRank >
547- static void LaunchKernel (const Tx* x_data, Ty* y_data,
548- const platform::Place& place, const ReduceOp& reducer,
549+ static void LaunchKernel (const Tx* x_data, Ty* y_data, const ReduceOp& reducer,
549550 const TransformOp& transformer, const Ty& init,
550551 gpuStream_t stream, ReduceConfig<Ty> config) {
551552#define CUB_REDUCE_TYPE_CASE (type ) \
@@ -589,7 +590,6 @@ static void LaunchKernel(const Tx* x_data, Ty* y_data,
589590template <typename Tx, typename Ty, int BlockDim, typename ReduceOp,
590591 typename TransformOp>
591592static void LaunchReduceKernel (const Tx* x_data, Ty* y_data,
592- const platform::Place& place,
593593 const ReduceOp& reducer,
594594 const TransformOp& transformer, const Ty& init,
595595 gpuStream_t stream, ReduceConfig<Ty> config) {
@@ -606,26 +606,9 @@ static void LaunchReduceKernel(const Tx* x_data, Ty* y_data,
606606 case i: { \
607607 constexpr auto kReduceRank = i; \
608608 LaunchKernel<Tx, Ty, BlockDim, ReduceOp, TransformOp, kRank , kReduceRank >( \
609- x_data, y_data, place, reducer, transformer, init, stream, config); \
609+ x_data, y_data, reducer, transformer, init, stream, config); \
610610 } break
611611
612- // launch CUB::Reduce
613- if (config.reduce_type == static_cast <int >(ReduceType::kReduceAll )) {
614- cub::TransformInputIterator<Ty, TransformOp, const Tx*> trans_x (
615- x_data, transformer);
616- size_t temp_storage_bytes = 0 ;
617- cub::DeviceReduce::Reduce (nullptr , temp_storage_bytes, trans_x, y_data,
618- config.reduce_num , reducer, init, stream);
619- framework::Tensor tmp;
620- auto * temp_storage = tmp.mutable_data <uint8_t >(
621- framework::make_ddim ({static_cast <int64_t >(temp_storage_bytes)}),
622- place);
623- cub::DeviceReduce::Reduce (temp_storage, temp_storage_bytes, trans_x, y_data,
624- config.reduce_num , reducer, init, stream);
625-
626- return ;
627- }
628-
629612 detail::CheckReduceRank (reduce_rank, rank);
630613 switch (rank) {
631614 CUB_RANK_CASE (2 , CUB_REDUCE_RANK_CASE (1 ););
@@ -649,10 +632,12 @@ static void LaunchReduceKernel(const Tx* x_data, Ty* y_data,
649632#undef CUB_RANK_CASE
650633}
651634
652- template <typename Tx, typename Ty, typename ReduceOp, typename TransformOp>
635+ template <typename Tx, typename Ty, typename ReduceOp, typename TransformOp,
636+ typename CubTransformOp = TransformOp>
653637void TensorReduceFunc (const framework::Tensor& x, framework::Tensor* y,
654638 std::vector<int > origin_reduce_dims, const Ty& init,
655639 const ReduceOp& reducer, const TransformOp& transformer,
640+ const CubTransformOp& cub_transformer,
656641 gpuStream_t stream) {
657642 auto x_dim = framework::vectorize<int >(x.dims ());
658643 auto config = ReduceConfig<Ty>(origin_reduce_dims, x_dim);
@@ -673,13 +658,28 @@ void TensorReduceFunc(const framework::Tensor& x, framework::Tensor* y,
673658 y->Resize (out_dims);
674659 return ;
675660 }
661+ // launch CUB::Reduce
662+ if (config.reduce_type == static_cast <int >(ReduceType::kReduceAll )) {
663+ cub::TransformInputIterator<Ty, CubTransformOp, const Tx*> trans_x (
664+ x_data, cub_transformer);
665+ size_t temp_storage_bytes = 0 ;
666+ cub::DeviceReduce::Reduce (nullptr , temp_storage_bytes, trans_x, y_data,
667+ config.reduce_num , reducer, init, stream);
668+ framework::Tensor tmp;
669+ auto * temp_storage = tmp.mutable_data <uint8_t >(
670+ framework::make_ddim ({static_cast <int64_t >(temp_storage_bytes)}),
671+ x.place ());
672+ cub::DeviceReduce::Reduce (temp_storage, temp_storage_bytes, trans_x, y_data,
673+ config.reduce_num , reducer, init, stream);
674+
675+ return ;
676+ }
676677
677- #define CUB_BLOCK_DIM_CASE (block_dim ) \
678- case block_dim: { \
679- constexpr auto kBlockDim = block_dim; \
680- LaunchReduceKernel<Tx, Ty, block_dim, ReduceOp, TransformOp>( \
681- x_data, y_data, x.place (), reducer, transformer, init, stream, \
682- config); \
678+ #define CUB_BLOCK_DIM_CASE (block_dim ) \
679+ case block_dim: { \
680+ constexpr auto kBlockDim = block_dim; \
681+ LaunchReduceKernel<Tx, Ty, block_dim, ReduceOp, TransformOp>( \
682+ x_data, y_data, reducer, transformer, init, stream, config); \
683683 } break
684684
685685 switch (detail::GetBlockDim (config.reduce_num )) {
@@ -696,5 +696,36 @@ void TensorReduceFunc(const framework::Tensor& x, framework::Tensor* y,
696696#undef CUB_BLOCK_DIM_CASE
697697}
698698
699+ template <typename Tx, typename ReduceOp,
700+ template <typename , typename > class TransformOp >
701+ struct TensorReduceFunctorImpl {
702+ const framework::Tensor& x;
703+ framework::Tensor* y;
704+ std::vector<int > origin_reduce_dims;
705+ const double & init;
706+ const ReduceOp& reducer;
707+ gpuStream_t stream;
708+ TensorReduceFunctorImpl (const framework::Tensor& x, framework::Tensor* y,
709+ std::vector<int > origin_reduce_dims,
710+ const double & init, const ReduceOp& reducer,
711+ gpuStream_t stream)
712+ : x(x),
713+ y (y),
714+ origin_reduce_dims(origin_reduce_dims),
715+ init(init),
716+ reducer(reducer),
717+ stream(stream) {}
718+
719+ template <typename Ty>
720+
721+ void apply () const {
722+ const Ty& init_cast = static_cast <Ty>(init);
723+ TensorReduceFunc<Tx, Ty, ReduceOp, TransformOp<Ty, Ty>,
724+ TransformOp<Tx, Ty>>(x, y, origin_reduce_dims, init_cast,
725+ reducer, TransformOp<Ty, Ty>(),
726+ TransformOp<Tx, Ty>(), stream);
727+ }
728+ };
729+
699730} // namespace operators
700731} // namespace paddle
0 commit comments