Skip to content

[mlir][x86vector] Improve intrinsic operands creation #138666

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Open
wants to merge 1 commit into
base: main
Choose a base branch
from

Conversation

adam-smnk
Copy link
Contributor

Refactors intrinsic op interface to delegate initial operands mapping to the dialect converter and allow intrinsic operands getters to only perform last mile post-processing.

Refactors intrinsic op interface to delegate initial operands mapping
to the dialect converter and allow intrinsic operands getters to only
perform the last mile post-processing.
@llvmbot
Copy link
Member

llvmbot commented May 6, 2025

@llvm/pr-subscribers-mlir-vector

@llvm/pr-subscribers-mlir

Author: Adam Siemieniuk (adam-smnk)

Changes

Refactors intrinsic op interface to delegate initial operands mapping to the dialect converter and allow intrinsic operands getters to only perform last mile post-processing.


Full diff: https://github.com/llvm/llvm-project/pull/138666.diff

4 Files Affected:

  • (modified) mlir/include/mlir/Dialect/X86Vector/X86Vector.td (+20-5)
  • (modified) mlir/include/mlir/Dialect/X86Vector/X86VectorInterfaces.td (+4-2)
  • (modified) mlir/lib/Dialect/X86Vector/IR/X86VectorDialect.cpp (+36-36)
  • (modified) mlir/lib/Dialect/X86Vector/Transforms/LegalizeForLLVMExport.cpp (+12-9)
diff --git a/mlir/include/mlir/Dialect/X86Vector/X86Vector.td b/mlir/include/mlir/Dialect/X86Vector/X86Vector.td
index 4f8301f9380b8..25d9c404f0181 100644
--- a/mlir/include/mlir/Dialect/X86Vector/X86Vector.td
+++ b/mlir/include/mlir/Dialect/X86Vector/X86Vector.td
@@ -83,7 +83,10 @@ def MaskCompressOp : AVX512_Op<"mask.compress", [Pure,
     }
   }];
   let extraClassDeclaration = [{
-    SmallVector<Value> getIntrinsicOperands(::mlir::RewriterBase&, const LLVMTypeConverter&);
+    SmallVector<Value> getIntrinsicOperands(
+        ::mlir::ArrayRef<Value> operands,
+        const ::mlir::LLVMTypeConverter &typeConverter,
+        ::mlir::RewriterBase &rewriter);
   }];
 }
 
@@ -404,7 +407,10 @@ def DotOp : AVX_LowOp<"dot", [Pure,
     }
   }];
   let extraClassDeclaration = [{
-    SmallVector<Value> getIntrinsicOperands(::mlir::RewriterBase&, const LLVMTypeConverter&);
+    SmallVector<Value> getIntrinsicOperands(
+        ::mlir::ArrayRef<Value> operands,
+        const ::mlir::LLVMTypeConverter &typeConverter,
+        ::mlir::RewriterBase &rewriter);
   }];
 }
 
@@ -452,7 +458,10 @@ def BcstToPackedF32Op : AVX_Op<"bcst_to_f32.packed", [MemoryEffects<[MemRead]>,
   }];
 
   let extraClassDeclaration = [{
-        SmallVector<Value> getIntrinsicOperands(::mlir::RewriterBase&, const LLVMTypeConverter&);
+    SmallVector<Value> getIntrinsicOperands(
+        ::mlir::ArrayRef<Value> operands,
+        const ::mlir::LLVMTypeConverter &typeConverter,
+        ::mlir::RewriterBase &rewriter);
   }];
 
 }
@@ -500,7 +509,10 @@ def CvtPackedEvenIndexedToF32Op : AVX_Op<"cvt.packed.even.indexed_to_f32", [Memo
   }];
 
   let extraClassDeclaration = [{
-        SmallVector<Value> getIntrinsicOperands(::mlir::RewriterBase&, const LLVMTypeConverter&);
+    SmallVector<Value> getIntrinsicOperands(
+        ::mlir::ArrayRef<Value> operands,
+        const ::mlir::LLVMTypeConverter &typeConverter,
+        ::mlir::RewriterBase &rewriter);
   }];
 }
 
@@ -543,7 +555,10 @@ def CvtPackedOddIndexedToF32Op : AVX_Op<"cvt.packed.odd.indexed_to_f32", [Memory
   }];
 
   let extraClassDeclaration = [{
-        SmallVector<Value> getIntrinsicOperands(::mlir::RewriterBase&, const LLVMTypeConverter&);
+    SmallVector<Value> getIntrinsicOperands(
+        ::mlir::ArrayRef<Value> operands,
+        const ::mlir::LLVMTypeConverter &typeConverter,
+        ::mlir::RewriterBase &rewriter);
   }];
 }
 #endif // X86VECTOR_OPS
diff --git a/mlir/include/mlir/Dialect/X86Vector/X86VectorInterfaces.td b/mlir/include/mlir/Dialect/X86Vector/X86VectorInterfaces.td
index 5176f4a447b6e..cde9d1dce65ee 100644
--- a/mlir/include/mlir/Dialect/X86Vector/X86VectorInterfaces.td
+++ b/mlir/include/mlir/Dialect/X86Vector/X86VectorInterfaces.td
@@ -58,9 +58,11 @@ def OneToOneIntrinsicOpInterface : OpInterface<"OneToOneIntrinsicOp"> {
       }],
       /*retType=*/"SmallVector<Value>",
       /*methodName=*/"getIntrinsicOperands",
-      /*args=*/(ins "::mlir::RewriterBase &":$rewriter, "const LLVMTypeConverter &":$typeConverter),
+      /*args=*/(ins "::mlir::ArrayRef<Value>":$operands,
+                    "const ::mlir::LLVMTypeConverter &":$typeConverter,
+                    "::mlir::RewriterBase &":$rewriter),
       /*methodBody=*/"",
-      /*defaultImplementation=*/"return SmallVector<Value>($_op->getOperands());"
+      /*defaultImplementation=*/"return SmallVector<Value>(operands);"
     >,
   ];
 }
diff --git a/mlir/lib/Dialect/X86Vector/IR/X86VectorDialect.cpp b/mlir/lib/Dialect/X86Vector/IR/X86VectorDialect.cpp
index 8d383b1f8103b..cc7ab7f3f3895 100644
--- a/mlir/lib/Dialect/X86Vector/IR/X86VectorDialect.cpp
+++ b/mlir/lib/Dialect/X86Vector/IR/X86VectorDialect.cpp
@@ -31,24 +31,11 @@ void x86vector::X86VectorDialect::initialize() {
       >();
 }
 
-static SmallVector<Value>
-getMemrefBuffPtr(Location loc, ::mlir::TypedValue<::mlir::MemRefType> memrefVal,
-                 RewriterBase &rewriter,
-                 const LLVMTypeConverter &typeConverter) {
-  SmallVector<Value> operands;
-  auto opType = memrefVal.getType();
-
-  Type llvmStructType = typeConverter.convertType(opType);
-  Value llvmStruct =
-      rewriter
-          .create<UnrealizedConversionCastOp>(loc, llvmStructType, memrefVal)
-          .getResult(0);
-  MemRefDescriptor memRefDescriptor(llvmStruct);
-
-  Value ptr = memRefDescriptor.bufferPtr(rewriter, loc, typeConverter, opType);
-  operands.push_back(ptr);
-
-  return operands;
+static Value getMemrefBuffPtr(Location loc, MemRefType type, Value buffer,
+                              const LLVMTypeConverter &typeConverter,
+                              RewriterBase &rewriter) {
+  MemRefDescriptor memRefDescriptor(buffer);
+  return memRefDescriptor.bufferPtr(rewriter, loc, typeConverter, type);
 }
 
 LogicalResult x86vector::MaskCompressOp::verify() {
@@ -66,48 +53,61 @@ LogicalResult x86vector::MaskCompressOp::verify() {
 }
 
 SmallVector<Value> x86vector::MaskCompressOp::getIntrinsicOperands(
-    RewriterBase &rewriter, const LLVMTypeConverter &typeConverter) {
+    ArrayRef<Value> operands, const LLVMTypeConverter &typeConverter,
+    RewriterBase &rewriter) {
   auto loc = getLoc();
+  Adaptor adaptor(operands, *this);
 
-  auto opType = getA().getType();
+  auto opType = adaptor.getA().getType();
   Value src;
-  if (getSrc()) {
-    src = getSrc();
-  } else if (getConstantSrc()) {
-    src = rewriter.create<LLVM::ConstantOp>(loc, opType, getConstantSrcAttr());
+  if (adaptor.getSrc()) {
+    src = adaptor.getSrc();
+  } else if (adaptor.getConstantSrc()) {
+    src = rewriter.create<LLVM::ConstantOp>(loc, opType,
+                                            adaptor.getConstantSrcAttr());
   } else {
     auto zeroAttr = rewriter.getZeroAttr(opType);
     src = rewriter.create<LLVM::ConstantOp>(loc, opType, zeroAttr);
   }
 
-  return SmallVector<Value>{getA(), src, getK()};
+  return SmallVector<Value>{adaptor.getA(), src, adaptor.getK()};
 }
 
 SmallVector<Value>
-x86vector::DotOp::getIntrinsicOperands(RewriterBase &rewriter,
-                                       const LLVMTypeConverter &typeConverter) {
-  SmallVector<Value> operands(getOperands());
+x86vector::DotOp::getIntrinsicOperands(ArrayRef<Value> operands,
+                                       const LLVMTypeConverter &typeConverter,
+                                       RewriterBase &rewriter) {
+  SmallVector<Value> intrinsicOperands(operands);
   // Dot product of all elements, broadcasted to all elements.
   Value scale =
       rewriter.create<LLVM::ConstantOp>(getLoc(), rewriter.getI8Type(), 0xff);
-  operands.push_back(scale);
+  intrinsicOperands.push_back(scale);
 
-  return operands;
+  return intrinsicOperands;
 }
 
 SmallVector<Value> x86vector::BcstToPackedF32Op::getIntrinsicOperands(
-    RewriterBase &rewriter, const LLVMTypeConverter &typeConverter) {
-  return getMemrefBuffPtr(getLoc(), getA(), rewriter, typeConverter);
+    ArrayRef<Value> operands, const LLVMTypeConverter &typeConverter,
+    RewriterBase &rewriter) {
+  Adaptor adaptor(operands, *this);
+  return {getMemrefBuffPtr(getLoc(), getA().getType(), adaptor.getA(),
+                           typeConverter, rewriter)};
 }
 
 SmallVector<Value> x86vector::CvtPackedEvenIndexedToF32Op::getIntrinsicOperands(
-    RewriterBase &rewriter, const LLVMTypeConverter &typeConverter) {
-  return getMemrefBuffPtr(getLoc(), getA(), rewriter, typeConverter);
+    ArrayRef<Value> operands, const LLVMTypeConverter &typeConverter,
+    RewriterBase &rewriter) {
+  Adaptor adaptor(operands, *this);
+  return {getMemrefBuffPtr(getLoc(), getA().getType(), adaptor.getA(),
+                           typeConverter, rewriter)};
 }
 
 SmallVector<Value> x86vector::CvtPackedOddIndexedToF32Op::getIntrinsicOperands(
-    RewriterBase &rewriter, const LLVMTypeConverter &typeConverter) {
-  return getMemrefBuffPtr(getLoc(), getA(), rewriter, typeConverter);
+    ArrayRef<Value> operands, const LLVMTypeConverter &typeConverter,
+    RewriterBase &rewriter) {
+  Adaptor adaptor(operands, *this);
+  return {getMemrefBuffPtr(getLoc(), getA().getType(), adaptor.getA(),
+                           typeConverter, rewriter)};
 }
 
 #define GET_OP_CLASSES
diff --git a/mlir/lib/Dialect/X86Vector/Transforms/LegalizeForLLVMExport.cpp b/mlir/lib/Dialect/X86Vector/Transforms/LegalizeForLLVMExport.cpp
index 9ee44a63ba2e4..483c1f5c3e4c6 100644
--- a/mlir/lib/Dialect/X86Vector/Transforms/LegalizeForLLVMExport.cpp
+++ b/mlir/lib/Dialect/X86Vector/Transforms/LegalizeForLLVMExport.cpp
@@ -84,20 +84,23 @@ LogicalResult intrinsicRewrite(Operation *op, StringAttr intrinsic,
 /// Generic one-to-one conversion of simply mappable operations into calls
 /// to their respective LLVM intrinsics.
 struct OneToOneIntrinsicOpConversion
-    : public OpInterfaceRewritePattern<x86vector::OneToOneIntrinsicOp> {
-  using OpInterfaceRewritePattern<
-      x86vector::OneToOneIntrinsicOp>::OpInterfaceRewritePattern;
+    : public OpInterfaceConversionPattern<x86vector::OneToOneIntrinsicOp> {
+  using OpInterfaceConversionPattern<
+      x86vector::OneToOneIntrinsicOp>::OpInterfaceConversionPattern;
 
   OneToOneIntrinsicOpConversion(const LLVMTypeConverter &typeConverter,
                                 PatternBenefit benefit = 1)
-      : OpInterfaceRewritePattern(&typeConverter.getContext(), benefit),
+      : OpInterfaceConversionPattern(typeConverter, &typeConverter.getContext(),
+                                     benefit),
         typeConverter(typeConverter) {}
 
-  LogicalResult matchAndRewrite(x86vector::OneToOneIntrinsicOp op,
-                                PatternRewriter &rewriter) const override {
-    return intrinsicRewrite(op, rewriter.getStringAttr(op.getIntrinsicName()),
-                            op.getIntrinsicOperands(rewriter, typeConverter),
-                            typeConverter, rewriter);
+  LogicalResult
+  matchAndRewrite(x86vector::OneToOneIntrinsicOp op, ArrayRef<Value> operands,
+                  ConversionPatternRewriter &rewriter) const override {
+    return intrinsicRewrite(
+        op, rewriter.getStringAttr(op.getIntrinsicName()),
+        op.getIntrinsicOperands(operands, typeConverter, rewriter),
+        typeConverter, rewriter);
   }
 
 private:

Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment
Projects
None yet
Development

Successfully merging this pull request may close these issues.

2 participants