Skip to content

Commit a82ee44

Browse files
Removed code dividing offset by the components in the shader code.
1 parent b559ec9 commit a82ee44

File tree

1 file changed

+6
-6
lines changed

1 file changed

+6
-6
lines changed

onnxruntime/core/providers/webgpu/nn/conv_backprop_webgpu.cc

Lines changed: 6 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -45,7 +45,7 @@ Status ConvTranspose2DProgram::GenerateShaderCode(ShaderHelper& shader) const {
4545
}
4646
} else {
4747
if (is_channels_last_) {
48-
ss << "let dy_offset = " << dy.IndicesToOffset("dy_indices_t(batch, idyR, idyC, inputChannel)") << " / " << a_components_ << ";\n"
48+
ss << "let dy_offset = " << dy.IndicesToOffset("dy_indices_t(batch, idyR, idyC, inputChannel)") << ";\n"
4949
<< "let xValue = " << dy.GetByOffset("dy_offset") << ";\n";
5050
} else {
5151
ss << "let xValue = " << dy.GetByIndices("dy_indices_t(batch, inputChannel, idyR, idyC)") << ";\n";
@@ -57,7 +57,7 @@ Status ConvTranspose2DProgram::GenerateShaderCode(ShaderHelper& shader) const {
5757
} else {
5858
for (uint32_t i = 0; i < a_components_; ++i) {
5959
ss << "let w_indices = w_indices_t(u32(wRPerm), u32(wCPerm), inputChannel + " << i << ", wOutChannel);\n"
60-
<< "let w_offset = " << w.IndicesToOffset("w_indices") << " / " << b_components_ << ";\n"
60+
<< "let w_offset = " << w.IndicesToOffset("w_indices") << ";\n"
6161
<< "let wValue = " << w.GetByOffset("w_offset") << ";\n"
6262
<< "dotProd = dotProd + xValue[" << i << "] * wValue;\n";
6363
}
@@ -132,8 +132,8 @@ Status ConvTranspose2DProgram::GenerateShaderCode(ShaderHelper& shader) const {
132132
if (pack_input_as4_) {
133133
shader.MainFunctionBody() << " let dy_indices = dy_indices_t(batch, idyR, idyC, inputChannels);\n"
134134
<< " let w_indices = w_indices_t(u32(wRPerm), u32(wCPerm, inputChannel, wOutChannel);\n"
135-
<< " var x_offset = " << dy.IndicesToOffset("dy_indices") << " / " << a_components_ << ";\n"
136-
<< " var w_offset = " << w.IndicesToOffset("w_indices") << " / " << b_components_ << ";\n";
135+
<< " var x_offset = " << dy.IndicesToOffset("dy_indices") << ";\n"
136+
<< " var w_offset = " << w.IndicesToOffset("w_indices") << ";\n";
137137
}
138138

139139
shader.MainFunctionBody() << " for (var d2: u32 = 0; d2 < uniforms.input_channels_per_group_int; d2 = d2 + " << (pack_input_as4_ ? 4 : a_components_) << ") {\n"
@@ -145,7 +145,7 @@ Status ConvTranspose2DProgram::GenerateShaderCode(ShaderHelper& shader) const {
145145
<< " }\n"
146146
<< " wR = wR + uniforms.strides.x - 1;\n"
147147
<< "}\n"
148-
<< "let value = dotProd" << (has_bias_ ? " + bias[d1 / " + std::to_string(components_) + "]" : "") << ";\n"
148+
<< "let value = dotProd" << (has_bias_ ? " + bias[d1]" : "") << ";\n"
149149
<< output.SetByOffset("global_idx", "value") << "\n";
150150
return Status::OK();
151151
}
@@ -169,7 +169,7 @@ ConvTranspose2DProgram CreateConvTranspose2DProgram(const std::vector<const Tens
169169
input_shape_vector[input_shape_vector.size() - 1] /= a_components;
170170
TensorShape reduced_input_shape(input_shape_vector);
171171
InlinedVector<int64_t> weight_shape_vector = weight_shape.AsShapeVector();
172-
weight_shape_vector[weight_shape_vector.size() - 1] /= components;
172+
weight_shape_vector[weight_shape_vector.size() - 1] /= b_components;
173173
TensorShape reduced_weight_shape(weight_shape_vector);
174174
InlinedVector<int64_t> output_shape_vector = output_shape.AsShapeVector();
175175
output_shape_vector[output_shape_vector.size() - 1] /= components;

0 commit comments

Comments
 (0)