From ee26865b9c06be650736ce0479abca64fab32211 Mon Sep 17 00:00:00 2001
From: Nirvedh Meshram
Date: Tue, 28 Jan 2025 22:41:12 +0000
Subject: [PATCH] Bump llvm at 95d993a838863269dc1b90de3808c1e40ac6d5f2

---
 externals/llvm-project                        |  2 +-
 .../TorchToTosa/TosaLegalizeUtils.cpp         | 44 +++++++------------
 lib/Dialect/Torch/Utils/Utils.cpp             |  8 ++--
 3 files changed, 21 insertions(+), 33 deletions(-)

diff --git a/externals/llvm-project b/externals/llvm-project
index e2402615a5a7..95d993a83886 160000
--- a/externals/llvm-project
+++ b/externals/llvm-project
@@ -1 +1 @@
-Subproject commit e2402615a5a76d46a433dfcc1de10b38a1263c9d
+Subproject commit 95d993a838863269dc1b90de3808c1e40ac6d5f2
diff --git a/lib/Conversion/TorchToTosa/TosaLegalizeUtils.cpp b/lib/Conversion/TorchToTosa/TosaLegalizeUtils.cpp
index 1ed360ddae61..be7da26fb9dc 100644
--- a/lib/Conversion/TorchToTosa/TosaLegalizeUtils.cpp
+++ b/lib/Conversion/TorchToTosa/TosaLegalizeUtils.cpp
@@ -301,31 +301,31 @@ std::optional<Value> getConstTensor(PatternRewriter &rewriter,
       (src.isF32() && dest.isInteger(8)) ||
       (src.isF32() && dest.isBF16()) ||
       (src.isF32() && dest.isF16()) ||
-      (src.isF32() && dest.isFloat8E4M3()) ||
-      (src.isF32() && dest.isFloat8E5M2()) ||
+      (src.isF32() && isa<Float8E4M3Type>(dest)) ||
+      (src.isF32() && isa<Float8E5M2Type>(dest)) ||
       // f16 -> *
       (src.isF16() && dest.isInteger(32)) ||
       (src.isF16() && dest.isInteger(16)) ||
       (src.isF16() && dest.isInteger(8)) ||
       (src.isF16() && dest.isBF16()) ||
       (src.isF16() && dest.isF32()) ||
-      (src.isF16() && dest.isFloat8E4M3()) ||
-      (src.isF16() && dest.isFloat8E5M2()) ||
+      (src.isF16() && isa<Float8E4M3Type>(dest)) ||
+      (src.isF16() && isa<Float8E5M2Type>(dest)) ||
       // bf16 -> *
       (src.isBF16() && dest.isInteger(32)) ||
       (src.isBF16() && dest.isInteger(16)) ||
       (src.isBF16() && dest.isInteger(8)) ||
       (src.isBF16() && dest.isF32()) ||
-      (src.isBF16() && dest.isFloat8E4M3()) ||
-      (src.isBF16() && dest.isFloat8E5M2()) ||
+      (src.isBF16() && isa<Float8E4M3Type>(dest)) ||
+      (src.isBF16() && isa<Float8E5M2Type>(dest)) ||
       // fp8e4m3 -> *
-      (src.isFloat8E4M3() && dest.isBF16()) ||
-      (src.isFloat8E4M3() && dest.isF32()) ||
-      (src.isFloat8E4M3() && dest.isF16()) ||
+      (isa<Float8E4M3Type>(src) && dest.isBF16()) ||
+      (isa<Float8E4M3Type>(src) && dest.isF32()) ||
+      (isa<Float8E4M3Type>(src) && dest.isF16()) ||
       // fp8e5m2 -> *
-      (src.isFloat8E5M2() && dest.isBF16()) ||
-      (src.isFloat8E5M2() && dest.isF32()) ||
-      (src.isFloat8E5M2() && dest.isF16())) {
+      (isa<Float8E5M2Type>(src) && dest.isBF16()) ||
+      (isa<Float8E5M2Type>(src) && dest.isF32()) ||
+      (isa<Float8E5M2Type>(src) && dest.isF16())) {
     return success();
   }
   // clang-format on
@@ -488,10 +488,10 @@ LogicalResult getConvOpsAccType(PatternRewriter &rewriter,
   } else if (inputElemTy.isInteger(16) && weightElemTy.isInteger(8) &&
              outputElemTy.isInteger(48)) {
     accType = mlir::TypeAttr::get(rewriter.getIntegerType(48));
-  } else if ((inputElemTy.isFloat8E4M3() && weightElemTy.isFloat8E4M3() &&
-              outputElemTy.isF16()) ||
-             (inputElemTy.isFloat8E5M2() && weightElemTy.isFloat8E5M2() &&
-              outputElemTy.isF16())) {
+  } else if ((isa<Float8E4M3Type>(inputElemTy) &&
+              isa<Float8E4M3Type>(weightElemTy) && outputElemTy.isF16()) ||
+             (isa<Float8E5M2Type>(inputElemTy) &&
+              isa<Float8E5M2Type>(weightElemTy) && outputElemTy.isF16())) {
     accType = mlir::TypeAttr::get(rewriter.getF16Type());
   } else {
     accType = mlir::TypeAttr::get(outputElemTy);
@@ -500,17 +500,5 @@ LogicalResult getConvOpsAccType(PatternRewriter &rewriter,
   return success();
 }
 
-// Temporary function to get TOSA const shape
-// TODO: Remove this function when getTosaConstShape is available in
-// externals/llvm-project/mlir/include/mlir/Dialect/Tosa/Utils/ConversionUtils.h
-Value getTosaConstShape(PatternRewriter &rewriter, Location loc,
-                        llvm::ArrayRef<int64_t> shape) {
-  auto attr = rewriter.getIndexTensorAttr(shape);
-  auto type = mlir::tosa::shapeType::get(rewriter.getContext(), shape.size());
-  mlir::Operation *mlir_op =
-      rewriter.create<tosa::ConstShapeOp>(loc, type, attr);
-  return mlir_op->getResult(0);
-}
-
 } // namespace tosa
 } // namespace mlir
diff --git a/lib/Dialect/Torch/Utils/Utils.cpp b/lib/Dialect/Torch/Utils/Utils.cpp
index c0984efffd9c..7f80e84044df 100644
--- a/lib/Dialect/Torch/Utils/Utils.cpp
+++ b/lib/Dialect/Torch/Utils/Utils.cpp
@@ -652,13 +652,13 @@ Type Torch::getDefaultAccType(PatternRewriter &rewriter, Type inputType) {
     return rewriter.getF32Type();
   if (isa<Float64Type>(inputType))
     return rewriter.getF64Type();
-  if (inputType.isFloat8E5M2())
+  if (isa<Float8E5M2Type>(inputType))
     return rewriter.getF32Type();
-  if (inputType.isFloat8E4M3FN())
+  if (isa<Float8E4M3FNType>(inputType))
     return rewriter.getF32Type();
-  if (inputType.isFloat8E5M2FNUZ())
+  if (isa<Float8E5M2FNUZType>(inputType))
     return rewriter.getF32Type();
-  if (inputType.isFloat8E4M3FNUZ())
+  if (isa<Float8E4M3FNUZType>(inputType))
     return rewriter.getF32Type();
   if (inputType.isInteger(8))
     // this is an intentional deviation from CUDA (which accumulates i8 to i64)