@@ -118,7 +118,11 @@ __device__ T reduceBlock(T* smem,
118118
119119// Block-wide reduction where each thread locally reduces N
120120// values before letting a single warp take over - assumes
121- // threadVals is in registers, not shared memory
121+ // threadVals is in registers, not shared memory. Note that
122+ // numVals in this case is the number of values in the overall
123+ // reduction, i.e. if there are 512 threads with N=2, and say
124+ // there are 768 elements in the input block, then numVals is 768,
125+ // not, say, 384 (i.e. 768 / N=2)
122126template <typename T, typename ReduceOp, int N>
123127__device__ T reduceBlockWithNThreadLocalReductions(T *smem,
124128 T threadVals[N],
@@ -135,7 +139,7 @@ __device__ T reduceBlockWithNThreadLocalReductions(T *smem,
135139 local = reduceOp(local, next);
136140 }
137141
138- return reduceBlock&lt;T, ReduceOp&gt;(smem, blockDim.x &lt; numVals ? blockDim.x : numVals, local, reduceOp, init);
142+ return reduceBlock&lt;T, ReduceOp&gt;(smem, THCCeilDiv(numVals, N), local, reduceOp, init);
139143}
140144
141145// Make sure the given tensor doesn't have too many dimensions
0 commit comments