
Commit 4990147

implement warp shuffle based reduction; enable for arch >= 3.0
1 parent edf3c71 commit 4990147

File tree

2 files changed: +118 −2 lines


lib/THC/THCReduceApplyUtils.cuh

Lines changed: 100 additions & 2 deletions
@@ -8,6 +8,7 @@
 #include "THCTensor.h"
 #include "THCDeviceUtils.cuh"
 #include "THCTensorInfo.cuh"
+#include "THCAsmUtils.cuh"
 
 // Enum that indicates whether tensor arguments are read/write or
 // read-only
@@ -26,8 +27,101 @@ __device__ __forceinline__ IndexType getLinearBlockId() {
 // level) with N elements per thread in the block, so we have to use min(numvals,
 // max block size) to determine this count.
 template <typename T, int N>
-int reduceSmemSize(THCState *state, int numVals) {
-  return THCRoundUp(std::min(numVals, 1024), 32) * N * sizeof(T);
+int reduceSmemSize(THCState *state, long numVals) {
+  // check if we can use a warp shuffle
+  cudaDeviceProp *props = THCState_getCurrentDeviceProperties(state);
+  if (props->major >= 3) {
+    return props->warpSize * N * sizeof(T);
+  } else {
+    return THCRoundUp(std::min(numVals, (long) props->maxThreadsPerBlock), (long) props->warpSize) * N * sizeof(T);
+  }
+}
+
+template <typename T>
+struct THCWarpUtils {
+  static __device__ __forceinline__ T shflxor(T val, unsigned int mask) {
+    return __shfl_xor(val, mask);
+  }
+};
+
+template <>
+struct THCWarpUtils<unsigned char> {
+  static __device__ __forceinline__ unsigned char shflxor(unsigned char val, unsigned int mask) {
+    return (unsigned char) __shfl_xor((int) val, mask);
+  }
+};
+
+template <>
+struct THCWarpUtils<char> {
+  static __device__ __forceinline__ char shflxor(char val, unsigned int mask) {
+    return (char) __shfl_xor((int) val, mask);
+  }
+};
+
+template <>
+struct THCWarpUtils<short> {
+  static __device__ __forceinline__ short shflxor(short val, unsigned int mask) {
+    return (short) __shfl_xor((int) val, mask);
+  }
+};
+
+template <>
+struct THCWarpUtils<double> {
+  static __device__ __forceinline__ double shflxor(double val, unsigned int mask) {
+    int2 a = *reinterpret_cast<int2*>(&val);
+    a.x = __shfl_xor(a.x, mask);
+    a.y = __shfl_xor(a.y, mask);
+    return *reinterpret_cast<double*>(&a);
+  }
+};
+
+template <>
+struct THCWarpUtils<long> {
+  static __device__ __forceinline__ long shflxor(long val, unsigned int mask) {
+    int2 a = *reinterpret_cast<int2*>(&val);
+    a.x = __shfl_xor(a.x, mask);
+    a.y = __shfl_xor(a.y, mask);
+    return *reinterpret_cast<long*>(&a);
+  }
+};
+
+template <typename T, typename ReduceOp, int N>
+__device__ void warpReduce(T threadVals[N], ReduceOp reduceOp) {
+#pragma unroll
+  for (int mask = 1; mask < warpSize; mask *= 2) {
+#pragma unroll
+    for (int i = 0; i < N; ++i) {
+      T neighbor = THCWarpUtils<T>::shflxor(threadVals[i], mask);
+      threadVals[i] = reduceOp(threadVals[i], neighbor);
+    }
+  }
+}
+
+template <typename T, typename ReduceOp, int N>
+__device__ void warpReduceBlock(T *smem, T threadVals[N], int numVals, ReduceOp reduceOp, T init) {
+  assert(blockDim.x % warpSize == 0);
+  // First, warps cooperate to reduce values within the warp
+  warpReduce<T, ReduceOp, N>(threadVals, reduceOp);
+  int lane = getLaneId();
+  int warp = threadIdx.x / warpSize;
+
+  if (lane == 0) {
+#pragma unroll
+    for (int i = 0; i < N; ++i) {
+      smem[warp + (i * warpSize)] = threadVals[i];
+    }
+  }
+  __syncthreads();
+
+#pragma unroll
+  for (int i = 0; i < N; ++i) {
+    threadVals[i] = (threadIdx.x < (blockDim.x / warpSize)) ? smem[lane + (i * warpSize)] : init;
+  }
+
+  if (warp == 0) {
+    warpReduce<T, ReduceOp, N>(threadVals, reduceOp);
+  }
 }
 
 // Reduce N values concurrently, i.e. suppose N = 2, and there are 4 threads:
@@ -47,6 +141,9 @@ __device__ void reduceNValuesInBlock(T *smem,
     return;
   }
 
+#if __CUDA_ARCH__ >= 300
+  warpReduceBlock<T, ReduceOp, N>(smem, threadVals, numVals, reduceOp, init);
+#else
   // We store each of the N values contiguously, so if N = 2, all values for
   // the first threadVal for each thread in the block are stored followed by
   // all of the values for the second threadVal for each thread in the block
@@ -102,6 +199,7 @@ __device__ void reduceNValuesInBlock(T *smem,
       }
     }
   }
+#endif
 }
 
 // Block-wide reduction in shared memory helper; only threadIdx.x == 0 will
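
For context (not part of the commit): warpReduce above is the standard XOR-butterfly shuffle reduction, in which lanes exchange partial results with partners whose lane IDs differ in one bit, so after log2(warpSize) rounds every lane holds the reduction over the whole warp. A minimal standalone sketch of the same pattern, using the same pre-CUDA-9 __shfl_xor intrinsic the patch relies on (the kernel name and setup are illustrative only):

// Standalone sketch of the XOR-butterfly reduction used by warpReduce.
// Illustrative only: one float per lane, sum as the reduce op, single warp.
__global__ void warpSumSketch(const float* in, float* out) {
  float v = in[threadIdx.x];
  // Each round exchanges values between lanes whose IDs differ in one bit;
  // after log2(warpSize) rounds every lane holds the full warp sum.
  for (int mask = 1; mask < warpSize; mask *= 2) {
    v += __shfl_xor(v, mask);
  }
  if (threadIdx.x == 0) {
    *out = v;
  }
}

This is also why reduceSmemSize can shrink to warpSize * N * sizeof(T) on sm_30+: warpReduceBlock only needs shared memory to exchange one partial per warp (at most warpSize warps in a 1024-thread block), rather than one slot per thread.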

lib/THC/THCTensorMode.cuh

Lines changed: 18 additions & 0 deletions
@@ -56,13 +56,31 @@ struct ModeUnsignedBoolPair {
   bool flag;
 };
 
+template <>
+struct THCWarpUtils<ModeUnsignedBoolPair> {
+  static __device__ __forceinline__ ModeUnsignedBoolPair shflxor(ModeUnsignedBoolPair val, unsigned int mask) {
+    val.val = __shfl_xor(val.val, mask);
+    val.flag = (bool) __shfl_xor((int) val.flag, mask);
+    return val;
+  }
+};
+
 // In the kernel below, we have a common pattern of reducing (unsigned int, unsigned int)
 // pairs of data
 struct ModeUnsignedPair {
   unsigned int val;
   unsigned int index;
 };
 
+template <>
+struct THCWarpUtils<ModeUnsignedPair> {
+  static __device__ __forceinline__ ModeUnsignedPair shflxor(ModeUnsignedPair val, unsigned int mask) {
+    val.val = __shfl_xor(val.val, mask);
+    val.index = __shfl_xor(val.index, mask);
+    return val;
+  }
+};
+
 template <typename T>
 struct MaxReduceOp {
   __host__ __device__ inline T operator()(const T& a, const T& b) {
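
As a usage illustration (hypothetical, not from this commit): a caller sizes dynamic shared memory on the host with reduceSmemSize and passes per-thread partials to reduceNValuesInBlock; as in the warp path above, where warp 0 performs the final warpReduce, thread 0 reads the block-wide result back from its threadVals. A sketch for one float per thread (N = 1) with an addition reduce op; AddOp and blockSumSketch are made-up names, and numVals is assumed to be at most blockDim.x:

// Hypothetical caller of the helpers in this commit; illustrative only.
// Launch with blockDim.x a multiple of warpSize (warpReduceBlock asserts this)
// and dynamic shared memory sized via reduceSmemSize<float, 1>(state, numVals).
struct AddOp {
  __device__ __forceinline__ float operator()(const float& a, const float& b) const {
    return a + b;
  }
};

__global__ void blockSumSketch(const float* in, float* out, int numVals) {
  extern __shared__ float smem[];
  float threadVals[1];
  // Threads past numVals hold the identity so the shuffle path reduces correctly.
  threadVals[0] = (threadIdx.x < numVals) ? in[threadIdx.x] : 0.0f;
  reduceNValuesInBlock<float, AddOp, 1>(smem, threadVals, numVals, AddOp(), 0.0f);
  if (threadIdx.x == 0) {
    out[blockIdx.x] = threadVals[0];
  }
}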
