Commit edf3c71

small fix in nthreadlocal; add doc
1 parent 6e7e3ec commit edf3c71

1 file changed (+6, -2)

lib/THC/THCReduceApplyUtils.cuh

Lines changed: 6 additions & 2 deletions
@@ -118,7 +118,11 @@ __device__ T reduceBlock(T* smem,
 
 // Block-wide reduction where each thread locally reduces N
 // values before letting a single warp take over - assumes
-// threadVals is in registers, not shared memory
+// threadVals is in registers, not shared memory. Note that
+// numVals in this case is the number of values in the overall
+// reduction, i.e. if there are 512 threads with N=2, and say
+// there are 768 elements in the input block, then numVals is 768,
+// not, say, 384 (i.e. 768 / N=2)
 template <typename T, typename ReduceOp, int N>
 __device__ T reduceBlockWithNThreadLocalReductions(T *smem,
                                                    T threadVals[N],
@@ -135,7 +139,7 @@ __device__ T reduceBlockWithNThreadLocalReductions(T *smem,
     local = reduceOp(local, next);
   }
 
-  return reduceBlock<T, ReduceOp>(smem, blockDim.x < numVals ? blockDim.x : numVals, local, reduceOp, init);
+  return reduceBlock<T, ReduceOp>(smem, THCCeilDiv(numVals, N), local, reduceOp, init);
 }
 
 // Make sure the given tensor doesn't have too many dimensions
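For context on the one-line fix: after each thread locally reduces its N values, only ceil(numVals / N) threads actually hold a valid partial result, so that is the participant count reduceBlock should combine, rather than min(blockDim.x, numVals) as before. The following is a minimal host-side sketch (not the THC code) that works through the example from the new comment; ceilDiv is a local stand-in assumed to behave like THCCeilDiv.

    // Sketch: why the second argument to reduceBlock changes in this commit.
    #include <cstdio>

    static int ceilDiv(int a, int b) {
      // Integer ceiling division, assumed equivalent to THCCeilDiv.
      return (a + b - 1) / b;
    }

    int main() {
      const int blockDimX = 512; // threads per block, from the comment's example
      const int N = 2;           // values each thread reduces locally
      const int numVals = 768;   // total values in the overall reduction

      // Old expression: min(blockDim.x, numVals) -> 512, even though only
      // 384 threads hold a locally reduced partial value.
      int oldCount = blockDimX < numVals ? blockDimX : numVals;

      // New expression: ceil(numVals / N) -> 384, the number of valid
      // per-thread partials that the block-wide reduction should combine.
      int newCount = ceilDiv(numVals, N);

      printf("old participant count: %d\n", oldCount); // 512
      printf("new participant count: %d\n", newCount); // 384
      return 0;
    }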
