
[webgpu] Use 64 as the workgroup size of DP4AMatMulQuantize #24129

Open · wants to merge 6 commits into main
100 changes: 80 additions & 20 deletions onnxruntime/contrib_ops/webgpu/quantization/dp4a_matmul_nbits.cc
@@ -45,24 +45,82 @@ Status DP4AMatMulQuantizeProgram::GenerateShaderCode(ShaderHelper& shader) const
shader.AddInput("input_a", ShaderUsage::UseUniform | ShaderUsage::UseIndicesTypeAlias | ShaderUsage::UseValueTypeAlias | ShaderUsage::UseElementTypeAlias);
shader.AddOutput("output", ShaderUsage::UseUniform);
shader.AddOutput("scales", ShaderUsage::UseUniform);
shader.AdditionalImplementation() << R"ADDNL_FN(
var<workgroup> a_values : array<array<input_a_value_t, 32>, 2>;
var<workgroup> max_values : array<input_a_value_t, 4>;

fn readInput(offset: u32) -> input_a_value_t
{
if (offset >= uniforms.output_size) {
return input_a_value_t(0);
}
return input_a[offset];
}
)ADDNL_FN";

shader.MainFunctionBody() << R"MAIN_FN(
var local_a : array<vec4<input_a_element_t>, 32>;
var max_value:vec4<input_a_element_t> = vec4<input_a_element_t>(0);
for (var idx:u32=0;idx<32;idx+=1)
{
local_a[idx] = input_a[workgroup_idx*32 + idx];
max_value = max(max_value, abs(local_a[idx]));
}
var scale = max(max_value.x, max_value.y);
scale = max(scale, max_value.z);
scale = max(scale, max_value.w);
for (var idx:u32=0;idx<32;idx+=1)
{
output[workgroup_idx*32+idx] = pack4x8snorm(vec4<f32>(local_a[idx]/scale));
}
// 127 is the max value of signed int8 [-127,127] used by pack4x8snorm for 1.0f.
scales[workgroup_idx] = scale/127;
)MAIN_FN";
if (sg_size == 32) {
let local_a = readInput(global_idx);
let max_val = subgroupMax(abs(local_a));
if (global_idx >= uniforms.output_size) {
return;
}
let max_temp = max(max_val.xy, max_val.zw);
let scale = max(max_temp[0], max_temp[1]);
let norm_a = local_a/scale;
output[global_idx] = pack4x8snorm(vec4<f32>(norm_a));
if (local_idx % 32 == 0)
{
// 127 is the max value of signed int8 [-127,127] used by pack4x8snorm for 1.0f.
scales[workgroup_idx * 2 + local_idx / 32] = scale/127;
}
} else if (sg_size == 16) {
let local_a = readInput(global_idx);
max_values[local_idx / 16] = subgroupMax(abs(local_a));
if (global_idx >= uniforms.output_size) {
return;
}
var max_val = input_a_value_t(0);
if (local_idx < 32) {
max_val = max(max_values[0], max_values[1]);
} else {
max_val = max(max_values[2], max_values[3]);
}
let max_temp = max(max_val.xy, max_val.zw);
let scale = max(max_temp[0], max_temp[1]);
let norm_a = local_a/scale;
output[global_idx] = pack4x8snorm(vec4<f32>(norm_a));
if (local_idx % 32 == 0)
{
// 127 is the max value of signed int8 [-127,127] used by pack4x8snorm for 1.0f.
scales[workgroup_idx * 2 + local_idx / 32] = scale/127;
}
} else {
let local_row = local_idx / 32u;
let local_col = local_idx % 32u;
a_values[local_row][local_col] = readInput(global_idx);
workgroupBarrier();

if (global_idx >= uniforms.output_size) {
return;
}

var max_val = input_a_value_t(0);
for (var i = 0u; i < 32u; i++)
{
max_val = max(max_val, abs(a_values[local_row][i]));
}
let max_temp = max(max_val.xy, max_val.zw);
let scale = max(max_temp[0], max_temp[1]);
let norm_a = a_values[local_row][local_col]/scale;
output[global_idx] = pack4x8snorm(vec4<f32>(norm_a));
if (local_col == 0u)
{
// 127 is the max value of signed int8 [-127,127] used by pack4x8snorm for 1.0f.
scales[workgroup_idx * 2 + local_row] = scale/127;
}
}
)MAIN_FN";
return Status::OK();
}
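
For context, a minimal host-side sketch (not part of this change) of what one quantization block computes: a block of 128 elements of A shares a single scale taken from its maximum magnitude, each value is normalized by that magnitude and packed to signed 8-bit with pack4x8snorm-style rounding, and the stored scale is max/127 so dequantization is a single multiply. The helper name and the use of std::lround are assumptions for illustration only.

#include <algorithm>
#include <cmath>
#include <cstdint>
#include <vector>

// Emulates one 128-element quantization block of DP4AMatMulQuantize on the host.
std::vector<int8_t> QuantizeBlockSketch(const std::vector<float>& block, float& stored_scale) {
  float max_abs = 0.0f;
  for (float v : block) max_abs = std::max(max_abs, std::fabs(v));
  if (max_abs == 0.0f) max_abs = 1.0f;  // guard for this sketch only
  // Matches "scales[...] = scale / 127" in the shader: dequantization is int8 * stored_scale.
  stored_scale = max_abs / 127.0f;
  std::vector<int8_t> quantized(block.size());
  for (size_t i = 0; i < block.size(); ++i) {
    float norm = std::clamp(block[i] / max_abs, -1.0f, 1.0f);
    quantized[i] = static_cast<int8_t>(std::lround(norm * 127.0f));  // ~pack4x8snorm rounding
  }
  return quantized;
}

With the 64-invocation workgroup introduced here, each invocation handles one vec4 of A, so a workgroup covers two such 128-element blocks and writes two scales, which is why the shader indexes scales[workgroup_idx * 2 + local_idx / 32].
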

@@ -386,15 +444,17 @@ Status ApplyDP4AMatrixMatMulNBits(const Tensor* a, const Tensor* b, const Tensor

constexpr uint32_t kBlockSizeA = 128;
DP4AMatMulQuantizeProgram quantize_program;
quantize_program.SetWorkgroupSize(1);
quantize_program.SetDispatchGroupSize(M * K / kBlockSizeA, 1, 1);
quantize_program.SetWorkgroupSize(64);
uint32_t tile_size = 64 * kVec4Components;
quantize_program.SetDispatchGroupSize((M * K + tile_size - 1) / tile_size, 1, 1);
TensorShape a_quant_shape{1, M, K / kU32Components};
Tensor a_quant = context.CreateGPUTensor(DataTypeImpl::GetType<uint32_t>(), a_quant_shape);
TensorShapeVector a_scales_dims({1, 1, M, K / kBlockSizeA});
Tensor a_scale = context.CreateGPUTensor(a->DataType(), a_scales_dims);
quantize_program.AddInputs({{a, ProgramTensorMetadataDependency::TypeAndRank, static_cast<int>(kVec4Components)}})
.AddOutputs({{&a_quant, ProgramTensorMetadataDependency::Rank, a_quant.Shape(), 1},
{&a_scale, ProgramTensorMetadataDependency::Rank, a_scale.Shape(), 1}});
{&a_scale, ProgramTensorMetadataDependency::Rank, 1}})
.AddUniformVariable({M * K / kU32Components});
ORT_RETURN_IF_ERROR(context.RunProgram(quantize_program));

if (M < min_M_for_tile_optimization) {
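
To make the new dispatch math concrete, a small arithmetic sketch (illustrative only; kVec4Components == 4 is an assumption based on the vec4 packing of input_a):

#include <cstdint>

constexpr uint32_t kWorkgroupSize = 64;
constexpr uint32_t kVec4Components = 4;                           // assumed value for this example
constexpr uint32_t tile_size = kWorkgroupSize * kVec4Components;  // 256 elements of A per workgroup

// Example: M = 1, K = 4096, i.e. M * K = 4096 elements of A.
// Old scheme: one single-invocation workgroup per 128-element block -> 4096 / 128 = 32 workgroups.
// New scheme: ceil(4096 / 256) = 16 workgroups of 64 invocations each.
constexpr uint32_t num_workgroups = (1u * 4096u + tile_size - 1u) / tile_size;
static_assert(num_workgroups == 16, "dispatch count for the example dimensions");

Because the dispatch is rounded up, trailing invocations can land past the end of A; that is what the new output_size uniform and the readInput guard (or the early return on global_idx) are for.
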
@@ -16,6 +16,7 @@ class DP4AMatMulQuantizeProgram final : public Program<DP4AMatMulQuantizeProgram
public:
DP4AMatMulQuantizeProgram() : Program{"DP4AMatMulQuantize"} {}
Status GenerateShaderCode(ShaderHelper& sh) const override;
WEBGPU_PROGRAM_DEFINE_UNIFORM_VARIABLES({"output_size", ProgramUniformVariableDataType::Uint32});
};

class DP4AMatMulNBitsProgram final : public Program<DP4AMatMulNBitsProgram> {