XLA does too many un-fused transposes #16914
Yes, you could add logging in the *thunk files, but I'm not sure it will help you: as you've pointed out, by that point the fusion decisions have already been made.
You could start by dumping HLO after every pass with …
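For reference, XLA's `--xla_dump_to` and `--xla_dump_hlo_pass_re` debug flags (not necessarily the exact flag the comment above had in mind) can be passed from JAX through the `XLA_FLAGS` environment variable. The sketch below uses a hypothetical function `f`; it dumps the module before and after optimization plus a snapshot after every HLO pass. The `*after_optimizations.txt` files are the same kind of dump as the one attached to this issue, and the directory should also contain buffer-assignment files.

```python
import glob
import os
import tempfile

# XLA reads XLA_FLAGS when the backend is initialised, so set it before importing jax.
DUMP_DIR = tempfile.mkdtemp(prefix="xla_dump_")
os.environ["XLA_FLAGS"] = (
    f"--xla_dump_to={DUMP_DIR} "  # write HLO dumps into this directory
    "--xla_dump_hlo_pass_re=.*"   # also dump after every pass whose name matches the regex
)

import jax
import jax.numpy as jnp


@jax.jit
def f(x):
    # Hypothetical example: one transpose with two consumers.
    y = jnp.transpose(x, (0, 2, 1))
    return jnp.sin(y), y * 2.0


f(jnp.ones((2, 16, 8)))  # compiling the function triggers the dumps

dumped = sorted(os.path.basename(p) for p in glob.glob(os.path.join(DUMP_DIR, "*")))
after_opt = [n for n in dumped if "after_optimizations" in n]
print(len(dumped), "files dumped; after-optimization dumps:", after_opt)
```

Per-pass file names include the pass name, so sorting them by their sequence number makes it easier to follow one module through the pipeline.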
Would you know how I can do that? In the first instance, I'd like to be able to see clearly that the same memory is being read multiple times by different kernel calls. Is there some runtime flag that will make it (e.g.) log stream executions along with argument shapes?
Yes, I've been looking over this, but it's a lot of manual effort matching up the passes due to renaming et al. (oh, and my limited familiarity!). Is there any documentation (beyond the source code) for things like the …? I guess what I'm trying to see is the memory read/write graph at the CUDA kernel boundaries. Is there some existing way to see this? If not, I guess I'll try to put together a script to generate a dot graph.
You could look at various *thunk files, and inside the
Yes, that would be buffer assignment. You'll only see offsets there, since the actual memory is allocated at runtime, but you can match them with instruction names in the after-optimizations HLO.
FWIW, you can see how to parse the buffer assignment format here: https://github.com/openxla/xla/blob/main/xla/tools/driver.cc#L311
(This is running on an NVIDIA RTX 4090 GPU with JAX 0.4.31.)
The code I have is something like the example below. Here, the depth-wise convolution wants the input transposed from [batch, sequence, feature] to [batch, feature, sequence] so that it can apply the convolution along the sequence axis.
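A hypothetical reconstruction of the pattern described (not the original code, and, as noted below, a trimmed example like this may well not reproduce the un-fused transposes). `depthwise_conv_block`, the `[width, 1, feature]` kernel layout and the three consumers are assumptions; `optimized_hlo` uses JAX's ahead-of-time API (`jit(...).lower(...).compile().as_text()`) to return the post-optimization HLO for the current backend, which is where the transpose fusions would show up.

```python
import jax
import jax.numpy as jnp
from jax import lax


def depthwise_conv_block(x, kernel):
    """x: [batch, sequence, feature]; kernel: [width, 1, feature] (one filter per channel)."""
    feature = x.shape[-1]
    # [batch, sequence, feature] -> [batch, feature, sequence]
    xt = jnp.transpose(x, (0, 2, 1))
    # Depth-wise 1-D convolution along `sequence`: one group per feature channel.
    y = lax.conv_general_dilated(
        xt,
        kernel,
        window_strides=(1,),
        padding="SAME",
        dimension_numbers=("NCW", "WIO", "NCW"),
        feature_group_count=feature,
    )
    # Back to [batch, sequence, feature].
    y = jnp.transpose(y, (0, 2, 1))
    # The convolution output has three consumers.
    return jax.nn.silu(y), y * 2.0, jnp.tanh(y)


def optimized_hlo(x, kernel):
    """Lower and compile for the current backend and return the optimized HLO text."""
    return jax.jit(depthwise_conv_block).lower(x, kernel).compile().as_text()
```

For example, `print(optimized_hlo(jnp.ones((2, 16, 8)), jnp.ones((3, 1, 8))))` and search the output for `transpose` to see how many transpose fusions the compiler produced for this particular shape and backend.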
The output from the convolution is used 3 times, and XLA generates at least 3 separate (fused) transposes, each of which does a full read and write of memory. This is very slow and causes sadness.
Unfortunately, this example code doesn't reproduce the problem: the problem seems to be quite sensitive to the surrounding code, and trimming it down makes most of the issue go away. A screen-grab from the profiled code partly shows the issue:
After the convolution there is a `loop_transpose_fusion`; then, after the two CUTLASS GEMM kernels, there are two `input_transpose_fusion` kernels; and following `kernel__1` is another `input_transpose_fusion`. Each of these fusions does a full read/write of memory.
My main question is: How can I effectively debug this?
E.g., is there a way to log all the GPU kernel calls along with their argument shapes?
Is there some way to see why the transposes didn't fuse into a single kernel with 1 input and 3 outputs?
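A rough estimate of what the single multi-output kernel would save, assuming each transpose fusion reads only the convolution output $y$ ($N = B \cdot S \cdot F$ elements of $b$ bytes each) and writes one full-size result, with $k$ consumers and caching ignored:

```math
\begin{aligned}
T_{\text{separate}} &= k\,(Nb + Nb) = 2kNb \\
T_{\text{multi-output}} &= Nb + k\,Nb = (k+1)\,Nb \\
k = 3:\quad T_{\text{separate}} &= 6Nb, \qquad T_{\text{multi-output}} = 4Nb
\end{aligned}
```

Under these assumptions, one kernel with 1 input and 3 outputs would move about a third less data than three separate transposes, because $y$ is read once instead of three times.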
module_0071.jit_apply.sm_8.9_gpu_after_optimizations.txt