Skip to content

Commit 8540f68

Browse files
MaheshRavishankardcaballeSimon CamphausenhanhanWThomasRaoux
authored
Integrate llvm-project and bump dependencies. (#10087)
* Integrate llvm-project and bump dependencies. * MHLO_COMMIT=3f94dc7ff68524b1cf2dd420d9997df8f3ccd7a8 * LLVM_COMIT=6c66b089bcd7 * TF_COMMIT=0946dfd12422221e3695a479891aa2f2cad147e5 * Fixes with TypedAttr * Update mhlo and llvm pointers to cherry-pick fixes * OpaqueAttr fixes * Fix EmitC failures in llvm integrate * Fix StreamBase.td * Use patched mlir-hlo branch * fix nvgpu lowering * Fix convert type utility * Disable TFL/test/flex_ops.mlir * Cherry-pick local MHLO patch https://github.com/iree-org/iree-mhlo-fork/commit/a6a3308969278cf9ec36132fc9235314fd091c23 Co-authored-by: Diego Caballero <[email protected]> Co-authored-by: Simon Camphausen <[email protected]> Co-authored-by: Hanhan Wang <[email protected]> Co-authored-by: Thomas Raoux <[email protected]> Co-authored-by: Stella Laurenzo <[email protected]>
1 parent 36a212a commit 8540f68

File tree

64 files changed

+637
-718
lines changed

Some content is hidden

Large Commits have some content hidden by default. Use the searchbox below for content that may be hidden.

64 files changed

+637
-718
lines changed

compiler/src/iree/compiler/Codegen/Common/test/convert_to_destination_passing_style.mlir

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -479,7 +479,7 @@ func.func @three_init_tensor_uses() {
479479
%c1638400 = arith.constant 1638400 : index
480480
%c0 = arith.constant 0 : index
481481
%cst = arith.constant 3.40282347E+38 : f32
482-
%cst_0 = arith.constant opaque<"elided_large_const", "0xDEADBEEF"> : tensor<64xf32>
482+
%cst_0 = arith.constant dense_resource<__elided__> : tensor<64xf32>
483483
%cst_1 = arith.constant 0.000000e+00 : f32
484484
%0 = hal.interface.binding.subspan set(0) binding(0) type(storage_buffer) offset(%c0) alignment(32) : !flow.dispatch.tensor<readonly:6400x64xf32>
485485
%1 = hal.interface.binding.subspan set(0) binding(0) type(storage_buffer) offset(%c1638400) alignment(32) : !flow.dispatch.tensor<readonly:64x64xf32>

compiler/src/iree/compiler/Codegen/Common/test/iree_comprehensive_bufferize.mlir

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -148,7 +148,7 @@ func.func @matmul_fill() {
148148
func.func @elementwise() {
149149
%c4 = arith.constant 4 : index
150150
%c0 = arith.constant 0 : index
151-
%cst = arith.constant opaque<"elided_large_const", "0xDEADBEEF"> : tensor<1x10xf32>
151+
%cst = arith.constant dense_resource<__elided__> : tensor<1x10xf32>
152152
%c512 = arith.constant 512 : index
153153
%c64 = arith.constant 64 : index
154154
%c10 = arith.constant 10 : index
@@ -186,7 +186,7 @@ func.func @elementwise() {
186186
return
187187
}
188188
// CHECK: func.func @elementwise()
189-
// CHECK-DAG: %[[CST_TENSOR:.+]] = arith.constant opaque<"elided_large_const", "0xDEADBEEF"> : tensor<1x10xf32>
189+
// CHECK-DAG: %[[CST_TENSOR:.+]] = arith.constant dense_resource<__elided__> : tensor<1x10xf32>
190190
// CHECK-DAG: %[[CST_BUF:.+]] = bufferization.to_memref %[[CST_TENSOR]]
191191
// CHECK-DAG: %[[IN_BUF:.+]] = hal.interface.binding.subspan set(0) binding(0) {{.+}} : memref<1x10xf32>
192192
// CHECK-DAG: %[[OUT_BUF:.+]] = hal.interface.binding.subspan set(0) binding(1) {{.+}} : memref<1x10xf32>

compiler/src/iree/compiler/Codegen/LLVMCPU/test/hal_interface_bindings.mlir

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -18,9 +18,9 @@ func.func @binding_ptrs() {
1818
// CHECK: %[[STATE:.+]] = llvm.load %arg1 : !llvm.ptr<struct<[[DISPATCH_STATE_TYPE]]>>
1919
// CHECK: %[[BINDING_PTRS:.+]] = llvm.extractvalue %[[STATE]][10]
2020
// CHECK: %[[C1:.+]] = llvm.mlir.constant(1 : i64) : i64
21-
// CHECK: %[[ARRAY_PTR:.+]] = llvm.getelementptr %[[BINDING_PTRS]][%[[C1]]] : (!llvm.ptr<ptr<i8>>, i64) -> !llvm.ptr<ptr<i8>>
21+
// CHECK: %[[ARRAY_PTR:.+]] = llvm.getelementptr %[[BINDING_PTRS]][1] : (!llvm.ptr<ptr<i8>>) -> !llvm.ptr<ptr<i8>>
2222
// CHECK: %[[BASE_PTR_I8:.+]] = llvm.load %[[ARRAY_PTR]] : !llvm.ptr<ptr<i8>>
23-
// CHECK: %[[BUFFER_I8:.+]] = llvm.getelementptr %[[BASE_PTR_I8]][%[[C72]]] : (!llvm.ptr<i8>, i64) -> !llvm.ptr<i8>
23+
// CHECK: %[[BUFFER_I8:.+]] = llvm.getelementptr %[[BASE_PTR_I8]][72] : (!llvm.ptr<i8>) -> !llvm.ptr<i8>
2424
// CHECK: %[[BUFFER_F32:.+]] = llvm.bitcast %[[BUFFER_I8]] : !llvm.ptr<i8> to !llvm.ptr<f32>
2525
// CHECK: %[[DESC_A:.+]] = llvm.mlir.undef : !llvm.struct<(ptr<f32>, ptr<f32>, i64, array<2 x i64>, array<2 x i64>)>
2626
// CHECK: %[[DESC_B:.+]] = llvm.insertvalue %[[BUFFER_F32]], %[[DESC_A]][0]

compiler/src/iree/compiler/Codegen/LLVMGPU/BUILD

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -75,6 +75,7 @@ iree_compiler_cc_library(
7575
"@llvm-project//mlir:MemRefTransforms",
7676
"@llvm-project//mlir:NVGPUDialect",
7777
"@llvm-project//mlir:NVGPUToNVVM",
78+
"@llvm-project//mlir:NVGPUTransforms",
7879
"@llvm-project//mlir:NVVMDialect",
7980
"@llvm-project//mlir:PDLDialect",
8081
"@llvm-project//mlir:PDLInterpDialect",

compiler/src/iree/compiler/Codegen/LLVMGPU/CMakeLists.txt

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -64,6 +64,7 @@ iree_cc_library(
6464
MLIRMemRefTransforms
6565
MLIRNVGPUDialect
6666
MLIRNVGPUToNVVM
67+
MLIRNVGPUTransforms
6768
MLIRNVVMDialect
6869
MLIRPDLDialect
6970
MLIRPDLInterpDialect

compiler/src/iree/compiler/Codegen/LLVMGPU/LLVMGPUVectorToGPU.cpp

Lines changed: 19 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -12,6 +12,7 @@
1212
#include "mlir/Dialect/Arithmetic/IR/Arithmetic.h"
1313
#include "mlir/Dialect/GPU/IR/GPUDialect.h"
1414
#include "mlir/Dialect/NVGPU/IR/NVGPUDialect.h"
15+
#include "mlir/Dialect/NVGPU/Transforms/Transforms.h"
1516
#include "mlir/Transforms/GreedyPatternRewriteDriver.h"
1617

1718
namespace mlir {
@@ -241,10 +242,26 @@ struct LLVMGPUVectorToGPUPass
241242
populatePrepareVectorToMMAPatterns(patterns, llvmgpuUseMMASync);
242243
(void)applyPatternsAndFoldGreedily(getOperation(), std::move(patterns));
243244

244-
if (llvmgpuUseMMASync)
245+
if (llvmgpuUseMMASync) {
245246
(void)convertVectorToNVVMCompatibleMMASync(funcOp);
246-
else
247+
248+
// TODO: Remove once populateMmaSyncF32ToTF32Patterns is fixed to not add
249+
// attribute tf32 attributes to none f32 ops.
250+
bool hasFP32mma = false;
251+
funcOp.walk([&hasFP32mma](nvgpu::MmaSyncOp op) {
252+
if (op.getType().cast<VectorType>().getElementType().isF32())
253+
hasFP32mma = true;
254+
});
255+
if (hasFP32mma) {
256+
// Use TF32 for float32 case for now.
257+
RewritePatternSet patterns(funcOp.getContext());
258+
nvgpu::populateMmaSyncF32ToTF32Patterns(
259+
patterns, nvgpu::MmaSyncF32Lowering::TF32);
260+
(void)applyPatternsAndFoldGreedily(getOperation(), std::move(patterns));
261+
}
262+
} else {
247263
convertVectorToMMAOps(funcOp);
264+
}
248265
createAsyncGroups(funcOp);
249266
}
250267
};

compiler/src/iree/compiler/Codegen/SPIRV/test/config_default_linalg_ops.mlir

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -321,7 +321,7 @@ hal.executable @dwconv_elementwise {
321321
hal.executable.export public @dwconv_elementwise layout(#executable_layout)
322322
builtin.module {
323323
func.func @dwconv_elementwise() {
324-
%cst = arith.constant opaque<"_", "0xDEADBEEF"> : tensor<3x3x1x4xf32>
324+
%cst = arith.constant dense_resource<__elided__> : tensor<3x3x1x4xf32>
325325
%cst_8 = arith.constant 1.001000e+00 : f32
326326
%cst_9 = arith.constant 0.000000e+00 : f32
327327
%c18 = arith.constant 18 : index

compiler/src/iree/compiler/Dialect/Flow/Transforms/StripAndSplatConstantVariables.cpp

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -55,7 +55,7 @@ class StripAndSplatConstantVariablesPass
5555

5656
auto tensorType = op.getType().cast<TensorType>();
5757
auto elementType = tensorType.getElementType();
58-
DenseElementsAttr newValue;
58+
TypedAttr newValue;
5959
if (elementType.isa<FloatType>()) {
6060
newValue = DenseElementsAttr::get(
6161
tensorType, FloatAttr::get(elementType, 1.0 / replaceIndex));

compiler/src/iree/compiler/Dialect/HAL/IR/BUILD

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -33,6 +33,7 @@ td_library(
3333
),
3434
deps = [
3535
"//compiler/src/iree/compiler/Dialect/Util/IR:td_files",
36+
"@llvm-project//mlir:BuiltinDialectTdFiles",
3637
"@llvm-project//mlir:FuncTdFiles",
3738
"@llvm-project//mlir:OpBaseTdFiles",
3839
"@llvm-project//mlir:ViewLikeInterfaceTdFiles",

compiler/src/iree/compiler/Dialect/HAL/IR/HALOps.td

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -11,6 +11,7 @@ include "iree/compiler/Dialect/HAL/IR/HALBase.td"
1111
include "iree/compiler/Dialect/HAL/IR/HALInterfaces.td"
1212
include "iree/compiler/Dialect/Util/IR/UtilAttrs.td"
1313
include "iree/compiler/Dialect/Util/IR/UtilInterfaces.td"
14+
include "mlir/IR/BuiltinAttributeInterfaces.td"
1415
include "mlir/IR/OpAsmInterface.td"
1516
include "mlir/IR/SymbolInterfaces.td"
1617
include "mlir/Interfaces/SideEffectInterfaces.td"
@@ -1360,7 +1361,7 @@ def HAL_DeviceQueryOp :
13601361
HAL_Device:$device,
13611362
StrAttr:$category,
13621363
StrAttr:$key,
1363-
OptionalAttr<AnyAttr>:$default_value
1364+
OptionalAttr<TypedAttrInterface>:$default_value
13641365
);
13651366
let results = (outs
13661367
I1:$ok,

compiler/src/iree/compiler/Dialect/Stream/Conversion/UtilToStream/ConvertUtilToStream.cpp

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -93,7 +93,7 @@ struct GlobalOpExpansion
9393
auto resourceOp = rewriter.replaceOpWithNewOp<IREE::Util::GlobalOp>(
9494
globalOp, globalOp.getName(), globalOp.getIsMutable(), resourceType,
9595
initialValue && !tensorInitializerRequired
96-
? llvm::Optional<Attribute>{initialValue}
96+
? llvm::Optional<TypedAttr>{initialValue}
9797
: llvm::None);
9898
resourceOp.setVisibility(globalOp.getVisibility());
9999

@@ -108,7 +108,7 @@ struct GlobalOpExpansion
108108
auto indexType = rewriter.getIndexType();
109109
auto resourceSizeOp = rewriter.create<IREE::Util::GlobalOp>(
110110
globalOp.getLoc(), (globalOp.getName() + "__size").str(),
111-
globalOp.getIsMutable(), indexType, Optional<Attribute>{});
111+
globalOp.getIsMutable(), indexType, Optional<TypedAttr>{});
112112
resourceSizeOp.setVisibility(globalOp.getVisibility());
113113

114114
// Materialize the initializer if we need to setup a tensor-like constant.

compiler/src/iree/compiler/Dialect/Stream/IR/BUILD

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -27,6 +27,7 @@ td_library(
2727
),
2828
deps = [
2929
"//compiler/src/iree/compiler/Dialect/Util/IR:td_files",
30+
"@llvm-project//mlir:BuiltinDialectTdFiles",
3031
"@llvm-project//mlir:FuncTdFiles",
3132
"@llvm-project//mlir:InferTypeOpInterfaceTdFiles",
3233
"@llvm-project//mlir:OpBaseTdFiles",

compiler/src/iree/compiler/Dialect/Stream/IR/StreamBase.td

Lines changed: 3 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -12,6 +12,7 @@ include "iree/compiler/Dialect/Util/IR/UtilBase.td"
1212
include "iree/compiler/Dialect/Util/IR/UtilInterfaces.td"
1313
include "iree/compiler/Dialect/Util/IR/UtilTypes.td"
1414
include "mlir/IR/AttrTypeBase.td"
15+
include "mlir/IR/BuiltinAttributeInterfaces.td"
1516
include "mlir/IR/EnumAttr.td"
1617
include "mlir/IR/SubElementInterfaces.td"
1718

@@ -320,7 +321,8 @@ def Stream_Timepoint : TypeDef<Stream_Dialect, "Timepoint", [
320321
let parameters = (ins);
321322
}
322323

323-
def Stream_TimepointAttr : AttrDef<Stream_Dialect, "Timepoint", []> {
324+
def Stream_TimepointAttr : AttrDef<Stream_Dialect, "Timepoint",
325+
[TypedAttrInterface]> {
324326
let mnemonic = "timepoint";
325327
let summary = [{an immediately-resolved timepoint}];
326328
let description = [{}];

compiler/src/iree/compiler/Dialect/Stream/IR/StreamOpFolders.cpp

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -868,7 +868,7 @@ struct TensorConstantToSplat : public OpRewritePattern<TensorConstantOp> {
868868
"only constant splat attrs can be converted to splat ops");
869869
}
870870

871-
auto splatElementAttr = splatAttr.getSplatValue<Attribute>();
871+
auto splatElementAttr = splatAttr.getSplatValue<TypedAttr>();
872872
auto splatValue = rewriter.create<arith::ConstantOp>(
873873
constantOp.getLoc(), splatElementAttr.getType(), splatElementAttr);
874874
auto resultType = IREE::Stream::ResourceType::get(constantOp.getContext());
@@ -1202,7 +1202,7 @@ struct ConvertSplatConstantsIntoSplats
12021202
if (!value.isSplat()) return failure();
12031203

12041204
auto splatElementAttr =
1205-
value.dyn_cast<SplatElementsAttr>().getSplatValue<Attribute>();
1205+
value.dyn_cast<SplatElementsAttr>().getSplatValue<TypedAttr>();
12061206
auto splatValue = rewriter.create<arith::ConstantOp>(
12071207
constantOp.getLoc(), splatElementAttr.getType(), splatElementAttr);
12081208
rewriter.replaceOpWithNewOp<IREE::Stream::AsyncSplatOp>(

compiler/src/iree/compiler/Dialect/Stream/Transforms/OutlineConstants.cpp

Lines changed: 0 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -30,8 +30,6 @@ static bool isOutlinableValue(Attribute value) {
3030
if (auto elementsAttr = value.dyn_cast<DenseElementsAttr>()) {
3131
// Don't outline splats - we want those fused.
3232
return !elementsAttr.isSplat();
33-
} else if (auto opaqueAttr = value.dyn_cast<OpaqueElementsAttr>()) {
34-
return !opaqueAttr.isSplat();
3533
}
3634
return false;
3735
}

compiler/src/iree/compiler/Dialect/Stream/Transforms/PackConstants.cpp

Lines changed: 0 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -50,12 +50,6 @@ struct ConstantSlice {
5050
uint64_t getRawLength() const {
5151
if (auto denseAttr = value.dyn_cast<DenseElementsAttr>()) {
5252
return denseAttr.getRawData().size();
53-
} else if (auto opaqueAttr = value.dyn_cast<OpaqueElementsAttr>()) {
54-
// Later on in the pipeline opaque attrs will cause the compiler to fail
55-
// (as at some point we need to get the data) but this allows us to run
56-
// the stream transforms on IR that has had its large constants elided.
57-
return opaqueAttr.getNumElements() *
58-
opaqueAttr.getElementType().getIntOrFloatBitWidth();
5953
} else {
6054
assert(false && "invalid constant attr type");
6155
return 0;

compiler/src/iree/compiler/Dialect/Stream/Transforms/SpecializeDispatches.cpp

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -41,7 +41,7 @@ struct ConstantSet {
4141
// Locations of all constants that went into the table.
4242
SetVector<Location> locs;
4343
// Operand index -> all values from dispatch sites.
44-
SmallVector<std::pair<unsigned, SmallVector<Attribute>>> values;
44+
SmallVector<std::pair<unsigned, SmallVector<TypedAttr>>> values;
4545
};
4646

4747
struct ConstantTable {
@@ -92,7 +92,7 @@ static ConstantTable buildConstantTable(
9292
set.type = operandType;
9393
typeOrder.push_back(operandType);
9494
}
95-
SmallVector<Attribute> values;
95+
SmallVector<TypedAttr> values;
9696
for (auto dispatchOp : dispatchOps) {
9797
auto operand = dispatchOp.getUniformOperands()[idx];
9898
Attribute constantValue;

compiler/src/iree/compiler/Dialect/Util/IR/BUILD

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -30,6 +30,7 @@ td_library(
3030
include = ["*.td"],
3131
),
3232
deps = [
33+
"@llvm-project//mlir:BuiltinDialectTdFiles",
3334
"@llvm-project//mlir:CallInterfacesTdFiles",
3435
"@llvm-project//mlir:ControlFlowInterfacesTdFiles",
3536
"@llvm-project//mlir:FunctionInterfacesTdFiles",

compiler/src/iree/compiler/Dialect/Util/IR/UtilAttrs.cpp

Lines changed: 2 additions & 13 deletions
Original file line numberDiff line numberDiff line change
@@ -346,13 +346,6 @@ CompositeAttr CompositeAttr::get(MLIRContext *context,
346346
if (auto serializableAttr =
347347
valueAttr.dyn_cast<SerializableAttrInterface>()) {
348348
calculatedLength += serializableAttr.getStorageSize();
349-
} else if (auto opaqueAttr = valueAttr.dyn_cast<OpaqueElementsAttr>()) {
350-
// Allow opaque attrs to be placed into composites ease debugging of IR
351-
// that has had large attrs elided; these will fail to actually serialize
352-
// but being able to run most passes with these unserializable attrs is
353-
// useful.
354-
calculatedLength += opaqueAttr.getNumElements() *
355-
opaqueAttr.getElementType().getIntOrFloatBitWidth();
356349
} else {
357350
return {};
358351
}
@@ -369,12 +362,8 @@ LogicalResult CompositeAttr::verify(
369362
if (auto serializableAttr =
370363
valueAttr.dyn_cast<SerializableAttrInterface>()) {
371364
calculatedLength += serializableAttr.getStorageSize();
372-
} else if (auto opaqueAttr = valueAttr.dyn_cast<OpaqueElementsAttr>()) {
373-
calculatedLength += opaqueAttr.getNumElements() *
374-
opaqueAttr.getElementType().getIntOrFloatBitWidth();
375365
} else {
376-
return emitError() << "value is not serializable: "
377-
<< valueAttr.getType();
366+
return emitError() << "value is not serializable: " << valueAttr;
378367
}
379368
}
380369
if (calculatedLength != totalLength) {
@@ -464,7 +453,7 @@ LogicalResult CompositeAttr::serializeToStream(llvm::support::endianness endian,
464453
auto serializableAttr = valueAttr.dyn_cast<SerializableAttrInterface>();
465454
if (!serializableAttr) {
466455
llvm::errs() << "unable to serialize a non-serializable attribute: "
467-
<< valueAttr.getType() << "\n";
456+
<< valueAttr << "\n";
468457
return failure();
469458
}
470459
if (failed(serializableAttr.serializeToStream(endian, os))) {

compiler/src/iree/compiler/Dialect/Util/IR/UtilOps.cpp

Lines changed: 15 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -84,7 +84,10 @@ ParseResult parseTypeOrAttr(OpAsmParser &parser, TypeAttr &typeAttr,
8484
return parser.emitError(parser.getCurrentLocation())
8585
<< "expected attribute";
8686
}
87-
typeAttr = TypeAttr::get(attr.getType());
87+
88+
if (auto typedAttr = attr.dyn_cast<TypedAttr>()) {
89+
typeAttr = TypeAttr::get(typedAttr.getType());
90+
}
8891
return success();
8992
}
9093

@@ -107,7 +110,8 @@ ParseResult parseTypeOrAttr(OpAsmParser &parser, TypeAttr &typeAttr,
107110
void printTypeOrAttr(OpAsmPrinter &p, Operation *op, TypeAttr type,
108111
Attribute attr) {
109112
bool needsSpace = false;
110-
if (!attr || attr.getType() != type.getValue()) {
113+
auto typedAttr = attr.dyn_cast_or_null<TypedAttr>();
114+
if (!typedAttr || typedAttr.getType() != type.getValue()) {
111115
p << ": ";
112116
p.printAttribute(type);
113117
needsSpace = true; // subsequent attr value needs a space separator
@@ -679,7 +683,7 @@ ParseResult UnfoldableConstantOp::parse(OpAsmParser &parser,
679683
// If the attribute is a symbol reference, then we expect a trailing type.
680684
Type type;
681685
if (!valueAttr.isa<SymbolRefAttr>())
682-
type = valueAttr.getType();
686+
type = valueAttr.cast<TypedAttr>().getType();
683687
else if (parser.parseColonType(type))
684688
return failure();
685689

@@ -785,9 +789,16 @@ static bool isGlobalTypeCompatible(Type globalType, Type accessType) {
785789
return globalType == accessType;
786790
}
787791

792+
void GlobalOp::build(OpBuilder &builder, OperationState &result, StringRef name,
793+
bool isMutable, Type type, TypedAttr initialValue,
794+
ArrayRef<NamedAttribute> attrs) {
795+
build(builder, result, name, isMutable, type,
796+
Optional<TypedAttr>(initialValue), attrs);
797+
}
798+
788799
void GlobalOp::build(OpBuilder &builder, OperationState &result, StringRef name,
789800
bool isMutable, Type type,
790-
Optional<Attribute> initialValue,
801+
Optional<TypedAttr> initialValue,
791802
ArrayRef<NamedAttribute> attrs) {
792803
result.addAttribute(SymbolTable::getSymbolAttrName(),
793804
builder.getStringAttr(name));

compiler/src/iree/compiler/Dialect/Util/IR/UtilOps.td

Lines changed: 11 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -10,6 +10,7 @@
1010
include "iree/compiler/Dialect/Util/IR/UtilBase.td"
1111
include "iree/compiler/Dialect/Util/IR/UtilInterfaces.td"
1212
include "iree/compiler/Dialect/Util/IR/UtilTypes.td"
13+
include "mlir/IR/BuiltinAttributeInterfaces.td"
1314
include "mlir/IR/FunctionInterfaces.td"
1415
include "mlir/IR/OpAsmInterface.td"
1516
include "mlir/IR/SymbolInterfaces.td"
@@ -337,7 +338,7 @@ def Util_UnfoldableConstantOp : Util_Op<"unfoldable_constant"> {
337338
let results = (outs AnyType);
338339

339340
let builders = [
340-
OpBuilder<(ins "Attribute":$value),
341+
OpBuilder<(ins "TypedAttr":$value),
341342
[{ build($_builder, $_state, value.getType(), value); }]>];
342343

343344
let hasCanonicalizer = 1;
@@ -453,7 +454,7 @@ def Util_GlobalOp : Util_Op<"global", [
453454
SymbolNameAttr:$sym_name,
454455
TypeAttr:$type,
455456
UnitAttr:$is_mutable,
456-
OptionalAttr<AnyAttr>:$initial_value
457+
OptionalAttr<TypedAttrInterface>:$initial_value
457458
);
458459

459460
let assemblyFormat = [{
@@ -470,7 +471,14 @@ def Util_GlobalOp : Util_Op<"global", [
470471
"StringRef":$name,
471472
"bool":$isMutable,
472473
"Type":$type,
473-
"Optional<Attribute>":$initialValue,
474+
"TypedAttr":$initialValue,
475+
CArg<"ArrayRef<NamedAttribute>", "{}">:$attrs
476+
)>,
477+
OpBuilder<(ins
478+
"StringRef":$name,
479+
"bool":$isMutable,
480+
"Type":$type,
481+
"Optional<TypedAttr>":$initialValue,
474482
CArg<"ArrayRef<NamedAttribute>", "{}">:$attrs
475483
)>,
476484
OpBuilder<(ins

0 commit comments

Comments
 (0)