Skip to content

Commit f1d790c

Browse files
authored
[webgpu] fix LayerNorm with empty input (#24244)
### Description This PR fixes test case `CudaKernelTest.LayerNorm_NullInput`, in which the input is 0-sized for LayerNorm. `context.Output()` need to be called before returning.
1 parent d71aa4d commit f1d790c

File tree

1 file changed

+4
-4
lines changed

1 file changed

+4
-4
lines changed

onnxruntime/core/providers/webgpu/nn/layer_norm.cc

+4-4
Original file line numberDiff line numberDiff line change
@@ -74,10 +74,6 @@ Status LayerNorm<simplified>::ComputeInternal(onnxruntime::webgpu::ComputeContex
7474

7575
const auto x_shape = x->Shape();
7676

77-
if (x_shape.Size() == 0) {
78-
return Status::OK();
79-
}
80-
8177
const bool is_fp16 = x->GetElementType() == ONNX_TENSOR_ELEMENT_DATA_TYPE_FLOAT16;
8278

8379
const size_t axis = NormalizeAxis(axis_, x_shape.NumDimensions());
@@ -109,6 +105,10 @@ Status LayerNorm<simplified>::ComputeInternal(onnxruntime::webgpu::ComputeContex
109105
auto* mean = context.Output(1, mean_shape);
110106
auto* inv_std_dev = context.Output(2, mean_shape);
111107

108+
if (x_shape.Size() == 0) {
109+
return Status::OK();
110+
}
111+
112112
LayerNormProgram program{bias != nullptr, is_fp16, simplified, mean != nullptr, inv_std_dev != nullptr};
113113

114114
program

0 commit comments

Comments
 (0)