[webgpu] Flash attention for generation #23808
Conversation
You can commit the suggested changes from lintrunner.
1. Only copy the new kv data for static kv cache
2. Add flash decoding for sequence_length = 1
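As a rough illustration of point 1 (a minimal Python sketch with hypothetical names, not the PR's actual code): with a static kv cache the present buffer is preallocated, so only the rows for the newly generated tokens need to be written, instead of copying the whole cache every step.

```python
import numpy as np

def append_new_kv(static_cache, new_kv, past_len):
    """Write only the newly generated kv rows into a preallocated cache.

    static_cache: [max_seq_len, H], allocated once (static kv cache).
    new_kv:       [new_len, H], keys or values for the new tokens only.
    past_len:     number of tokens already in the cache.
    """
    new_len = new_kv.shape[0]
    static_cache[past_len:past_len + new_len] = new_kv  # no full-cache copy
    return past_len + new_len

# During generation new_len is 1, so a single row is written per step.
cache = np.zeros((2048, 64), dtype=np.float32)
total_len = append_new_kv(cache, np.ones((1, 64), dtype=np.float32), past_len=128)
```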
Force-pushed from 6f6d6d1 to f0424fd
Can you merge with main?
Done. This PR is ready for review. Thanks.
Rename valid_new_present_shape to copy_kv_shape to make it easier to understand.
Thanks for your suggestion.
can you merge with main?
Done
This PR adds flash decoding support to optimize generation speed when the total sequence length is large. Previously, once the total sequence length grew big enough, the softmax and softmax * V shaders became the bottleneck since they only used a limited number of GPU cores. This change adds flash decoding to split the present key/value along the total sequence length, then reduces the partial results to get the final output.
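To make the split/reduce scheme concrete, here is a minimal numpy sketch (hypothetical function names; the real implementation is in WGSL shaders): each split computes a local softmax max and exp sum plus a partial output over its chunk of the present key/value, and a reduce pass rescales and combines them into the exact attention result.

```python
import numpy as np

def flash_decode(q, K, V, split_size):
    """Single-head decode step (sequence_length = 1) via split-k flash decoding.

    q: [H] query for the one new token; K, V: [T, H] present key/value.
    Returns the same result as softmax(q @ K.T / sqrt(H)) @ V.
    """
    scale = 1.0 / np.sqrt(q.shape[0])
    partials = []  # per split: (local max, local exp_sum, partial output)
    for start in range(0, K.shape[0], split_size):
        k, v = K[start:start + split_size], V[start:start + split_size]
        s = (k @ q) * scale               # attention scores for this split
        m = s.max()                       # local max -> metadata buffer
        p = np.exp(s - m)
        partials.append((m, p.sum(), p @ v))

    # Reduce pass: rescale every split to the global max, then combine.
    g = max(m for m, _, _ in partials)
    num = sum(np.exp(m - g) * o for m, _, o in partials)
    den = sum(np.exp(m - g) * e for m, e, _ in partials)
    return num / den

# Sanity check against the unsplit reference.
rng = np.random.default_rng(0)
q = rng.standard_normal(64)
K = rng.standard_normal((1024, 64))
V = rng.standard_normal((1024, 64))
s = (K @ q) / np.sqrt(64)
ref = (np.exp(s - s.max()) / np.exp(s - s.max()).sum()) @ V
assert np.allclose(flash_decode(q, K, V, split_size=256), ref)
```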
On an NV RTX 2000 Ada, TPS improves from 34.4 to 41.4 at 1K tokens for phi4 with a static kv cache.
On Meteor Lake, TPS improves from 16 to 19 at 1K tokens for phi4 with a static kv cache.
Side effect of this PR:
It adds two extra buffers, one for metadata (the max and exp_sum of each split) and one for the split qkv results with shape [B, N, split_k, H], which increases memory usage.
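For illustration only, the two extra buffers and the reduce over them might look like this in numpy (an assumed layout, not necessarily the PR's exact one): the metadata buffer holds a (max, exp_sum) pair per split, and the reduce pass folds away the split_k dimension of the partial outputs.

```python
import numpy as np

B, N, split_k, H = 1, 32, 8, 64  # batch, heads, number of splits, head size

# Extra buffer 1: per-split metadata, one (max, exp_sum) pair per [B, N, split_k].
metadata = np.zeros((B, N, split_k, 2), dtype=np.float32)
# Extra buffer 2: per-split partial outputs, folded away by the reduce pass.
partial_out = np.zeros((B, N, split_k, H), dtype=np.float32)

def reduce_splits(metadata, partial_out):
    """Combine split results (written by the split pass) into [B, N, H]."""
    m, e = metadata[..., 0], metadata[..., 1]       # [B, N, split_k]
    g = m.max(axis=-1, keepdims=True)               # global max per head
    w = np.exp(m - g)                               # per-split rescale factor
    num = (w[..., None] * partial_out).sum(axis=2)  # [B, N, H]
    den = (w * e).sum(axis=-1, keepdims=True)       # [B, N, 1]
    return num / den
```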
TODO:
Ideally there should be only two shaders, which would also reduce the intermediate memory. The computeQKT pass could be merged into the split shader, with the final softmax adjustment done in the reduce shader. However, when the total sequence length exceeds a certain value, the result becomes garbage. Since I can't resolve this in a short time, it is left as a TODO to fix in the future.