Skip to content

Commit 7ef0ddc

Browse files
authored
[WebGPU EP] If Implementation for WebGPU EP (#24242)
Increases operator covereage for WebGPU EP.
1 parent 4a669fd commit 7ef0ddc

File tree

4 files changed

+107
-6
lines changed

4 files changed

+107
-6
lines changed

onnxruntime/core/framework/session_state.cc

+1-1
Original file line numberDiff line numberDiff line change
@@ -1127,7 +1127,7 @@ Status SessionState::CreateSubgraphSessionState() {
11271127
if (!ep.empty() &&
11281128
ep != kCpuExecutionProvider && ep != kCudaExecutionProvider &&
11291129
ep != kRocmExecutionProvider && ep != kDmlExecutionProvider &&
1130-
ep != kJsExecutionProvider) {
1130+
ep != kJsExecutionProvider && ep != kWebGpuExecutionProvider) {
11311131
// SessionState is only used when ORT is executing the subgraph. If a non-ORT EP has taken the control flow
11321132
// node containing the subgraph it will create whatever state it needs internally.
11331133
continue;
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,77 @@
1+
// Copyright (c) Microsoft Corporation. All rights reserved.
2+
// Licensed under the MIT License.
3+
4+
#include "core/providers/webgpu/controlflow/if.h"
5+
6+
using namespace ONNX_NAMESPACE;
7+
using namespace onnxruntime::common;
8+
9+
namespace onnxruntime {
10+
namespace webgpu {
11+
12+
ONNX_OPERATOR_VERSIONED_KERNEL_EX(If,
13+
kOnnxDomain,
14+
1, 10,
15+
kWebGpuExecutionProvider,
16+
(*KernelDefBuilder::Create())
17+
.InputMemoryType(OrtMemTypeCPUInput, 0) // 'cond' needs to be on CPU
18+
.TypeConstraint("B", DataTypeImpl::GetTensorType<bool>())
19+
.TypeConstraint("V", DataTypeImpl::AllFixedSizeTensorTypes()),
20+
If);
21+
// output shape rules requiring the output shapes of the 'THEN' and 'ELSE'
22+
// branches to be the same were relaxed in opset-11
23+
ONNX_OPERATOR_VERSIONED_KERNEL_EX(If,
24+
kOnnxDomain,
25+
11, 12,
26+
kWebGpuExecutionProvider,
27+
(*KernelDefBuilder::Create())
28+
.InputMemoryType(OrtMemTypeCPUInput, 0) // 'cond' needs to be on CPU
29+
.TypeConstraint("B", DataTypeImpl::GetTensorType<bool>())
30+
.TypeConstraint("V", DataTypeImpl::AllFixedSizeTensorTypes()),
31+
If);
32+
33+
// opset-13 supports sequence type for If's subgraph outputs
34+
ONNX_OPERATOR_VERSIONED_KERNEL_EX(If,
35+
kOnnxDomain,
36+
13, 18,
37+
kWebGpuExecutionProvider,
38+
(*KernelDefBuilder::Create())
39+
.InputMemoryType(OrtMemTypeCPUInput, 0) // 'cond' needs to be on CPU
40+
.TypeConstraint("B", DataTypeImpl::GetTensorType<bool>())
41+
// Support sequence/optional tensors when all WebGPU infra
42+
// (including tests runner) supports it
43+
.TypeConstraint("V", DataTypeImpl::AllFixedSizeTensorTypes()),
44+
If);
45+
46+
// opset-19 supports float8
47+
ONNX_OPERATOR_VERSIONED_KERNEL_EX(If,
48+
kOnnxDomain,
49+
19, 20,
50+
kWebGpuExecutionProvider,
51+
(*KernelDefBuilder::Create())
52+
.InputMemoryType(OrtMemTypeCPUInput, 0) // 'cond' needs to be on CPU
53+
.TypeConstraint("B", DataTypeImpl::GetTensorType<bool>())
54+
// Support sequence/optional tensors when all WebGPU infra
55+
// (including tests runner) supports it
56+
.TypeConstraint("V", DataTypeImpl::AllFixedSizeTensorTypes()),
57+
If);
58+
59+
ONNX_OPERATOR_KERNEL_EX(If,
60+
kOnnxDomain,
61+
21,
62+
kWebGpuExecutionProvider,
63+
(*KernelDefBuilder::Create())
64+
.InputMemoryType(OrtMemTypeCPUInput, 0) // 'cond' needs to be on CPU
65+
.TypeConstraint("B", DataTypeImpl::GetTensorType<bool>())
66+
// Support sequence/optional tensors when all WebGPU infra
67+
// (including tests runner) supports it
68+
.TypeConstraint("V", DataTypeImpl::AllFixedSizeTensorTypes()),
69+
If);
70+
71+
Status If::Compute(OpKernelContext* ctx) const {
72+
// call the base CPU version.
73+
return onnxruntime::If::Compute(ctx);
74+
}
75+
76+
} // namespace webgpu
77+
} // namespace onnxruntime
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,22 @@
1+
// Copyright (c) Microsoft Corporation. All rights reserved.
2+
// Licensed under the MIT License.
3+
4+
#pragma once
5+
6+
#include "core/providers/webgpu/webgpu_kernel.h"
7+
#include "core/common/common.h"
8+
#include "core/providers/cpu/controlflow/if.h"
9+
10+
namespace onnxruntime {
11+
namespace webgpu {
12+
13+
// Use the CPU implementation for the logic
14+
class If final : public onnxruntime::If {
15+
public:
16+
If(const OpKernelInfo& info) : onnxruntime::If(info) {}
17+
18+
Status Compute(OpKernelContext* ctx) const override;
19+
};
20+
21+
} // namespace webgpu
22+
} // namespace onnxruntime

onnxruntime/core/providers/webgpu/webgpu_execution_provider.cc

+7-5
Original file line numberDiff line numberDiff line change
@@ -373,7 +373,8 @@ class ONNX_OPERATOR_KERNEL_CLASS_NAME(kWebGpuExecutionProvider, kOnnxDomain, 23,
373373
class ONNX_OPERATOR_VERSIONED_KERNEL_CLASS_NAME(kWebGpuExecutionProvider, kOnnxDomain, 1, 10, If);
374374
class ONNX_OPERATOR_VERSIONED_KERNEL_CLASS_NAME(kWebGpuExecutionProvider, kOnnxDomain, 11, 12, If);
375375
class ONNX_OPERATOR_VERSIONED_KERNEL_CLASS_NAME(kWebGpuExecutionProvider, kOnnxDomain, 13, 18, If);
376-
class ONNX_OPERATOR_KERNEL_CLASS_NAME(kWebGpuExecutionProvider, kOnnxDomain, 19, If);
376+
class ONNX_OPERATOR_VERSIONED_KERNEL_CLASS_NAME(kWebGpuExecutionProvider, kOnnxDomain, 19, 20, If);
377+
class ONNX_OPERATOR_KERNEL_CLASS_NAME(kWebGpuExecutionProvider, kOnnxDomain, 21, If);
377378

378379
class ONNX_OPERATOR_VERSIONED_KERNEL_CLASS_NAME(kWebGpuExecutionProvider, kOnnxDomain, 7, 8, BatchNormalization);
379380
class ONNX_OPERATOR_VERSIONED_KERNEL_CLASS_NAME(kWebGpuExecutionProvider, kOnnxDomain, 9, 13, BatchNormalization);
@@ -700,10 +701,11 @@ std::unique_ptr<KernelRegistry> RegisterKernels() {
700701
BuildKernelCreateInfo<ONNX_OPERATOR_VERSIONED_KERNEL_CLASS_NAME(kWebGpuExecutionProvider, kOnnxDomain, 21, 22, Pad)>,
701702
BuildKernelCreateInfo<ONNX_OPERATOR_KERNEL_CLASS_NAME(kWebGpuExecutionProvider, kOnnxDomain, 23, Pad)>,
702703

703-
// BuildKernelCreateInfo<ONNX_OPERATOR_VERSIONED_KERNEL_CLASS_NAME(kWebGpuExecutionProvider, kOnnxDomain, 1, 10, If)>,
704-
// BuildKernelCreateInfo<ONNX_OPERATOR_VERSIONED_KERNEL_CLASS_NAME(kWebGpuExecutionProvider, kOnnxDomain, 11, 12, If)>,
705-
// BuildKernelCreateInfo<ONNX_OPERATOR_VERSIONED_KERNEL_CLASS_NAME(kWebGpuExecutionProvider, kOnnxDomain, 13, 18, If)>,
706-
// BuildKernelCreateInfo<ONNX_OPERATOR_KERNEL_CLASS_NAME(kWebGpuExecutionProvider, kOnnxDomain, 19, If)>,
704+
BuildKernelCreateInfo<ONNX_OPERATOR_VERSIONED_KERNEL_CLASS_NAME(kWebGpuExecutionProvider, kOnnxDomain, 1, 10, If)>,
705+
BuildKernelCreateInfo<ONNX_OPERATOR_VERSIONED_KERNEL_CLASS_NAME(kWebGpuExecutionProvider, kOnnxDomain, 11, 12, If)>,
706+
BuildKernelCreateInfo<ONNX_OPERATOR_VERSIONED_KERNEL_CLASS_NAME(kWebGpuExecutionProvider, kOnnxDomain, 13, 18, If)>,
707+
BuildKernelCreateInfo<ONNX_OPERATOR_VERSIONED_KERNEL_CLASS_NAME(kWebGpuExecutionProvider, kOnnxDomain, 19, 20, If)>,
708+
BuildKernelCreateInfo<ONNX_OPERATOR_KERNEL_CLASS_NAME(kWebGpuExecutionProvider, kOnnxDomain, 21, If)>,
707709

708710
BuildKernelCreateInfo<ONNX_OPERATOR_VERSIONED_KERNEL_CLASS_NAME(kWebGpuExecutionProvider, kOnnxDomain, 7, 8, BatchNormalization)>,
709711
BuildKernelCreateInfo<ONNX_OPERATOR_VERSIONED_KERNEL_CLASS_NAME(kWebGpuExecutionProvider, kOnnxDomain, 9, 13, BatchNormalization)>,

0 commit comments

Comments
 (0)