diff --git a/lib/Dialect/Torch/Transforms/DecomposeComplexOps.cpp b/lib/Dialect/Torch/Transforms/DecomposeComplexOps.cpp index c98d3b3481d6..0fe0a8db92e2 100644 --- a/lib/Dialect/Torch/Transforms/DecomposeComplexOps.cpp +++ b/lib/Dialect/Torch/Transforms/DecomposeComplexOps.cpp @@ -2825,12 +2825,13 @@ class DecomposeAtenAminAmaxOp : public OpRewritePattern { dims = llvm::to_vector(llvm::seq(0, inputTy.getSizes().size())); } + int64_t inputRank = inputTy.getSizes().size(); + llvm::for_each(dims, [&](int64_t &d) { d = toPositiveDim(d, inputRank); }); + // For every dimension included in `dim` of the op, iterated over in // reverse order, we create a call to aten.max.dim. std::sort(dims.rbegin(), dims.rend()); for (int64_t dimInt : dims) { - int64_t inputRank = inputTy.getSizes().size(); - dimInt = toPositiveDim(dimInt, inputRank); if (!isValidDim(dimInt, inputRank)) return rewriter.notifyMatchFailure(op, "dim is statically invalid"); Value dim = Torch::ConstantIntOp::create( diff --git a/projects/pt1/python/torch_mlir_e2e_test/test_suite/reduction.py b/projects/pt1/python/torch_mlir_e2e_test/test_suite/reduction.py index 9b678ac02018..2a1c4b748b9f 100644 --- a/projects/pt1/python/torch_mlir_e2e_test/test_suite/reduction.py +++ b/projects/pt1/python/torch_mlir_e2e_test/test_suite/reduction.py @@ -1381,6 +1381,27 @@ def ReduceAmaxOutOfOrderDim_basic(module, tu: TestUtils): # ============================================================================== +class ReduceAmaxOutOfOrderWithNegDim(torch.nn.Module): + def __init__(self): + super().__init__() + + @export + @annotate_args( + [ + None, + ([-1, -1, -1, -1], torch.float32, True), + ] + ) + def forward(self, a): + return torch.ops.aten.amax(a, (2, 1, -1)) + + +@register_test_case(module_factory=lambda: ReduceAmaxOutOfOrderWithNegDim()) +def ReduceAmaxOutOfOrderWithNegDim_basic(module, tu: TestUtils): + module.forward(tu.rand(3, 4, 5, 6, high=100)) + + +# ============================================================================== class ReduceAmaxKeepDim(torch.nn.Module): def __init__(self):