Skip to content

Commit 57dbc70

Browse files
authored
Add Op (aten::masked_scatter) | feat (torchlib) (#2112)
From Gemma3, the error lacks of support is raised. https://github.com/huggingface/transformers/blob/7f5077e53682ca855afc826162b204ebf809f1f9/src/transformers/models/gemma3/modeling_gemma3.py#L1339
1 parent 6be9d18 commit 57dbc70

File tree

2 files changed

+20
-3
lines changed

2 files changed

+20
-3
lines changed

onnxscript/function_libs/torch_lib/ops/core.py

+19-3
Original file line numberDiff line numberDiff line change
@@ -5202,10 +5202,26 @@ def aten_masked_fill(self: TTensor, mask: BOOL, value: TTensor) -> TTensor:
52025202
return op.Where(mask, value_cast, self)
52035203

52045204

5205-
def aten_masked_scatter(self: TensorType, mask: TensorType, source: TensorType) -> TensorType:
5205+
@torch_op(("aten::masked_scatter"), trace_only=True)
5206+
def aten_masked_scatter(self: TTensor, mask: TTensor, source: TTensor) -> TTensor:
52065207
"""masked_scatter(Tensor self, Tensor mask, Tensor source) -> Tensor"""
52075208

5208-
raise NotImplementedError()
5209+
if len(mask.shape) < len(self.shape):
5210+
mask = op.Expand(mask, op.Shape(self))
5211+
else:
5212+
self = op.Expand(self, op.Shape(mask))
5213+
index = op.Transpose(op.NonZero(mask), perm=[1, 0])
5214+
5215+
# NOTE: source can have more elements than needed.
5216+
# It could also have arbitrary shape.
5217+
# This is not supported by ONNX::ScatterND, so we need to flatten and slice source tensor.
5218+
source = op.Reshape(source, op.Constant(value_ints=[-1]))
5219+
axes = op.Constant(value_ints=[0])
5220+
starts = op.Constant(value_ints=[0])
5221+
ends = op.Gather(op.Shape(index), op.Constant(value_ints=[0]), axis=0)
5222+
source = op.Slice(source, starts, ends, axes)
5223+
5224+
return op.ScatterND(self, index, source)
52095225

52105226

52115227
def aten_masked_select(self: TensorType, mask: TensorType) -> TensorType:
@@ -6429,7 +6445,7 @@ def aten_nextafter(self: TensorType, other: TensorType) -> TensorType:
64296445
raise NotImplementedError()
64306446

64316447

6432-
@torch_op("aten::nonzero")
6448+
@torch_op("aten::nonzero", trace_only=True)
64336449
def aten_nonzero(self: TTensor) -> INT64:
64346450
"""nonzero(Tensor self) -> Tensor"""
64356451
# NOTE: In torch the return shape is [n, d], while in onnx [d, n],

tests/function_libs/torch_lib/ops_test_data.py

+1
Original file line numberDiff line numberDiff line change
@@ -932,6 +932,7 @@ def _where_input_wrangler(
932932
dtypes=(torch.bool,),
933933
reason="fixme: ORT does not have an implementation for Where with bool inputs.",
934934
),
935+
TorchLibOpInfo("masked_scatter", core_ops.aten_masked_scatter),
935936
TorchLibOpInfo(
936937
"matmul",
937938
core_ops.aten_matmul,

0 commit comments

Comments
 (0)