Skip to content

Commit 1e86d5a

Browse files
[DIP] Infer operation data type based on its params for Corr2D (buddy-compiler#63)
* Add lit-test for correlation2D * Add correlation i32 test * Add support for i32 correlation 2d * Enable default attribute printing for DIP * Add a check for return type compatibility * Update lit test to find mlir cpu utils * Make output a param for Corr2d - having it as a return value causes the correlation2D sample to segfault - it is not clear how to return a MemRef from the C-interface * Mark attributes with <> in dip.mlir * Add an operand type check for Corr2D op and a test - Make sure that input, kernel, output and constant have the same element type and use it as the inferred type - Add a negative lit test to check params of the op * Fix review comments * Add support for i8,i64,f64 * Fix correlation2D_f64 test by constructing F64 correctly * Fix review comment and formatting issues * Fix more review comments - extend correlation2d_invalid_type test to cover a condition for supported types - add comments for utility functions * Remove insertZeroConstantOp from LowerDIPPass - it is present in DIPUtils.h * Trivial changes Co-authored-by: meshtag <[email protected]>
1 parent 2a826a3 commit 1e86d5a

File tree

13 files changed

+388
-32
lines changed

13 files changed

+388
-32
lines changed

examples/DIPDialect/dip.mlir

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -1,12 +1,12 @@
11
func.func @corr_2d_constant_padding(%inputImage : memref<?x?xf32>, %kernel : memref<?x?xf32>, %outputImage : memref<?x?xf32>, %centerX : index, %centerY : index, %constantValue : f32)
22
{
3-
dip.corr_2d CONSTANT_PADDING %inputImage, %kernel, %outputImage, %centerX, %centerY, %constantValue : memref<?x?xf32>, memref<?x?xf32>, memref<?x?xf32>, index, index, f32
3+
dip.corr_2d <CONSTANT_PADDING> %inputImage, %kernel, %outputImage, %centerX, %centerY, %constantValue : memref<?x?xf32>, memref<?x?xf32>, memref<?x?xf32>, index, index, f32
44
return
55
}
66

77
func.func @corr_2d_replicate_padding(%inputImage : memref<?x?xf32>, %kernel : memref<?x?xf32>, %outputImage : memref<?x?xf32>, %centerX : index, %centerY : index, %constantValue : f32)
88
{
9-
dip.corr_2d REPLICATE_PADDING %inputImage, %kernel, %outputImage, %centerX, %centerY , %constantValue : memref<?x?xf32>, memref<?x?xf32>, memref<?x?xf32>, index, index, f32
9+
dip.corr_2d <REPLICATE_PADDING> %inputImage, %kernel, %outputImage, %centerX, %centerY , %constantValue : memref<?x?xf32>, memref<?x?xf32>, memref<?x?xf32>, index, index, f32
1010
return
1111
}
1212

include/Dialect/DIP/DIPDialect.td

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -35,6 +35,7 @@ def DIP_Dialect : Dialect {
3535
of developing a MLIR backend for performing image processing operations
3636
such as 2D Correlation, Morphological processing, etc.
3737
}];
38+
let useDefaultAttributePrinterParser = 1;
3839
let cppNamespace = "::buddy::dip";
3940
}
4041

include/Dialect/DIP/DIPOps.td

Lines changed: 24 additions & 21 deletions
Original file line numberDiff line numberDiff line change
@@ -55,30 +55,31 @@ def DIP_InterpolationType : I32EnumAttr<"InterpolationType",
5555
let cppNamespace = "::buddy::dip";
5656
}
5757

58-
def DIP_BoundaryOptionAttr : EnumAttr<DIP_Dialect, DIP_BoundaryOption, "boundary_option">;
58+
def DIP_BoundaryOptionAttr : EnumAttr<DIP_Dialect, DIP_BoundaryOption, "boundary_option"> {
59+
let assemblyFormat = "`<` $value `>`";
60+
}
5961
def DIP_InterpolationAttr : EnumAttr<DIP_Dialect, DIP_InterpolationType, "interpolation_type">;
6062

61-
def DIP_Corr2DOp : DIP_Op<"corr_2d">
62-
{
63+
def DIP_Corr2DOp : DIP_Op<"corr_2d"> {
6364
let summary = [{This operation is used for performing 2D correlation on an image.
64-
The 2D correlation API provided by the linalg dialect is more suited for
65-
applications in which boundary extrapolation is not explicitly required.
66-
Due to this, dimensions of output are always less than the input dimensions after
67-
using linalg dialect's 2D correlation API.
68-
69-
dip.corr_2d performs boundary extrapolation for making the size of the output image
70-
equal to the size of the input image. Boundary extrapolation can be done using
71-
different methods, supported options are :
72-
a. Constant Padding : Uses a constant for padding whole extra region in input image
73-
for obtaining the boundary extrapolated output image. (kkk|abcdefg|kkk)
74-
b. Replicate Padding : Uses last/first element of respective column/row for padding
75-
the extra region used for creating the boundary extrapolated output image. (aaa|abcdefg|ggg)
76-
For example:
65+
The 2D correlation API provided by the linalg dialect is more suited for
66+
applications in which boundary extrapolation is not explicitly required.
67+
Due to this, dimensions of output are always less than the input dimensions after
68+
using linalg dialect's 2D correlation API.
69+
70+
dip.corr_2d performs boundary extrapolation for making the size of the output image
71+
equal to the size of the input image. Boundary extrapolation can be done using
72+
different methods, supported options are:
73+
a. Constant Padding : Uses a constant for padding whole extra region in input image
74+
for obtaining the boundary extrapolated output image. (kkk|abcdefg|kkk)
75+
b. Replicate Padding : Uses last/first element of respective column/row for padding
76+
the extra region used for creating the boundary extrapolated output image. (aaa|abcdefg|ggg)
77+
For example:
7778

78-
```mlir
79-
dip.corr_2d CONSTANT_PADDING %inputImage, %kernel, %output, %centerX, %centerY, %constantValue
80-
: memref<?x?xf32>, memref<?x?xf32>, memref<?x?xf32>, index, index, index
81-
```
79+
```mlir
80+
dip.corr_2d <CONSTANT_PADDING> %inputImage, %kernel, %output, %centerX, %centerY, %constantValue
81+
: memref<?x?xf32>, memref<?x?xf32>, memref<?x?xf32>, index, index, f32
82+
```
8283
}];
8384

8485
let arguments = (ins Arg<AnyRankedOrUnrankedMemRef, "inputMemref",
@@ -87,7 +88,9 @@ def DIP_Corr2DOp : DIP_Op<"corr_2d">
8788
[MemRead]>:$memrefK,
8889
Arg<AnyRankedOrUnrankedMemRef, "outputMemref",
8990
[MemRead]>:$memrefCO,
90-
Index : $centerX, Index : $centerY, F32 : $constantValue,
91+
Index : $centerX,
92+
Index : $centerY,
93+
AnyTypeOf<[AnyI8, AnyI32, AnyI64, AnyFloat]> : $constantValue,
9194
DIP_BoundaryOptionAttr:$boundary_option);
9295

9396
let assemblyFormat = [{

include/Utils/DIPUtils.h

Lines changed: 42 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -24,6 +24,44 @@
2424

2525
#include "Utils/Utils.h"
2626

27+
// Inserts a constant op with value 0 at location `loc`, based on the element
28+
// type `elemTy`. Supported types are: f32, f64 and integer types.
29+
Value insertZeroConstantOp(MLIRContext *ctx, OpBuilder &builder, Location loc,
30+
Type elemTy) {
31+
Value op = {};
32+
auto bitWidth = elemTy.getIntOrFloatBitWidth();
33+
if (elemTy.isF32() || elemTy.isF64()) {
34+
FloatType type =
35+
elemTy.isF32() ? FloatType::getF32(ctx) : FloatType::getF64(ctx);
36+
auto zero = APFloat::getZero(type.getFloatSemantics());
37+
op = builder.create<ConstantFloatOp>(loc, zero, type);
38+
} else if (elemTy.isInteger(bitWidth)) {
39+
IntegerType type = IntegerType::get(ctx, bitWidth);
40+
op = builder.create<ConstantIntOp>(loc, 0, type);
41+
}
42+
43+
return op;
44+
}
45+
46+
// Inserts FMA operation into a given location `loc` based on type `type`.
47+
// Note: FMA is done by Multiply and Add for integer types, because there is no
48+
// dedicated FMA operation for them.
49+
// Supported types: f32, f64, integer types
50+
Value insertFMAOp(OpBuilder &builder, Location loc, VectorType type,
51+
Value inputVec, Value kernelVec, Value outputVec) {
52+
Value res = {};
53+
auto elemTy = type.getElementType();
54+
auto bitWidth = elemTy.getIntOrFloatBitWidth();
55+
if (elemTy.isF32() || elemTy.isF64()) {
56+
res = builder.create<vector::FMAOp>(loc, inputVec, kernelVec, outputVec);
57+
} else if (elemTy.isInteger(bitWidth)) {
58+
Value mul = builder.create<arith::MulIOp>(loc, inputVec, kernelVec);
59+
res = builder.create<arith::AddIOp>(loc, mul, outputVec);
60+
}
61+
62+
return res;
63+
}
64+
2765
// Calculate result of FMA and store it in output memref. This function cannot
2866
// handle tail processing.
2967
void calcAndStoreFMAwoTailProcessing(OpBuilder &builder, Location loc,
@@ -32,7 +70,8 @@ void calcAndStoreFMAwoTailProcessing(OpBuilder &builder, Location loc,
3270
Value beginIdx, Value endIdx) {
3371
Value outputVec = builder.create<LoadOp>(loc, vecType, output,
3472
ValueRange{beginIdx, endIdx});
35-
Value resVec = builder.create<FMAOp>(loc, inputVec, kernelVec, outputVec);
73+
Value resVec =
74+
insertFMAOp(builder, loc, vecType, inputVec, kernelVec, outputVec);
3675
builder.create<StoreOp>(loc, resVec, output, ValueRange{beginIdx, endIdx});
3776
}
3877

@@ -72,7 +111,7 @@ void calcAndStoreFMAwTailProcessing(OpBuilder &builder, Location loc,
72111
Value outputVec = builder.create<LoadOp>(loc, vecType, output,
73112
ValueRange{beginIdx, endIdx});
74113
Value resVec =
75-
builder.create<FMAOp>(loc, inputVec, kernelVec, outputVec);
114+
insertFMAOp(builder, loc, vecType, inputVec, kernelVec, outputVec);
76115
builder.create<StoreOp>(loc, resVec, output,
77116
ValueRange{beginIdx, endIdx});
78117

@@ -85,7 +124,7 @@ void calcAndStoreFMAwTailProcessing(OpBuilder &builder, Location loc,
85124
loc, vecType, output, ValueRange{beginIdx, endIdx}, extraElemMask,
86125
zeroPadding);
87126
Value resVec =
88-
builder.create<FMAOp>(loc, inputVec, kernelVec, outputVec);
127+
insertFMAOp(builder, loc, vecType, inputVec, kernelVec, outputVec);
89128
builder.create<MaskedStoreOp>(loc, output, ValueRange{beginIdx, endIdx},
90129
extraElemMask, resVec);
91130

lib/Conversion/LowerDIP/LowerDIPPass.cpp

Lines changed: 25 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -25,6 +25,10 @@
2525
#include "mlir/Dialect/Math/IR/Math.h"
2626
#include "mlir/Dialect/MemRef/IR/MemRef.h"
2727
#include "mlir/Dialect/Vector/IR/VectorOps.h"
28+
#include "mlir/IR/Builders.h"
29+
#include "mlir/IR/Location.h"
30+
#include "mlir/IR/MLIRContext.h"
31+
#include "mlir/IR/ValueRange.h"
2832
#include "mlir/Pass/Pass.h"
2933

3034
#include "DIP/DIPDialect.h"
@@ -56,7 +60,7 @@ class DIPCorr2DOpLowering : public OpRewritePattern<dip::Corr2DOp> {
5660
LogicalResult matchAndRewrite(dip::Corr2DOp op,
5761
PatternRewriter &rewriter) const override {
5862
auto loc = op->getLoc();
59-
auto ctx = op->getContext();
63+
auto *ctx = op->getContext();
6064

6165
// Create constant indices.
6266
Value c0 = rewriter.create<ConstantIndexOp>(loc, 0);
@@ -72,7 +76,24 @@ class DIPCorr2DOpLowering : public OpRewritePattern<dip::Corr2DOp> {
7276
auto boundaryOptionAttr = op.boundary_option();
7377
Value strideVal = rewriter.create<ConstantIndexOp>(loc, stride);
7478

75-
FloatType f32 = FloatType::getF32(ctx);
79+
auto inElemTy = input.getType().cast<MemRefType>().getElementType();
80+
auto kElemTy = kernel.getType().cast<MemRefType>().getElementType();
81+
auto outElemTy = output.getType().cast<MemRefType>().getElementType();
82+
auto constElemTy = constantValue.getType();
83+
if (inElemTy != kElemTy || kElemTy != outElemTy ||
84+
outElemTy != constElemTy) {
85+
return op->emitOpError() << "input, kernel, output and constant must "
86+
"have the same element type";
87+
}
88+
// NB: we can infer the element type for the whole operation to be the same
89+
// as the input's, since we verified that all operand element types match.
90+
auto elemTy = inElemTy;
91+
auto bitWidth = elemTy.getIntOrFloatBitWidth();
92+
if (!elemTy.isF64() && !elemTy.isF32() && !elemTy.isInteger(bitWidth)) {
93+
return op->emitOpError() << "supports only f32, f64 and integer types. "
94+
<< elemTy << "is passed";
95+
}
96+
7697
IntegerType i1 = IntegerType::get(ctx, 1);
7798

7899
// Create DimOp.
@@ -90,11 +111,10 @@ class DIPCorr2DOpLowering : public OpRewritePattern<dip::Corr2DOp> {
90111
kernelSize};
91112
SmallVector<int64_t, 8> steps{1, 1, stride, 1};
92113

93-
VectorType vectorTy32 = VectorType::get({stride}, f32);
114+
VectorType vectorTy32 = VectorType::get({stride}, elemTy);
94115
VectorType vectorMaskTy = VectorType::get({stride}, i1);
95116

96-
Value zeroPaddingElem =
97-
rewriter.create<ConstantFloatOp>(loc, (APFloat)(float)0, f32);
117+
Value zeroPaddingElem = insertZeroConstantOp(ctx, rewriter, loc, elemTy);
98118
Value zeroPadding =
99119
rewriter.create<BroadcastOp>(loc, vectorTy32, zeroPaddingElem);
100120

Lines changed: 42 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,42 @@
1+
//
2+
// x86
3+
//
4+
// RUN: buddy-opt %s -lower-dip="DIP-strip-mining=64" -arith-expand --convert-vector-to-scf --lower-affine --convert-scf-to-cf --convert-vector-to-llvm \
5+
// RUN: --convert-memref-to-llvm --convert-func-to-llvm --reconcile-unrealized-casts \
6+
// RUN: | mlir-cpu-runner -O0 -e main -entry-point-result=i32 \
7+
// RUN: -shared-libs=%mlir_runner_utils_dir/libmlir_runner_utils%shlibext,%mlir_runner_utils_dir/libmlir_c_runner_utils%shlibext \
8+
// RUN: | FileCheck %s
9+
10+
memref.global "private" @global_input : memref<3x3xf32> = dense<[[0. , 1. , 2. ],
11+
[10., 11., 12.],
12+
[20., 21., 22.]]>
13+
14+
memref.global "private" @global_identity : memref<3x3xf32> = dense<[[0., 0., 0.],
15+
[0., 1., 0.],
16+
[0., 0., 0.]]>
17+
18+
memref.global "private" @global_output : memref<3x3xf32> = dense<[[0., 0., 0.],
19+
[0., 0., 0.],
20+
[0., 0., 0.]]>
21+
func.func private @printMemrefF32(memref<*xf32>) attributes { llvm.emit_c_interface }
22+
23+
func.func @main() -> i32 {
24+
%input = memref.get_global @global_input : memref<3x3xf32>
25+
%identity = memref.get_global @global_identity : memref<3x3xf32>
26+
%output = memref.get_global @global_output : memref<3x3xf32>
27+
28+
%kernelAnchorX = arith.constant 1 : index
29+
%kernelAnchorY = arith.constant 1 : index
30+
%c = arith.constant 0. : f32
31+
dip.corr_2d <CONSTANT_PADDING> %input, %identity, %output, %kernelAnchorX, %kernelAnchorY, %c : memref<3x3xf32>, memref<3x3xf32>, memref<3x3xf32>, index, index, f32
32+
33+
%printed_output = memref.cast %output : memref<3x3xf32> to memref<*xf32>
34+
call @printMemrefF32(%printed_output) : (memref<*xf32>) -> ()
35+
// CHECK: {{Unranked Memref base@ = 0x[0-9A-Fa-f]{1,} rank = 2 offset = 0 sizes = \[3, 3\] strides = \[3, 1\] data =}}
36+
// CHECK{LITERAL}: [[0, 1, 2],
37+
// CHECK{LITERAL}: [10, 11, 12],
38+
// CHECK{LITERAL}: [20, 21, 22]]
39+
40+
%ret = arith.constant 0 : i32
41+
return %ret : i32
42+
}
Lines changed: 42 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,42 @@
1+
//
2+
// x86
3+
//
4+
// RUN: buddy-opt %s -lower-dip="DIP-strip-mining=64" -arith-expand --convert-vector-to-scf --lower-affine --convert-scf-to-cf --convert-vector-to-llvm \
5+
// RUN: --convert-memref-to-llvm --convert-func-to-llvm --reconcile-unrealized-casts \
6+
// RUN: | mlir-cpu-runner -O0 -e main -entry-point-result=i32 \
7+
// RUN: -shared-libs=%mlir_runner_utils_dir/libmlir_runner_utils%shlibext,%mlir_runner_utils_dir/libmlir_c_runner_utils%shlibext \
8+
// RUN: | FileCheck %s
9+
10+
memref.global "private" @global_input : memref<3x3xf64> = dense<[[0. , 1. , 2. ],
11+
[10., 11., 12.],
12+
[20., 21., 22.]]>
13+
14+
memref.global "private" @global_identity : memref<3x3xf64> = dense<[[0., 0., 0.],
15+
[0., 1., 0.],
16+
[0., 0., 0.]]>
17+
18+
memref.global "private" @global_output : memref<3x3xf64> = dense<[[0., 0., 0.],
19+
[0., 0., 0.],
20+
[0., 0., 0.]]>
21+
func.func private @printMemrefF64(memref<*xf64>) attributes { llvm.emit_c_interface }
22+
23+
func.func @main() -> i32 {
24+
%input = memref.get_global @global_input : memref<3x3xf64>
25+
%identity = memref.get_global @global_identity : memref<3x3xf64>
26+
%output = memref.get_global @global_output : memref<3x3xf64>
27+
28+
%kernelAnchorX = arith.constant 1 : index
29+
%kernelAnchorY = arith.constant 1 : index
30+
%c = arith.constant 0. : f64
31+
dip.corr_2d <CONSTANT_PADDING> %input, %identity, %output, %kernelAnchorX, %kernelAnchorY, %c : memref<3x3xf64>, memref<3x3xf64>, memref<3x3xf64>, index, index, f64
32+
33+
%printed_output = memref.cast %output : memref<3x3xf64> to memref<*xf64>
34+
call @printMemrefF64(%printed_output) : (memref<*xf64>) -> ()
35+
// CHECK: {{Unranked Memref base@ = 0x[0-9A-Fa-f]{1,} rank = 2 offset = 0 sizes = \[3, 3\] strides = \[3, 1\] data =}}
36+
// CHECK{LITERAL}: [[0, 1, 2],
37+
// CHECK{LITERAL}: [10, 11, 12],
38+
// CHECK{LITERAL}: [20, 21, 22]]
39+
40+
%ret = arith.constant 0 : i32
41+
return %ret : i32
42+
}
Lines changed: 41 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,41 @@
1+
//
2+
// x86
3+
//
4+
// RUN: buddy-opt %s -lower-dip="DIP-strip-mining=64" -arith-expand --convert-vector-to-scf --lower-affine --convert-scf-to-cf --convert-vector-to-llvm \
5+
// RUN: --convert-memref-to-llvm --convert-func-to-llvm --reconcile-unrealized-casts \
6+
// RUN: | mlir-cpu-runner -O0 -e main -entry-point-result=i32 \
7+
// RUN: -shared-libs=%mlir_runner_utils_dir/libmlir_runner_utils%shlibext,%mlir_runner_utils_dir/libmlir_c_runner_utils%shlibext \
8+
// RUN: | FileCheck %s
9+
10+
memref.global "private" @global_input : memref<3x3xi32> = dense<[[0 , 1 , 2 ],
11+
[10, 11, 12],
12+
[20, 21, 22]]>
13+
14+
memref.global "private" @global_identity : memref<3x3xi32> = dense<[[0, 0, 0],
15+
[0, 1, 0],
16+
[0, 0, 0]]>
17+
18+
memref.global "private" @global_output : memref<3x3xi32> = dense<[[0, 0, 0],
19+
[0, 0, 0],
20+
[0, 0, 0]]>
21+
func.func private @printMemrefI32(memref<*xi32>) attributes { llvm.emit_c_interface }
22+
23+
func.func @main() -> i32 {
24+
%input = memref.get_global @global_input : memref<3x3xi32>
25+
%identity = memref.get_global @global_identity : memref<3x3xi32>
26+
%output = memref.get_global @global_output: memref<3x3xi32>
27+
28+
%kernelAnchorX = arith.constant 1 : index
29+
%kernelAnchorY = arith.constant 1 : index
30+
%c = arith.constant 0 : i32
31+
dip.corr_2d <CONSTANT_PADDING> %input, %identity, %output, %kernelAnchorX, %kernelAnchorY, %c : memref<3x3xi32>, memref<3x3xi32>, memref<3x3xi32>, index, index, i32
32+
33+
%printed_output = memref.cast %output : memref<3x3xi32> to memref<*xi32>
34+
call @printMemrefI32(%printed_output) : (memref<*xi32>) -> ()
35+
// CHECK: {{Unranked Memref base@ = 0x[0-9A-Fa-f]{1,} rank = 2 offset = 0 sizes = \[3, 3\] strides = \[3, 1\] data =}}
36+
// CHECK{LITERAL}: [[0, 1, 2],
37+
// CHECK{LITERAL}: [10, 11, 12],
38+
// CHECK{LITERAL}: [20, 21, 22]]
39+
%ret = arith.constant 0 : i32
40+
return %ret : i32
41+
}
Lines changed: 42 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,42 @@
1+
//
2+
// x86
3+
//
4+
// RUN: buddy-opt %s -lower-dip="DIP-strip-mining=64" -arith-expand --convert-vector-to-scf --lower-affine --convert-scf-to-cf --convert-vector-to-llvm \
5+
// RUN: --convert-memref-to-llvm --convert-func-to-llvm --reconcile-unrealized-casts \
6+
// RUN: | mlir-cpu-runner -O0 -e main -entry-point-result=i32 \
7+
// RUN: -shared-libs=%mlir_runner_utils_dir/libmlir_runner_utils%shlibext,%mlir_runner_utils_dir/libmlir_c_runner_utils%shlibext \
8+
// RUN: | FileCheck %s
9+
10+
memref.global "private" @global_input : memref<3x3xi64> = dense<[[0 , 1 , 2 ],
11+
[10, 11, 12],
12+
[20, 21, 22]]>
13+
14+
memref.global "private" @global_identity : memref<3x3xi64> = dense<[[0, 0, 0],
15+
[0, 1, 0],
16+
[0, 0, 0]]>
17+
18+
memref.global "private" @global_output : memref<3x3xi64> = dense<[[0, 0, 0],
19+
[0, 0, 0],
20+
[0, 0, 0]]>
21+
22+
func.func private @printMemrefI64(memref<*xi64>) attributes { llvm.emit_c_interface }
23+
24+
func.func @main() -> i32 {
25+
%input = memref.get_global @global_input : memref<3x3xi64>
26+
%identity = memref.get_global @global_identity : memref<3x3xi64>
27+
%output = memref.get_global @global_output: memref<3x3xi64>
28+
29+
%kernelAnchorX = arith.constant 1 : index
30+
%kernelAnchorY = arith.constant 1 : index
31+
%c = arith.constant 0 : i64
32+
dip.corr_2d <CONSTANT_PADDING> %input, %identity, %output, %kernelAnchorX, %kernelAnchorY, %c : memref<3x3xi64>, memref<3x3xi64>, memref<3x3xi64>, index, index, i64
33+
34+
%printed_output = memref.cast %output : memref<3x3xi64> to memref<*xi64>
35+
call @printMemrefI64(%printed_output) : (memref<*xi64>) -> ()
36+
// CHECK: {{Unranked Memref base@ = 0x[0-9A-Fa-f]{1,} rank = 2 offset = 0 sizes = \[3, 3\] strides = \[3, 1\] data =}}
37+
// CHECK{LITERAL}: [[0, 1, 2],
38+
// CHECK{LITERAL}: [10, 11, 12],
39+
// CHECK{LITERAL}: [20, 21, 22]]
40+
%ret = arith.constant 0 : i32
41+
return %ret : i32
42+
}

0 commit comments

Comments
 (0)