Use linear-indexing broadcast kernel when possible #520
Conversation
This looks to be working well, so tagging people who ran into issues before: @ToucheSir and @chengchingwen. Note that this will still cause additional compilation, i.e. a new kernel every time the size of any container involved in a broadcast changes, but I'm curious which workloads would trigger that (once in the steady-state application regime, of course).
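To make the recompilation concrete, here is a rough CPU-side sketch (hypothetical, not the actual GPUArrays kernel): when the bounds are baked in as a type parameter, e.g. via Val, every distinct broadcast shape compiles a fresh specialization, but in exchange the divisions in the index conversion become divisions by compile-time constants.

```julia
# Hypothetical illustration of size-specialized indexing: `dims` is a type
# parameter, so every new container size triggers compilation of a new
# method instance, while the divisions inside the linear-to-cartesian
# conversion become divisions by compile-time constants.
linear_to_cartesian(i, ::Val{dims}) where {dims} = CartesianIndices(dims)[i]

linear_to_cartesian(5, Val((4, 3)))  # compiles a specialization for size (4, 3)
linear_to_cartesian(5, Val((8, 3)))  # a different size compiles another one
```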
I had a look back through the CI failure on the Flux side. Apparently this call was the one that failed:
But that's strange, because surely broadcasting + was already covered by the GPUArrays and CUDA.jl test suites? Anyhow, I doubt this will cause any problems for FluxML as long as elementwise broadcasting of binary ops still works across the board.
Looking back at the CUDA.jl CI logs, there also seemed to be some issue with printing, which is why I added a
I tried testing Transformers.jl, but that doesn't seem possible right now (see chengchingwen/Transformers.jl#153 and the linked PRs in NeuralAttentionlib.jl).
One alternative would be to expose 1d/2d/3d indices and only generate four broadcast kernels. I'll experiment with that, as it would lead to far fewer kernels being compiled (though the fact that the bounds aren't fully statically known may come at a cost again). Given #451, the above would also mean that KA.jl would need to support 1d/2d/3d indices, so cc @vchuravy.
... or, I should probably just confine this optimization to Metal.jl... |
This, unfortunately, happens a lot when doing sequence-generation inference with transformer models. It might also happen during training, but there it can be avoided with padding.
OK, good to know. I have an alternative in JuliaGPU/Metal.jl#304, relying on hardware indices instead. That will only accelerate 2d and 3d broadcasts though, so it's a trade-off.
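Roughly, the idea is that the cartesian position comes straight from the hardware launch grid, so no integer division is needed at all. A sketch with CUDA.jl's threadIdx/blockIdx intrinsics, purely for illustration (the Metal.jl PR uses Metal's thread-position intrinsics, and this is not its actual code):

```julia
using CUDA

# Illustrative 2d kernel: the cartesian position comes directly from the
# hardware thread/block indices, so no division by the array extents is
# needed to recover it from a linear index.
function bc_kernel_2d!(dest, src)
    i = (blockIdx().x - 1) * blockDim().x + threadIdx().x
    j = (blockIdx().y - 1) * blockDim().y + threadIdx().y
    if i <= size(dest, 1) && j <= size(dest, 2)
        @inbounds dest[i, j] = src[i, j] + 1f0
    end
    return nothing
end

dest = CUDA.zeros(Float32, 128, 64)
src  = CUDA.ones(Float32, 128, 64)
@cuda threads=(16, 16) blocks=(cld(128, 16), cld(64, 16)) bc_kernel_2d!(dest, src)
```

The obvious limitation is the one mentioned above: hardware indices only come in 1d/2d/3d flavors, so higher-dimensional broadcasts still need the linear-index fallback.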
I think we might be able to port the algorithms used in libdivide to implement a new
I only noticed a significant impact of the
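Regarding the libdivide idea above, here is a rough standalone sketch of what that could look like (hypothetical, not tied to any existing GPUArrays API): for a fixed divisor you precompute a magic multiplier and shift once on the host, after which each division becomes a widening multiply plus a shift.

```julia
# Hypothetical sketch of the libdivide / Granlund-Montgomery trick:
# precompute a magic multiplier and shift for a fixed divisor `d`, so that
# `x ÷ d` becomes a widening multiply plus a shift, with no idiv at run time.
# This simplified version widens to 128 bits instead of handling the
# overflow cases libdivide deals with explicitly.
struct MagicDiv
    magic::UInt64
    shift::Int
end

function MagicDiv(d::UInt32)
    ell = ceil(Int, log2(d))                         # smallest ell with d <= 2^ell
    magic = UInt64(cld(UInt128(1) << (32 + ell), d))
    return MagicDiv(magic, 32 + ell)
end

# x ÷ d without a hardware integer division
magicdiv(x::UInt32, m::MagicDiv) = UInt32((UInt128(x) * m.magic) >> m.shift)

m7 = MagicDiv(UInt32(7))
@assert all(magicdiv(x, m7) == x ÷ UInt32(7) for x in UInt32(0):UInt32(100_000))
```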
Attempt to re-land #454, this time using a slightly nicer implementation.
It hasn't fundamentally changed though, so it should run into the same issues. Let's do this carefully.
The motivation is also unchanged: on certain platforms, like Metal.jl, the integer divisions required to go from a linear hardware index to a cartesian one for indexing the input/output containers are extremely expensive. By using static iteration bounds, the compiler can replace the idiv with a series of bit shifts, which improves the performance of broadcast by 3-4x on those platforms.

cc @maxwindiff
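As a minimal CPU-side illustration of the effect (using nothing beyond Base; this is not the kernel code in this PR):

```julia
# Constant bounds: the divisors are compile-time constants, so LLVM can
# strength-reduce them to shifts/multiplies (inspect with `@code_llvm static_l2c(5)`;
# no `udiv`/`sdiv` should show up).
static_l2c(i) = CartesianIndices((32, 64))[i]

# Runtime bounds: the divisors are only known at run time, so a real
# integer division remains in the generated code.
dynamic_l2c(i, dims) = CartesianIndices(dims)[i]
```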