From 234834431211f5f1aa9ca471820e444550fa22de Mon Sep 17 00:00:00 2001
From: root
Date: Wed, 19 Mar 2025 20:27:14 +0000
Subject: [PATCH 1/6] Initial implementation of AtenOuterOp

- Defined the op in Linear.cpp

TODO:
- Testing, and perhaps add some tests inside torch-mlir?
---
 lib/Conversion/TorchToLinalg/Linear.cpp | 135 ++++++++++++++++++++++++
 1 file changed, 135 insertions(+)

diff --git a/lib/Conversion/TorchToLinalg/Linear.cpp b/lib/Conversion/TorchToLinalg/Linear.cpp
index 9945c52a1684..04dbfa1b5351 100644
--- a/lib/Conversion/TorchToLinalg/Linear.cpp
+++ b/lib/Conversion/TorchToLinalg/Linear.cpp
@@ -1673,6 +1673,139 @@ struct ConvertAtenFftRfftOp final : OpConversionPattern<AtenFftRfftOp> {
 
 } // namespace
 
+namespace {
+  class ConvertAtenOuterOp : public OpConversionPattern<AtenOuterOp> {
+  public:
+    using OpConversionPattern::OpConversionPattern;
+    LogicalResult
+    matchAndRewrite(AtenOuterOp op, OpAdaptor adaptor,
+                    ConversionPatternRewriter &rewriter) const override {
+
+      Location loc = op->getLoc();
+      Value lhs = adaptor.getSelf();
+      Value rhs = op->getOperand(1);
+
+      if (failed(verifyLinalgCompatibleTypes(op, rewriter))) {
+        return failure();
+      }
+      auto lhsType = cast<RankedTensorType>(lhs.getType());
+      auto rhsType = cast<RankedTensorType>(rhs.getType());
+
+      auto lhsTorchType = cast<ValueTensorType>(op.getSelf().getType());
+      auto rhsTorchType = cast<ValueTensorType>(op.getOperand(1).getType());
+
+      // Get the rank of both matrices.
+      unsigned lhsRank = lhsType.getRank();
+      unsigned rhsRank = rhsType.getRank();
+
+      Value lhsZeroPoint, rhsZeroPoint;
+      getZeroPoint(op.getSelf(), lhsZeroPoint);
+      getZeroPoint(op.getOperand(1), rhsZeroPoint);
+
+      if (static_cast<bool>(lhsZeroPoint) != static_cast<bool>(rhsZeroPoint)) {
+        return rewriter.notifyMatchFailure(
+            op, "unsupported: aten.outer with mixed quantization");
+      }
+
+      bool isUnsigned = torch_to_linalg::isUnsignedTorchType(lhsTorchType);
+      bool isUnsignedR = torch_to_linalg::isUnsignedTorchType(rhsTorchType);
+
+      if (!lhsZeroPoint && lhsTorchType.getDtype() != rhsTorchType.getDtype()) {
+        // Allows quantized types to mismatch
+        return rewriter.notifyMatchFailure(
+            op, "unsupported: aten.outer with different input element types");
+      }
+
+      Type newResultType = getTypeConverter()->convertType(op.getType());
+      auto resultType = cast<RankedTensorType>(newResultType);
+      Type elementType = resultType.getElementType();
+
+      // Quantized case
+      if (lhsZeroPoint) {
+        // get each zero point ready to pass to a quantized_matmul
+        lhsZeroPoint = typeConverter->materializeTargetConversion(
+            rewriter, loc,
+            getTypeConverter()->convertType(lhsZeroPoint.getType()),
+            lhsZeroPoint);
+        rhsZeroPoint = typeConverter->materializeTargetConversion(
+            rewriter, loc,
+            getTypeConverter()->convertType(rhsZeroPoint.getType()),
+            rhsZeroPoint);
+        lhsZeroPoint = rewriter.create<arith::TruncIOp>(
+            loc, rewriter.getI32Type(), lhsZeroPoint);
+        rhsZeroPoint = rewriter.create<arith::TruncIOp>(
+            loc, rewriter.getI32Type(), rhsZeroPoint);
+
+        // change uint8 quantization -> int8 quantization
+        int64_t numBits =
+            cast<mlir::IntegerType>(lhsType.getElementType()).getWidth();
+        signShift(rewriter, loc, lhs, lhsZeroPoint, isUnsigned, numBits);
+        numBits = cast<mlir::IntegerType>(rhsType.getElementType()).getWidth();
+        signShift(rewriter, loc, rhs, rhsZeroPoint, isUnsignedR, numBits);
+
+        if (lhsRank == 1 && rhsRank == 1) {
+          int64_t lhsDim = lhsType.getShape()[0];
+          int64_t rhsDim = rhsType.getShape()[0];
+
+          // Unsqueeze: lhs: [n] -> [n, 1] and rhs: [m] -> [1, m]
+          auto lhsUnsqueezeType = RankedTensorType::get({lhsDim, 1}, lhsType.getElementType());
+          auto rhsUnsqueezeType = RankedTensorType::get({1, rhsDim}, rhsType.getElementType());
+          SmallVector<ReassociationIndices> reassociation = {{0, 1}};
+          lhs = rewriter.create<tensor::ExpandShapeOp>(loc, lhsUnsqueezeType, lhs, reassociation);
+          rhs = rewriter.create<tensor::ExpandShapeOp>(loc, rhsUnsqueezeType, rhs, reassociation);
+
+          // Create a zero tensor with shape [lhsDim, rhsDim] for the accumulator.
+          Value lhsDimVal = rewriter.create<tensor::DimOp>(loc, lhs, 0);
+          Value rhsDimVal = rewriter.create<tensor::DimOp>(loc, rhs, 1);
+          Value zeroTensor = createZeroInitTensor(rewriter, loc,
+                                                  ValueRange{lhsDimVal, rhsDimVal},
+                                                  elementType);
+
+          // Use the quantized version of matmul.
+          Value outerProd = rewriter.create<linalg::QuantizedMatmulOp>(
+              loc, zeroTensor.getType(),
+              ValueRange{lhs, rhs, lhsZeroPoint, rhsZeroPoint},
+              zeroTensor).getResult(0);
+
+          rewriter.replaceOpWithNewOp<tensor::CastOp>(op, newResultType, outerProd);
+          return success();
+        }
+        return rewriter.notifyMatchFailure(op, "unsupported: quantized aten.outer op case");
+      }
+
+      // Non-quantized outer product
+      if (lhsRank == 1 && rhsRank == 1) {
+        int64_t lhsDim = lhsType.getShape()[0];
+        int64_t rhsDim = rhsType.getShape()[0];
+
+        // Unsqueeze: lhs from [n] -> [n, 1] and rhs from [m] -> [1, m]
+        auto lhsUnsqueezeType = RankedTensorType::get({lhsDim, 1}, lhsType.getElementType());
+        auto rhsUnsqueezeType = RankedTensorType::get({1, rhsDim}, rhsType.getElementType());
+        SmallVector<ReassociationIndices> reassociation = {{0, 1}};
+        lhs = rewriter.create<tensor::ExpandShapeOp>(loc, lhsUnsqueezeType, lhs, reassociation);
+        rhs = rewriter.create<tensor::ExpandShapeOp>(loc, rhsUnsqueezeType, rhs, reassociation);
+
+        // Create a zero-initialized tensor with shape [lhsDim, rhsDim]
+        Value lhsDimVal = rewriter.create<tensor::DimOp>(loc, lhs, 0);
+        Value rhsDimVal = rewriter.create<tensor::DimOp>(loc, rhs, 1);
+        Value zeroTensor = createZeroInitTensor(rewriter, loc,
+                                                ValueRange{lhsDimVal, rhsDimVal},
+                                                elementType);
+
+        // Use linalg::MatmulOp to compute the outer product.
+        Value outerProd = rewriter.create<linalg::MatmulOp>(
+            loc, zeroTensor.getType(), ValueRange{lhs, rhs}, zeroTensor).getResult(0);
+
+        rewriter.replaceOpWithNewOp<tensor::CastOp>(op, newResultType, outerProd);
+        return success();
+      }
+
+      return failure();
+    }
+  };
+} // namespace
+
 void mlir::torch::torch_to_linalg::populateLinearPatternsAndLegality(
     TypeConverter &typeConverter, RewritePatternSet &patterns,
     ConversionTarget &target) {
@@ -1689,4 +1822,6 @@ void mlir::torch::torch_to_linalg::populateLinearPatternsAndLegality(
   patterns.add<ConvertAtenConvolutionOp>(typeConverter, context);
   target.addIllegalOp<AtenFftRfftOp>();
   patterns.add<ConvertAtenFftRfftOp>(typeConverter, context);
+  target.addIllegalOp<AtenOuterOp>();
+  patterns.add<ConvertAtenOuterOp>(typeConverter, context);
 }
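
For intuition about what PATCH 1 builds: it treats the outer product of two 1-D tensors as a matmul of the operands unsqueezed to [n, 1] and [1, m]. A small PyTorch sketch of that same formulation (for reference only, not part of the patch):

    import torch

    a, b = torch.rand(3), torch.rand(4)
    # [3] -> [3, 1] and [4] -> [1, 4]; a plain matmul then yields the [3, 4]
    # outer product, which is what the pattern expresses with expand_shape +
    # linalg.matmul.
    via_matmul = torch.matmul(a.unsqueeze(1), b.unsqueeze(0))
    assert torch.allclose(via_matmul, torch.outer(a, b))
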
From 6c4048c59cc19a741f42a89ae673adc8e048a9d8 Mon Sep 17 00:00:00 2001
From: amemov
Date: Mon, 24 Mar 2025 13:50:47 +0000
Subject: [PATCH 2/6] Addressed the comments:

- Rewrote the ConvertAtenOuterOp without unsqueezing
- Replaced linalg::MatmulOp with linalg::GenericOp for building the result of the op
- Added error messages
- Added test case in e2e tests - placed in matmul.py
---
 lib/Conversion/TorchToLinalg/Linear.cpp       | 193 ++++------
 .../torch_mlir_e2e_test/test_suite/matmul.py  |  24 +++
 2 files changed, 90 insertions(+), 127 deletions(-)

diff --git a/lib/Conversion/TorchToLinalg/Linear.cpp b/lib/Conversion/TorchToLinalg/Linear.cpp
index 04dbfa1b5351..75c9953ad4cd 100644
--- a/lib/Conversion/TorchToLinalg/Linear.cpp
+++ b/lib/Conversion/TorchToLinalg/Linear.cpp
@@ -1674,136 +1674,75 @@ struct ConvertAtenFftRfftOp final : OpConversionPattern<AtenFftRfftOp> {
 } // namespace
 
 namespace {
-  class ConvertAtenOuterOp : public OpConversionPattern<AtenOuterOp> {
-  public:
-    using OpConversionPattern::OpConversionPattern;
-    LogicalResult
-    matchAndRewrite(AtenOuterOp op, OpAdaptor adaptor,
-                    ConversionPatternRewriter &rewriter) const override {
-
-      Location loc = op->getLoc();
-      Value lhs = adaptor.getSelf();
-      Value rhs = op->getOperand(1);
-
-      if (failed(verifyLinalgCompatibleTypes(op, rewriter))) {
-        return failure();
-      }
-      auto lhsType = cast<RankedTensorType>(lhs.getType());
-      auto rhsType = cast<RankedTensorType>(rhs.getType());
-
-      auto lhsTorchType = cast<ValueTensorType>(op.getSelf().getType());
-      auto rhsTorchType = cast<ValueTensorType>(op.getOperand(1).getType());
-
-      // Get the rank of both matrices.
-      unsigned lhsRank = lhsType.getRank();
-      unsigned rhsRank = rhsType.getRank();
-
-      Value lhsZeroPoint, rhsZeroPoint;
-      getZeroPoint(op.getSelf(), lhsZeroPoint);
-      getZeroPoint(op.getOperand(1), rhsZeroPoint);
-
-      if (static_cast<bool>(lhsZeroPoint) != static_cast<bool>(rhsZeroPoint)) {
-        return rewriter.notifyMatchFailure(
-            op, "unsupported: aten.outer with mixed quantization");
-      }
-
-      bool isUnsigned = torch_to_linalg::isUnsignedTorchType(lhsTorchType);
-      bool isUnsignedR = torch_to_linalg::isUnsignedTorchType(rhsTorchType);
-
-      if (!lhsZeroPoint && lhsTorchType.getDtype() != rhsTorchType.getDtype()) {
-        // Allows quantized types to mismatch
-        return rewriter.notifyMatchFailure(
-            op, "unsupported: aten.outer with different input element types");
-      }
-
-      Type newResultType = getTypeConverter()->convertType(op.getType());
-      auto resultType = cast<RankedTensorType>(newResultType);
-      Type elementType = resultType.getElementType();
-
-      // Quantized case
-      if (lhsZeroPoint) {
-        // get each zero point ready to pass to a quantized_matmul
-        lhsZeroPoint = typeConverter->materializeTargetConversion(
-            rewriter, loc,
-            getTypeConverter()->convertType(lhsZeroPoint.getType()),
-            lhsZeroPoint);
-        rhsZeroPoint = typeConverter->materializeTargetConversion(
-            rewriter, loc,
-            getTypeConverter()->convertType(rhsZeroPoint.getType()),
-            rhsZeroPoint);
-        lhsZeroPoint = rewriter.create<arith::TruncIOp>(
-            loc, rewriter.getI32Type(), lhsZeroPoint);
-        rhsZeroPoint = rewriter.create<arith::TruncIOp>(
-            loc, rewriter.getI32Type(), rhsZeroPoint);
-
-        // change uint8 quantization -> int8 quantization
-        int64_t numBits =
-            cast<mlir::IntegerType>(lhsType.getElementType()).getWidth();
-        signShift(rewriter, loc, lhs, lhsZeroPoint, isUnsigned, numBits);
-        numBits = cast<mlir::IntegerType>(rhsType.getElementType()).getWidth();
-        signShift(rewriter, loc, rhs, rhsZeroPoint, isUnsignedR, numBits);
-
-        if (lhsRank == 1 && rhsRank == 1) {
-          int64_t lhsDim = lhsType.getShape()[0];
-          int64_t rhsDim = rhsType.getShape()[0];
-
-          // Unsqueeze: lhs: [n] -> [n, 1] and rhs: [m] -> [1, m]
-          auto lhsUnsqueezeType = RankedTensorType::get({lhsDim, 1}, lhsType.getElementType());
-          auto rhsUnsqueezeType = RankedTensorType::get({1, rhsDim}, rhsType.getElementType());
-          SmallVector<ReassociationIndices> reassociation = {{0, 1}};
-          lhs = rewriter.create<tensor::ExpandShapeOp>(loc, lhsUnsqueezeType, lhs, reassociation);
-          rhs = rewriter.create<tensor::ExpandShapeOp>(loc, rhsUnsqueezeType, rhs, reassociation);
-
-          // Create a zero tensor with shape [lhsDim, rhsDim] for the accumulator.
-          Value lhsDimVal = rewriter.create<tensor::DimOp>(loc, lhs, 0);
-          Value rhsDimVal = rewriter.create<tensor::DimOp>(loc, rhs, 1);
-          Value zeroTensor = createZeroInitTensor(rewriter, loc,
-                                                  ValueRange{lhsDimVal, rhsDimVal},
-                                                  elementType);
-
-          // Use the quantized version of matmul.
-          Value outerProd = rewriter.create<linalg::QuantizedMatmulOp>(
-              loc, zeroTensor.getType(),
-              ValueRange{lhs, rhs, lhsZeroPoint, rhsZeroPoint},
-              zeroTensor).getResult(0);
-
-          rewriter.replaceOpWithNewOp<tensor::CastOp>(op, newResultType, outerProd);
-          return success();
-        }
-        return rewriter.notifyMatchFailure(op, "unsupported: quantized aten.outer op case");
-      }
-
-      // Non-quantized outer product
-      if (lhsRank == 1 && rhsRank == 1) {
-        int64_t lhsDim = lhsType.getShape()[0];
-        int64_t rhsDim = rhsType.getShape()[0];
-
-        // Unsqueeze: lhs from [n] -> [n, 1] and rhs from [m] -> [1, m]
-        auto lhsUnsqueezeType = RankedTensorType::get({lhsDim, 1}, lhsType.getElementType());
-        auto rhsUnsqueezeType = RankedTensorType::get({1, rhsDim}, rhsType.getElementType());
-        SmallVector<ReassociationIndices> reassociation = {{0, 1}};
-        lhs = rewriter.create<tensor::ExpandShapeOp>(loc, lhsUnsqueezeType, lhs, reassociation);
-        rhs = rewriter.create<tensor::ExpandShapeOp>(loc, rhsUnsqueezeType, rhs, reassociation);
-
-        // Create a zero-initialized tensor with shape [lhsDim, rhsDim]
-        Value lhsDimVal = rewriter.create<tensor::DimOp>(loc, lhs, 0);
-        Value rhsDimVal = rewriter.create<tensor::DimOp>(loc, rhs, 1);
-        Value zeroTensor = createZeroInitTensor(rewriter, loc,
-                                                ValueRange{lhsDimVal, rhsDimVal},
-                                                elementType);
-
-        // Use linalg::MatmulOp to compute the outer product.
-        Value outerProd = rewriter.create<linalg::MatmulOp>(
-            loc, zeroTensor.getType(), ValueRange{lhs, rhs}, zeroTensor).getResult(0);
-
-        rewriter.replaceOpWithNewOp<tensor::CastOp>(op, newResultType, outerProd);
-        return success();
-      }
-
-      return failure();
-    }
-  };
+class ConvertAtenOuterOp : public OpConversionPattern<AtenOuterOp> {
+public:
+  using OpConversionPattern::OpConversionPattern;
+  LogicalResult
+  matchAndRewrite(AtenOuterOp op, OpAdaptor adaptor,
+                  ConversionPatternRewriter &rewriter) const override {
+
+    Location loc = op->getLoc();
+    Value lhs = adaptor.getSelf();
+    Value rhs = op->getOperand(1);
+
+    if (failed(verifyLinalgCompatibleTypes(op, rewriter))) {
+      return failure();
+    }
+    auto lhsType = cast<RankedTensorType>(lhs.getType());
+    auto rhsType = cast<RankedTensorType>(rhs.getType());
+
+    if (!lhsType || !rhsType)
+      return rewriter.notifyMatchFailure(op,
+                                         "outer: expected ranked tensor types");
+    if (lhsType.getRank() != 1 || rhsType.getRank() != 1)
+      return rewriter.notifyMatchFailure(
+          op, "outer: expected 1D tensors for outer op lowering");
+
+    Value lhsDim = getDimOp(rewriter, loc, lhs, 1);
+    Value rhsDim = getDimOp(rewriter, loc, rhs, 1);
+    Type elementType = lhsType.getElementType();
+    Type newResultType = getTypeConverter()->convertType(op.getType());
+
+    // Create a zero-initialized tensor with shape [lhsDim, rhsDim]
+    Value zeroTensor = createZeroInitTensor(
+        rewriter, loc, ValueRange{lhsDim, rhsDim}, elementType);
+
+    // Set up affine indexing maps:
+    // We create a 2D loop iteration space. For the lhs, we use the first index
+    // (i), for the rhs, the second index (j), and for the result, both (i, j).
+    AffineMap mapLhs =
+        AffineMap::get(2, /*symbolCount=*/0, {rewriter.getAffineDimExpr(0)},
+                       rewriter.getContext());
+    AffineMap mapRhs =
+        AffineMap::get(2, /*symbolCount=*/0, {rewriter.getAffineDimExpr(1)},
+                       rewriter.getContext());
+    AffineMap mapOut =
+        AffineMap::getMultiDimIdentityMap(2, rewriter.getContext());
+
+    SmallVector<utils::IteratorType> iteratorTypes = {
+        utils::IteratorType::parallel, utils::IteratorType::parallel};
+
+    Value outerProd =
+        rewriter
+            .create<linalg::GenericOp>(
+                loc, zeroTensor.getType(),
+                /*inputs=*/ValueRange{lhsDim, rhsDim},
+                /*outputs=*/zeroTensor,
+                /*indexingMaps=*/
+                SmallVector<AffineMap>{mapLhs, mapRhs, mapOut},
+                /*iteratorTypes=*/iteratorTypes,
+                [&](OpBuilder &b, Location loc, ValueRange args) {
+                  Value lhsElem = args[0];
+                  Value rhsElem = args[1];
+                  Value mult = b.create<arith::MulFOp>(loc, lhsElem, rhsElem);
+                  b.create<linalg::YieldOp>(loc, mult);
+                })
+            .getResult(0);
+
+    rewriter.replaceOpWithNewOp<tensor::CastOp>(op, newResultType, outerProd);
+    return success();
+  }
+};
 } // namespace
 
 void mlir::torch::torch_to_linalg::populateLinearPatternsAndLegality(
diff --git a/projects/pt1/python/torch_mlir_e2e_test/test_suite/matmul.py b/projects/pt1/python/torch_mlir_e2e_test/test_suite/matmul.py
index 17240cf953df..ecd9e8657c31 100644
--- a/projects/pt1/python/torch_mlir_e2e_test/test_suite/matmul.py
+++ b/projects/pt1/python/torch_mlir_e2e_test/test_suite/matmul.py
@@ -918,3 +918,27 @@ def forward(self, a, b):
 @register_test_case(module_factory=lambda: AtenLinalgCrossDynamic())
 def AtenLinalgCrossDynamic_basic(module, tu: TestUtils):
     module.forward(tu.rand(4, 3, 1, 6), tu.rand(4, 3, 7, 1))
+
+
+# ==============================================================================
+
+
+class AtenOuter(torch.nn.Module):
+    def __init__(self):
+        super().__init__()
+
+    @export
+    @annotate_args(
+        [
+            None,
+            ([-1], torch.float32, True),
+            ([-1], torch.float32, True),
+        ]
+    )
+    def forward(self, lhs, rhs):
+        return torch.outer(lhs, rhs)
+
+
+@register_test_case(module_factory=lambda: AtenOuter())
+def AtenOuter_basic(module, tu: TestUtils):
+    module.forward(tu.rand(3), tu.rand(3))
\ No newline at end of file
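
The linalg.generic formulation above walks a two-dimensional (i, j) iteration space and multiplies lhs[i] by rhs[j] into the result, which is exactly the broadcasted product the new AtenOuter e2e test exercises. A reference sketch of that computation in plain PyTorch (illustrative only, not part of the patch):

    import torch

    a, b = torch.rand(5), torch.rand(7)
    # out[i, j] = a[i] * b[j], matching the generic op's indexing maps
    # (i) x (j) -> (i, j) with both loops parallel.
    ref = a.unsqueeze(1) * b.unsqueeze(0)
    assert torch.allclose(ref, torch.outer(a, b))
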
From e8d511c4f01eb7203bf5ddf312ffd24090d8b624 Mon Sep 17 00:00:00 2001
From: amemov
Date: Thu, 3 Apr 2025 13:57:16 +0000
Subject: [PATCH 3/6] Changed createZeroInitTensor to createInitTensor with NULL
---
 lib/Conversion/TorchToLinalg/Linear.cpp                       | 4 ++--
 projects/pt1/python/torch_mlir_e2e_test/test_suite/matmul.py  | 2 +-
 2 files changed, 3 insertions(+), 3 deletions(-)

diff --git a/lib/Conversion/TorchToLinalg/Linear.cpp b/lib/Conversion/TorchToLinalg/Linear.cpp
index 75c9953ad4cd..b48b81e67482 100644
--- a/lib/Conversion/TorchToLinalg/Linear.cpp
+++ b/lib/Conversion/TorchToLinalg/Linear.cpp
@@ -1706,8 +1706,8 @@ class ConvertAtenOuterOp : public OpConversionPattern<AtenOuterOp> {
     // Create a zero-initialized tensor with shape [lhsDim, rhsDim]
     Value zeroTensor = createZeroInitTensor(
         rewriter, loc, ValueRange{lhsDim, rhsDim}, elementType);
-
-    // Set up affine indexing maps: 
+
+    // Set up affine indexing maps:
     // We create a 2D loop iteration space. For the lhs, we use the first index
     // (i), for the rhs, the second index (j), and for the result, both (i, j).
     AffineMap mapLhs =
diff --git a/projects/pt1/python/torch_mlir_e2e_test/test_suite/matmul.py b/projects/pt1/python/torch_mlir_e2e_test/test_suite/matmul.py
index ecd9e8657c31..af307dd325d2 100644
--- a/projects/pt1/python/torch_mlir_e2e_test/test_suite/matmul.py
+++ b/projects/pt1/python/torch_mlir_e2e_test/test_suite/matmul.py
@@ -941,4 +941,4 @@ def forward(self, lhs, rhs):
 
 @register_test_case(module_factory=lambda: AtenOuter())
 def AtenOuter_basic(module, tu: TestUtils):
-    module.forward(tu.rand(3), tu.rand(3))
\ No newline at end of file
+    module.forward(tu.rand(3), tu.rand(3))

From 73dfc8ad925090431a93d29aac06510d1ed1cbff Mon Sep 17 00:00:00 2001
From: amemov
Date: Wed, 9 Apr 2025 03:33:01 +0000
Subject: [PATCH 4/6] Added missing change
---
 lib/Conversion/TorchToLinalg/Linear.cpp | 8 ++++----
 1 file changed, 4 insertions(+), 4 deletions(-)

diff --git a/lib/Conversion/TorchToLinalg/Linear.cpp b/lib/Conversion/TorchToLinalg/Linear.cpp
index b48b81e67482..7e1fbf9dd738 100644
--- a/lib/Conversion/TorchToLinalg/Linear.cpp
+++ b/lib/Conversion/TorchToLinalg/Linear.cpp
@@ -1704,8 +1704,8 @@ class ConvertAtenOuterOp : public OpConversionPattern<AtenOuterOp> {
     Type newResultType = getTypeConverter()->convertType(op.getType());
 
     // Create a zero-initialized tensor with shape [lhsDim, rhsDim]
-    Value zeroTensor = createZeroInitTensor(
-        rewriter, loc, ValueRange{lhsDim, rhsDim}, elementType);
+    Value initTensor = createInitTensor(
+        rewriter, loc, ValueRange{lhsDim, rhsDim}, elementType, NULL);
 
     // Set up affine indexing maps:
     // We create a 2D loop iteration space. For the lhs, we use the first index
     // (i), for the rhs, the second index (j), and for the result, both (i, j).
@@ -1725,9 +1725,9 @@ class ConvertAtenOuterOp : public OpConversionPattern<AtenOuterOp> {
     Value outerProd =
         rewriter
            .create<linalg::GenericOp>(
-                loc, zeroTensor.getType(),
+                loc, initTensor.getType(),
                 /*inputs=*/ValueRange{lhsDim, rhsDim},
-                /*outputs=*/zeroTensor,
+                /*outputs=*/initTensor,
                 /*indexingMaps=*/
                 SmallVector<AffineMap>{mapLhs, mapRhs, mapOut},
                 /*iteratorTypes=*/iteratorTypes,

From 4df6def3c5fae37412c1c9e1f814f71873c7d5a4 Mon Sep 17 00:00:00 2001
From: amemov
Date: Sat, 12 Apr 2025 15:12:54 +0000
Subject: [PATCH 5/6] Addressed the problem with testing
---
 lib/Conversion/TorchToLinalg/Linear.cpp                       | 8 ++++----
 .../pt1/python/torch_mlir_e2e_test/test_suite/matmul.py       | 4 ++--
 2 files changed, 6 insertions(+), 6 deletions(-)

diff --git a/lib/Conversion/TorchToLinalg/Linear.cpp b/lib/Conversion/TorchToLinalg/Linear.cpp
index 7e1fbf9dd738..d9a4e74c231d 100644
--- a/lib/Conversion/TorchToLinalg/Linear.cpp
+++ b/lib/Conversion/TorchToLinalg/Linear.cpp
@@ -1688,8 +1688,8 @@ class ConvertAtenOuterOp : public OpConversionPattern<AtenOuterOp> {
     if (failed(verifyLinalgCompatibleTypes(op, rewriter))) {
       return failure();
     }
-    auto lhsType = cast<RankedTensorType>(lhs.getType());
-    auto rhsType = cast<RankedTensorType>(rhs.getType());
+    auto lhsType = dyn_cast<RankedTensorType>(lhs.getType());
+    auto rhsType = dyn_cast<RankedTensorType>(rhs.getType());
 
     if (!lhsType || !rhsType)
       return rewriter.notifyMatchFailure(op,
                                          "outer: expected ranked tensor types");
     if (lhsType.getRank() != 1 || rhsType.getRank() != 1)
       return rewriter.notifyMatchFailure(
           op, "outer: expected 1D tensors for outer op lowering");
 
-    Value lhsDim = getDimOp(rewriter, loc, lhs, 1);
-    Value rhsDim = getDimOp(rewriter, loc, rhs, 1);
+    Value lhsDim = getDimOp(rewriter, loc, lhs, 0);
+    Value rhsDim = getDimOp(rewriter, loc, rhs, 0);
     Type elementType = lhsType.getElementType();
     Type newResultType = getTypeConverter()->convertType(op.getType());
diff --git a/projects/pt1/python/torch_mlir_e2e_test/test_suite/matmul.py b/projects/pt1/python/torch_mlir_e2e_test/test_suite/matmul.py
index af307dd325d2..9fde04330b74 100644
--- a/projects/pt1/python/torch_mlir_e2e_test/test_suite/matmul.py
+++ b/projects/pt1/python/torch_mlir_e2e_test/test_suite/matmul.py
@@ -931,8 +931,8 @@ def __init__(self):
     @annotate_args(
         [
             None,
-            ([-1], torch.float32, True),
-            ([-1], torch.float32, True),
+            ([3], torch.float32, True),
+            ([3], torch.float32, True),
         ]
     )
     def forward(self, lhs, rhs):

From 0e887376b552e6b9d5f7d6cda6128f445fcb23f1 Mon Sep 17 00:00:00 2001
From: amemov
Date: Sat, 19 Apr 2025 23:09:50 +0000
Subject: [PATCH 6/6] Addressed the feedback
---
 lib/Conversion/TorchToLinalg/Linear.cpp        |  8 +++--
 .../torch_mlir_e2e_test/test_suite/matmul.py   | 34 +++++++++++++++++++
 2 files changed, 39 insertions(+), 3 deletions(-)

diff --git a/lib/Conversion/TorchToLinalg/Linear.cpp b/lib/Conversion/TorchToLinalg/Linear.cpp
index d9a4e74c231d..dff644a65dce 100644
--- a/lib/Conversion/TorchToLinalg/Linear.cpp
+++ b/lib/Conversion/TorchToLinalg/Linear.cpp
@@ -1683,7 +1683,7 @@ class ConvertAtenOuterOp : public OpConversionPattern<AtenOuterOp> {
 
     Location loc = op->getLoc();
     Value lhs = adaptor.getSelf();
-    Value rhs = op->getOperand(1);
+    Value rhs = adaptor.getVec2();
 
     if (failed(verifyLinalgCompatibleTypes(op, rewriter))) {
       return failure();
@@ -1704,8 +1704,10 @@ class ConvertAtenOuterOp : public OpConversionPattern<AtenOuterOp> {
     Type newResultType = getTypeConverter()->convertType(op.getType());
 
     // Create a zero-initialized tensor with shape [lhsDim, rhsDim]
-    Value initTensor = createInitTensor(
-        rewriter, loc, ValueRange{lhsDim, rhsDim}, elementType, NULL);
+    SmallVector<OpFoldResult> resultShape =
+        getAsOpFoldResult(ValueRange{lhsDim, rhsDim});
+    Value initTensor =
+        rewriter.create<tensor::EmptyOp>(loc, resultShape, elementType);
 
     // Set up affine indexing maps:
     // We create a 2D loop iteration space. For the lhs, we use the first index
diff --git a/projects/pt1/python/torch_mlir_e2e_test/test_suite/matmul.py b/projects/pt1/python/torch_mlir_e2e_test/test_suite/matmul.py
index 9fde04330b74..79bacc2bb9c7 100644
--- a/projects/pt1/python/torch_mlir_e2e_test/test_suite/matmul.py
+++ b/projects/pt1/python/torch_mlir_e2e_test/test_suite/matmul.py
@@ -942,3 +942,37 @@ def forward(self, lhs, rhs):
 @register_test_case(module_factory=lambda: AtenOuter())
 def AtenOuter_basic(module, tu: TestUtils):
     module.forward(tu.rand(3), tu.rand(3))
+
+
+# ==============================================================================
+
+
+class AtenOuterDynamic(torch.nn.Module):
+    def __init__(self):
+        super().__init__()
+
+    @export
+    @annotate_args(
+        [
+            None,
+            ([-1], torch.float32, True),
+            ([-1], torch.float32, True),
+        ]
+    )
+    def forward(self, lhs, rhs):
+        return torch.outer(lhs, rhs)
+
+
+@register_test_case(module_factory=lambda: AtenOuterDynamic())
+def AtenOuterDynamic_basic(module, tu: TestUtils):
+    module.forward(tu.rand(5), tu.rand(5))
+
+
+@register_test_case(module_factory=lambda: AtenOuterDynamic())
+def AtenOuterDynamic_lhs_larger(module, tu: TestUtils):
+    module.forward(tu.rand(7), tu.rand(4))
+
+
+@register_test_case(module_factory=lambda: AtenOuterDynamic())
+def AtenOuterDynamic_rhs_larger(module, tu: TestUtils):
+    module.forward(tu.rand(2), tu.rand(6))
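
The dynamic-shape tests added in PATCH 6 cover inputs of different lengths; the behaviour they rely on is that torch.outer always produces a (len(lhs), len(rhs)) result with out[i, j] = lhs[i] * rhs[j]. A quick standalone check of that behaviour (for reference only, not part of the patch series):

    import torch

    lhs, rhs = torch.rand(7), torch.rand(4)
    out = torch.outer(lhs, rhs)
    # The result dims come from dim 0 of each input, which is why the lowering
    # takes lhsDim and rhsDim from index 0 (the PATCH 5 fix).
    assert out.shape == (7, 4)
    assert torch.allclose(out[2, 3], lhs[2] * rhs[3])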