-
Notifications
You must be signed in to change notification settings - Fork 13.4k
[mlir][x86vector] Improve intrinsic operands creation #138666
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Open
adam-smnk
wants to merge
1
commit into
llvm:main
Choose a base branch
from
adam-smnk:x86vector-improve-intrinsic-operands-creation
base: main
Could not load branches
Branch not found: {{ refName }}
Loading
Could not load tags
Nothing to show
Loading
Are you sure you want to change the base?
Some commits from the old base branch may be removed from the timeline,
and old review comments may become outdated.
Open
[mlir][x86vector] Improve intrinsic operands creation #138666
adam-smnk
wants to merge
1
commit into
llvm:main
from
adam-smnk:x86vector-improve-intrinsic-operands-creation
Conversation
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Refactors intrinsic op interface to delegate initial operands mapping to the dialect converter and allow intrinsic operands getters to only perform the last mile post-processing.
@llvm/pr-subscribers-mlir-vector @llvm/pr-subscribers-mlir Author: Adam Siemieniuk (adam-smnk) ChangesRefactors intrinsic op interface to delegate initial operands mapping to the dialect converter and allow intrinsic operands getters to only perform last mile post-processing. Full diff: https://github.com/llvm/llvm-project/pull/138666.diff 4 Files Affected:
diff --git a/mlir/include/mlir/Dialect/X86Vector/X86Vector.td b/mlir/include/mlir/Dialect/X86Vector/X86Vector.td
index 4f8301f9380b8..25d9c404f0181 100644
--- a/mlir/include/mlir/Dialect/X86Vector/X86Vector.td
+++ b/mlir/include/mlir/Dialect/X86Vector/X86Vector.td
@@ -83,7 +83,10 @@ def MaskCompressOp : AVX512_Op<"mask.compress", [Pure,
}
}];
let extraClassDeclaration = [{
- SmallVector<Value> getIntrinsicOperands(::mlir::RewriterBase&, const LLVMTypeConverter&);
+ SmallVector<Value> getIntrinsicOperands(
+ ::mlir::ArrayRef<Value> operands,
+ const ::mlir::LLVMTypeConverter &typeConverter,
+ ::mlir::RewriterBase &rewriter);
}];
}
@@ -404,7 +407,10 @@ def DotOp : AVX_LowOp<"dot", [Pure,
}
}];
let extraClassDeclaration = [{
- SmallVector<Value> getIntrinsicOperands(::mlir::RewriterBase&, const LLVMTypeConverter&);
+ SmallVector<Value> getIntrinsicOperands(
+ ::mlir::ArrayRef<Value> operands,
+ const ::mlir::LLVMTypeConverter &typeConverter,
+ ::mlir::RewriterBase &rewriter);
}];
}
@@ -452,7 +458,10 @@ def BcstToPackedF32Op : AVX_Op<"bcst_to_f32.packed", [MemoryEffects<[MemRead]>,
}];
let extraClassDeclaration = [{
- SmallVector<Value> getIntrinsicOperands(::mlir::RewriterBase&, const LLVMTypeConverter&);
+ SmallVector<Value> getIntrinsicOperands(
+ ::mlir::ArrayRef<Value> operands,
+ const ::mlir::LLVMTypeConverter &typeConverter,
+ ::mlir::RewriterBase &rewriter);
}];
}
@@ -500,7 +509,10 @@ def CvtPackedEvenIndexedToF32Op : AVX_Op<"cvt.packed.even.indexed_to_f32", [Memo
}];
let extraClassDeclaration = [{
- SmallVector<Value> getIntrinsicOperands(::mlir::RewriterBase&, const LLVMTypeConverter&);
+ SmallVector<Value> getIntrinsicOperands(
+ ::mlir::ArrayRef<Value> operands,
+ const ::mlir::LLVMTypeConverter &typeConverter,
+ ::mlir::RewriterBase &rewriter);
}];
}
@@ -543,7 +555,10 @@ def CvtPackedOddIndexedToF32Op : AVX_Op<"cvt.packed.odd.indexed_to_f32", [Memory
}];
let extraClassDeclaration = [{
- SmallVector<Value> getIntrinsicOperands(::mlir::RewriterBase&, const LLVMTypeConverter&);
+ SmallVector<Value> getIntrinsicOperands(
+ ::mlir::ArrayRef<Value> operands,
+ const ::mlir::LLVMTypeConverter &typeConverter,
+ ::mlir::RewriterBase &rewriter);
}];
}
#endif // X86VECTOR_OPS
diff --git a/mlir/include/mlir/Dialect/X86Vector/X86VectorInterfaces.td b/mlir/include/mlir/Dialect/X86Vector/X86VectorInterfaces.td
index 5176f4a447b6e..cde9d1dce65ee 100644
--- a/mlir/include/mlir/Dialect/X86Vector/X86VectorInterfaces.td
+++ b/mlir/include/mlir/Dialect/X86Vector/X86VectorInterfaces.td
@@ -58,9 +58,11 @@ def OneToOneIntrinsicOpInterface : OpInterface<"OneToOneIntrinsicOp"> {
}],
/*retType=*/"SmallVector<Value>",
/*methodName=*/"getIntrinsicOperands",
- /*args=*/(ins "::mlir::RewriterBase &":$rewriter, "const LLVMTypeConverter &":$typeConverter),
+ /*args=*/(ins "::mlir::ArrayRef<Value>":$operands,
+ "const ::mlir::LLVMTypeConverter &":$typeConverter,
+ "::mlir::RewriterBase &":$rewriter),
/*methodBody=*/"",
- /*defaultImplementation=*/"return SmallVector<Value>($_op->getOperands());"
+ /*defaultImplementation=*/"return SmallVector<Value>(operands);"
>,
];
}
diff --git a/mlir/lib/Dialect/X86Vector/IR/X86VectorDialect.cpp b/mlir/lib/Dialect/X86Vector/IR/X86VectorDialect.cpp
index 8d383b1f8103b..cc7ab7f3f3895 100644
--- a/mlir/lib/Dialect/X86Vector/IR/X86VectorDialect.cpp
+++ b/mlir/lib/Dialect/X86Vector/IR/X86VectorDialect.cpp
@@ -31,24 +31,11 @@ void x86vector::X86VectorDialect::initialize() {
>();
}
-static SmallVector<Value>
-getMemrefBuffPtr(Location loc, ::mlir::TypedValue<::mlir::MemRefType> memrefVal,
- RewriterBase &rewriter,
- const LLVMTypeConverter &typeConverter) {
- SmallVector<Value> operands;
- auto opType = memrefVal.getType();
-
- Type llvmStructType = typeConverter.convertType(opType);
- Value llvmStruct =
- rewriter
- .create<UnrealizedConversionCastOp>(loc, llvmStructType, memrefVal)
- .getResult(0);
- MemRefDescriptor memRefDescriptor(llvmStruct);
-
- Value ptr = memRefDescriptor.bufferPtr(rewriter, loc, typeConverter, opType);
- operands.push_back(ptr);
-
- return operands;
+static Value getMemrefBuffPtr(Location loc, MemRefType type, Value buffer,
+ const LLVMTypeConverter &typeConverter,
+ RewriterBase &rewriter) {
+ MemRefDescriptor memRefDescriptor(buffer);
+ return memRefDescriptor.bufferPtr(rewriter, loc, typeConverter, type);
}
LogicalResult x86vector::MaskCompressOp::verify() {
@@ -66,48 +53,61 @@ LogicalResult x86vector::MaskCompressOp::verify() {
}
SmallVector<Value> x86vector::MaskCompressOp::getIntrinsicOperands(
- RewriterBase &rewriter, const LLVMTypeConverter &typeConverter) {
+ ArrayRef<Value> operands, const LLVMTypeConverter &typeConverter,
+ RewriterBase &rewriter) {
auto loc = getLoc();
+ Adaptor adaptor(operands, *this);
- auto opType = getA().getType();
+ auto opType = adaptor.getA().getType();
Value src;
- if (getSrc()) {
- src = getSrc();
- } else if (getConstantSrc()) {
- src = rewriter.create<LLVM::ConstantOp>(loc, opType, getConstantSrcAttr());
+ if (adaptor.getSrc()) {
+ src = adaptor.getSrc();
+ } else if (adaptor.getConstantSrc()) {
+ src = rewriter.create<LLVM::ConstantOp>(loc, opType,
+ adaptor.getConstantSrcAttr());
} else {
auto zeroAttr = rewriter.getZeroAttr(opType);
src = rewriter.create<LLVM::ConstantOp>(loc, opType, zeroAttr);
}
- return SmallVector<Value>{getA(), src, getK()};
+ return SmallVector<Value>{adaptor.getA(), src, adaptor.getK()};
}
SmallVector<Value>
-x86vector::DotOp::getIntrinsicOperands(RewriterBase &rewriter,
- const LLVMTypeConverter &typeConverter) {
- SmallVector<Value> operands(getOperands());
+x86vector::DotOp::getIntrinsicOperands(ArrayRef<Value> operands,
+ const LLVMTypeConverter &typeConverter,
+ RewriterBase &rewriter) {
+ SmallVector<Value> intrinsicOperands(operands);
// Dot product of all elements, broadcasted to all elements.
Value scale =
rewriter.create<LLVM::ConstantOp>(getLoc(), rewriter.getI8Type(), 0xff);
- operands.push_back(scale);
+ intrinsicOperands.push_back(scale);
- return operands;
+ return intrinsicOperands;
}
SmallVector<Value> x86vector::BcstToPackedF32Op::getIntrinsicOperands(
- RewriterBase &rewriter, const LLVMTypeConverter &typeConverter) {
- return getMemrefBuffPtr(getLoc(), getA(), rewriter, typeConverter);
+ ArrayRef<Value> operands, const LLVMTypeConverter &typeConverter,
+ RewriterBase &rewriter) {
+ Adaptor adaptor(operands, *this);
+ return {getMemrefBuffPtr(getLoc(), getA().getType(), adaptor.getA(),
+ typeConverter, rewriter)};
}
SmallVector<Value> x86vector::CvtPackedEvenIndexedToF32Op::getIntrinsicOperands(
- RewriterBase &rewriter, const LLVMTypeConverter &typeConverter) {
- return getMemrefBuffPtr(getLoc(), getA(), rewriter, typeConverter);
+ ArrayRef<Value> operands, const LLVMTypeConverter &typeConverter,
+ RewriterBase &rewriter) {
+ Adaptor adaptor(operands, *this);
+ return {getMemrefBuffPtr(getLoc(), getA().getType(), adaptor.getA(),
+ typeConverter, rewriter)};
}
SmallVector<Value> x86vector::CvtPackedOddIndexedToF32Op::getIntrinsicOperands(
- RewriterBase &rewriter, const LLVMTypeConverter &typeConverter) {
- return getMemrefBuffPtr(getLoc(), getA(), rewriter, typeConverter);
+ ArrayRef<Value> operands, const LLVMTypeConverter &typeConverter,
+ RewriterBase &rewriter) {
+ Adaptor adaptor(operands, *this);
+ return {getMemrefBuffPtr(getLoc(), getA().getType(), adaptor.getA(),
+ typeConverter, rewriter)};
}
#define GET_OP_CLASSES
diff --git a/mlir/lib/Dialect/X86Vector/Transforms/LegalizeForLLVMExport.cpp b/mlir/lib/Dialect/X86Vector/Transforms/LegalizeForLLVMExport.cpp
index 9ee44a63ba2e4..483c1f5c3e4c6 100644
--- a/mlir/lib/Dialect/X86Vector/Transforms/LegalizeForLLVMExport.cpp
+++ b/mlir/lib/Dialect/X86Vector/Transforms/LegalizeForLLVMExport.cpp
@@ -84,20 +84,23 @@ LogicalResult intrinsicRewrite(Operation *op, StringAttr intrinsic,
/// Generic one-to-one conversion of simply mappable operations into calls
/// to their respective LLVM intrinsics.
struct OneToOneIntrinsicOpConversion
- : public OpInterfaceRewritePattern<x86vector::OneToOneIntrinsicOp> {
- using OpInterfaceRewritePattern<
- x86vector::OneToOneIntrinsicOp>::OpInterfaceRewritePattern;
+ : public OpInterfaceConversionPattern<x86vector::OneToOneIntrinsicOp> {
+ using OpInterfaceConversionPattern<
+ x86vector::OneToOneIntrinsicOp>::OpInterfaceConversionPattern;
OneToOneIntrinsicOpConversion(const LLVMTypeConverter &typeConverter,
PatternBenefit benefit = 1)
- : OpInterfaceRewritePattern(&typeConverter.getContext(), benefit),
+ : OpInterfaceConversionPattern(typeConverter, &typeConverter.getContext(),
+ benefit),
typeConverter(typeConverter) {}
- LogicalResult matchAndRewrite(x86vector::OneToOneIntrinsicOp op,
- PatternRewriter &rewriter) const override {
- return intrinsicRewrite(op, rewriter.getStringAttr(op.getIntrinsicName()),
- op.getIntrinsicOperands(rewriter, typeConverter),
- typeConverter, rewriter);
+ LogicalResult
+ matchAndRewrite(x86vector::OneToOneIntrinsicOp op, ArrayRef<Value> operands,
+ ConversionPatternRewriter &rewriter) const override {
+ return intrinsicRewrite(
+ op, rewriter.getStringAttr(op.getIntrinsicName()),
+ op.getIntrinsicOperands(operands, typeConverter, rewriter),
+ typeConverter, rewriter);
}
private:
|
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment
Add this suggestion to a batch that can be applied as a single commit.
This suggestion is invalid because no changes were made to the code.
Suggestions cannot be applied while the pull request is closed.
Suggestions cannot be applied while viewing a subset of changes.
Only one suggestion per line can be applied in a batch.
Add this suggestion to a batch that can be applied as a single commit.
Applying suggestions on deleted lines is not supported.
You must change the existing code in this line in order to create a valid suggestion.
Outdated suggestions cannot be applied.
This suggestion has been applied or marked resolved.
Suggestions cannot be applied from pending reviews.
Suggestions cannot be applied on multi-line comments.
Suggestions cannot be applied while the pull request is queued to merge.
Suggestion cannot be applied right now. Please check back later.
Refactors intrinsic op interface to delegate initial operands mapping to the dialect converter and allow intrinsic operands getters to only perform last mile post-processing.