@@ -419,12 +419,12 @@ template <typename Tx, typename Ty, typename ReduceOp, typename TransformOp,
           int BlockDim>
 __device__ __forceinline__ void ReduceLastDim(const Tx* x, Ty* y,
                                               ReduceOp reducer,
-                                              TransformOp transformer,
+                                              TransformOp transformer, Ty init,
                                               int reduce_num) {
   __shared__ typename cub::BlockReduce<Ty, BlockDim>::TempStorage temp_storage;
   int idx_x = blockIdx.x * reduce_num;
   int idx_y = threadIdx.x;
-  Ty reduce_var = reducer.initial();
+  Ty reduce_var = init;
   for (int idx_y = threadIdx.x; idx_y < reduce_num; idx_y += BlockDim) {
     reduce_var =
         reducer(reduce_var, static_cast<Ty>(transformer(x[idx_x + idx_y])));
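The hunk above is the heart of the patch: the reduction's identity element is now computed on the host and passed in as `Ty init`, instead of being fetched via `reducer.initial()` inside the kernel. A minimal, self-contained sketch of the resulting pattern, with an illustrative `SumOp` functor and kernel name (not Paddle's actual code):

```cuda
#include <cub/block/block_reduce.cuh>

template <typename T>
struct SumOp {
  // Identity element; host-only now that the kernel receives it as an argument.
  T initial() const { return static_cast<T>(0); }
  __device__ T operator()(const T& a, const T& b) const { return a + b; }
};

template <typename T, typename ReduceOp, int BlockDim>
__global__ void LastDimReduce(const T* x, T* y, ReduceOp reducer, T init,
                              int reduce_num) {
  using BlockReduce = cub::BlockReduce<T, BlockDim>;
  __shared__ typename BlockReduce::TempStorage temp_storage;
  const T* row = x + blockIdx.x * reduce_num;  // one block per output row
  T val = init;  // seeded from the host instead of reducer.initial()
  for (int i = threadIdx.x; i < reduce_num; i += BlockDim) {
    val = reducer(val, row[i]);
  }
  val = BlockReduce(temp_storage).Reduce(val, reducer);
  if (threadIdx.x == 0) y[blockIdx.x] = val;
}
```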
@@ -448,12 +448,12 @@ template <typename Tx, typename Ty, typename ReduceOp, typename TransformOp>
 __device__ __forceinline__ void ReduceHigherDim(const Tx* x, Ty* y,
                                                 ReduceOp reducer,
                                                 TransformOp transformer,
-                                                int reduce_num, int left_num,
-                                                int block_size) {
+                                                Ty init, int reduce_num,
+                                                int left_num, int block_size) {
   int idx = blockIdx.x * blockDim.x + threadIdx.x;
   int idy = blockIdx.y * block_size;
 
-  Ty reduce_var = reducer.initial();
+  Ty reduce_var = init;
 
   if (idx < left_num) {
     int loop = reduce_num - idy;
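`ReduceHigherDim` handles the case where the reduce axis is not innermost: threads map to the contiguous non-reduced ("left") axis, while `gridDim.y` splits the reduce axis into tiles of `block_size` elements, each tile producing one partial result. A hedged reconstruction of that indexing, simplified and with illustrative names:

```cuda
template <typename T, typename ReduceOp>
__global__ void HigherDimReduce(const T* x, T* y, ReduceOp reducer, T init,
                                int reduce_num, int left_num, int block_size) {
  int idx = blockIdx.x * blockDim.x + threadIdx.x;  // position on the left axis
  int idy = blockIdx.y * block_size;                // start of this tile
  if (idx >= left_num) return;
  T val = init;
  int loop = min(block_size, reduce_num - idy);     // clamp the last tile
  for (int iy = 0; iy < loop; ++iy) {
    val = reducer(val, x[(idy + iy) * left_num + idx]);
  }
  y[blockIdx.y * left_num + idx] = val;  // one partial per gridDim.y tile
}
```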
@@ -532,7 +532,7 @@ __device__ __forceinline__ void ReduceAny(
 template <typename Tx, typename Ty, typename ReduceOp, typename TransformOp,
           int BlockDim, int Rank, int ReduceRank, int ReduceType>
 __device__ __forceinline__ void ReduceModule(
-    const Tx* x, Ty* y, ReduceOp reducer, TransformOp transformer,
+    const Tx* x, Ty* y, ReduceOp reducer, TransformOp transformer, Ty init,
     int reduce_num, int left_num, int blocking_size,
     paddle::framework::Array<int, Rank> x_strides,
     paddle::framework::Array<int, ReduceRank> reduce_dim,
@@ -542,12 +542,12 @@ __device__ __forceinline__ void ReduceModule(
   // reduce_rank == 1 && reduce_dim[0] == x_dim.size() - 1
   if (ReduceType == ReduceType::kReduceLastDim) {
     ReduceLastDim<Tx, Ty, ReduceOp, TransformOp, BlockDim>(
-        x, y, reducer, transformer, reduce_num);
+        x, y, reducer, transformer, init, reduce_num);
 
     // reduce_rank == 1 && reduce_dim[0] != x_dim.size() - 1
   } else if (ReduceType == ReduceType::kReduceHigherDim) {
     ReduceHigherDim<Tx, Ty, ReduceOp, TransformOp>(
-        x, y, reducer, transformer, reduce_num, left_num, blocking_size);
+        x, y, reducer, transformer, init, reduce_num, left_num, blocking_size);
 
     // reduce_rank >= 2
   } else {
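Note that `ReduceModule` branches on `ReduceType`, a template parameter, so each instantiation compiles down to exactly one of the three paths with no runtime branch. A tiny illustration of the effect:

```cuda
template <int ReduceType>
__global__ void Module(float* y) {
  // Every condition is a compile-time constant; the compiler keeps only
  // the arm matching this instantiation and discards the rest.
  if (ReduceType == 0) {
    y[0] = 1.0f;  // "last dim" path
  } else if (ReduceType == 1) {
    y[0] = 2.0f;  // "higher dim" path
  } else {
    y[0] = 3.0f;  // generic path
  }
}
// Module<0>, Module<1>, and Module<2> each contain a single store.
```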
@@ -560,32 +560,32 @@ __device__ __forceinline__ void ReduceModule(
 template <typename Tx, typename Ty, typename ReduceOp, typename TransformOp,
           int BlockDim, int Rank, int ReduceRank, int ReduceType>
 __global__ void ReduceKernelFunction(
-    const Tx* x, Ty* y, ReduceOp reducer, TransformOp transformer,
+    const Tx* x, Ty* y, ReduceOp reducer, TransformOp transformer, Ty init,
     int reduce_num, int left_num, int block_size,
     paddle::framework::Array<int, Rank> x_strides,
     paddle::framework::Array<int, ReduceRank> reduce_dim,
     paddle::framework::Array<int, ReduceRank> reduce_strides,
     paddle::framework::Array<int, Rank - ReduceRank> left_dim,
     paddle::framework::Array<int, Rank - ReduceRank> left_strides) {
   ReduceModule<Tx, Ty, ReduceOp, TransformOp, BlockDim, Rank, ReduceRank,
-               ReduceType>(x, y, reducer, transformer, reduce_num, left_num,
-                           block_size, x_strides, reduce_dim, reduce_strides,
-                           left_dim, left_strides);
+               ReduceType>(x, y, reducer, transformer, init, reduce_num,
+                           left_num, block_size, x_strides, reduce_dim,
+                           reduce_strides, left_dim, left_strides);
 }
 
 template <typename Tx, typename Ty, int BlockDim, typename ReduceOp,
           typename TransformOp, int kRank, int kReduceRank>
 static void LaunchKernel(const Tx* x_data, Ty* y_data, const ReduceOp& reducer,
-                         const TransformOp& transformer, gpuStream_t stream,
-                         ReduceConfig<Ty> config) {
+                         const TransformOp& transformer, Ty init,
+                         gpuStream_t stream, ReduceConfig<Ty> config) {
 #define CUB_REDUCE_TYPE_CASE(type)                                            \
   case type: {                                                                \
     constexpr auto kReduceType = type;                                        \
     ReduceKernelFunction<                                                     \
         Tx, Ty, ReduceOp, TransformOp, BlockDim, kRank, kReduceRank,          \
         kReduceType><<<config.grid, config.block, 0, stream>>>(               \
-        x_data, config.output_data, reducer, transformer, config.reduce_num,  \
-        config.left_num, config.blocking_size,                                \
+        x_data, config.output_data, reducer, transformer, init,               \
+        config.reduce_num, config.left_num, config.blocking_size,             \
         detail::VectorToArray<int, kRank>(config.x_strides),                  \
         detail::VectorToArray<int, kReduceRank>(config.reduce_dim),           \
         detail::VectorToArray<int, kReduceRank>(config.reduce_strides),       \
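`CUB_REDUCE_TYPE_CASE` is the usual switch-to-constexpr dispatch trick: a runtime reduce-type value is rebound as a `constexpr` inside each `case`, which makes it usable as a template argument. A standalone sketch of the idiom with hypothetical names:

```cuda
enum ReduceType { kReduceLastDim = 0, kReduceHigherDim = 1, kReduceAny = 2 };

template <int kReduceType>
void LaunchImpl() {
  // Instantiate and launch the kernel specialized for this reduce path.
}

void Dispatch(int reduce_type) {
#define REDUCE_TYPE_CASE(type)   \
  case type: {                   \
    constexpr auto kType = type; \
    LaunchImpl<kType>();         \
  } break

  switch (reduce_type) {
    REDUCE_TYPE_CASE(kReduceLastDim);
    REDUCE_TYPE_CASE(kReduceHigherDim);
    REDUCE_TYPE_CASE(kReduceAny);
  }
#undef REDUCE_TYPE_CASE
}
```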
@@ -607,7 +607,7 @@ static void LaunchKernel(const Tx* x_data, Ty* y_data, const ReduceOp& reducer,
         Ty, Ty, ReduceOp, detail::IdentityFunctor<Ty>, 128, kRank, kReduceRank,
         ReduceType::kReduceHigherDim><<<grid, block, 0, stream>>>(
         config.output_data, y_data, reducer,
-        detail::IdentityFunctor<Ty>(config.grid.y), config.grid.y,
+        detail::IdentityFunctor<Ty>(config.grid.y), init, config.grid.y,
         config.left_num, config.grid.y,
         detail::VectorToArray<int, kRank>(config.x_strides),
         detail::VectorToArray<int, kReduceRank>(config.reduce_dim),
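This second launch is the cleanup pass: when `config.grid.y > 1`, the first kernel leaves `grid.y` partial results per output element, and this `kReduceHigherDim` launch folds them into the final values, now seeded with the same `init`. Reusing the `HigherDimReduce` sketch above, the host-side flow might look like this (illustrative, not Paddle's actual code):

```cuda
template <typename T, typename ReduceOp>
void TwoPassReduce(const T* x, T* workspace, T* y, ReduceOp reducer, T init,
                   int reduce_num, int left_num, dim3 grid, dim3 block,
                   cudaStream_t stream) {
  // Pass 1: each of the grid.y tiles emits one partial result per output,
  // laid out as [grid.y][left_num] in the workspace.
  int tile = (reduce_num + grid.y - 1) / grid.y;
  HigherDimReduce<<<grid, block, 0, stream>>>(x, workspace, reducer, init,
                                              reduce_num, left_num, tile);
  // Pass 2: fold the grid.y partials per output down to a single value.
  dim3 grid2((left_num + block.x - 1) / block.x, 1);
  HigherDimReduce<<<grid2, block, 0, stream>>>(
      workspace, y, reducer, init, static_cast<int>(grid.y), left_num,
      static_cast<int>(grid.y));
}
```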
@@ -621,7 +621,7 @@ template <typename Tx, typename Ty, int BlockDim, typename ReduceOp,
           typename TransformOp>
 static void LaunchReduceKernel(const Tx* x_data, Ty* y_data,
                                const ReduceOp& reducer,
-                               const TransformOp& transformer,
+                               const TransformOp& transformer, Ty init,
                                gpuStream_t stream, ReduceConfig<Ty> config) {
   int reduce_rank = config.reduce_strides.size();
   int rank = config.x_strides.size();
@@ -636,7 +636,7 @@ static void LaunchReduceKernel(const Tx* x_data, Ty* y_data,
   case i: {                                                                   \
     constexpr auto kReduceRank = i;                                           \
     LaunchKernel<Tx, Ty, BlockDim, ReduceOp, TransformOp, kRank, kReduceRank>(\
-        x_data, y_data, reducer, transformer, stream, config);                \
+        x_data, y_data, reducer, transformer, init, stream, config);          \
   } break
 
   detail::CheckReduceRank(reduce_rank, rank);
@@ -711,8 +711,8 @@ void TensorReduceFunctorImpl(const framework::Tensor& x, framework::Tensor* y,
   case block_dim: {                                                           \
     constexpr auto kBlockDim = block_dim;                                     \
     LaunchReduceKernel<Tx, Ty, block_dim, ReduceOp<Tx, Ty>, TransformOp>(     \
-        x_data, y_data, reducer, TransformOp(config.reduce_num), stream,      \
-        config);                                                              \
+        x_data, y_data, reducer, TransformOp(config.reduce_num),              \
+        reducer.initial(), stream, config);                                   \
   } break
 
   switch (detail::GetBlockDim(config.reduce_num)) {
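With that, the net effect of the patch is visible at the outermost call site: `reducer.initial()` is evaluated once on the host and threaded down through every launch, so no device code depends on `initial()` being callable on the GPU. Against the sketches above, a hypothetical end-to-end call:

```cuda
void Reduce2D(const float* d_x, float* d_y, int rows, int cols,
              cudaStream_t stream) {
  SumOp<float> reducer;
  float init = reducer.initial();  // host side, once
  LastDimReduce<float, SumOp<float>, 128>
      <<<rows, 128, 0, stream>>>(d_x, d_y, reducer, init, cols);
}
```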