Skip to content

Commit 24e0bd3

Browse files
authored
[JS/WebGPU] Support Log operator (#17045)
### Description <!-- Describe your changes. --> ### Motivation and Context <!-- - Why is this change required? What problem does it solve? - If it fixes an open issue, please link to the issue here. -->
1 parent 289600b commit 24e0bd3

File tree

6 files changed

+15
-1
lines changed

6 files changed

+15
-1
lines changed

js/web/docs/webgpu-operators.md

+1
Original file line numberDiff line numberDiff line change
@@ -43,6 +43,7 @@ Do not modify directly.*
4343
| InstanceNormalization | ai.onnx(6+); com.ms.internal.nhwc(6+) | |
4444
| LayerNormalization | ai.onnx(17+) | |
4545
| LeakyRelu | ai.onnx(6-15,16+) | |
46+
| Log | ai.onnx(6-12,13+) | |
4647
| MatMul | ai.onnx(1-12,13+) | |
4748
| MaxPool | ai.onnx(1-7,8-9,10,11,12+); com.ms.internal.nhwc(11,12+) | need perf optimization; need implementing activation |
4849
| MemcpyFromHost | ai.onnx(1+) | |

js/web/lib/wasm/jsep/webgpu/op-resolve-rules.ts

+1
Original file line numberDiff line numberDiff line change
@@ -61,6 +61,7 @@ export const WEBGPU_OP_RESOLVE_RULES: Map<string, OperatorImplementation> = new
6161
['InstanceNormalization', [instanceNorm, parseInstanceNormAttributes]],
6262
['LayerNormalization', [layerNorm, parseLayerNormAttributes]],
6363
['LeakyRelu', [unaryOps.leakyRelu, unaryOps.parseAlphaAttributes]],
64+
['Log', [unaryOps.log]],
6465
['MatMul', [matMul]],
6566
// TODO: support new attributes for MaxPool-8 and MaxPool-10
6667
['MaxPool', [pool.maxPool, pool.parseMaxPoolAttributes]],

js/web/lib/wasm/jsep/webgpu/ops/unary-op.ts

+4
Original file line numberDiff line numberDiff line change
@@ -231,3 +231,7 @@ export const thresholdedRelu = (context: ComputeContext, attributes: AlphaAttrib
231231
`const thresholded_relu_alpha_: vec4<f32> = vec4<f32>(${attributes.alpha});`, attributes.cacheKey));
232232
return 0;
233233
};
234+
235+
export const log = (context: ComputeContext): void => {
236+
context.compute(createElementwiseProgramInfoLoader(context.inputs[0], 'Log', 'log'));
237+
};

js/web/test/suite-test-list.jsonc

+1-1
Original file line numberDiff line numberDiff line change
@@ -1339,7 +1339,7 @@
13391339
"global-average-pool.jsonc",
13401340
//"greater.jsonc",
13411341
//"less.jsonc",
1342-
//"log.jsonc",
1342+
"log.jsonc",
13431343
//"matmul.jsonc", // <--- some tests fail (when input is 3D/4D/5D)
13441344
"mul.jsonc",
13451345
//"neg.jsonc",

onnxruntime/core/providers/js/js_execution_provider.cc

+4
Original file line numberDiff line numberDiff line change
@@ -95,6 +95,8 @@ class ONNX_OPERATOR_VERSIONED_KERNEL_CLASS_NAME(kJsExecutionProvider, kOnnxDomai
9595
class ONNX_OPERATOR_KERNEL_CLASS_NAME(kJsExecutionProvider, kOnnxDomain, 13, Erf);
9696
class ONNX_OPERATOR_VERSIONED_KERNEL_CLASS_NAME(kJsExecutionProvider, kOnnxDomain, 6, 12, Sigmoid);
9797
class ONNX_OPERATOR_KERNEL_CLASS_NAME(kJsExecutionProvider, kOnnxDomain, 13, Sigmoid);
98+
class ONNX_OPERATOR_VERSIONED_KERNEL_CLASS_NAME(kJsExecutionProvider, kOnnxDomain, 6, 12, Log);
99+
class ONNX_OPERATOR_KERNEL_CLASS_NAME(kJsExecutionProvider, kOnnxDomain, 13, Log);
98100

99101
class ONNX_OPERATOR_KERNEL_CLASS_NAME(kJsExecutionProvider, kOnnxDomain, 7, Sin);
100102
class ONNX_OPERATOR_KERNEL_CLASS_NAME(kJsExecutionProvider, kOnnxDomain, 7, Cos);
@@ -319,6 +321,8 @@ std::unique_ptr<KernelRegistry> RegisterKernels() {
319321
KERNEL_CREATE_INFO(13, Erf),
320322
KERNEL_CREATE_INFO_VERSIONED(6, 12, Sigmoid),
321323
KERNEL_CREATE_INFO(13, Sigmoid),
324+
KERNEL_CREATE_INFO_VERSIONED(6, 12, Log),
325+
KERNEL_CREATE_INFO(13, Log),
322326

323327
KERNEL_CREATE_INFO(7, Sin),
324328
KERNEL_CREATE_INFO(7, Cos),

onnxruntime/core/providers/js/operators/unary.cc

+4
Original file line numberDiff line numberDiff line change
@@ -56,6 +56,10 @@ JSEP_KERNEL_IMPL(Sigmoid, Sigmoid)
5656
JSEP_ELEMENTWISE_VERSIONED_KERNEL(Sigmoid, 6, 12, float, Sigmoid)
5757
JSEP_ELEMENTWISE_KERNEL(Sigmoid, 13, float, Sigmoid)
5858

59+
JSEP_KERNEL_IMPL(Log, Log)
60+
JSEP_ELEMENTWISE_VERSIONED_KERNEL(Log, 6, 12, float, Log)
61+
JSEP_ELEMENTWISE_KERNEL(Log, 13, float, Log)
62+
5963
JSEP_KERNEL_IMPL(Sin, Sin)
6064
JSEP_ELEMENTWISE_KERNEL(Sin, 7, float, Sin)
6165

0 commit comments

Comments
 (0)