Use linear-indexing broadcast kernel when possible #520

Merged: maleadt merged 7 commits into master from tb/static_cartesian_indices on Mar 8, 2024

Conversation

maleadt
Member

@maleadt maleadt commented Mar 5, 2024

Attempt to re-land #454, this time using a slightly nicer implementation.

It hasn't fundamentally changed though, so should run into the same issues. Let's do this carefully.

The motivation is also unchanged: on certain platforms, like Metal.jl, the integer divisions required to go from a linear hardware index to a cartesian one for indexing the input/output containers are extremely expensive. By using static iteration bounds, the compiler can replace the idiv with a series of bitshifts. This improves the performance of broadcast by 3-4x on those platforms.
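
For illustration, the index conversion in question can be sketched on the CPU as follows. This is not the code from this PR: `linear_to_cartesian` is a hypothetical helper, and the `Val`-based variant only assumes that making the bounds part of the type lets the compiler see constant divisors.

```julia
# Hypothetical helper, not part of GPUArrays.jl: convert a zero-based linear
# index into one-based, column-major cartesian coordinates.
linear_to_cartesian(i0::Int, ::Tuple{}) = ()
function linear_to_cartesian(i0::Int, dims::Tuple)
    # One integer division (plus remainder) per dimension.
    q, r = divrem(i0, first(dims))
    return (r + 1, linear_to_cartesian(q, Base.tail(dims))...)
end

# Static variant: the bounds are part of the type, so each division is by a
# compile-time constant that the compiler is free to strength-reduce.
linear_to_cartesian(i0::Int, ::Val{dims}) where {dims} =
    linear_to_cartesian(i0, dims)

# Both variants agree with Base's CartesianIndices.
dims = (4, 3, 2)
for i in 1:prod(dims)
    expected = Tuple(CartesianIndices(dims)[i])
    @assert linear_to_cartesian(i - 1, dims) == expected
    @assert linear_to_cartesian(i - 1, Val(dims)) == expected
end
```

Note that every distinct `dims` passed through `Val` produces a separate method specialization, which is where the extra compilation comes from.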

cc @maxwindiff

@maleadt
Member Author

maleadt commented Mar 5, 2024

This looks to be working well, so tagging people who ran into issues before: @ToucheSir and @chengchingwen. Note that this will still cause additional compilation, i.e. every time the size of any container involved in a broadcast changes, but I'm curious about which workloads would trigger that (once in the steady-state application regime, of course).

@maleadt maleadt force-pushed the tb/static_cartesian_indices branch from fcc80ce to 90fb573 Compare March 5, 2024 09:02
@ToucheSir

I had a look back through the CI failure on the Flux side. Apparently this call was the one that failed:

   [15] broadcast(::typeof(+), ::CuArray{Float32, 4, CUDA.Mem.DeviceBuffer}, ::CuArray{Float32, 4, CUDA.Mem.DeviceBuffer})
      @ Base.Broadcast ./broadcast.jl:821

But that's strange, because surely broadcasting + was already tested by GPUArrays + CUDA.jl? Anyhow, I doubt this will cause any problems for FluxML as long as elementwise broadcasting of binary ops still works across the board.

@maleadt
Member Author

maleadt commented Mar 5, 2024

Looking back at the CUDA.jl CI logs, there seemed to be some issue with printing too, which is why I added a show method here. I'm not sure whether that was the cause of an issue, or whether it was just masking an actual error in CI...

@maleadt
Member Author

maleadt commented Mar 6, 2024

I tried testing Transformers.jl, but that doesn't seem possible right now (see chengchingwen/Transformers.jl#153 and linked PRs in NeuralAttentionlib.jl).

@maleadt
Member Author

maleadt commented Mar 6, 2024

One alternative would be that we expose 1d/2d/3d indices and only generate 4 broadcast kernels. I'll experiment with that, as it would lead to far fewer kernels being compiled (but the fact that the bounds aren't fully statically known may come at a cost again).

Given #451 the above would also mean that KA.jl would need to support 1d/2d/3d indices, so cc @vchuravy.

@maleadt
Member Author

maleadt commented Mar 6, 2024

... or, I should probably just confine this optimization to Metal.jl...

@chengchingwen
Contributor

> i.e. every time the size of any container involved in a broadcast changes, but I'm curious about which workloads would trigger that (once in the steady-state application regime, of course).

This, unfortunately, happens a lot when doing sequence generation inference with transformer models. It might also happen during training but can be avoided with padding.

@maleadt
Member Author

maleadt commented Mar 6, 2024

OK, good to know. I have an alternative in JuliaGPU/Metal.jl#304, relying on hardware indices instead. That will only accelerate 2d and 3d broadcasts though, so it's a trade-off.

@maleadt maleadt changed the title Introduce StaticCartesianIndices to eliminate expensive integer divisions. Use linear-indexing broadcast kernel when possible Mar 6, 2024
@chengchingwen
Contributor

I think we might be able to port the algorithms used in libdivide to implement a new CartesianIndices without integer division. It's similar to the method used in StaticCartesian.jl, but without requiring the divisor to be known at compile time.
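
As a rough sketch of that idea (not libdivide's actual code, and with hypothetical names `MagicDivisor` and `fastdiv`), a divisor that is only known at run time can be turned into a multiplier and a shift once, after which every division becomes a widening multiply plus a shift. This version uses `UInt128` intermediates for simplicity; a GPU implementation would presumably use a high-multiply instead.

```julia
# Hypothetical sketch: precomputed constants for dividing UInt32 values by a
# fixed divisor d using the round-up multiply-and-shift method.
struct MagicDivisor
    divisor::UInt32
    multiplier::UInt64  # ceil(2^(32 + shift) / divisor), fits in 33 bits
    shift::Int          # ceil(log2(divisor))
end

function MagicDivisor(d::Integer)
    1 <= d <= typemax(UInt32) || throw(ArgumentError("divisor must fit in UInt32 and be nonzero"))
    # Bit length of d - 1 equals ceil(log2(d)), and is 0 for d == 1.
    shift = 32 - leading_zeros(UInt32(d - 1))
    multiplier = cld(UInt128(1) << (32 + shift), UInt128(d))
    return MagicDivisor(UInt32(d), UInt64(multiplier), shift)
end

# floor(n / d) without a division instruction on the hot path.
function fastdiv(n::UInt32, m::MagicDivisor)
    return UInt32((UInt128(n) * m.multiplier) >> (32 + m.shift))
end

# The setup cost is paid once per divisor, e.g. once per array dimension.
by7 = MagicDivisor(7)
@assert fastdiv(UInt32(100), by7) == div(UInt32(100), UInt32(7))
```

The remainder follows as `n - q * d`, so a cartesian index could be rebuilt dimension by dimension from one `MagicDivisor` per bound.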

@maleadt maleadt merged commit e4d40ea into master Mar 8, 2024
@maleadt maleadt deleted the tb/static_cartesian_indices branch March 8, 2024 10:56
@maleadt
Member Author

maleadt commented Mar 8, 2024

I only noticed significant impact of the idiv on Metal.jl, so I've opted to move the specialization to Metal.jl (forcing static bounds when a specific broadcast shape is used more than 10 times).
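
The heuristic described above could look roughly like the following. This is a hypothetical sketch, not Metal.jl's actual code: `should_specialize!` and `STATIC_THRESHOLD` are made-up names, and thread safety is ignored.

```julia
# Hypothetical sketch: count how often each broadcast shape is seen, and
# report when compiling a kernel with static bounds becomes worthwhile.
const STATIC_THRESHOLD = 10  # "more than 10 times", per the comment above

function should_specialize!(counts::Dict{Dims,Int}, shape::Dims)
    n = get(counts, shape, 0) + 1
    counts[shape] = n
    return n > STATIC_THRESHOLD
end

counts = Dict{Dims,Int}()
results = [should_specialize!(counts, (128, 128)) for _ in 1:11]
@assert !any(results[1:10]) && results[11]
```

Shapes that only show up a few times, such as the growing sequence lengths mentioned earlier, would keep using the generic kernel and never trigger extra compilation.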

3 participants