Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
8 changes: 6 additions & 2 deletions lib/Conversion/TorchToTosa/TosaLegalizeCommon.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -603,9 +603,13 @@ std::optional<Value> convertScatterNdOp(PatternRewriter &rewriter,

// fillK: range of each index, total number of fillInput(could be scatter)
// after flattened k = 1*1*3 = 3
for (int i = 0; i < ND; i++) {
fillK *= fillValuesType.getShape()[i];
int64_t fillNumElements = 1;
for (int64_t dim : fillValuesType.getShape()) {
fillNumElements *= dim;
}
if (fillNumElements % C != 0)
return std::nullopt;
fillK = fillNumElements / C;
SmallVector<int64_t, 3> tosaFillValuesShape({N, fillK, C}); // {1,3,1}

// Reshape/Flatten fillValues to 3d tensor
Expand Down
1 change: 0 additions & 1 deletion projects/pt1/e2e_testing/xfail_sets.py
Original file line number Diff line number Diff line change
Expand Up @@ -3825,7 +3825,6 @@
"IndexPutImpl1DFloatAccumulateModule_basic",
"IndexPutImpl1DIntAccumulateModule_basic",
"IndexPutImpl2DFloatAccumulateModule_basic",
"IndexPutImpl2DImplicitModule_basic",
"IndexPutImpl2DIndexModule_basic",
"IndexPutImpl2DNoneIndexStaticModule_basic",
"IndexPutImpl3DFloatAccumulateModule_basic",
Expand Down
20 changes: 20 additions & 0 deletions test/Conversion/TorchToTosa/basic.mlir
Original file line number Diff line number Diff line change
Expand Up @@ -2746,6 +2746,26 @@ func.func @torch.aten.diag_embed$basic(%arg0: !torch.vtensor<[2,3,4],f32>) -> !t

// -----

// CHECK-LABEL: func.func @torch.aten.index_put_hacked_twin_flattened_updates(
// CHECK: %[[SCATTER:.*]] = tosa.scatter
// CHECK-SAME: (tensor<1x6x1xf32>, tensor<1x6xi32>, tensor<1x6x1xf32>) -> tensor<1x6x1xf32>
// CHECK: %[[RESHAPE:.*]] = tosa.reshape %[[SCATTER]]
// CHECK-SAME: (tensor<1x6x1xf32>, !tosa.shape<3>) -> tensor<1x2x3xf32>
// CHECK: torch_c.from_builtin_tensor %[[RESHAPE]] : tensor<1x2x3xf32> -> !torch.vtensor<[1,2,3],f32>
func.func @torch.aten.index_put_hacked_twin_flattened_updates(
%arg0: !torch.vtensor<[1,2,3],f32>,
%arg1: !torch.vtensor<[6],si64>,
%arg2: !torch.vtensor<[6],si64>,
%arg3: !torch.vtensor<[6],si64>,
%arg4: !torch.vtensor<[6],f32>) -> !torch.vtensor<[1,2,3],f32> {
%indices = torch.prim.ListConstruct %arg1, %arg2, %arg3 : (!torch.vtensor<[6],si64>, !torch.vtensor<[6],si64>, !torch.vtensor<[6],si64>) -> !torch.list<vtensor>
%false = torch.constant.bool false
%0 = torch.aten.index_put.hacked_twin %arg0, %indices, %arg4, %false : !torch.vtensor<[1,2,3],f32>, !torch.list<vtensor>, !torch.vtensor<[6],f32>, !torch.bool -> !torch.vtensor<[1,2,3],f32>
return %0 : !torch.vtensor<[1,2,3],f32>
}

// -----

// CHECK-LABEL: func.func @torch.aten.index.Tensor_hacked_twin(
// CHECK-SAME: %[[VAL_0:.*]]: !torch.vtensor<[2,4,2],si64>,
// CHECK-SAME: %[[VAL_1:.*]]: !torch.vtensor<[],si64>) -> !torch.vtensor<[4,2],si64> {
Expand Down
Loading