Skip to content

Commit 8700894

Browse files
committed
update TensorReduceFunc
1 parent 790173a commit 8700894

File tree

4 files changed: +70 −46 lines

paddle/fluid/operators/reduce_ops/reduce_max_op.cu

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -45,7 +45,7 @@ class ReduceMaxKernel : public framework::OpKernel<T> {
4545
auto stream = context.cuda_device_context().stream();
4646
TensorReduceFunc<T, T, CustomMax<T>, detail::IdentityFunctor<T>>(
4747
*input, output, reduce_dims, DataBound<T>::min(), CustomMax<T>(),
48-
detail::IdentityFunctor<T>(), stream);
48+
detail::IdentityFunctor<T>(), detail::IdentityFunctor<T>(), stream);
4949
}
5050
};
5151

paddle/fluid/operators/reduce_ops/reduce_min_op.cu

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -45,7 +45,7 @@ class ReduceMinKernel : public framework::OpKernel<T> {
4545
auto stream = context.cuda_device_context().stream();
4646
TensorReduceFunc<T, T, CustomMin<T>, detail::IdentityFunctor<T>>(
4747
*input, output, reduce_dims, DataBound<T>::max(), CustomMin<T>(),
48-
detail::IdentityFunctor<T>(), stream);
48+
detail::IdentityFunctor<T>(), detail::IdentityFunctor<T>(), stream);
4949
}
5050
};
5151

paddle/fluid/operators/reduce_ops/reduce_op.cuh

Lines changed: 62 additions & 31 deletions
Original file line numberDiff line numberDiff line change
@@ -38,11 +38,13 @@ namespace operators {
3838
namespace detail {
3939

4040
// Post processing function for sum, max, min, prod, any
41-
template <typename T>
41+
template <typename Tx, typename Ty = Tx>
4242
struct IdentityFunctor {
4343
HOSTDEVICE explicit inline IdentityFunctor() {}
4444

45-
HOSTDEVICE inline T operator()(const T& x) const { return x; }
45+
HOSTDEVICE inline Ty operator()(const Tx& x) const {
46+
return static_cast<Ty>(x);
47+
}
4648
};
4749

4850
// Post processing function for mean
@@ -81,7 +83,7 @@ static inline std::vector<int> GetDimStrides(const std::vector<int>& dims,
8183
#ifdef __HIPCC__
8284
constexpr int kMaxBlock = 256;
8385
#else
84-
constexpr int kMaxBlock = 512;
86+
constexpr int kMaxBlock = 128;
8587
#endif
8688

8789
// get blockDim for reduceLastDim and reduceAny
@@ -544,8 +546,7 @@ __global__ void ReduceKernelFunction(
544546

545547
template <typename Tx, typename Ty, int BlockDim, typename ReduceOp,
546548
typename TransformOp, int kRank, int kReduceRank>
547-
static void LaunchKernel(const Tx* x_data, Ty* y_data,
548-
const platform::Place& place, const ReduceOp& reducer,
549+
static void LaunchKernel(const Tx* x_data, Ty* y_data, const ReduceOp& reducer,
549550
const TransformOp& transformer, const Ty& init,
550551
gpuStream_t stream, ReduceConfig<Ty> config) {
551552
#define CUB_REDUCE_TYPE_CASE(type) \
@@ -589,7 +590,6 @@ static void LaunchKernel(const Tx* x_data, Ty* y_data,
589590
template <typename Tx, typename Ty, int BlockDim, typename ReduceOp,
590591
typename TransformOp>
591592
static void LaunchReduceKernel(const Tx* x_data, Ty* y_data,
592-
const platform::Place& place,
593593
const ReduceOp& reducer,
594594
const TransformOp& transformer, const Ty& init,
595595
gpuStream_t stream, ReduceConfig<Ty> config) {
@@ -606,26 +606,9 @@ static void LaunchReduceKernel(const Tx* x_data, Ty* y_data,
606606
case i: { \
607607
constexpr auto kReduceRank = i; \
608608
LaunchKernel<Tx, Ty, BlockDim, ReduceOp, TransformOp, kRank, kReduceRank>( \
609-
x_data, y_data, place, reducer, transformer, init, stream, config); \
609+
x_data, y_data, reducer, transformer, init, stream, config); \
610610
} break
611611

612-
// launch CUB::Reduce
613-
if (config.reduce_type == static_cast<int>(ReduceType::kReduceAll)) {
614-
cub::TransformInputIterator<Ty, TransformOp, const Tx*> trans_x(
615-
x_data, transformer);
616-
size_t temp_storage_bytes = 0;
617-
cub::DeviceReduce::Reduce(nullptr, temp_storage_bytes, trans_x, y_data,
618-
config.reduce_num, reducer, init, stream);
619-
framework::Tensor tmp;
620-
auto* temp_storage = tmp.mutable_data<uint8_t>(
621-
framework::make_ddim({static_cast<int64_t>(temp_storage_bytes)}),
622-
place);
623-
cub::DeviceReduce::Reduce(temp_storage, temp_storage_bytes, trans_x, y_data,
624-
config.reduce_num, reducer, init, stream);
625-
626-
return;
627-
}
628-
629612
detail::CheckReduceRank(reduce_rank, rank);
630613
switch (rank) {
631614
CUB_RANK_CASE(2, CUB_REDUCE_RANK_CASE(1););
@@ -649,10 +632,12 @@ static void LaunchReduceKernel(const Tx* x_data, Ty* y_data,
649632
#undef CUB_RANK_CASE
650633
}
651634

652-
template <typename Tx, typename Ty, typename ReduceOp, typename TransformOp>
635+
template <typename Tx, typename Ty, typename ReduceOp, typename TransformOp,
636+
typename CubTransformOp = TransformOp>
653637
void TensorReduceFunc(const framework::Tensor& x, framework::Tensor* y,
654638
std::vector<int> origin_reduce_dims, const Ty& init,
655639
const ReduceOp& reducer, const TransformOp& transformer,
640+
const CubTransformOp& cub_transformer,
656641
gpuStream_t stream) {
657642
auto x_dim = framework::vectorize<int>(x.dims());
658643
auto config = ReduceConfig<Ty>(origin_reduce_dims, x_dim);
@@ -673,13 +658,28 @@ void TensorReduceFunc(const framework::Tensor& x, framework::Tensor* y,
673658
y->Resize(out_dims);
674659
return;
675660
}
661+
// launch CUB::Reduce
662+
if (config.reduce_type == static_cast<int>(ReduceType::kReduceAll)) {
663+
cub::TransformInputIterator<Ty, CubTransformOp, const Tx*> trans_x(
664+
x_data, cub_transformer);
665+
size_t temp_storage_bytes = 0;
666+
cub::DeviceReduce::Reduce(nullptr, temp_storage_bytes, trans_x, y_data,
667+
config.reduce_num, reducer, init, stream);
668+
framework::Tensor tmp;
669+
auto* temp_storage = tmp.mutable_data<uint8_t>(
670+
framework::make_ddim({static_cast<int64_t>(temp_storage_bytes)}),
671+
x.place());
672+
cub::DeviceReduce::Reduce(temp_storage, temp_storage_bytes, trans_x, y_data,
673+
config.reduce_num, reducer, init, stream);
674+
675+
return;
676+
}
676677

677-
#define CUB_BLOCK_DIM_CASE(block_dim) \
678-
case block_dim: { \
679-
constexpr auto kBlockDim = block_dim; \
680-
LaunchReduceKernel<Tx, Ty, block_dim, ReduceOp, TransformOp>( \
681-
x_data, y_data, x.place(), reducer, transformer, init, stream, \
682-
config); \
678+
#define CUB_BLOCK_DIM_CASE(block_dim) \
679+
case block_dim: { \
680+
constexpr auto kBlockDim = block_dim; \
681+
LaunchReduceKernel<Tx, Ty, block_dim, ReduceOp, TransformOp>( \
682+
x_data, y_data, reducer, transformer, init, stream, config); \
683683
} break
684684

685685
switch (detail::GetBlockDim(config.reduce_num)) {
@@ -696,5 +696,36 @@ void TensorReduceFunc(const framework::Tensor& x, framework::Tensor* y,
696696
#undef CUB_BLOCK_DIM_CASE
697697
}
698698

699+
template <typename Tx, typename ReduceOp,
700+
template <typename, typename> class TransformOp>
701+
struct TensorReduceFunctorImpl {
702+
const framework::Tensor& x;
703+
framework::Tensor* y;
704+
std::vector<int> origin_reduce_dims;
705+
const double& init;
706+
const ReduceOp& reducer;
707+
gpuStream_t stream;
708+
TensorReduceFunctorImpl(const framework::Tensor& x, framework::Tensor* y,
709+
std::vector<int> origin_reduce_dims,
710+
const double& init, const ReduceOp& reducer,
711+
gpuStream_t stream)
712+
: x(x),
713+
y(y),
714+
origin_reduce_dims(origin_reduce_dims),
715+
init(init),
716+
reducer(reducer),
717+
stream(stream) {}
718+
719+
template <typename Ty>
720+
721+
void apply() const {
722+
const Ty& init_cast = static_cast<Ty>(init);
723+
TensorReduceFunc<Tx, Ty, ReduceOp, TransformOp<Ty, Ty>,
724+
TransformOp<Tx, Ty>>(x, y, origin_reduce_dims, init_cast,
725+
reducer, TransformOp<Ty, Ty>(),
726+
TransformOp<Tx, Ty>(), stream);
727+
}
728+
};
729+
699730
} // namespace operators
700731
} // namespace paddle

paddle/fluid/operators/reduce_ops/reduce_prod_op.cu

Lines changed: 6 additions & 13 deletions
Original file line numberDiff line numberDiff line change
@@ -51,22 +51,15 @@ class ReduceProdKernel : public framework::OpKernel<T> {
5151

5252
auto stream = context.cuda_device_context().stream();
5353
if (out_dtype >= 0) {
54-
#define VisitDataTypeSmall_t(cpp_type, proto_type) \
55-
do { \
56-
if (static_cast<framework::proto::VarType::Type>(out_dtype) == \
57-
proto_type) { \
58-
TensorReduceFunc<T, cpp_type, CustomMul<cpp_type>, \
59-
detail::IdentityFunctor<cpp_type>>( \
60-
*input, output, reduce_dims, static_cast<cpp_type>(1.0f), \
61-
CustomMul<cpp_type>(), detail::IdentityFunctor<cpp_type>(), stream); \
62-
} \
63-
} while (0)
64-
_ForEachDataTypeSmall_(VisitDataTypeSmall_t);
65-
#undef VisitDataTypeSmall_t
54+
framework::VisitDataTypeSmall(
55+
static_cast<framework::proto::VarType::Type>(out_dtype),
56+
TensorReduceFunctorImpl<T, cub::Sum, detail::IdentityFunctor>(
57+
*input, output, reduce_dims, static_cast<double>(1.0f),
58+
cub::Sum(), stream));
6659
} else {
6760
TensorReduceFunc<T, T, CustomMul<T>, detail::IdentityFunctor<T>>(
6861
*input, output, reduce_dims, static_cast<T>(1.0f), CustomMul<T>(),
69-
detail::IdentityFunctor<T>(), stream);
62+
detail::IdentityFunctor<T>(), detail::IdentityFunctor<T>(), stream);
7063
}
7164
}
7265
};

0 commit comments

Comments (0)