Skip to content

[Native WebGPU EP] Add packedQKV and do_rotary attribute support to GroupQueryAttention operator #23386

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 30 commits into from
Mar 14, 2025
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
Show all changes
30 commits
Select commit Hold shift + click to select a range
fdd5ceb
Added GroupQueryAttention do_rotary attribute.
satyajandhyala Jan 15, 2025
f6b0222
Added packed QKV and rotary embedding support for GQA
satyajandhyala Jan 16, 2025
ae87526
Fix lint errors.
satyajandhyala Jan 16, 2025
df90ffa
Fixed shader code compilation errors.
satyajandhyala Jan 16, 2025
0704462
more lint stuff
satyajandhyala Jan 16, 2025
177f535
Fixed shader code issues.
satyajandhyala Jan 17, 2025
0b94f10
Added split functionality to unpack packed-QKV.
satyajandhyala Jan 21, 2025
f0d238a
Removed unnecessary uniforms in GeneratePositionIdsProgram
satyajandhyala Jan 21, 2025
1009fc9
Apply split and rotary embedding before converting input from BSD to BNSH
satyajandhyala Jan 21, 2025
a4d8482
Fix the input_output_stride for 4-dim input.
satyajandhyala Jan 21, 2025
4d42d06
Merge branch 'main' of https://github.com/microsoft/onnxruntime into …
satyajandhyala Jan 22, 2025
e406b81
Allocate position_ids tensor size/shape even for the first prompt
satyajandhyala Feb 3, 2025
9e38efd
Merge branch 'main' of https://github.com/microsoft/onnxruntime into …
satyajandhyala Feb 5, 2025
de0d4b0
Merge branch 'main' of https://github.com/microsoft/onnxruntime into …
satyajandhyala Feb 12, 2025
0b08117
Fixed the input_output_strides
satyajandhyala Feb 18, 2025
a7328f5
Added is_first_prompt to the shader that generates position ids…
satyajandhyala Feb 19, 2025
531c6e3
Fixed position_ids generation code.
satyajandhyala Feb 20, 2025
29819ed
Check is_first_prompt and is_subsequence_prompt flags in the c++ code…
satyajandhyala Feb 20, 2025
91e8801
Merge branch 'main' of https://github.com/microsoft/onnxruntime into …
satyajandhyala Feb 20, 2025
6bbef62
lint
satyajandhyala Feb 21, 2025
ff84b7b
Removed unused variable.
satyajandhyala Feb 21, 2025
d4e4f29
Merge branch 'main' of https://github.com/microsoft/onnxruntime into …
satyajandhyala Feb 26, 2025
e468128
Added condition to check do_rotary before calling fa2
satyajandhyala Feb 26, 2025
9f2782c
typo
satyajandhyala Feb 26, 2025
b4a6546
Revert changes to rotary embedding code.
satyajandhyala Feb 26, 2025
a80e9f9
Removed packed QKV support in attention.
satyajandhyala Mar 6, 2025
19c48d4
lint
satyajandhyala Mar 6, 2025
5c70cb3
Merge branch 'main' of https://github.com/microsoft/onnxruntime into …
satyajandhyala Mar 6, 2025
3d4e022
Merge branch 'main' of https://github.com/microsoft/onnxruntime into …
satyajandhyala Mar 13, 2025
ce3d60b
Replaced gsl::narrow with gsl::narrow_cast
satyajandhyala Mar 13, 2025
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
147 changes: 142 additions & 5 deletions onnxruntime/contrib_ops/webgpu/bert/group_query_attention.cc
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,7 @@
#include "contrib_ops/webgpu/bert/attention_common.h"
#include "contrib_ops/webgpu/bert/group_query_attention.h"
#include "contrib_ops/webgpu/webgpu_contrib_kernels.h"
#include "contrib_ops/webgpu/bert/rotary_embedding.h"
#include "contrib_ops/webgpu/bert/flash_attention.h"

#include "core/providers/webgpu/webgpu_supported_types.h"
Expand All @@ -30,6 +31,117 @@ ONNX_OPERATOR_KERNEL_EX(
.InputMemoryType(OrtMemTypeCPUInput, 6),
GroupQueryAttention);

// Generates the WGSL shader that unpacks a packed QKV tensor into separate
// query, key, and value outputs. One invocation handles one element of the
// packed input; the element's index along dim 2 (the hidden dimension)
// selects which output receives it.
Status SplitPackedQKVProgram::GenerateShaderCode(ShaderHelper& sh) const {
const auto& packed_qkv = sh.AddInput("packed_qkv", ShaderUsage::UseOffsetToIndices | ShaderUsage::UseUniform);
const auto& query = sh.AddOutput("query", ShaderUsage::UseSetByIndices | ShaderUsage::UseUniform);
const auto& key = sh.AddOutput("key", ShaderUsage::UseSetByIndices | ShaderUsage::UseUniform);
const auto& value = sh.AddOutput("val", ShaderUsage::UseSetByIndices | ShaderUsage::UseUniform);
// Packed layout along dim 2: [0, hidden_size) -> query,
// [hidden_size, hidden_size + kv_hidden_size) -> key, remainder -> value.
// Key/value indices are rebased so each output is written from offset 0.
sh.MainFunctionBody() << " let packed_qkv_indices = " << packed_qkv.OffsetToIndices("global_idx") << ";\n"
<< " let input_data = " << packed_qkv.GetByOffset("global_idx") << ";\n"
<< " let index = " << packed_qkv.IndicesGet("packed_qkv_indices", "2") << ";\n"
<< " if (index < uniforms.hidden_size) {\n"
<< " " << query.SetByIndices("packed_qkv_indices", "input_data") << ";\n"
<< " } else if (index < (uniforms.hidden_size + uniforms.kv_hidden_size)) {\n"
<< " var key_indices = packed_qkv_indices;\n"
<< " " << key.IndicesSet("key_indices", "2", "u32(index - uniforms.hidden_size)") << ";\n"
<< " " << key.SetByIndices("key_indices", "input_data") << ";\n"
<< " } else {\n"
<< " var val_indices = packed_qkv_indices;\n"
<< " " << value.IndicesSet("val_indices", "2", "u32(index - uniforms.hidden_size - uniforms.kv_hidden_size)") << ";\n"
<< " " << value.SetByIndices("val_indices", "input_data") << ";\n"
<< " }";
return Status::OK();
}

// Splits a packed QKV input tensor into separate query, key, and value GPU
// tensors by dispatching SplitPackedQKVProgram with one thread per element
// of the packed input.
Status SplitPackedQKV(onnxruntime::webgpu::ComputeContext& context, const WebgpuAttentionParameters& params, const Tensor* packedQKV, Tensor* query, Tensor* key, Tensor* val) {
SplitPackedQKVProgram split_program;
const auto total_elements = packedQKV->Shape().Size();
split_program.AddInput({packedQKV, ProgramTensorMetadataDependency::Rank});
split_program.AddOutputs({{query, ProgramTensorMetadataDependency::Rank},
                          {key, ProgramTensorMetadataDependency::Rank},
                          {val, ProgramTensorMetadataDependency::Rank}});
// The shader needs the q/kv hidden sizes to route each element to its output.
split_program.AddUniformVariables({
    {static_cast<uint32_t>(params.hidden_size_)},
    {static_cast<uint32_t>(params.kv_hidden_size_)},
});
split_program.SetDispatchGroupSize((total_elements + WORKGROUP_SIZE - 1) / WORKGROUP_SIZE);
return context.RunProgram(split_program);
}

// Generates the WGSL shader that fills the position-id tensor consumed by the
// rotary-embedding kernel. Three variants are emitted depending on the prompt
// flags captured at construction time.
// NOTE(review): total_seqlen = seqlen + 1 implies the seqlens input holds
// (total sequence length - 1) per batch — confirm against the GQA spec.
Status GeneratePositionIDsProgram::GenerateShaderCode(ShaderHelper& sh) const {
const auto& output = sh.AddOutput("output", ShaderUsage::UseUniform);
const auto& seqlens = sh.AddInput("seqlens", ShaderUsage::UseUniform);
// One invocation per (batch, sequence) slot; seqlen is read per batch.
sh.MainFunctionBody() << " var pos_id: i32 = 0;\n"
<< " let batch_idx = global_idx / uniforms.sequence_length;\n"
<< " let sequence_idx = i32(global_idx % uniforms.sequence_length);\n"
<< " let seqlen = " << seqlens.GetByOffset("batch_idx") << ";\n";
if (is_first_prompt_) {
// First prompt: in-range positions get their own sequence index;
// out-of-range (padding) slots get 1.
sh.MainFunctionBody() << " let total_seqlen = seqlen + 1;\n"
<< " if (sequence_idx < total_seqlen) {\n"
<< " pos_id = sequence_idx;\n"
<< " } else {\n"
<< " pos_id = 1;\n"
<< " }\n"
<< " " << output.SetByOffset("global_idx", "pos_id") << "\n";
} else if (is_subsequent_prompt_) {
// Subsequent prompt: positions continue from the already-cached past
// length; out-of-range slots again get 1.
sh.MainFunctionBody() << " let total_seqlen = seqlen + 1;\n"
<< " let past_seqlen = total_seqlen - i32(uniforms.sequence_length);\n"
<< " if (past_seqlen + sequence_idx < total_seqlen) {\n"
<< " pos_id = past_seqlen + sequence_idx;\n"
<< " } else {\n"
<< " pos_id = 1;\n"
<< " }\n"
<< " " << output.SetByOffset("global_idx", "pos_id") << "\n";
} else {
// Token generation (sequence_length == 1): write one position per batch,
// equal to that batch's seqlen value; only the first batch_size threads write.
sh.MainFunctionBody() << " if (global_idx < uniforms.batch_size) {\n"
<< " " << output.SetByOffset("global_idx", "seqlen") << "\n"
<< " }\n";
}
return Status::OK();
}

// Runs GeneratePositionIDsProgram to build the (batch_size, sequence_length)
// position-id tensor from the per-batch seqlens input. The prompt-mode flags
// select which shader variant is generated, so they are part of the cache key.
Status GeneratePositionIDs(onnxruntime::webgpu::ComputeContext& context, const WebgpuAttentionParameters& params, const Tensor* seqlens, Tensor* output_tensor) {
const auto total_positions = params.batch_size_ * params.sequence_length_;
GeneratePositionIDsProgram program(params.is_first_prompt_, params.is_subsequent_prompt_);
program.CacheHint(params.is_first_prompt_, params.is_subsequent_prompt_);
program.AddInput({seqlens, ProgramTensorMetadataDependency::Rank});
program.AddOutput({output_tensor, ProgramTensorMetadataDependency::Rank});
program.AddUniformVariables({{static_cast<uint32_t>(params.batch_size_)},
                             {static_cast<uint32_t>(params.sequence_length_)}});
// One thread per output slot.
program.SetDispatchGroupSize((total_positions + WORKGROUP_SIZE - 1) / WORKGROUP_SIZE);
return context.RunProgram(program);
}

// Applies rotary position embedding to `input` (query when is_query_input is
// true, otherwise key) using the supplied position ids and cos/sin caches,
// writing the rotated result into `output`. Input/output layout is (B, S, H)
// flattened from (B, S, N, head_size).
Status RunRotaryEmbedding(onnxruntime::webgpu::ComputeContext& context, const WebgpuAttentionParameters& params, const Tensor* input, const Tensor* pos_ids, const Tensor* cos_cache, const Tensor* sin_cache, Tensor* output, bool is_query_input) {
const auto half_rotary_embedding_dim = gsl::narrow_cast<uint32_t>(cos_cache->Shape()[1]);
const auto head_size = params.head_size_;
const auto hidden_size = is_query_input ? params.hidden_size_ : params.kv_hidden_size_;
// Dispatch domain: one element per (batch, seq, head, rotary-dim) position.
// The last dim is head_size minus the half rotary dim, matching the kernel's
// per-thread work partitioning.
const TensorShape global_shape({params.batch_size_, params.sequence_length_, hidden_size / head_size, static_cast<int64_t>(head_size - half_rotary_embedding_dim)});
const auto rank = global_shape.NumDimensions();
std::vector<uint32_t> global_dims(rank);
std::vector<uint32_t> global_strides(rank);
for (size_t j = 0; j < rank; ++j) {
global_dims[j] = gsl::narrow_cast<uint32_t>(global_shape[j]);
global_strides[j] = gsl::narrow_cast<uint32_t>(global_shape.SizeFromDimension(j + 1));
}
// Element strides of the (B, S, N, head_size) input/output layout.
const auto input_output_strides = std::vector<uint32_t>({gsl::narrow_cast<uint32_t>(input->Shape().SizeFromDimension(1)), gsl::narrow_cast<uint32_t>(hidden_size), gsl::narrow_cast<uint32_t>(head_size), 1});
// Fixed: cast target was `const uint32_t`; the const qualifier on a cast's
// target type is meaningless and was dropped.
const auto output_size = gsl::narrow_cast<uint32_t>(global_shape.Size());

RotaryEmbeddingProgram program(params.rotary_interleaved_);
program
    .CacheHint(params.rotary_interleaved_)
    .AddInputs({{input, ProgramTensorMetadataDependency::Rank},
                {pos_ids, ProgramTensorMetadataDependency::Rank},
                {cos_cache, ProgramTensorMetadataDependency::Rank},
                {sin_cache, ProgramTensorMetadataDependency::Rank}})
    .AddOutput(output)
    .SetDispatchGroupSize((output_size + WORKGROUP_SIZE - 1) / WORKGROUP_SIZE)
    .AddUniformVariables({{params.scale_},
                          {gsl::make_span(global_dims)},
                          {gsl::make_span(global_strides)},
                          {gsl::make_span(input_output_strides)}})
    .AddIndices(TensorShape{1, 1});
return context.RunProgram(program);
}

Status GroupQueryAttention::ComputeInternal(onnxruntime::webgpu::ComputeContext& context) const {
const Tensor* query = context.Input<Tensor>(0);
const Tensor* key = context.Input<Tensor>(1);
Expand All @@ -41,7 +153,7 @@ Status GroupQueryAttention::ComputeInternal(onnxruntime::webgpu::ComputeContext&
const Tensor* cos_cache = context.Input<Tensor>(7);
const Tensor* sin_cache = context.Input<Tensor>(8);

GroupQueryAttentionParameters params;
GroupQueryAttentionParameters params = {};
ORT_RETURN_IF_ERROR(group_query_attention_helper::CheckInputs(query,
key,
value,
Expand All @@ -57,9 +169,6 @@ Status GroupQueryAttention::ComputeInternal(onnxruntime::webgpu::ComputeContext&
scale_,
softcap_));
WebgpuAttentionParameters parameters(params);
if (parameters.is_packed_qkv_) {
ORT_NOT_IMPLEMENTED("Packed QKV of shape (B, L, N, 3, H) not implemented for webgpu-ep.");
}
TensorShapeVector output_shape(3);
output_shape[0] = static_cast<int64_t>(parameters.batch_size_);
output_shape[1] = static_cast<int64_t>(parameters.sequence_length_);
Expand All @@ -75,11 +184,39 @@ Status GroupQueryAttention::ComputeInternal(onnxruntime::webgpu::ComputeContext&
Tensor* present_value = context.Output(2, present_kv_shape);
parameters.past_present_share_buffer_ = present_key != nullptr && present_value != nullptr && past_key != nullptr && past_value != nullptr && past_key->DataRaw() == present_key->DataRaw() && past_value->DataRaw() == present_value->DataRaw();

if (CanApplyFlashAttention(nullptr /* bias */, present_key, present_value, parameters, context)) {
if (!do_rotary_ && CanApplyFlashAttention(nullptr /* bias */, present_key, present_value, parameters, context)) {
return ApplyFlashAttention(query, key, value, nullptr /* attention_bias */, output, past_key, present_key, past_value,
present_value, parameters, context);
}

Tensor qSplit;
Tensor kSplit;
Tensor vSplit;
if (parameters.is_packed_qkv_) {
qSplit = context.CreateGPUTensor(query->DataType(), TensorShape({parameters.batch_size_, parameters.sequence_length_, parameters.hidden_size_}));
kSplit = context.CreateGPUTensor(query->DataType(), TensorShape({parameters.batch_size_, parameters.sequence_length_, parameters.kv_hidden_size_}));
vSplit = context.CreateGPUTensor(query->DataType(), TensorShape({parameters.batch_size_, parameters.sequence_length_, parameters.kv_hidden_size_}));
ORT_RETURN_IF_ERROR(SplitPackedQKV(context, parameters, query, &qSplit, &kSplit, &vSplit));
parameters.is_packed_qkv_ = false;
query = &qSplit;
key = &kSplit;
value = &vSplit;
}

Tensor qRotary;
Tensor kRotary;
if (do_rotary_) {
qRotary = context.CreateGPUTensor(query->DataType(), query->Shape());
kRotary = context.CreateGPUTensor(key->DataType(), key->Shape());
auto pos_ids_shape = TensorShape({parameters.batch_size_, parameters.sequence_length_});
Tensor pos_ids = context.CreateGPUTensor(DataTypeImpl::GetType<int64_t>(), pos_ids_shape);
ORT_RETURN_IF_ERROR(GeneratePositionIDs(context, parameters, seqlen_k, &pos_ids));
ORT_RETURN_IF_ERROR(RunRotaryEmbedding(context, parameters, query, &pos_ids, cos_cache, sin_cache, &qRotary, /* is_query_input = */ true));
ORT_RETURN_IF_ERROR(RunRotaryEmbedding(context, parameters, key, &pos_ids, cos_cache, sin_cache, &kRotary, /* is_query_input = */ false));
query = &qRotary;
key = &kRotary;
}

TensorShapeVector q_new_dims({parameters.batch_size_, parameters.num_heads_,
parameters.sequence_length_, parameters.head_size_});
TensorShape q_new_shape(q_new_dims);
Expand Down
23 changes: 23 additions & 0 deletions onnxruntime/contrib_ops/webgpu/bert/group_query_attention.h
Original file line number Diff line number Diff line change
Expand Up @@ -14,6 +14,29 @@ namespace webgpu {

using namespace onnxruntime::webgpu;

// Compute program that fills a position-id tensor from per-batch seqlens for
// use by rotary embedding in GroupQueryAttention. The generated shader differs
// per prompt mode, so both flags are captured at construction (and must be
// reflected in the program's cache hint by the caller).
class GeneratePositionIDsProgram final : public Program<GeneratePositionIDsProgram> {
public:
GeneratePositionIDsProgram(bool is_first_prompt, bool is_subsequent_prompt) : Program{"GeneratePositionIDs"}, is_first_prompt_(is_first_prompt), is_subsequent_prompt_(is_subsequent_prompt) {}

Status GenerateShaderCode(ShaderHelper& sh) const override;

WEBGPU_PROGRAM_DEFINE_UNIFORM_VARIABLES({"batch_size", ProgramUniformVariableDataType::Uint32}, {"sequence_length", ProgramUniformVariableDataType::Uint32});

private:
// Selects the first-prompt shader variant (positions start at 0).
bool is_first_prompt_;
// Selects the subsequent-prompt variant (positions continue from past length).
bool is_subsequent_prompt_;
};

// Compute program that splits a packed QKV tensor into separate query, key,
// and value tensors. The hidden-size uniforms tell the shader where the q/k/v
// sections begin along the packed hidden dimension.
class SplitPackedQKVProgram final : public Program<SplitPackedQKVProgram> {
public:
SplitPackedQKVProgram() : Program{"SplitPackedQKV"} {}

Status GenerateShaderCode(ShaderHelper& sh) const override;

WEBGPU_PROGRAM_DEFINE_UNIFORM_VARIABLES({"hidden_size", ProgramUniformVariableDataType::Uint32},
{"kv_hidden_size", ProgramUniformVariableDataType::Uint32});
};

class GroupQueryAttention final : public WebGpuKernel {
public:
GroupQueryAttention(const OpKernelInfo& info) : WebGpuKernel(info) {
Expand Down
Loading