
Commit 5f6b44c

Add tests and update xfail_sets.py
Parent: b3bc02b

3 files changed, +76 -19 lines changed

lib/Conversion/TorchToLinalg/IndirectDataMovement.cpp

Lines changed: 16 additions & 3 deletions
@@ -417,6 +417,18 @@ class ConvertAtenEmbeddingBagPaddingIdxOp
 };
 } // namespace
 
+static Value wrapIndicesAroundMax(OpBuilder &b, Location loc, Value index,
+                                  Value input, int64_t dim) {
+  // Performs index = index % maxIndex to wrap the index around maxIndex.
+  Value maxIndexValue = castIndexToInt64(b, loc, getDimOp(b, loc, input, dim));
+  Value isBeyondMaxIndices = b.create<arith::CmpIOp>(
+      loc, arith::CmpIPredicate::sge, index, maxIndexValue);
+  Value wrappedIndices = b.create<arith::RemSIOp>(loc, index, maxIndexValue);
+  return b.create<arith::SelectOp>(loc, isBeyondMaxIndices, wrappedIndices,
+                                   index);
+}
+
 namespace {
 // Let's say we have an input tensor: initialized with some random values of
 // size [4, 5, 6]. An index tensor (always 1-d): [0, 2] of size [2], and an
@@ -478,16 +490,17 @@ class ConvertAtenIndexSelectOp : public OpConversionPattern<AtenIndexSelectOp> {
 
     auto indexingMaps = AffineMap::inferFromExprList({indicesExpr, resultExpr},
                                                      rewriter.getContext());
-
     Value finalRes =
        rewriter
             .create<linalg::GenericOp>(
                loc, initTensor.getType(), ValueRange{indices}, initTensor,
                /*indexingMaps=*/indexingMaps,
                /*iteratorTypes=*/iteratorTypes,
                [&](OpBuilder &b, Location loc, ValueRange args) {
-                  Value index = rewriter.create<arith::IndexCastOp>(
-                      loc, rewriter.getIndexType(), args[0]);
+                  Value index =
+                      wrapIndicesAroundMax(b, loc, args[0], input, dimInt);
+                  index = rewriter.create<arith::IndexCastOp>(
+                      loc, rewriter.getIndexType(), index);
                  SmallVector<Value> indexTarget;
                  for (unsigned i = 0; i < inputRank; i++)
                    indexTarget.push_back(b.create<linalg::IndexOp>(loc, i));
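For reference, a minimal scalar model (plain C++, not part of the commit; the standalone function name is illustrative) of what the three arith ops in wrapIndicesAroundMax compute per element:

#include <cstdint>
#include <iostream>

// Scalar model of the lowering above: arith.cmpi sge, arith.remsi, and
// arith.select combine so that an index below the dimension size passes
// through unchanged, while a larger one is replaced by its remainder.
int64_t wrapIndexAroundMax(int64_t index, int64_t maxIndex) {
  bool isBeyondMax = index >= maxIndex; // arith.cmpi sge
  int64_t wrapped = index % maxIndex;   // arith.remsi
  return isBeyondMax ? wrapped : index; // arith.select
}

int main() {
  // For a dimension of size 6:
  std::cout << wrapIndexAroundMax(2, 6) << "\n"; // 2 (in range, unchanged)
  std::cout << wrapIndexAroundMax(6, 6) << "\n"; // 0 (wrapped)
  std::cout << wrapIndexAroundMax(8, 6) << "\n"; // 2 (wrapped)
}

Note that the select only fires for index >= maxIndex, so negative indices pass through this helper unchanged.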

projects/pt1/e2e_testing/xfail_sets.py

Lines changed: 1 addition & 0 deletions
@@ -1767,6 +1767,7 @@
     "AdaptiveAvgPool2dNonUnitOutputSizeStaticModule_basic",
     "AdaptiveAvgPool2dOutputSizeDivisibleByInputStaticModule_basic",
     "AdaptiveAvgPool2dUnitOutputSizeStaticModule_basic",
+    "AsStridedWithOffsetModule_basic",
     "ElementwiseAtenLogicalNotOpPromoteModule_basic",
     "ElementwiseCosIntModule_basic",
     "ElementwiseReciprocalIntModule_basic",

test/Conversion/TorchToTosa/basic.mlir

Lines changed: 59 additions & 16 deletions
@@ -1893,22 +1893,29 @@ func.func @torch.aten.diagonal$basic(%arg0: !torch.vtensor<[3,4,5,6], si32>) ->
 // CHECK: %[[VAL_3:.*]] = torch_c.to_builtin_tensor %[[VAL_0]] : !torch.vtensor<[4,5,6],f32> -> tensor<4x5x6xf32>
 // CHECK: %[[VAL_4:.*]] = torch.constant.int 2
 // CHECK: %[[VAL_5:.*]] = tosa.cast %[[VAL_2]] : (tensor<2xi64>) -> tensor<2xi32>
-// CHECK: %[[VAL_6:.*]] = tosa.reshape %[[VAL_5]] {new_shape = array<i64: 1, 1, 2>} : (tensor<2xi32>) -> tensor<1x1x2xi32>
-// CHECK: %[[VAL_7:.*]] = tosa.tile %[[VAL_6]] {multiples = array<i64: 4, 5, 1>} : (tensor<1x1x2xi32>) -> tensor<4x5x2xi32>
-// CHECK: %[[VAL_8:.*]] = tosa.reshape %[[VAL_7]] {new_shape = array<i64: 4, 5, 2, 1>} : (tensor<4x5x2xi32>) -> tensor<4x5x2x1xi32>
-// CHECK: %[[VAL_9:.*]] = "tosa.const"() <{value = dense<{{\[\[}}{{\[\[}}0], [0]], {{\[\[}}0], [0]], {{\[\[}}0], [0]], {{\[\[}}0], [0]], {{\[\[}}0], [0]]], {{\[\[}}[1], [1]], {{\[\[}}1], [1]], {{\[\[}}1], [1]], {{\[\[}}1], [1]], {{\[\[}}1], [1]]], {{\[\[}}[2], [2]], {{\[\[}}2], [2]], {{\[\[}}2], [2]], {{\[\[}}2], [2]], {{\[\[}}2], [2]]], {{\[\[}}[3], [3]], {{\[\[}}3], [3]], {{\[\[}}3], [3]], {{\[\[}}3], [3]], {{\[\[}}3], [3]]]]> : tensor<4x5x2x1xi32>}> : () -> tensor<4x5x2x1xi32>
-// CHECK: %[[VAL_10:.*]] = "tosa.const"() <{value = dense<{{\[\[}}{{\[\[}}0], [0]], {{\[\[}}1], [1]], {{\[\[}}2], [2]], {{\[\[}}3], [3]], {{\[\[}}4], [4]]], {{\[\[}}[0], [0]], {{\[\[}}1], [1]], {{\[\[}}2], [2]], {{\[\[}}3], [3]], {{\[\[}}4], [4]]], {{\[\[}}[0], [0]], {{\[\[}}1], [1]], {{\[\[}}2], [2]], {{\[\[}}3], [3]], {{\[\[}}4], [4]]], {{\[\[}}[0], [0]], {{\[\[}}1], [1]], {{\[\[}}2], [2]], {{\[\[}}3], [3]], {{\[\[}}4], [4]]]]> : tensor<4x5x2x1xi32>}> : () -> tensor<4x5x2x1xi32>
-// CHECK: %[[VAL_11:.*]] = tosa.concat %[[VAL_9]], %[[VAL_10]], %[[VAL_8]] {axis = 3 : i32} : (tensor<4x5x2x1xi32>, tensor<4x5x2x1xi32>, tensor<4x5x2x1xi32>) -> tensor<4x5x2x3xi32>
-// CHECK: %[[VAL_12:.*]] = tosa.reshape %[[VAL_3]] {new_shape = array<i64: 1, 120, 1>} : (tensor<4x5x6xf32>) -> tensor<1x120x1xf32>
-// CHECK: %[[VAL_13:.*]] = tosa.reshape %[[VAL_11]] {new_shape = array<i64: 40, 3>} : (tensor<4x5x2x3xi32>) -> tensor<40x3xi32>
-// CHECK: %[[VAL_14:.*]] = "tosa.const"() <{value = dense<[30, 6, 1]> : tensor<3xi32>}> : () -> tensor<3xi32>
-// CHECK: %[[VAL_15:.*]] = tosa.mul %[[VAL_13]], %[[VAL_14]] {shift = 0 : i8} : (tensor<40x3xi32>, tensor<3xi32>) -> tensor<40x3xi32>
-// CHECK: %[[VAL_16:.*]] = tosa.reduce_sum %[[VAL_15]] {axis = 1 : i32} : (tensor<40x3xi32>) -> tensor<40x1xi32>
-// CHECK: %[[VAL_17:.*]] = tosa.reshape %[[VAL_16]] {new_shape = array<i64: 1, 40>} : (tensor<40x1xi32>) -> tensor<1x40xi32>
-// CHECK: %[[VAL_18:.*]] = tosa.gather %[[VAL_12]], %[[VAL_17]] : (tensor<1x120x1xf32>, tensor<1x40xi32>) -> tensor<1x40x1xf32>
-// CHECK: %[[VAL_19:.*]] = tosa.reshape %[[VAL_18]] {new_shape = array<i64: 4, 5, 2>} : (tensor<1x40x1xf32>) -> tensor<4x5x2xf32>
-// CHECK: %[[VAL_20:.*]] = torch_c.from_builtin_tensor %[[VAL_19]] : tensor<4x5x2xf32> -> !torch.vtensor<[4,5,2],f32>
-// CHECK: return %[[VAL_20]] : !torch.vtensor<[4,5,2],f32>
+// CHECK: %[[VAL_6:.*]] = "tosa.const"() <{value = dense<120> : tensor<i32>}> : () -> tensor<i32>
+// CHECK: %[[VAL_7:.*]] = "tosa.const"() <{value = dense<119> : tensor<i32>}> : () -> tensor<i32>
+// CHECK: %[[VAL_8:.*]] = tosa.greater %[[VAL_5]], %[[VAL_7]] : (tensor<2xi32>, tensor<i32>) -> tensor<2xi1>
+// CHECK: %[[VAL_9:.*]] = tosa.int_div %[[VAL_5]], %[[VAL_6]] : (tensor<2xi32>, tensor<i32>) -> tensor<2xi32>
+// CHECK: %[[VAL_10:.*]] = tosa.mul %[[VAL_9]], %[[VAL_6]] {shift = 0 : i8} : (tensor<2xi32>, tensor<i32>) -> tensor<2xi32>
+// CHECK: %[[VAL_11:.*]] = tosa.sub %[[VAL_5]], %[[VAL_10]] : (tensor<2xi32>, tensor<2xi32>) -> tensor<2xi32>
+// CHECK: %[[VAL_12:.*]] = tosa.select %[[VAL_8]], %[[VAL_11]], %[[VAL_5]] : (tensor<2xi1>, tensor<2xi32>, tensor<2xi32>) -> tensor<2xi32>
+// CHECK: %[[VAL_13:.*]] = tosa.reshape %[[VAL_12]] {new_shape = array<i64: 1, 1, 2>} : (tensor<2xi32>) -> tensor<1x1x2xi32>
+// CHECK: %[[VAL_14:.*]] = tosa.tile %[[VAL_13]] {multiples = array<i64: 4, 5, 1>} : (tensor<1x1x2xi32>) -> tensor<4x5x2xi32>
+// CHECK: %[[VAL_15:.*]] = tosa.reshape %[[VAL_14]] {new_shape = array<i64: 4, 5, 2, 1>} : (tensor<4x5x2xi32>) -> tensor<4x5x2x1xi32>
+// CHECK: %[[VAL_16:.*]] = "tosa.const"() <{value = dense<{{\[\[}}{{\[\[}}0], [0]], {{\[\[}}0], [0]], {{\[\[}}0], [0]], {{\[\[}}0], [0]], {{\[\[}}0], [0]]], {{\[\[}}[1], [1]], {{\[\[}}1], [1]], {{\[\[}}1], [1]], {{\[\[}}1], [1]], {{\[\[}}1], [1]]], {{\[\[}}[2], [2]], {{\[\[}}2], [2]], {{\[\[}}2], [2]], {{\[\[}}2], [2]], {{\[\[}}2], [2]]], {{\[\[}}[3], [3]], {{\[\[}}3], [3]], {{\[\[}}3], [3]], {{\[\[}}3], [3]], {{\[\[}}3], [3]]]]> : tensor<4x5x2x1xi32>}> : () -> tensor<4x5x2x1xi32>
+// CHECK: %[[VAL_17:.*]] = "tosa.const"() <{value = dense<{{\[\[}}{{\[\[}}0], [0]], {{\[\[}}1], [1]], {{\[\[}}2], [2]], {{\[\[}}3], [3]], {{\[\[}}4], [4]]], {{\[\[}}[0], [0]], {{\[\[}}1], [1]], {{\[\[}}2], [2]], {{\[\[}}3], [3]], {{\[\[}}4], [4]]], {{\[\[}}[0], [0]], {{\[\[}}1], [1]], {{\[\[}}2], [2]], {{\[\[}}3], [3]], {{\[\[}}4], [4]]], {{\[\[}}[0], [0]], {{\[\[}}1], [1]], {{\[\[}}2], [2]], {{\[\[}}3], [3]], {{\[\[}}4], [4]]]]> : tensor<4x5x2x1xi32>}> : () -> tensor<4x5x2x1xi32>
+// CHECK: %[[VAL_18:.*]] = tosa.concat %[[VAL_16]], %[[VAL_17]], %[[VAL_15]] {axis = 3 : i32} : (tensor<4x5x2x1xi32>, tensor<4x5x2x1xi32>, tensor<4x5x2x1xi32>) -> tensor<4x5x2x3xi32>
+// CHECK: %[[VAL_19:.*]] = tosa.reshape %[[VAL_3]] {new_shape = array<i64: 1, 120, 1>} : (tensor<4x5x6xf32>) -> tensor<1x120x1xf32>
+// CHECK: %[[VAL_20:.*]] = tosa.reshape %[[VAL_18]] {new_shape = array<i64: 40, 3>} : (tensor<4x5x2x3xi32>) -> tensor<40x3xi32>
+// CHECK: %[[VAL_21:.*]] = "tosa.const"() <{value = dense<[30, 6, 1]> : tensor<3xi32>}> : () -> tensor<3xi32>
+// CHECK: %[[VAL_22:.*]] = tosa.mul %[[VAL_20]], %[[VAL_21]] {shift = 0 : i8} : (tensor<40x3xi32>, tensor<3xi32>) -> tensor<40x3xi32>
+// CHECK: %[[VAL_23:.*]] = tosa.reduce_sum %[[VAL_22]] {axis = 1 : i32} : (tensor<40x3xi32>) -> tensor<40x1xi32>
+// CHECK: %[[VAL_24:.*]] = tosa.reshape %[[VAL_23]] {new_shape = array<i64: 1, 40>} : (tensor<40x1xi32>) -> tensor<1x40xi32>
+// CHECK: %[[VAL_25:.*]] = tosa.gather %[[VAL_19]], %[[VAL_24]] : (tensor<1x120x1xf32>, tensor<1x40xi32>) -> tensor<1x40x1xf32>
+// CHECK: %[[VAL_26:.*]] = tosa.reshape %[[VAL_25]] {new_shape = array<i64: 4, 5, 2>} : (tensor<1x40x1xf32>) -> tensor<4x5x2xf32>
+// CHECK: %[[VAL_27:.*]] = torch_c.from_builtin_tensor %[[VAL_26]] : tensor<4x5x2xf32> -> !torch.vtensor<[4,5,2],f32>
+// CHECK: return %[[VAL_27]] : !torch.vtensor<[4,5,2],f32>
 // CHECK: }
 func.func @torch.aten.index_select(%arg0: !torch.vtensor<[4,5,6],f32>, %arg1: !torch.vtensor<[2],si64>) -> !torch.vtensor<[4,5,2],f32> {
   %int2 = torch.constant.int 2
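A scalar sketch (plain C++, illustrative only, not part of the commit) of the two mechanisms the updated CHECK lines encode: the remainder expansion that wraps out-of-range indices, and the stride-based flattening that feeds tosa.gather:

#include <cstdint>
#include <iostream>

// TOSA has no remainder op, so x % n is expanded from ops it does have:
// int_div, mul, and sub, guarded by greater + select so in-range values
// pass through untouched (VAL_6..VAL_12 above, with n = 120).
int32_t wrapTosaStyle(int32_t x, int32_t n) {
  bool beyond = x > n - 1;            // tosa.greater against n - 1
  int32_t quotient = x / n;           // tosa.int_div
  int32_t wrapped = x - quotient * n; // tosa.mul + tosa.sub
  return beyond ? wrapped : x;        // tosa.select
}

// The gather operates on the input flattened to tensor<1x120x1xf32>; each
// (i, j, k) coordinate into the [4,5,6] input is flattened with row-major
// strides [30, 6, 1] (the tosa.mul + tosa.reduce_sum pair, VAL_21..VAL_23).
int32_t flattenIndex(int32_t i, int32_t j, int32_t k) {
  return i * 30 + j * 6 + k;
}

int main() {
  std::cout << wrapTosaStyle(125, 120) << "\n"; // 5 (wrapped)
  std::cout << wrapTosaStyle(42, 120) << "\n";  // 42 (unchanged)
  std::cout << flattenIndex(1, 2, 0) << "\n";   // 42
}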
@@ -2306,6 +2313,42 @@ func.func @torch.aten.as_strided$basic(%arg0: !torch.vtensor<[5,5],f32>) -> !tor
   return %2 : !torch.vtensor<[3,3],f32>
 }
 
+// -----
+// CHECK-LABEL: func.func @torch.aten.as_strided$offset(
+// CHECK-SAME: %[[VAL_0:.*]]: !torch.vtensor<[5,5],f32>) -> !torch.vtensor<[3,3],f32> {
+// CHECK: %[[VAL_1:.*]] = torch_c.to_builtin_tensor %[[VAL_0]] : !torch.vtensor<[5,5],f32> -> tensor<5x5xf32>
+// CHECK: %[[VAL_2:.*]] = torch.constant.int 30
+// CHECK: %[[VAL_3:.*]] = torch.constant.int 1
+// CHECK: %[[VAL_4:.*]] = torch.constant.int 2
+// CHECK: %[[VAL_5:.*]] = torch.constant.int 3
+// CHECK: %[[VAL_6:.*]] = torch.prim.ListConstruct %[[VAL_5]], %[[VAL_5]] : (!torch.int, !torch.int) -> !torch.list<int>
+// CHECK: %[[VAL_7:.*]] = torch.prim.ListConstruct %[[VAL_4]], %[[VAL_3]] : (!torch.int, !torch.int) -> !torch.list<int>
+// CHECK: %[[VAL_8:.*]] = tosa.reshape %[[VAL_1]] {new_shape = array<i64: 25>} : (tensor<5x5xf32>) -> tensor<25xf32>
+// CHECK: %[[VAL_9:.*]] = "tosa.const"() <{value = dense<[5, 6, 7, 7, 8, 9, 9, 10, 11]> : tensor<9xi32>}> : () -> tensor<9xi32>
+// CHECK: %[[VAL_10:.*]] = tosa.reshape %[[VAL_9]] {new_shape = array<i64: 9, 1>} : (tensor<9xi32>) -> tensor<9x1xi32>
+// CHECK: %[[VAL_11:.*]] = tosa.concat %[[VAL_10]] {axis = 1 : i32} : (tensor<9x1xi32>) -> tensor<9x1xi32>
+// CHECK: %[[VAL_12:.*]] = tosa.reshape %[[VAL_8]] {new_shape = array<i64: 1, 25, 1>} : (tensor<25xf32>) -> tensor<1x25x1xf32>
+// CHECK: %[[VAL_13:.*]] = tosa.reshape %[[VAL_11]] {new_shape = array<i64: 9, 1>} : (tensor<9x1xi32>) -> tensor<9x1xi32>
+// CHECK: %[[VAL_14:.*]] = "tosa.const"() <{value = dense<1> : tensor<1xi32>}> : () -> tensor<1xi32>
+// CHECK: %[[VAL_15:.*]] = tosa.mul %[[VAL_13]], %[[VAL_14]] {shift = 0 : i8} : (tensor<9x1xi32>, tensor<1xi32>) -> tensor<9x1xi32>
+// CHECK: %[[VAL_16:.*]] = tosa.reduce_sum %[[VAL_15]] {axis = 1 : i32} : (tensor<9x1xi32>) -> tensor<9x1xi32>
+// CHECK: %[[VAL_17:.*]] = tosa.reshape %[[VAL_16]] {new_shape = array<i64: 1, 9>} : (tensor<9x1xi32>) -> tensor<1x9xi32>
+// CHECK: %[[VAL_18:.*]] = tosa.gather %[[VAL_12]], %[[VAL_17]] : (tensor<1x25x1xf32>, tensor<1x9xi32>) -> tensor<1x9x1xf32>
+// CHECK: %[[VAL_19:.*]] = tosa.reshape %[[VAL_18]] {new_shape = array<i64: 9>} : (tensor<1x9x1xf32>) -> tensor<9xf32>
+// CHECK: %[[VAL_20:.*]] = tosa.reshape %[[VAL_19]] {new_shape = array<i64: 3, 3>} : (tensor<9xf32>) -> tensor<3x3xf32>
+// CHECK: %[[VAL_21:.*]] = torch_c.from_builtin_tensor %[[VAL_20]] : tensor<3x3xf32> -> !torch.vtensor<[3,3],f32>
+// CHECK: return %[[VAL_21]] : !torch.vtensor<[3,3],f32>
+func.func @torch.aten.as_strided$offset(%arg0: !torch.vtensor<[5,5],f32>) -> !torch.vtensor<[3,3],f32> {
+  %int30 = torch.constant.int 30
+  %int1 = torch.constant.int 1
+  %int2 = torch.constant.int 2
+  %int3 = torch.constant.int 3
+  %0 = torch.prim.ListConstruct %int3, %int3 : (!torch.int, !torch.int) -> !torch.list<int>
+  %1 = torch.prim.ListConstruct %int2, %int1 : (!torch.int, !torch.int) -> !torch.list<int>
+  %2 = torch.aten.as_strided %arg0, %0, %1, %int30 : !torch.vtensor<[5,5],f32>, !torch.list<int>, !torch.list<int>, !torch.int -> !torch.vtensor<[3,3],f32>
+  return %2 : !torch.vtensor<[3,3],f32>
+}
+
 // -----
 
 // CHECK-LABEL: func.func @torch.aten.max_pool1d$basic(
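For the new test, a short sketch (plain C++, illustrative only) of how the constant gather-index vector [5, 6, 7, 7, 8, 9, 9, 10, 11] in VAL_9 follows from the test inputs, assuming the storage offset is wrapped around the element count (30 % 25 = 5) before the strided walk:

#include <cstdint>
#include <iostream>

int main() {
  const int64_t numel = 25;          // 5x5 input flattened
  const int64_t offset = 30 % numel; // wrapped storage offset -> 5
  const int64_t size[2] = {3, 3};    // output size
  const int64_t stride[2] = {2, 1};  // strides passed to as_strided

  // Enumerate the flattened source index for each output element.
  for (int64_t i = 0; i < size[0]; ++i)
    for (int64_t j = 0; j < size[1]; ++j)
      std::cout << offset + i * stride[0] + j * stride[1] << " ";
  std::cout << "\n"; // 5 6 7 7 8 9 9 10 11
}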
