[webgpu] Flash attention for generation #23808

Merged: 18 commits merged into main from attention_generate_fa on Apr 8, 2025

Conversation

@qjia7 (Contributor) commented Feb 25, 2025

This PR adds flash decoding support to improve generation speed when the total sequence length is large. Previously, once the total sequence length grew large enough, the softmax and softmax * V shaders became the bottleneck because they only used a limited number of GPU cores. With this change, flash decoding splits the present key/value along the total sequence length and then performs a reduce to obtain the final result (see the sketch after the TODO below).

On an NV RTX 2000 Ada, TPS improves from 34.4 to 41.4 at 1K tokens for phi4 with static KV cache.
On Meteor Lake, TPS improves from 16 to 19 at 1K tokens for phi4 with static KV cache.

Side effect of this PR:
It adds two extra buffers: 1) metadata (the max and exp_sum of each split), and 2) the split qkv results with shape [B, N, split_k, H], which increases memory usage.

TODO:
Ideally there should be only two shaders, which would also reduce the intermediate memory: computeQKT could be merged into the split shader, with the final softmax adjustment done in the reduce shader. However, I hit an issue where the result becomes garbage once the total sequence length exceeds a certain value. Since I can't resolve it quickly, I'm leaving it as a TODO to fix in the future.
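
For illustration, here is a minimal NumPy sketch of the split/reduce scheme described above. The shapes, split size, and function name are assumptions made for the example and do not reflect the actual WGSL shader layout in this PR; the sketch only shows how the per-split max/exp_sum metadata and the [B, N, split_k, H] partial results combine into the final output.

```python
# Minimal flash-decoding sketch (assumed shapes/names, not the shader code in this PR).
import numpy as np

def flash_decode(q, k, v, split_size=256):
    # q: [B, N, H] (single generated token), k/v: [B, N, S, H] (present key/value).
    B, N, H = q.shape
    S = k.shape[2]
    scale = 1.0 / np.sqrt(H)
    num_splits = (S + split_size - 1) // split_size

    # The two extra buffers added by this PR, conceptually:
    #   metadata: per-split max and exp_sum        -> [B, N, split_k, 2]
    #   partial softmax*V results per split        -> [B, N, split_k, H]
    meta = np.zeros((B, N, num_splits, 2), dtype=q.dtype)
    partial = np.zeros((B, N, num_splits, H), dtype=q.dtype)

    # Split pass: each split computes a local (unnormalized) softmax over its K/V slice.
    for s in range(num_splits):
        lo, hi = s * split_size, min((s + 1) * split_size, S)
        qk = np.einsum('bnh,bnth->bnt', q, k[:, :, lo:hi]) * scale  # [B, N, t]
        m = qk.max(axis=-1)                       # local max
        p = np.exp(qk - m[..., None])             # unnormalized probabilities
        l = p.sum(axis=-1)                        # local exp_sum
        meta[:, :, s, 0], meta[:, :, s, 1] = m, l
        partial[:, :, s] = np.einsum('bnt,bnth->bnh', p, v[:, :, lo:hi])

    # Reduce pass: rescale each split by exp(local_max - global_max) and combine.
    m_all, l_all = meta[..., 0], meta[..., 1]
    m_global = m_all.max(axis=-1, keepdims=True)  # [B, N, 1]
    alpha = np.exp(m_all - m_global)              # [B, N, split_k]
    out = (partial * alpha[..., None]).sum(axis=2)
    denom = (l_all * alpha).sum(axis=-1, keepdims=True)
    return out / denom
```

Numerically this matches a plain softmax(q·kᵀ·scale)·v over the full sequence; splitting only changes how the work is distributed across GPU cores, not the result.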

@github-actions bot left a comment:


You can commit the suggested changes from lintrunner.

@guschmue guschmue added the ep:WebGPU ort-web webgpu provider label Feb 26, 2025
1. Only copy the new kv data for static kv cache
2. Add flash decoding for sequence_length = 1
@qjia7 qjia7 force-pushed the attention_generate_fa branch from 6f6d6d1 to f0424fd Compare March 10, 2025 14:10
@qjia7 qjia7 changed the title [WIP] Flash attention for generation [webgpu] Flash attention for generation Mar 11, 2025
@qjia7 qjia7 requested review from sushraja-msft and guschmue March 11, 2025 13:22
@qjia7 qjia7 marked this pull request as ready for review March 11, 2025 13:22
@qjia7 qjia7 marked this pull request as draft March 19, 2025 10:12
@qjia7 qjia7 marked this pull request as ready for review March 19, 2025 13:26
@guschmue (Contributor) commented:

can you merge with main?

@qjia7 (Contributor, Author) commented Mar 21, 2025

can you merge with main?

Done.

This PR is ready for review. Thanks.

@qjia7 (Contributor, Author) left a comment:


Renamed valid_new_present_shape to copy_kv_shape to make it easier to understand. Thanks for your suggestion.

@qjia7 qjia7 requested a review from sushraja-msft March 28, 2025 02:04
sushraja-msft previously approved these changes Apr 4, 2025
guschmue previously approved these changes Apr 4, 2025
@guschmue (Contributor) commented Apr 4, 2025

can you merge with main?

@qjia7 qjia7 dismissed stale reviews from guschmue and sushraja-msft via 191cf41 April 7, 2025 06:19
@qjia7 (Contributor, Author) left a comment:


can you merge with main?

Done

@qjia7 qjia7 requested review from sushraja-msft and guschmue April 7, 2025 06:29
@guschmue guschmue removed the request for review from sushraja-msft April 8, 2025 15:23
@guschmue guschmue merged commit 18f91e5 into main Apr 8, 2025
87 of 89 checks passed
@guschmue guschmue deleted the attention_generate_fa branch April 8, 2025 15:28
quic-zhaoxul pushed a commit to CodeLinaro/onnxruntime that referenced this pull request Apr 17, 2025
Labels: ep:WebGPU (ort-web webgpu provider)
4 participants