Faster (still slow) fallback matrix multiplication #590


Draft: christiangnrd wants to merge 6 commits into master from fastmatmul
Conversation

christiangnrd (Member) commented Apr 13, 2025

Taken from the KernelAbstractions.jl performant matmul example.

I had to make a few changes, such as using `unsafe_indices`, since the algorithm itself does the bounds checking; I was getting wrong results until I added that.

I also made it so `I` and `J` are only fetched once. I'm not sure if the old way was outdated or meant to prevent a bug I didn't encounter. Edit: I guess I found out why that was there. Why is it only necessary for some backends, and why does the other way work on nightly?

Finally, I made the tile size 16 instead of 32, since it cannot be set dynamically and Metal does not always have 1024 (32×32) threads per threadgroup available.
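For context, here is a minimal sketch of the tiling scheme (hypothetical and simplified, not the exact kernel in this PR; `tiled_matmul!` and its names are made up for illustration). Each workgroup stages one tile of `A` and `B` in shared memory, and since `unsafe_indices=true` removes KernelAbstractions' implicit masking of out-of-range workitems, every global load and store is guarded by hand:

```julia
using KernelAbstractions

# Hypothetical sketch: each (TILE × TILE) workgroup computes one tile of C.
# With unsafe_indices=true, KA does not mask out-of-range workitems, so
# all bounds checks below are the kernel's own responsibility.
@kernel unsafe_indices=true function tiled_matmul!(C, @Const(A), @Const(B))
    gi, gj = @index(Group, NTuple)  # which tile of C this group computes
    li, lj = @index(Local, NTuple)  # this workitem's position in the tile

    TILE = @uniform @groupsize()[1]

    tileA = @localmem eltype(C) (TILE, TILE)
    tileB = @localmem eltype(C) (TILE, TILE)

    N, M = size(C)
    K = size(A, 2)

    # Global output coordinates, fetched once up front (on some backends
    # values read after @synchronize may need special care; see below).
    I = (gi - 1) * TILE + li
    J = (gj - 1) * TILE + lj

    # Accumulator in private memory so it survives the barriers.
    acc = @private eltype(C) 1
    @inbounds acc[1] = zero(eltype(C))

    @inbounds for t in 0:(cld(K, TILE) - 1)
        k = t * TILE
        # Stage one tile of A and B, zero-padding out-of-range entries.
        tileA[li, lj] = (I <= N && k + lj <= K) ? A[I, k + lj] : zero(eltype(C))
        tileB[li, lj] = (k + li <= K && J <= M) ? B[k + li, J] : zero(eltype(C))
        @synchronize
        for kk in 1:TILE
            acc[1] += tileA[li, kk] * tileB[kk, lj]
        end
        @synchronize
    end

    # Guarded store: only in-range workitems write to C.
    @inbounds if I <= N && J <= M
        C[I, J] = acc[1]
    end
end
```

Launched with a `(16, 16)` workgroup and an `ndrange` rounded up to a multiple of the tile size, every group is full and the explicit guards handle the ragged edges.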

maleadt (Member) commented Apr 14, 2025

> I had to make a few changes, such as using `unsafe_indices`, since the algorithm itself does the bounds checking; I was getting wrong results until I added that.

Oof, that's bad, and unexpected. cc @vchuravy

> I guess I found out why that was there.

Care to elaborate?

Any performance numbers?

christiangnrd (Member, Author) commented Apr 14, 2025

> > I had to make a few changes, such as using `unsafe_indices`, since the algorithm itself does the bounds checking; I was getting wrong results until I added that.
>
> Oof, that's bad, and unexpected. cc @vchuravy

To reproduce, you can apply this patch:

Patch:

```diff
diff --git a/src/host/linalg.jl b/src/host/linalg.jl
index b59598f..2e51d9f 100644
--- a/src/host/linalg.jl
+++ b/src/host/linalg.jl
@@ -326,7 +326,7 @@ function LinearAlgebra.ldiv!(B::AbstractGPUVecOrMat,
 end
 
 # XXX: figure out how to do dynamically
-MAX_TILE_DIM = 16
+MAX_TILE_DIM = 2 # THIS CHANGE MADE TO SIMPLIFY MWE OUTPUT
 
 ## matrix multiplication
 # legacy method
@@ -346,7 +346,7 @@ function generic_matmatmul!(C::AbstractGPUMatrix{R}, A::AbstractGPUMatrix{T}, B:
         return fill!(C, zero(R))
     end
 
-    @kernel unsafe_indices=true function coalesced_matmul_kernel!(
+    @kernel function coalesced_matmul_kernel!(
             output, @Const(input1), @Const(input2), N, Q, M,
             ::Val{BANK} = Val(1),
         ) where {BANK}
@@ -408,7 +408,7 @@ function generic_matmatmul!(C::AbstractGPUMatrix{R}, A::AbstractGPUMatrix{T}, B:
         end
     end
 
-    coalesced_matmul_kernel!(get_backend(C), (MAX_TILE_DIM, MAX_TILE_DIM))(C, A, B, N, Q, M;ndrange=map(x -> ceil(Int,x/MAX_TILE_DIM)*MAX_TILE_DIM, size(C)))
+    coalesced_matmul_kernel!(get_backend(C), (MAX_TILE_DIM, MAX_TILE_DIM))(C, A, B, N, Q, M;ndrange=size(C))
     C
 end
 function generic_matmatmul!(C::AbstractArray{R}, A::AbstractArray{T}, B::AbstractArray{S}, add::MulAddMul) where {T,S,R}
```

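(As an aside, the `ndrange` expression the second hunk removes is what rounds each output dimension up to a multiple of the tile size, so that edge tiles still get full workgroups; a minimal illustration using the patch's `MAX_TILE_DIM = 2`:)

```julia
MAX_TILE_DIM = 2
# ceil(Int, x / t) * t rounds x up to the next multiple of t, so a 3×3
# output is covered by a 4×4 ndrange of 2×2 tiles.
ndrange = map(x -> ceil(Int, x / MAX_TILE_DIM) * MAX_TILE_DIM, (3, 3))  # (4, 4)
```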
And then when run using Metal/CUDA, it gives:

```julia
julia> using Metal, GPUArrays; a = Metal.ones(3,3); b = Metal.ones(3,3); c = Metal.zeros(3,3); GPUArrays.generic_matmatmul!(c,a,b, true, false)
Precompiling Metal...
  3 dependencies successfully precompiled in 13 seconds. 66 already precompiled.
3×3 MtlMatrix{Float32, Metal.PrivateStorage}:
 3.0  3.0   2.0
 3.0  3.0   2.0
 2.0  2.0  16.6155

# With CUDA
julia> using CUDA, GPUArrays; a = CUDA.ones(3,3); b = CUDA.ones(3,3); c = CUDA.zeros(3,3); GPUArrays.generic_matmatmul!(c,a,b, true, false)
Precompiling CUDA...
  3 dependencies successfully precompiled in 38 seconds. 96 already precompiled.
3×3 CuArray{Float32, 2, CUDA.DeviceMemory}:
 3.0  3.0  2.0
 3.0  3.0  2.0
 2.0  2.0  2.0
```

It does not seem to be broken with JLArrays, and with CUDA it is less broken in that the other three quadrants come out close to integer values. This happens on KA 0.9.34; I haven't tested the master branch.
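A quick way to check the breakage programmatically (a hypothetical snippet along the lines of the reproducer above, not part of the test suite) is to compare against the known result:

```julia
using Metal, GPUArrays, Test

a = Metal.ones(3, 3); b = Metal.ones(3, 3); c = Metal.zeros(3, 3)
GPUArrays.generic_matmatmul!(c, a, b, true, false)

# ones(3,3) * ones(3,3) should be exactly 3 everywhere; with the patch
# applied this test fails on Metal and CUDA.
@test Array(c) ≈ fill(3.0f0, 3, 3)
```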

> > I guess I found out why that was there.
>
> Care to elaborate?

Yes, sorry about that. CI was showing that on some platforms, `I` and `J` were no longer in scope/defined after an `@synchronize` call. Without looking into this case specifically, I assume it has to do with the code transformations the `@kernel` macro performs.
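If that is the cause, the hazard would look roughly like this (a hypothetical demo; `barrier_scope_demo!` is made up): values needed after an `@synchronize` have to be re-derived or kept in `@private` storage on the affected backends.

```julia
using KernelAbstractions

# On some backends @kernel splits the body at @synchronize, so a plain
# local read after the barrier may no longer be defined. Private
# storage (or re-fetching @index) survives the split.
@kernel function barrier_scope_demo!(out)
    i, j = @index(Global, NTuple)
    tmp = @private Int 1
    @inbounds tmp[1] = i + j
    @synchronize
    @inbounds out[i, j] = tmp[1]  # safe: read from private memory
end
```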

> Any performance numbers?

It seems to be at least as fast as the naive algorithm, and up to 4-5× faster.

Linux, Ryzen 3700X with RTX 3060 (note the different y-axes in the bottom row):
[benchmark figure: bench_all_1_3060]

M2 Max (30-core GPU):
[benchmark figure: bench_all_1]
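For anyone who wants to reproduce the numbers, a rough sketch of the measurement (hypothetical sizes and GFLOP/s accounting, assuming CUDA.jl and BenchmarkTools.jl; the plots above come from a fuller harness):

```julia
using CUDA, GPUArrays, BenchmarkTools

for n in (256, 1024, 4096)
    A = CUDA.rand(Float32, n, n)
    B = CUDA.rand(Float32, n, n)
    C = CUDA.zeros(Float32, n, n)
    # CUDA.@sync ensures the timing includes kernel completion.
    t = @belapsed CUDA.@sync GPUArrays.generic_matmatmul!($C, $A, $B, true, false)
    println("n = $n: ", round(2n^3 / t / 1e9; digits = 1), " GFLOP/s")
end
```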

@christiangnrd
Copy link
Member Author

christiangnrd commented Apr 19, 2025

Based on JuliaGPU/KernelAbstractions.jl#590 passing tests, maybe this should wait until GPUArrays supports KA v0.10.

christiangnrd force-pushed the fastmatmul branch 2 times, most recently from 3892d1b to 1499c12 on April 25, 2025 at 15:03.