88#include " THCTensor.h"
99#include " THCDeviceUtils.cuh"
1010#include " THCTensorInfo.cuh"
11+ #include " THCAsmUtils.cuh"
1112
1213// Enum that indicates whether tensor arguments are read/write or
1314// read-only
@@ -26,8 +27,101 @@ __device__ __forceinline__ IndexType getLinearBlockId() {
2627// level) with N elements per thread in the block, so we have to use min(numvals,
2728// max block size) to determine this count.
2829template <typename T, int N>
29- int reduceSmemSize (THCState *state, int numVals) {
30- return THCRoundUp (std::min (numVals, 1024 ), 32 ) * N * sizeof (T);
// Returns the number of bytes of smem a block-wide reduction of `numVals`
// values (N per thread) requires on the current device. On sm_30+ the warp
// shuffle path only spills one partial per warp, so warpSize slots per value
// suffice; older devices reduce entirely through shared memory.
template <typename T, int N>
int reduceSmemSize(THCState *state, long numVals) {
  cudaDeviceProp *props = THCState_getCurrentDeviceProperties(state);

  // sm_30+ can use warp shuffles: one smem slot per warp per value.
  if (props->major >= 3) {
    return props->warpSize * N * sizeof(T);
  }

  // Pre-Kepler fallback: one slot per participating thread, capped at the
  // maximum block size and rounded up to a whole number of warps.
  long cappedVals = std::min(numVals, (long) props->maxThreadsPerBlock);
  return THCRoundUp(cappedVals, (long) props->warpSize) * N * sizeof(T);
}
39+
// Generic warp-level XOR shuffle: forwards to the native __shfl_xor
// intrinsic (pre-CUDA-9 form, no participant mask). Types that the
// intrinsic does not support directly are handled by the specializations
// below.
template <typename T>
struct THCWarpUtils {
  static __device__ __forceinline__ T shflxor(T val, unsigned int mask) {
    return __shfl_xor(val, mask);
  }
};
46+
// unsigned char has no native __shfl_xor overload: widen to int for the
// shuffle, then narrow the result back.
template <>
struct THCWarpUtils<unsigned char> {
  static __device__ __forceinline__ unsigned char shflxor(unsigned char val, unsigned int mask) {
    return (unsigned char) __shfl_xor((int) val, mask);
  }
};
53+
// char has no native __shfl_xor overload: widen to int for the shuffle,
// then narrow the result back.
template <>
struct THCWarpUtils<char> {
  static __device__ __forceinline__ char shflxor(char val, unsigned int mask) {
    return (char) __shfl_xor((int) val, mask);
  }
};
60+
// short has no native __shfl_xor overload: widen to int for the shuffle,
// then narrow the result back.
template <>
struct THCWarpUtils<short> {
  static __device__ __forceinline__ short shflxor(short val, unsigned int mask) {
    return (short) __shfl_xor((int) val, mask);
  }
};
67+
// double (64-bit) exceeds the 32-bit shuffle width: type-pun it to int2,
// shuffle the two halves independently, and reassemble. Both lanes of a
// pair shuffle with the same mask, so the halves stay matched.
template <>
struct THCWarpUtils<double> {
  static __device__ __forceinline__ double shflxor(double val, unsigned int mask) {
    int2 a = *reinterpret_cast<int2*>(&val);
    a.x = __shfl_xor(a.x, mask);
    a.y = __shfl_xor(a.y, mask);
    return *reinterpret_cast<double*>(&a);
  }
};
77+
// 64-bit long: same two-half shuffle as the double specialization.
// NOTE(review): this punning assumes sizeof(long) == 8 (LP64); on an ILP32
// target long is 4 bytes and the int2 read would overrun — confirm build
// targets are LP64 only.
template <>
struct THCWarpUtils<long> {
  static __device__ __forceinline__ long shflxor(long val, unsigned int mask) {
    int2 a = *reinterpret_cast<int2*>(&val);
    a.x = __shfl_xor(a.x, mask);
    a.y = __shfl_xor(a.y, mask);
    return *reinterpret_cast<long*>(&a);
  }
};
87+
// Butterfly (XOR) reduction within a single warp: after log2(warpSize)
// rounds of shflxor exchanges, every lane holds reduceOp applied over all
// lanes' threadVals[i]. The N per-thread values are reduced independently.
// Requires sm_30+ (warp shuffle instructions).
template <typename T, typename ReduceOp, int N>
__device__ void warpReduce(T threadVals[N], ReduceOp reduceOp) {
  // NOTE(review): warpSize is not a compile-time constant, so this outer
  // #pragma unroll is advisory at best — confirm whether a literal 32 was
  // intended here to guarantee full unrolling.
#pragma unroll
  for (int mask = 1; mask < warpSize; mask *= 2) {
#pragma unroll
    for (int i = 0; i < N; ++i) {
      // Exchange with the lane whose id differs in bit `mask`, then fold.
      T neighbor = THCWarpUtils<T>::shflxor(threadVals[i], mask);
      threadVals[i] = reduceOp(threadVals[i], neighbor);
    }
  }
}
99+
// Block-wide reduction of N values per thread via warp shuffles plus one
// shared-memory exchange of per-warp partials.
//
// smem layout: value i's per-warp partials live at smem[warp + i*warpSize],
// so smem must hold at least warpSize * N elements of T (matching the
// sm_30+ branch of reduceSmemSize). Preconditions: blockDim.x is a
// multiple of warpSize (asserted), and the block has at most warpSize
// warps so that lane indices can address the per-warp partials.
// On return, the final reduced values are held in warp 0's threadVals;
// other warps' threadVals contain intermediate partials. `numVals` is
// accepted for signature parity but is not read in this function.
template <typename T, typename ReduceOp, int N>
__device__ void warpReduceBlock(T *smem, T threadVals[N], int numVals, ReduceOp reduceOp, T init) {
  assert(blockDim.x % warpSize == 0);
  // First, warps cooperate to reduce values within the warp
  warpReduce<T, ReduceOp, N>(threadVals, reduceOp);
  int lane = getLaneId();
  int warp = threadIdx.x / warpSize;

  // Lane 0 of each warp publishes that warp's partial for each value.
  if (lane == 0) {

#pragma unroll
    for (int i = 0; i < N; ++i) {
      smem[warp + (i * warpSize)] = threadVals[i];
    }
  }
  __syncthreads();

  // Every thread reloads the partial belonging to its lane index, or `init`
  // when there are fewer warps than lanes; only warp 0's copies feed the
  // final step below.
#pragma unroll
  for (int i = 0; i < N; ++i) {
    threadVals[i] = (threadIdx.x < (blockDim.x / warpSize)) ? smem[lane + (i * warpSize)] : init;
  }

  // Warp 0 folds the per-warp partials into the block-wide result.
  // NOTE(review): there is no trailing __syncthreads() here — any caller
  // that reuses smem right after this returns must synchronize first;
  // confirm the call sites handle that.
  if (warp == 0) {
    warpReduce<T, ReduceOp, N>(threadVals, reduceOp);
  }
}
32126
33127// Reduce N values concurrently, i.e. suppose N = 2, and there are 4 threads:
@@ -47,6 +141,9 @@ __device__ void reduceNValuesInBlock(T *smem,
47141 return ;
48142 }
49143
144+ #if __CUDA_ARCH__ >= 300
145+ warpReduceBlock<T, ReduceOp, N>(smem, threadVals, numVals, reduceOp, init);
146+ #else
50147 // We store each of the N values contiguously, so if N = 2, all values for
51148 // the first threadVal for each thread in the block are stored followed by
52149 // all of the values for the second threadVal for each thread in the block
@@ -102,6 +199,7 @@ __device__ void reduceNValuesInBlock(T *smem,
102199 }
103200 }
104201 }
202+ #endif
105203}
106204
107205// Block-wide reduction in shared memory helper; only threadIdx.x == 0 will
0 commit comments