
Faster Reductions using Coalesced Reads #12

@THargreaves


I wanted to share a quick demo of a way to make reductions over the columns of matrices substantially faster by ensuring that reads from global memory are coalesced. This is very much a prototype: I've made some assumptions about the input matrix dimensions so that I didn't have to worry about masking out-of-range threads, but the idea should hold for general dimensions and be just as fast.

I plan to come back and formalise this approach when I have time, but I wanted to throw it out early to get feedback and start some discussion (this sort of approach applies to kernel design more generally).

The core idea is that for simple reductions, the fundamental bottleneck is memory bandwidth, not compute.

On my RTX 4090, for example, the memory bandwidth is 1008 GB/s and the peak FP32 compute throughput is 82.6 TFLOPS. Focusing on the square D x D matrix case, a reduction must read D^2 and write D Float32 values, i.e. D(D + 1) * 4 bytes. If op costs F FLOPs, the total compute is roughly F * D(D - 1) (D - 1 applications of op per column), assuming perfect parallelism. For large D, the two time bounds, 4D^2 / (1008 GB/s) and F * D^2 / (82.6 TFLOPS), cross only when F = 4 * 82.6e12 / 1008e9 ≈ 328. Most reductions are far from this point, so the kernels are entirely memory bound.
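
The crossover figure can be reproduced with a few lines of plain Julia (CPU only, no GPU packages). This is a back-of-the-envelope sketch using the peak numbers quoted above; the helper names are hypothetical, and real kernels reach neither peak, so treat the result as an order-of-magnitude check.

```julia
# Roofline-style check for a D x D Float32 column reduction, using the RTX 4090 peaks quoted above.
const MEM_BW = 1008e9       # bytes per second
const PEAK_FLOPS = 82.6e12  # FP32 FLOP per second

# Lower bound on time from memory traffic: read D^2 elements, write D elements.
mem_time(D; bytes = 4) = D * (D + 1) * bytes / MEM_BW

# Lower bound on time from arithmetic: D columns, D - 1 applications of op each, F FLOPs per op.
compute_time(D, F) = F * D * (D - 1) / PEAK_FLOPS

# For large D the D-dependence cancels, so the bounds cross at F = bytes * PEAK_FLOPS / MEM_BW.
crossover_flops(; bytes = 4) = bytes * PEAK_FLOPS / MEM_BW

println(round(Int, crossover_flops()))        # 328
D = 2^14
println(mem_time(D) > compute_time(D, 1))     # true: a 1-FLOP op is memory bound
```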

For this reason, it is critical that all writes/reads involving global memory are coalesced so we can obtain the maximum throughput.

Going forwards I will focus on the case where _reduce_nd_by_thread is used (one thread per output element).

When reducing along rows (dims=2), coalescing occurs naturally and near theoretically optimal performance is achieved. For column reductions (dims=1), the reads are strided, leading to a 3–4x slowdown (I'm actually surprised it isn't more!).
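
To make the access pattern concrete, here is a rough model (hypothetical helpers, plain Julia) that counts how many 128-byte cache lines a 32-thread warp touches in one step of the one-thread-per-output kernel, for a column-major Float32 matrix. It is a simplification of how the hardware groups a warp's loads, not a profiler measurement.

```julia
# Rough model: how many 128-byte cache lines does one 32-lane warp touch in a single load
# step of the one-thread-per-output reduction? Column-major Float32 matrix assumed.
const WARP = 32
const LINE_BYTES = 128

# 0-based linear offsets read by lanes 0..31 at reduction step k.
# dims=1: lane t owns column t and reads element (k, t), so lanes are nrows apart.
# dims=2: lane t owns row t and reads element (t, k), so lanes are adjacent.
function warp_offsets(nrows, k, dims)
    lanes = 0:WARP-1
    return dims == 1 ? k .+ lanes .* nrows : lanes .+ k * nrows
end

# Distinct cache lines covered by a set of element offsets.
lines_touched(offsets; elsize = 4) = length(unique((offsets .* elsize) .÷ LINE_BYTES))

println(lines_touched(warp_offsets(2^8, 0, 1)))   # A_wide, dims=1: 32 lines (one per lane)
println(lines_touched(warp_offsets(2^18, 0, 2)))  # A_tall, dims=2: 1 line
```

The measured slowdown is far smaller than 32x, plausibly because each lane walks down its own column, so a fetched line can be reused on later steps while it stays in cache.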

It's possible to get around this by using a bit of shared memory to temporarily store tiles of the input matrix that are read in a coalesced fashion. Below I provide a kernel, column_reduction_kernel!, which implements this in a naive way. Because each 32 x 32 tile is owned by a single warp, no block-level synchronisation is required.
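
As a sanity check of the tile indexing, the load/reduce pattern can be emulated on the CPU. This is a sketch, not the GPU kernel: `tiled_column_sums` is a hypothetical helper that hard-codes `+` as the op, processes one 32-column group (one "warp") at a time, and keeps the prototype's assumption that both dimensions are multiples of 32.

```julia
# CPU emulation of the tiled column reduction (summation only).
# Assumes size(A, 1) and size(A, 2) are multiples of TILE, as in the prototype kernel.
const TILE = 32

function tiled_column_sums(A::AbstractMatrix{T}) where {T}
    H, W = size(A)
    @assert H % TILE == 0 && W % TILE == 0 "prototype assumes exact tiling"
    out = zeros(T, 1, W)
    tile = zeros(T, TILE, TILE)            # stands in for one warp's shared-memory tile
    for col0 in 0:TILE:W-1                 # one "warp" per group of TILE columns
        acc = zeros(T, TILE)               # one accumulator per lane, i.e. per column
        for row0 in 0:TILE:H-1
            # Load phase: at step i, lanes 0..31 read TILE consecutive rows of one column,
            # which are contiguous in column-major memory (a coalesced read on a GPU).
            for i in 0:TILE-1, lane in 0:TILE-1
                tile[lane+1, i+1] = A[row0+lane+1, col0+i+1]
            end
            # Reduce phase: each lane walks down its own column of the tile.
            for i in 0:TILE-1, lane in 0:TILE-1
                acc[lane+1] += tile[i+1, lane+1]
            end
        end
        out[1, col0+1:col0+TILE] .= acc
    end
    return out
end

A = rand(Float32, 128, 256)
println(tiled_column_sums(A) ≈ sum(A, dims=1))   # expected: true
```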

Admittedly, this naive approach puts a fair amount of pressure on shared memory: each warp's 32 x 32 Float32 tile takes 4 KiB, so a 128-thread block needs 16 KiB, whereas a 1024-thread block would need 128 KiB, well beyond the 48 KiB of static shared memory a CUDA block can use. This is why the block size is 128 threads. The pressure can be relieved substantially by replacing the 32 x 32 tile with a 32 x K tile and having 32 / K threads work on each column. This is not optimal in terms of compute parallelism, but that doesn't matter since the compute cost is negligible for all but the most expensive ops.
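
The shared memory budget is simple arithmetic. A sketch, assuming 32-thread warps, one 32 x K Float32 tile per warp (one tile row per lane), and a hypothetical helper `shared_bytes`:

```julia
# Shared memory per block for the tiled kernel: one (warp rows) x K tile per warp.
# K = 32 reproduces the naive 32 x 32 version.
shared_bytes(block_size; K = 32, elsize = 4, warp = 32) = (block_size ÷ warp) * warp * K * elsize

# Threads sharing each column's reduction when the tile is 32 x K.
threads_per_column(K; warp = 32) = warp ÷ K

println(shared_bytes(128))           # 16384 bytes (16 KiB): current prototype
println(shared_bytes(1024))          # 131072 bytes (128 KiB): naive tile, full-size block
println(shared_bytes(1024; K = 4))   # 16384 bytes (16 KiB)
println(threads_per_column(4))       # 8
```

With K = 4, a 1024-thread block fits in the same 16 KiB as the current 128-thread version, at the cost of 8 threads sharing each column, whose partial results would then need combining (not sketched here).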

Even with this restriction on block size and the shared memory pressure, the kernel matches the near-optimal speed of the row reduction (in fact it is slightly faster here: 308.7 μs vs 343.0 μs median).

MRE and output below.

using AcceleratedKernels
using KernelAbstractions
using BenchmarkTools
using CUDA

using AcceleratedKernels: i16
using KernelAbstractions: synchronize

T = Float32
A_wide = CUDA.rand(T, 2^8, 2^18);
A_tall = CUDA.rand(T, 2^18, 2^8);
op = Base.add_sum
init = zero(T)
dst_wide = similar(A_wide, 1, size(A_wide, 2));
dst_tall = similar(A_tall, size(A_tall, 1), 1);

io = IOContext(stdout)
backend = get_backend(A_wide)
block_size = 1024

@kernel inbounds = true cpu = false function existing_kernel!(@Const(src), dst, op, init, dims)

    # One thread per output element, when there are more outer elements than in the reduced dim
    # e.g. reduce(+, rand(3, 1000), dims=1) => only 3 elements in the reduced dim
    src_sizes = size(src)
    src_strides = strides(src)
    dst_sizes = size(dst)
    dst_strides = strides(dst)

    output_size = length(dst)
    reduce_size = src_sizes[dims]

    ndims = length(src_sizes)

    N = @groupsize()[1]

    # NOTE: for many index calculations in this library, computation using zero-indexing leads to
    # fewer operations (also code is transpiled to CUDA / ROCm / oneAPI / Metal code which do zero
    # indexing). Internal calculations will be done using zero indexing except when actually
    # accessing memory. As with C, the lower bound is inclusive, the upper bound exclusive.

    # Group (block) and local (thread) indices
    iblock = @index(Group, Linear) - 0x1
    ithread = @index(Local, Linear) - 0x1

    # Each thread handles one output element
    tid = ithread + iblock * N
    if tid < output_size

        # # Sometimes slightly faster method using additional memory with
        # # output_idx = @private typeof(iblock) (ndims,)
        # tmp = tid
        # KernelAbstractions.Extras.@unroll for i in ndims:-1:1
        #     output_idx[i] = tmp ÷ dst_strides[i]
        #     tmp = tmp % dst_strides[i]
        # end
        # # Compute the base index in src (excluding the reduced axis)
        # input_base_idx = 0
        # KernelAbstractions.Extras.@unroll for i in 1:ndims
        #     i == dims && continue
        #     input_base_idx += output_idx[i] * src_strides[i]
        # end

        # Compute the base index in src (excluding the reduced axis)
        input_base_idx = typeof(ithread)(0)
        tmp = tid
        KernelAbstractions.Extras.@unroll for i in ndims:-1i16:1i16
            if i != dims
                input_base_idx += (tmp ÷ dst_strides[i]) * src_strides[i]
            end
            tmp = tmp % dst_strides[i]
        end

        # Go over each element in the reduced dimension; this implementation assumes that there
        # are so many outer elements (each processed by an independent thread) that we afford to
        # loop sequentially over the reduced dimension (e.g. reduce(+, rand(3, 1000), dims=1))
        res = init
        for i in 0x0:reduce_size-0x1
            src_idx = input_base_idx + i * src_strides[dims]
            res = op(res, src[src_idx+0x1])
        end
        dst[tid+0x1] = res
    end
end

println("\n#### Existing wide reduction ####\n")
blocks = (size(dst_wide, 2) + block_size - 1) ÷ block_size
kernel! = existing_kernel!(backend, block_size)
# Validate
kernel!(A_wide, dst_wide, op, init, 1, ndrange=(blocks * block_size,))
synchronize(backend)
println("Valid: ", dst_wide ≈ sum(A_wide, dims=1))
res = @benchmark begin
    kernel!(A_wide, dst_wide, op, init, 1, ndrange=(blocks * block_size,))
    synchronize(backend)
end
show(io, "text/plain", res)

println("\n#### Existing tall reduction ####\n")
blocks = (size(dst_tall, 1) + block_size - 1) ÷ block_size
kernel!(A_tall, dst_tall, op, init, 2, ndrange=(blocks * block_size,))
synchronize(backend)
println("Valid: ", dst_tall ≈ sum(A_tall, dims=2))
res = @benchmark begin
    kernel!(A_tall, dst_tall, op, init, 2, ndrange=(blocks * block_size,))
    synchronize(backend)
end
show(io, "text/plain", res)

@kernel inbounds = true cpu = false function column_reduction_kernel!(@Const(src), dst, op, init)
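    # Fixed parameters, chosen to meet shared memory constraints
    # TODO: generalise by having multiple threads compute each column
    # Prototype assumptions (no masking): size(src, 2) is a multiple of BLOCK_SIZE and
    # size(src, 1) is a multiple of TILE_DIM; otherwise tile loads can read out of bounds
    # or leave stale rows in the tile
    BLOCK_SIZE = 128
    TILE_DIM = 32
    NUM_WARPS = 4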

    # One 32x32 tile of shared memory per warp
    tiles = @localmem eltype(src) (TILE_DIM, TILE_DIM, NUM_WARPS)

    warp_id = (@index(Local, Linear) - 0x1) ÷ TILE_DIM     # Which warp (0-3)
    lane_id = (@index(Local, Linear) - 0x1) % TILE_DIM     # Position in warp (0-31)
    global_col = @index(Group, Linear) - 0x1               # Which block of columns

    src_height = size(src, 1)
    src_width = size(src, 2)

    # Each thread is responsible for one column in the output
    global_col_idx = global_col * BLOCK_SIZE + (@index(Local, Linear) - 0x1)

    if global_col_idx < src_width
        result = init

        # Process input in TILE_DIM x TILE_DIM tiles
        for tile_start = 0:TILE_DIM:src_height-1
            # Load tile into shared memory with coalesced reads (one column at a time)
            KernelAbstractions.Extras.@unroll for i = 0:TILE_DIM-1
                row_idx = tile_start + lane_id
                col_idx = global_col * BLOCK_SIZE + warp_id * TILE_DIM + i
                if row_idx < src_height
                    tiles[lane_id+1, i+1, warp_id+1] = src[row_idx+1, col_idx+1]
                end
            end
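            # No block-level @synchronize needed: each warp only touches its own tile.
            # Strictly, lanes read values written by other lanes of the same warp, so on
            # Volta and later GPUs a warp-level barrier (__syncwarp in CUDA) is needed here
            # and before the next tile load for guaranteed correctness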

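            # Reduce down this thread's column of the tile. Note: lanes read with a stride of
            # TILE_DIM elements, so all 32 lanes hit the same shared-memory bank; padding the
            # tile to (TILE_DIM + 1, TILE_DIM, NUM_WARPS) is the usual way to avoid this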
            KernelAbstractions.Extras.@unroll for i = 0:TILE_DIM-1
                result = op(result, tiles[i+1, lane_id+1, warp_id+1])
            end
        end

        # Write result to global memory (naturally coalesced)
        dst[1, global_col_idx+1] = result
    end
end

println("\n#### Coalesced column reduction ####\n")
col_block_size = 128
blocks = (size(A_wide, 2) + col_block_size - 1) ÷ col_block_size
kernel! = column_reduction_kernel!(backend, col_block_size)
# Validate
kernel!(A_wide, dst_wide, op, init, ndrange=(blocks * col_block_size,))
synchronize(backend)
println("Valid: ", dst_wide ≈ sum(A_wide, dims=1))
res = @benchmark begin
    kernel!(A_wide, dst_wide, op, init, ndrange=(blocks * col_block_size))
    synchronize(backend)
end
show(io, "text/plain", res)
#### Existing wide reduction ####

Valid: true
BenchmarkTools.Trial: 4623 samples with 1 evaluation.
 Range (min … max):  1.069 ms … 1.435 ms  ┊ GC (min … max): 0.00% … 0.00%
 Time  (median):     1.079 ms             ┊ GC (median):    0.00%
 Time  (mean ± σ):   1.080 ms ± 7.359 μs  ┊ GC (mean ± σ):  0.00% ± 0.00%

              ▁▂▄▅▆▇██▇▅▆▅▂▂▁                                
  ▂▂▂▂▃▃▄▄▅▆▆█████████████████▇▆▇▆▅▅▄▅▄▃▄▃▃▃▃▃▃▃▃▂▂▂▂▂▂▂▂▁▂ ▄
  1.07 ms        Histogram: frequency by time        1.1 ms <

 Memory estimate: 1.62 KiB, allocs estimate: 58.
#### Existing tall reduction ####

Valid: true
BenchmarkTools.Trial: 10000 samples with 1 evaluation.
 Range (min … max):  335.978 μs … 419.454 μs  ┊ GC (min … max): 0.00% … 0.00%
 Time  (median):     342.958 μs               ┊ GC (median):    0.00%
 Time  (mean ± σ):   342.409 μs ±   2.824 μs  ┊ GC (mean ± σ):  0.00% ± 0.00%

                ▂▁           ▁ ▁   ▁▁▄▄▆▆▅▇█▆▅▃▂                 
  ▁▁▁▂▂▃▄▄▅▄▅▆▇████▇▆▆▅▅▅▅▅▇██████▇█████████████▇▅▃▄▂▂▂▂▂▂▂▁▁▁▁ ▄
  336 μs           Histogram: frequency by time          348 μs <

 Memory estimate: 1.62 KiB, allocs estimate: 58.
#### Coalesced column reduction ####

Valid: true
BenchmarkTools.Trial: 10000 samples with 1 evaluation.
 Range (min … max):  306.203 μs … 369.769 μs  ┊ GC (min … max): 0.00% … 0.00%
 Time  (median):     308.687 μs               ┊ GC (median):    0.00%
 Time  (mean ± σ):   309.001 μs ±   1.745 μs  ┊ GC (mean ± σ):  0.00% ± 0.00%

        ▁▂▅▇██▇▇▅▅▂▁  ▁▂▁▂▃▂▁             ▁ ▂ ▁                  
  ▁▁▂▄▅▆██████████████████████▇▇▆▅▄▅▄▅▇▆████████▇▆▅▄▃▃▂▂▂▁▂▁▁▁▁ ▅
  306 μs           Histogram: frequency by time          313 μs <

 Memory estimate: 1.52 KiB, allocs estimate: 57.
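
The median times above can be converted to effective bandwidth, counting one read of the 2^8 x 2^18 Float32 input and one write of the 2^18-element output. This is a sketch that ignores caching effects; `effective_gbps` is a hypothetical helper.

```julia
# Effective bandwidth implied by the median times reported above.
# Bytes moved: one read of the 2^8 x 2^18 Float32 input plus one write of the 2^18 outputs.
const BYTES_MOVED = (2^8 * 2^18 + 2^18) * sizeof(Float32)
const PEAK_GBPS = 1008.0   # RTX 4090 figure quoted earlier

effective_gbps(t_seconds) = BYTES_MOVED / t_seconds / 1e9

medians = [
    "existing wide (dims=1)" => 1.079e-3,
    "existing tall (dims=2)" => 342.958e-6,
    "coalesced column"       => 308.687e-6,
]

for (name, t) in medians
    gbps = effective_gbps(t)
    pct = 100 * gbps / PEAK_GBPS
    println(rpad(name, 24), round(gbps; digits=1), " GB/s (", round(pct; digits=1), "% of peak)")
end
```

This works out to roughly 250 GB/s for the existing dims=1 kernel, 786 GB/s for dims=2 and 873 GB/s (about 87% of 1008 GB/s) for the coalesced kernel.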

For multidimensional arrays (say, 5 dimensions in the examples below), the same ideas hold:

  • Reducing over "trailing" dimensions, e.g. dims = (3, 4, 5) is naturally coalesced
  • Reducing over "leading" dimensions, e.g. dims = (1, 2) can use this same approach
  • Reducing over "interior" dimensions, e.g. dims = (3, 4) can be handled by an adaptation of the current approach.
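
One way to see why these three cases cover every contiguous set of dims: a reduction over a contiguous run of dims can be recast as reducing the middle axis of a 3-D (pre, red, post) array. `as_3d` below is a hypothetical CPU-side helper showing the reshape; leading dims give pre == 1 (the tiled dims=1 case) and trailing dims give post == 1 (the naturally coalesced dims=2 case).

```julia
# Hypothetical helper: view an array as (pre, red, post), where `dims` is a contiguous range
# of reduced dimensions, pre is the product of the dims before it and post of those after.
function as_3d(A::AbstractArray, dims::UnitRange{Int})
    sz = collect(size(A))                 # Vector{Int}, so empty products are 1
    pre  = prod(sz[1:first(dims)-1])
    red  = prod(sz[dims])
    post = prod(sz[last(dims)+1:end])
    return reshape(A, pre, red, post)
end

A5 = rand(Float32, 2, 3, 4, 5, 6)
B = as_3d(A5, 3:4)                        # interior dims (3, 4) -> size (6, 20, 6)
println(size(B))
println(vec(sum(B, dims=2)) ≈ vec(sum(A5, dims=(3, 4))))   # expected: true
```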

The only thing I'm not sure how to handle is dims = (2, 4). I think this might require two kernel calls, one reducing each dimension.
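
The two-launch idea for dims = (2, 4) can be checked on the CPU with Base's `sum`: reduce one dim, then the other. This is a sketch that assumes op is associative and commutative (so the regrouping only changes floating-point rounding). A `permutedims` copy that makes the reduced dims adjacent is another option, trading an extra full read and write of the array for a single contiguous reduction.

```julia
# Non-contiguous reduction dims = (2, 4) split into two single-dim reductions,
# each of which would map to one kernel launch.
A5 = rand(Float32, 2, 3, 4, 5, 6)
two_pass = sum(sum(A5, dims=4), dims=2)   # reduce dim 4 first, then dim 2
println(two_pass ≈ sum(A5, dims=(2, 4)))  # expected: true

# Alternative: permute so the reduced dims are trailing, then do one contiguous reduction.
permuted = sum(permutedims(A5, (1, 3, 5, 2, 4)), dims=(4, 5))   # size (2, 4, 6, 1, 1)
println(vec(permuted) ≈ vec(sum(A5, dims=(2, 4))))              # expected: true
```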

Eager to hear any thoughts on this topic!
