diff --git a/include/torch-mlir/Dialect/Torch/IR/GeneratedTorchOps.td b/include/torch-mlir/Dialect/Torch/IR/GeneratedTorchOps.td
index ec9dd518db73..49ae20903990 100644
--- a/include/torch-mlir/Dialect/Torch/IR/GeneratedTorchOps.td
+++ b/include/torch-mlir/Dialect/Torch/IR/GeneratedTorchOps.td
@@ -10979,6 +10979,29 @@ def Torch_AtenAtleast2dOp : Torch_Op<"aten.atleast_2d", [
   }];
 }
 
+def Torch_AtenAtleast3dOp : Torch_Op<"aten.atleast_3d", [
+    AllowsTypeRefinement,
+    HasValueSemantics,
+    ReadOnly
+  ]> {
+  let summary = "Generated op for `aten::atleast_3d : (Tensor) -> (Tensor)`";
+  let arguments = (ins
+    AnyTorchTensorType:$self
+  );
+  let results = (outs
+    AnyTorchOptionalTensorType:$result
+  );
+  let hasCustomAssemblyFormat = 1;
+  let extraClassDefinition = [{
+    ParseResult AtenAtleast3dOp::parse(OpAsmParser &parser, OperationState &result) {
+      return parseDefaultTorchOp(parser, result, 1, 1);
+    }
+    void AtenAtleast3dOp::print(OpAsmPrinter &printer) {
+      printDefaultTorchOp(printer, *this, 1, 1);
+    }
+  }];
+}
+
 def Torch_AtenEinsumOp : Torch_Op<"aten.einsum", [
     AllowsTypeRefinement,
     HasValueSemantics,
diff --git a/lib/Dialect/Torch/Transforms/AbstractInterpLibrary.cpp b/lib/Dialect/Torch/Transforms/AbstractInterpLibrary.cpp
index c1092171fcd6..47a56e007deb 100644
--- a/lib/Dialect/Torch/Transforms/AbstractInterpLibrary.cpp
+++ b/lib/Dialect/Torch/Transforms/AbstractInterpLibrary.cpp
@@ -10966,6 +10966,38 @@ StringRef mlir::torch::Torch::getAbstractInterpLibrary() {
 "    }\n"
 "    return %2 : !torch.list<int>\n"
 "  }\n"
+"  func.func @\"__torch_mlir_shape_fn.aten.atleast_3d\"(%arg0: !torch.list<int>) -> !torch.list<int> {\n"
+"    %int0 = torch.constant.int 0\n"
+"    %int1 = torch.constant.int 1\n"
+"    %int2 = torch.constant.int 2\n"
+"    %0 = torch.aten.len.t %arg0 : !torch.list<int> -> !torch.int\n"
+"    %1 = torch.aten.eq.int %0, %int0 : !torch.int, !torch.int -> !torch.bool\n"
+"    %2 = torch.prim.If %1 -> (!torch.list<int>) {\n"
+"      %3 = torch.prim.ListConstruct %int1, %int1, %int1 : (!torch.int, !torch.int, !torch.int) -> !torch.list<int>\n"
+"      torch.prim.If.yield %3 : !torch.list<int>\n"
+"    } else {\n"
+"      %3 = torch.aten.len.t %arg0 : !torch.list<int> -> !torch.int\n"
+"      %4 = torch.aten.eq.int %3, %int1 : !torch.int, !torch.int -> !torch.bool\n"
+"      %5 = torch.prim.If %4 -> (!torch.list<int>) {\n"
+"        %6 = torch.aten.__getitem__.t %arg0, %int0 : !torch.list<int>, !torch.int -> !torch.int\n"
+"        %7 = torch.prim.ListConstruct %int1, %6, %int1 : (!torch.int, !torch.int, !torch.int) -> !torch.list<int>\n"
+"        torch.prim.If.yield %7 : !torch.list<int>\n"
+"      } else {\n"
+"        %6 = torch.aten.len.t %arg0 : !torch.list<int> -> !torch.int\n"
+"        %7 = torch.aten.eq.int %6, %int2 : !torch.int, !torch.int -> !torch.bool\n"
+"        %8 = torch.prim.If %7 -> (!torch.list<int>) {\n"
+"          %9:2 = torch.prim.ListUnpack %arg0 : !torch.list<int> -> !torch.int, !torch.int\n"
+"          %10 = torch.prim.ListConstruct %9#0, %9#1, %int1 : (!torch.int, !torch.int, !torch.int) -> !torch.list<int>\n"
+"          torch.prim.If.yield %10 : !torch.list<int>\n"
+"        } else {\n"
+"          torch.prim.If.yield %arg0 : !torch.list<int>\n"
+"        }\n"
+"        torch.prim.If.yield %8 : !torch.list<int>\n"
+"      }\n"
+"      torch.prim.If.yield %5 : !torch.list<int>\n"
+"    }\n"
+"    return %2 : !torch.list<int>\n"
+"  }\n"
 "  func.func @\"__torch_mlir_shape_fn.aten.stack\"(%arg0: !torch.list<list<int>>, %arg1: !torch.int) -> !torch.list<int> {\n"
 "    %0 = call @__torch__.torch.jit._shape_functions.stack(%arg0, %arg1) : (!torch.list<list<int>>, !torch.int) -> !torch.list<int>\n"
 "    return %0 : !torch.list<int>\n"
 "  }\n"
@@ -15858,6 +15890,10 @@ StringRef mlir::torch::Torch::getAbstractInterpLibrary() {
 "    %0:2 = torch.prim.TupleUnpack %arg0 : !torch.tuple<int, int> -> !torch.int, !torch.int\n"
 "    return %0#1 : !torch.int\n"
 "  }\n"
+"  func.func @\"__torch_mlir_dtype_fn.aten.atleast_3d\"(%arg0: !torch.tuple<int, int>) -> !torch.int {\n"
+"    %0:2 = torch.prim.TupleUnpack %arg0 : !torch.tuple<int, int> -> !torch.int, !torch.int\n"
+"    return %0#1 : !torch.int\n"
+"  }\n"
 "  func.func @\"__torch_mlir_dtype_fn.aten.hstack\"(%arg0: !torch.list<tuple<int, int>>) -> !torch.int {\n"
 "    %true = torch.constant.bool true\n"
 "    %none = torch.constant.none\n"
diff --git a/lib/Dialect/Torch/Transforms/DecomposeComplexOps.cpp b/lib/Dialect/Torch/Transforms/DecomposeComplexOps.cpp
index 9384390cb216..b37daccd7c81 100644
--- a/lib/Dialect/Torch/Transforms/DecomposeComplexOps.cpp
+++ b/lib/Dialect/Torch/Transforms/DecomposeComplexOps.cpp
@@ -1941,6 +1941,41 @@ class DecomposeAtenAtleast2dOp : public OpRewritePattern<AtenAtleast2dOp> {
 };
 } // namespace
 
+class DecomposeAtenAtleast3dOp : public OpRewritePattern<AtenAtleast3dOp> {
+public:
+  using OpRewritePattern::OpRewritePattern;
+  LogicalResult matchAndRewrite(AtenAtleast3dOp op,
+                                PatternRewriter &rewriter) const override {
+    Location loc = op.getLoc();
+    Value input = op.getSelf();
+    Type opType = op.getType();
+
+    auto inputType = cast<BaseTensorType>(input.getType());
+    SmallVector<int64_t> inputShape(inputType.getSizes());
+
+    if (inputShape.size() >= 3) {
+      rewriter.replaceOp(op, input);
+      return success();
+    }
+
+    auto atleast2dResShape =
+        inputShape.empty()
+            ? SmallVector<int64_t>{1, 1}
+            : (inputShape.size() == 1 ? SmallVector<int64_t>{1, inputShape[0]}
+                                      : inputShape);
+    auto atleast2dResType = rewriter.getType<ValueTensorType>(
+        atleast2dResShape, inputType.getOptionalDtype());
+    auto atleast2dRes =
+        rewriter.create<AtenAtleast2dOp>(loc, atleast2dResType, input);
+
+    // The atleast_2d result is always rank 2 here; unsqueeze its trailing
+    // dimension: () -> (1, 1, 1), (N,) -> (1, N, 1), (M, N) -> (M, N, 1).
+    Value dim = rewriter.create<Torch::ConstantIntOp>(
+        loc, rewriter.getI64IntegerAttr(2));
+    rewriter.replaceOpWithNewOp<AtenUnsqueezeOp>(op, opType, atleast2dRes,
+                                                 dim);
+    return success();
+  }
+};
+
 namespace {
 // Decompose AtenEinsumOp to AtenMatmulOp, and supports possible reduce
 // operation and permute operation. Currently, this pass doesn't support
@@ -11722,6 +11757,7 @@ class DecomposeComplexOpsPass
     addPatternIfTargetOpIsIllegal<DecomposeAtenAtleast1dOp>(patterns);
     addPatternIfTargetOpIsIllegal<DecomposeAtenAtleast2dOp>(patterns);
+    addPatternIfTargetOpIsIllegal<DecomposeAtenAtleast3dOp>(patterns);
     addPatternIfTargetOpIsIllegal<DecomposeAtenEinsumOp>(patterns);
diff --git a/lib/Dialect/Torch/Transforms/LowerToBackendContract.cpp b/lib/Dialect/Torch/Transforms/LowerToBackendContract.cpp
index 9e1cb530aa0b..fb83535eef8c 100644
--- a/lib/Dialect/Torch/Transforms/LowerToBackendContract.cpp
+++ b/lib/Dialect/Torch/Transforms/LowerToBackendContract.cpp
@@ -400,6 +400,7 @@ static void markDecomposedOpsAsIllegal(MLIRContext *context,
   target.addIllegalOp<AtenAtleast1dOp>();
   target.addIllegalOp<AtenAtleast2dOp>();
+  target.addIllegalOp<AtenAtleast3dOp>();
   target.addIllegalOp<AtenEinsumOp>();
diff --git a/projects/pt1/e2e_testing/xfail_sets.py b/projects/pt1/e2e_testing/xfail_sets.py
index ef84b65e9b68..eb83687cc39f 100644
--- a/projects/pt1/e2e_testing/xfail_sets.py
+++ b/projects/pt1/e2e_testing/xfail_sets.py
@@ -1004,6 +1004,11 @@
     "Atleast2dModule0dInput_basic",
     "Atleast2dModule1dInput_basic",
     "Atleast2dModule2dInput_basic",
+    "Atleast2dModule3dInput_basic",
+    "Atleast3dModule0dInput_basic",
+    "Atleast3dModule1dInput_basic",
+    "Atleast3dModule2dInput_basic",
+    "Atleast3dModule3dInput_basic",
     "AtenLinear1D_basic",
     "AtenLinear2D_basic",
     "AtenLinear3DBias_basic",
@@ -2006,6 +2011,11 @@
     "Atleast2dModule0dInput_basic",
     "Atleast2dModule1dInput_basic",
     "Atleast2dModule2dInput_basic",
+    "Atleast2dModule3dInput_basic",
+    "Atleast3dModule0dInput_basic",
+    "Atleast3dModule1dInput_basic",
+    "Atleast3dModule2dInput_basic",
+    "Atleast3dModule3dInput_basic",
     "AtenLinear2D_basic",
     "AtenLinear3DBias_basic",
     "ElementwiseAddScalar_NumToTensorFloat_Module_basic",
diff --git a/projects/pt1/python/torch_mlir/jit_ir_importer/build_tools/abstract_interp_lib_gen.py b/projects/pt1/python/torch_mlir/jit_ir_importer/build_tools/abstract_interp_lib_gen.py
index cb10a8aead79..6f788205afc7 100644
--- a/projects/pt1/python/torch_mlir/jit_ir_importer/build_tools/abstract_interp_lib_gen.py
+++ b/projects/pt1/python/torch_mlir/jit_ir_importer/build_tools/abstract_interp_lib_gen.py
@@ -2318,6 +2318,18 @@ def aten〇atleast_2d〡shape(self: List[int]) -> List[int]:
     else:
         return self
 
+def aten〇atleast_3d〡shape(self: List[int]) -> List[int]:
+    if len(self) == 0:
+        return [1, 1, 1]
+    elif len(self) == 1:
+        x = self[0]
+        return [1, x, 1]
+    elif len(self) == 2:
+        x, y = self
+        return [x, y, 1]
+    else:
+        return self
+
 def aten〇stack〡shape(tensors: List[List[int]], dim: int = 0) -> List[int]:
     return upstream_shape_functions.stack(tensors, dim)
 
@@ -5676,6 +5688,11 @@ def aten〇atleast_2d〡dtype(self_rank_dtype: Tuple[int, int]) -> int:
     self_rank, self_dtype = self_rank_dtype
     return self_dtype
 
+@check_dtype_function(_check_tensors_with_the_same_dtype(num_of_tensors=1))
+def aten〇atleast_3d〡dtype(self_rank_dtype: Tuple[int, int]) -> int:
+    self_rank, self_dtype = self_rank_dtype
+    return self_dtype
+
 @check_dtype_function(
     [Invocation([NonZeroDTensorWithDtype(torch.bool), NonZeroDTensorWithDtype(torch.int32), NonZeroDTensorWithDtype(torch.int64)]),
      Invocation([NonZeroDTensorWithDtype(torch.float32), NonZeroDTensorWithDtype(torch.int32)]),
diff --git a/projects/pt1/python/torch_mlir/jit_ir_importer/build_tools/torch_ods_gen.py b/projects/pt1/python/torch_mlir/jit_ir_importer/build_tools/torch_ods_gen.py
index 46485b3173cc..f2ba5cf893bb 100644
--- a/projects/pt1/python/torch_mlir/jit_ir_importer/build_tools/torch_ods_gen.py
+++ b/projects/pt1/python/torch_mlir/jit_ir_importer/build_tools/torch_ods_gen.py
@@ -844,6 +844,7 @@ def emit_with_mutating_variants(key, **kwargs):
     emit("aten::one_hot : (Tensor, int) -> (Tensor)")
     emit("aten::atleast_1d : (Tensor) -> (Tensor)")
     emit("aten::atleast_2d : (Tensor) -> (Tensor)")
+    emit("aten::atleast_3d : (Tensor) -> (Tensor)")
     emit("aten::einsum : (str, Tensor[], int[]?) -> (Tensor)")
     emit("aten::trace : (Tensor) -> (Tensor)")
     emit("aten::bucketize.Tensor : (Tensor, Tensor, bool, bool) -> (Tensor)")
diff --git a/projects/pt1/python/torch_mlir_e2e_test/test_suite/reshape_like.py b/projects/pt1/python/torch_mlir_e2e_test/test_suite/reshape_like.py
index d1ddc42b39b1..209a6cd27c00 100644
--- a/projects/pt1/python/torch_mlir_e2e_test/test_suite/reshape_like.py
+++ b/projects/pt1/python/torch_mlir_e2e_test/test_suite/reshape_like.py
@@ -1551,6 +1551,102 @@ def Atleast1dModule1dInput_basic(module, tu: TestUtils):
     module.forward(tu.rand(4))
 
 
+class Atleast2dModule0dInput(torch.nn.Module):
+    @export
+    @annotate_args([None, ([], torch.float32, True)])
+    def forward(self, x):
+        return torch.ops.aten.atleast_2d(x)
+
+
+@register_test_case(module_factory=lambda: Atleast2dModule0dInput())
+def Atleast2dModule0dInput_basic(module, tu: TestUtils):
+    module.forward(tu.rand())
+
+
+class Atleast2dModule1dInput(torch.nn.Module):
+    @export
+    @annotate_args([None, ([10], torch.float32, True)])
+    def forward(self, x):
+        return torch.ops.aten.atleast_2d(x)
+
+
+@register_test_case(module_factory=lambda: Atleast2dModule1dInput())
+def Atleast2dModule1dInput_basic(module, tu: TestUtils):
+    module.forward(tu.rand(10))
+
+
+class Atleast2dModule2dInput(torch.nn.Module):
+    @export
+    @annotate_args([None, ([3, 4], torch.float32, True)])
+    def forward(self, x):
+        return torch.ops.aten.atleast_2d(x)
+
+
+@register_test_case(module_factory=lambda: Atleast2dModule2dInput())
+def Atleast2dModule2dInput_basic(module, tu: TestUtils):
+    module.forward(tu.rand(3, 4))
+
+
+class Atleast2dModule3dInput(torch.nn.Module):
+    @export
+    @annotate_args([None, ([2, 3, 4], torch.float32, True)])
+    def forward(self, x):
+        return torch.ops.aten.atleast_2d(x)
+
+
+@register_test_case(module_factory=lambda: Atleast2dModule3dInput())
+def Atleast2dModule3dInput_basic(module, tu: TestUtils):
+    module.forward(tu.rand(2, 3, 4))
+
+
+class Atleast3dModule0dInput(torch.nn.Module):
+    @export
+    @annotate_args([None, ([], torch.float32, True)])
+    def forward(self, x):
+        return torch.ops.aten.atleast_3d(x)
+
+
+@register_test_case(module_factory=lambda: Atleast3dModule0dInput())
+def Atleast3dModule0dInput_basic(module, tu: TestUtils):
+    module.forward(tu.rand())
+
+
+class Atleast3dModule1dInput(torch.nn.Module):
+    @export
+    @annotate_args([None, ([10], torch.float32, True)])
+    def forward(self, x):
+        return torch.ops.aten.atleast_3d(x)
+
+
+@register_test_case(module_factory=lambda: Atleast3dModule1dInput())
+def Atleast3dModule1dInput_basic(module, tu: TestUtils):
+    module.forward(tu.rand(10))
+
+
+class Atleast3dModule2dInput(torch.nn.Module):
+    @export
+    @annotate_args([None, ([4, 5], torch.float32, True)])
+    def forward(self, x):
+        return torch.ops.aten.atleast_3d(x)
+
+
+@register_test_case(module_factory=lambda: Atleast3dModule2dInput())
+def Atleast3dModule2dInput_basic(module, tu: TestUtils):
+    module.forward(tu.rand(4, 5))
+
+
+class Atleast3dModule3dInput(torch.nn.Module):
+    @export
+    @annotate_args([None, ([2, 3, 4], torch.float32, True)])
+    def forward(self, x):
+        return torch.ops.aten.atleast_3d(x)
+
+
+@register_test_case(module_factory=lambda: Atleast3dModule3dInput())
+def Atleast3dModule3dInput_basic(module, tu: TestUtils):
+    module.forward(tu.rand(2, 3, 4))
+
+
 # ==============================================================================
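
Not part of the patch: a minimal sketch of the `aten::atleast_3d` shape semantics that the shape function, decomposition, and e2e tests above target, as documented for `torch.atleast_3d` (assumes only a local `torch` install).

import torch

# torch.atleast_3d promotes ranks 0-2 and returns rank >= 3 inputs unchanged:
#   ()        -> (1, 1, 1)
#   (N,)      -> (1, N, 1)
#   (M, N)    -> (M, N, 1)
#   (A, B, C) -> (A, B, C)
print(torch.atleast_3d(torch.rand([])).shape)       # torch.Size([1, 1, 1])
print(torch.atleast_3d(torch.rand(10)).shape)       # torch.Size([1, 10, 1])
print(torch.atleast_3d(torch.rand(4, 5)).shape)     # torch.Size([4, 5, 1])
print(torch.atleast_3d(torch.rand(2, 3, 4)).shape)  # torch.Size([2, 3, 4])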