@@ -5202,10 +5202,26 @@ def aten_masked_fill(self: TTensor, mask: BOOL, value: TTensor) -> TTensor:
5202
5202
return op .Where (mask , value_cast , self )
5203
5203
5204
5204
5205
@torch_op("aten::masked_scatter", trace_only=True)
def aten_masked_scatter(self: TTensor, mask: TTensor, source: TTensor) -> TTensor:
    """masked_scatter(Tensor self, Tensor mask, Tensor source) -> Tensor"""

    # Bring self and mask to a common shape: whichever has the lower rank
    # is broadcast up to the other's shape.
    if len(mask.shape) < len(self.shape):
        mask = op.Expand(mask, op.Shape(self))
    else:
        self = op.Expand(self, op.Shape(mask))

    # NonZero yields indices as [d, n]; ScatterND wants one index row per
    # selected element, i.e. [n, d].
    indices = op.Transpose(op.NonZero(mask), perm=[1, 0])

    # NOTE: source can have more elements than needed.
    # It could also have arbitrary shape.
    # This is not supported by ONNX::ScatterND, so we need to flatten and slice source tensor.
    flat_source = op.Reshape(source, op.Constant(value_ints=[-1]))
    slice_axes = op.Constant(value_ints=[0])
    slice_starts = op.Constant(value_ints=[0])
    # Take exactly as many elements as there are True entries in mask
    # (the number of index rows).
    slice_ends = op.Gather(op.Shape(indices), op.Constant(value_ints=[0]), axis=0)
    flat_source = op.Slice(flat_source, slice_starts, slice_ends, slice_axes)

    return op.ScatterND(self, indices, flat_source)
5209
5225
5210
5226
5211
5227
def aten_masked_select (self : TensorType , mask : TensorType ) -> TensorType :
@@ -6429,7 +6445,7 @@ def aten_nextafter(self: TensorType, other: TensorType) -> TensorType:
6429
6445
raise NotImplementedError ()
6430
6446
6431
6447
6432
- @torch_op ("aten::nonzero" )
6448
+ @torch_op ("aten::nonzero" , trace_only = True )
6433
6449
def aten_nonzero (self : TTensor ) -> INT64 :
6434
6450
"""nonzero(Tensor self) -> Tensor"""
6435
6451
# NOTE: In torch the return shape is [n, d], while in onnx [d, n],
0 commit comments