Skip to content

Commit

Permalink
Integrate llvm-project and bump dependencies. (#10087)
Browse files Browse the repository at this point in the history
* Integrate llvm-project and bump dependencies.

* MHLO_COMMIT=3f94dc7ff68524b1cf2dd420d9997df8f3ccd7a8
* LLVM_COMIT=6c66b089bcd7
* TF_COMMIT=0946dfd12422221e3695a479891aa2f2cad147e5

* Fixes with TypedAttr
* Update mhlo and llvm pointers to cherry-pick fixes
* OpaqueAttr fixes
* Fix EmitC failures in llvm integrate
* Fix StreamBase.td
* Use patched mlir-hlo branch
* fix nvgpu lowering
* Fix convert type utility
* Disable TFL/test/flex_ops.mlir
* Cherry-pick local MHLO patch https://github.com/iree-org/iree-mhlo-fork/commit/a6a3308969278cf9ec36132fc9235314fd091c23

Co-authored-by: Diego Caballero <[email protected]>
Co-authored-by: Simon Camphausen <[email protected]>
Co-authored-by: Hanhan Wang <[email protected]>
Co-authored-by: Thomas Raoux <[email protected]>
Co-authored-by: Stella Laurenzo <[email protected]>
  • Loading branch information
6 people authored Aug 15, 2022
1 parent 36a212a commit 8540f68
Show file tree
Hide file tree
Showing 64 changed files with 637 additions and 718 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -479,7 +479,7 @@ func.func @three_init_tensor_uses() {
%c1638400 = arith.constant 1638400 : index
%c0 = arith.constant 0 : index
%cst = arith.constant 3.40282347E+38 : f32
%cst_0 = arith.constant opaque<"elided_large_const", "0xDEADBEEF"> : tensor<64xf32>
%cst_0 = arith.constant dense_resource<__elided__> : tensor<64xf32>
%cst_1 = arith.constant 0.000000e+00 : f32
%0 = hal.interface.binding.subspan set(0) binding(0) type(storage_buffer) offset(%c0) alignment(32) : !flow.dispatch.tensor<readonly:6400x64xf32>
%1 = hal.interface.binding.subspan set(0) binding(0) type(storage_buffer) offset(%c1638400) alignment(32) : !flow.dispatch.tensor<readonly:64x64xf32>
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -148,7 +148,7 @@ func.func @matmul_fill() {
func.func @elementwise() {
%c4 = arith.constant 4 : index
%c0 = arith.constant 0 : index
%cst = arith.constant opaque<"elided_large_const", "0xDEADBEEF"> : tensor<1x10xf32>
%cst = arith.constant dense_resource<__elided__> : tensor<1x10xf32>
%c512 = arith.constant 512 : index
%c64 = arith.constant 64 : index
%c10 = arith.constant 10 : index
Expand Down Expand Up @@ -186,7 +186,7 @@ func.func @elementwise() {
return
}
// CHECK: func.func @elementwise()
// CHECK-DAG: %[[CST_TENSOR:.+]] = arith.constant opaque<"elided_large_const", "0xDEADBEEF"> : tensor<1x10xf32>
// CHECK-DAG: %[[CST_TENSOR:.+]] = arith.constant dense_resource<__elided__> : tensor<1x10xf32>
// CHECK-DAG: %[[CST_BUF:.+]] = bufferization.to_memref %[[CST_TENSOR]]
// CHECK-DAG: %[[IN_BUF:.+]] = hal.interface.binding.subspan set(0) binding(0) {{.+}} : memref<1x10xf32>
// CHECK-DAG: %[[OUT_BUF:.+]] = hal.interface.binding.subspan set(0) binding(1) {{.+}} : memref<1x10xf32>
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -18,9 +18,9 @@ func.func @binding_ptrs() {
// CHECK: %[[STATE:.+]] = llvm.load %arg1 : !llvm.ptr<struct<[[DISPATCH_STATE_TYPE]]>>
// CHECK: %[[BINDING_PTRS:.+]] = llvm.extractvalue %[[STATE]][10]
// CHECK: %[[C1:.+]] = llvm.mlir.constant(1 : i64) : i64
// CHECK: %[[ARRAY_PTR:.+]] = llvm.getelementptr %[[BINDING_PTRS]][%[[C1]]] : (!llvm.ptr<ptr<i8>>, i64) -> !llvm.ptr<ptr<i8>>
// CHECK: %[[ARRAY_PTR:.+]] = llvm.getelementptr %[[BINDING_PTRS]][1] : (!llvm.ptr<ptr<i8>>) -> !llvm.ptr<ptr<i8>>
// CHECK: %[[BASE_PTR_I8:.+]] = llvm.load %[[ARRAY_PTR]] : !llvm.ptr<ptr<i8>>
// CHECK: %[[BUFFER_I8:.+]] = llvm.getelementptr %[[BASE_PTR_I8]][%[[C72]]] : (!llvm.ptr<i8>, i64) -> !llvm.ptr<i8>
// CHECK: %[[BUFFER_I8:.+]] = llvm.getelementptr %[[BASE_PTR_I8]][72] : (!llvm.ptr<i8>) -> !llvm.ptr<i8>
// CHECK: %[[BUFFER_F32:.+]] = llvm.bitcast %[[BUFFER_I8]] : !llvm.ptr<i8> to !llvm.ptr<f32>
// CHECK: %[[DESC_A:.+]] = llvm.mlir.undef : !llvm.struct<(ptr<f32>, ptr<f32>, i64, array<2 x i64>, array<2 x i64>)>
// CHECK: %[[DESC_B:.+]] = llvm.insertvalue %[[BUFFER_F32]], %[[DESC_A]][0]
Expand Down
1 change: 1 addition & 0 deletions compiler/src/iree/compiler/Codegen/LLVMGPU/BUILD
Original file line number Diff line number Diff line change
Expand Up @@ -75,6 +75,7 @@ iree_compiler_cc_library(
"@llvm-project//mlir:MemRefTransforms",
"@llvm-project//mlir:NVGPUDialect",
"@llvm-project//mlir:NVGPUToNVVM",
"@llvm-project//mlir:NVGPUTransforms",
"@llvm-project//mlir:NVVMDialect",
"@llvm-project//mlir:PDLDialect",
"@llvm-project//mlir:PDLInterpDialect",
Expand Down
1 change: 1 addition & 0 deletions compiler/src/iree/compiler/Codegen/LLVMGPU/CMakeLists.txt
Original file line number Diff line number Diff line change
Expand Up @@ -64,6 +64,7 @@ iree_cc_library(
MLIRMemRefTransforms
MLIRNVGPUDialect
MLIRNVGPUToNVVM
MLIRNVGPUTransforms
MLIRNVVMDialect
MLIRPDLDialect
MLIRPDLInterpDialect
Expand Down
21 changes: 19 additions & 2 deletions compiler/src/iree/compiler/Codegen/LLVMGPU/LLVMGPUVectorToGPU.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -12,6 +12,7 @@
#include "mlir/Dialect/Arithmetic/IR/Arithmetic.h"
#include "mlir/Dialect/GPU/IR/GPUDialect.h"
#include "mlir/Dialect/NVGPU/IR/NVGPUDialect.h"
#include "mlir/Dialect/NVGPU/Transforms/Transforms.h"
#include "mlir/Transforms/GreedyPatternRewriteDriver.h"

namespace mlir {
Expand Down Expand Up @@ -241,10 +242,26 @@ struct LLVMGPUVectorToGPUPass
populatePrepareVectorToMMAPatterns(patterns, llvmgpuUseMMASync);
(void)applyPatternsAndFoldGreedily(getOperation(), std::move(patterns));

if (llvmgpuUseMMASync)
if (llvmgpuUseMMASync) {
(void)convertVectorToNVVMCompatibleMMASync(funcOp);
else

// TODO: Remove once populateMmaSyncF32ToTF32Patterns is fixed to not add
// attribute tf32 attributes to none f32 ops.
bool hasFP32mma = false;
funcOp.walk([&hasFP32mma](nvgpu::MmaSyncOp op) {
if (op.getType().cast<VectorType>().getElementType().isF32())
hasFP32mma = true;
});
if (hasFP32mma) {
// Use TF32 for float32 case for now.
RewritePatternSet patterns(funcOp.getContext());
nvgpu::populateMmaSyncF32ToTF32Patterns(
patterns, nvgpu::MmaSyncF32Lowering::TF32);
(void)applyPatternsAndFoldGreedily(getOperation(), std::move(patterns));
}
} else {
convertVectorToMMAOps(funcOp);
}
createAsyncGroups(funcOp);
}
};
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -321,7 +321,7 @@ hal.executable @dwconv_elementwise {
hal.executable.export public @dwconv_elementwise layout(#executable_layout)
builtin.module {
func.func @dwconv_elementwise() {
%cst = arith.constant opaque<"_", "0xDEADBEEF"> : tensor<3x3x1x4xf32>
%cst = arith.constant dense_resource<__elided__> : tensor<3x3x1x4xf32>
%cst_8 = arith.constant 1.001000e+00 : f32
%cst_9 = arith.constant 0.000000e+00 : f32
%c18 = arith.constant 18 : index
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -55,7 +55,7 @@ class StripAndSplatConstantVariablesPass

auto tensorType = op.getType().cast<TensorType>();
auto elementType = tensorType.getElementType();
DenseElementsAttr newValue;
TypedAttr newValue;
if (elementType.isa<FloatType>()) {
newValue = DenseElementsAttr::get(
tensorType, FloatAttr::get(elementType, 1.0 / replaceIndex));
Expand Down
1 change: 1 addition & 0 deletions compiler/src/iree/compiler/Dialect/HAL/IR/BUILD
Original file line number Diff line number Diff line change
Expand Up @@ -33,6 +33,7 @@ td_library(
),
deps = [
"//compiler/src/iree/compiler/Dialect/Util/IR:td_files",
"@llvm-project//mlir:BuiltinDialectTdFiles",
"@llvm-project//mlir:FuncTdFiles",
"@llvm-project//mlir:OpBaseTdFiles",
"@llvm-project//mlir:ViewLikeInterfaceTdFiles",
Expand Down
3 changes: 2 additions & 1 deletion compiler/src/iree/compiler/Dialect/HAL/IR/HALOps.td
Original file line number Diff line number Diff line change
Expand Up @@ -11,6 +11,7 @@ include "iree/compiler/Dialect/HAL/IR/HALBase.td"
include "iree/compiler/Dialect/HAL/IR/HALInterfaces.td"
include "iree/compiler/Dialect/Util/IR/UtilAttrs.td"
include "iree/compiler/Dialect/Util/IR/UtilInterfaces.td"
include "mlir/IR/BuiltinAttributeInterfaces.td"
include "mlir/IR/OpAsmInterface.td"
include "mlir/IR/SymbolInterfaces.td"
include "mlir/Interfaces/SideEffectInterfaces.td"
Expand Down Expand Up @@ -1360,7 +1361,7 @@ def HAL_DeviceQueryOp :
HAL_Device:$device,
StrAttr:$category,
StrAttr:$key,
OptionalAttr<AnyAttr>:$default_value
OptionalAttr<TypedAttrInterface>:$default_value
);
let results = (outs
I1:$ok,
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -93,7 +93,7 @@ struct GlobalOpExpansion
auto resourceOp = rewriter.replaceOpWithNewOp<IREE::Util::GlobalOp>(
globalOp, globalOp.getName(), globalOp.getIsMutable(), resourceType,
initialValue && !tensorInitializerRequired
? llvm::Optional<Attribute>{initialValue}
? llvm::Optional<TypedAttr>{initialValue}
: llvm::None);
resourceOp.setVisibility(globalOp.getVisibility());

Expand All @@ -108,7 +108,7 @@ struct GlobalOpExpansion
auto indexType = rewriter.getIndexType();
auto resourceSizeOp = rewriter.create<IREE::Util::GlobalOp>(
globalOp.getLoc(), (globalOp.getName() + "__size").str(),
globalOp.getIsMutable(), indexType, Optional<Attribute>{});
globalOp.getIsMutable(), indexType, Optional<TypedAttr>{});
resourceSizeOp.setVisibility(globalOp.getVisibility());

// Materialize the initializer if we need to setup a tensor-like constant.
Expand Down
1 change: 1 addition & 0 deletions compiler/src/iree/compiler/Dialect/Stream/IR/BUILD
Original file line number Diff line number Diff line change
Expand Up @@ -27,6 +27,7 @@ td_library(
),
deps = [
"//compiler/src/iree/compiler/Dialect/Util/IR:td_files",
"@llvm-project//mlir:BuiltinDialectTdFiles",
"@llvm-project//mlir:FuncTdFiles",
"@llvm-project//mlir:InferTypeOpInterfaceTdFiles",
"@llvm-project//mlir:OpBaseTdFiles",
Expand Down
4 changes: 3 additions & 1 deletion compiler/src/iree/compiler/Dialect/Stream/IR/StreamBase.td
Original file line number Diff line number Diff line change
Expand Up @@ -12,6 +12,7 @@ include "iree/compiler/Dialect/Util/IR/UtilBase.td"
include "iree/compiler/Dialect/Util/IR/UtilInterfaces.td"
include "iree/compiler/Dialect/Util/IR/UtilTypes.td"
include "mlir/IR/AttrTypeBase.td"
include "mlir/IR/BuiltinAttributeInterfaces.td"
include "mlir/IR/EnumAttr.td"
include "mlir/IR/SubElementInterfaces.td"

Expand Down Expand Up @@ -320,7 +321,8 @@ def Stream_Timepoint : TypeDef<Stream_Dialect, "Timepoint", [
let parameters = (ins);
}

def Stream_TimepointAttr : AttrDef<Stream_Dialect, "Timepoint", []> {
def Stream_TimepointAttr : AttrDef<Stream_Dialect, "Timepoint",
[TypedAttrInterface]> {
let mnemonic = "timepoint";
let summary = [{an immediately-resolved timepoint}];
let description = [{}];
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -868,7 +868,7 @@ struct TensorConstantToSplat : public OpRewritePattern<TensorConstantOp> {
"only constant splat attrs can be converted to splat ops");
}

auto splatElementAttr = splatAttr.getSplatValue<Attribute>();
auto splatElementAttr = splatAttr.getSplatValue<TypedAttr>();
auto splatValue = rewriter.create<arith::ConstantOp>(
constantOp.getLoc(), splatElementAttr.getType(), splatElementAttr);
auto resultType = IREE::Stream::ResourceType::get(constantOp.getContext());
Expand Down Expand Up @@ -1202,7 +1202,7 @@ struct ConvertSplatConstantsIntoSplats
if (!value.isSplat()) return failure();

auto splatElementAttr =
value.dyn_cast<SplatElementsAttr>().getSplatValue<Attribute>();
value.dyn_cast<SplatElementsAttr>().getSplatValue<TypedAttr>();
auto splatValue = rewriter.create<arith::ConstantOp>(
constantOp.getLoc(), splatElementAttr.getType(), splatElementAttr);
rewriter.replaceOpWithNewOp<IREE::Stream::AsyncSplatOp>(
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -30,8 +30,6 @@ static bool isOutlinableValue(Attribute value) {
if (auto elementsAttr = value.dyn_cast<DenseElementsAttr>()) {
// Don't outline splats - we want those fused.
return !elementsAttr.isSplat();
} else if (auto opaqueAttr = value.dyn_cast<OpaqueElementsAttr>()) {
return !opaqueAttr.isSplat();
}
return false;
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -50,12 +50,6 @@ struct ConstantSlice {
uint64_t getRawLength() const {
if (auto denseAttr = value.dyn_cast<DenseElementsAttr>()) {
return denseAttr.getRawData().size();
} else if (auto opaqueAttr = value.dyn_cast<OpaqueElementsAttr>()) {
// Later on in the pipeline opaque attrs will cause the compiler to fail
// (as at some point we need to get the data) but this allows us to run
// the stream transforms on IR that has had its large constants elided.
return opaqueAttr.getNumElements() *
opaqueAttr.getElementType().getIntOrFloatBitWidth();
} else {
assert(false && "invalid constant attr type");
return 0;
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -41,7 +41,7 @@ struct ConstantSet {
// Locations of all constants that went into the table.
SetVector<Location> locs;
// Operand index -> all values from dispatch sites.
SmallVector<std::pair<unsigned, SmallVector<Attribute>>> values;
SmallVector<std::pair<unsigned, SmallVector<TypedAttr>>> values;
};

struct ConstantTable {
Expand Down Expand Up @@ -92,7 +92,7 @@ static ConstantTable buildConstantTable(
set.type = operandType;
typeOrder.push_back(operandType);
}
SmallVector<Attribute> values;
SmallVector<TypedAttr> values;
for (auto dispatchOp : dispatchOps) {
auto operand = dispatchOp.getUniformOperands()[idx];
Attribute constantValue;
Expand Down
1 change: 1 addition & 0 deletions compiler/src/iree/compiler/Dialect/Util/IR/BUILD
Original file line number Diff line number Diff line change
Expand Up @@ -30,6 +30,7 @@ td_library(
include = ["*.td"],
),
deps = [
"@llvm-project//mlir:BuiltinDialectTdFiles",
"@llvm-project//mlir:CallInterfacesTdFiles",
"@llvm-project//mlir:ControlFlowInterfacesTdFiles",
"@llvm-project//mlir:FunctionInterfacesTdFiles",
Expand Down
15 changes: 2 additions & 13 deletions compiler/src/iree/compiler/Dialect/Util/IR/UtilAttrs.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -346,13 +346,6 @@ CompositeAttr CompositeAttr::get(MLIRContext *context,
if (auto serializableAttr =
valueAttr.dyn_cast<SerializableAttrInterface>()) {
calculatedLength += serializableAttr.getStorageSize();
} else if (auto opaqueAttr = valueAttr.dyn_cast<OpaqueElementsAttr>()) {
// Allow opaque attrs to be placed into composites ease debugging of IR
// that has had large attrs elided; these will fail to actually serialize
// but being able to run most passes with these unserializable attrs is
// useful.
calculatedLength += opaqueAttr.getNumElements() *
opaqueAttr.getElementType().getIntOrFloatBitWidth();
} else {
return {};
}
Expand All @@ -369,12 +362,8 @@ LogicalResult CompositeAttr::verify(
if (auto serializableAttr =
valueAttr.dyn_cast<SerializableAttrInterface>()) {
calculatedLength += serializableAttr.getStorageSize();
} else if (auto opaqueAttr = valueAttr.dyn_cast<OpaqueElementsAttr>()) {
calculatedLength += opaqueAttr.getNumElements() *
opaqueAttr.getElementType().getIntOrFloatBitWidth();
} else {
return emitError() << "value is not serializable: "
<< valueAttr.getType();
return emitError() << "value is not serializable: " << valueAttr;
}
}
if (calculatedLength != totalLength) {
Expand Down Expand Up @@ -464,7 +453,7 @@ LogicalResult CompositeAttr::serializeToStream(llvm::support::endianness endian,
auto serializableAttr = valueAttr.dyn_cast<SerializableAttrInterface>();
if (!serializableAttr) {
llvm::errs() << "unable to serialize a non-serializable attribute: "
<< valueAttr.getType() << "\n";
<< valueAttr << "\n";
return failure();
}
if (failed(serializableAttr.serializeToStream(endian, os))) {
Expand Down
19 changes: 15 additions & 4 deletions compiler/src/iree/compiler/Dialect/Util/IR/UtilOps.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -84,7 +84,10 @@ ParseResult parseTypeOrAttr(OpAsmParser &parser, TypeAttr &typeAttr,
return parser.emitError(parser.getCurrentLocation())
<< "expected attribute";
}
typeAttr = TypeAttr::get(attr.getType());

if (auto typedAttr = attr.dyn_cast<TypedAttr>()) {
typeAttr = TypeAttr::get(typedAttr.getType());
}
return success();
}

Expand All @@ -107,7 +110,8 @@ ParseResult parseTypeOrAttr(OpAsmParser &parser, TypeAttr &typeAttr,
void printTypeOrAttr(OpAsmPrinter &p, Operation *op, TypeAttr type,
Attribute attr) {
bool needsSpace = false;
if (!attr || attr.getType() != type.getValue()) {
auto typedAttr = attr.dyn_cast_or_null<TypedAttr>();
if (!typedAttr || typedAttr.getType() != type.getValue()) {
p << ": ";
p.printAttribute(type);
needsSpace = true; // subsequent attr value needs a space separator
Expand Down Expand Up @@ -679,7 +683,7 @@ ParseResult UnfoldableConstantOp::parse(OpAsmParser &parser,
// If the attribute is a symbol reference, then we expect a trailing type.
Type type;
if (!valueAttr.isa<SymbolRefAttr>())
type = valueAttr.getType();
type = valueAttr.cast<TypedAttr>().getType();
else if (parser.parseColonType(type))
return failure();

Expand Down Expand Up @@ -785,9 +789,16 @@ static bool isGlobalTypeCompatible(Type globalType, Type accessType) {
return globalType == accessType;
}

void GlobalOp::build(OpBuilder &builder, OperationState &result, StringRef name,
bool isMutable, Type type, TypedAttr initialValue,
ArrayRef<NamedAttribute> attrs) {
build(builder, result, name, isMutable, type,
Optional<TypedAttr>(initialValue), attrs);
}

void GlobalOp::build(OpBuilder &builder, OperationState &result, StringRef name,
bool isMutable, Type type,
Optional<Attribute> initialValue,
Optional<TypedAttr> initialValue,
ArrayRef<NamedAttribute> attrs) {
result.addAttribute(SymbolTable::getSymbolAttrName(),
builder.getStringAttr(name));
Expand Down
14 changes: 11 additions & 3 deletions compiler/src/iree/compiler/Dialect/Util/IR/UtilOps.td
Original file line number Diff line number Diff line change
Expand Up @@ -10,6 +10,7 @@
include "iree/compiler/Dialect/Util/IR/UtilBase.td"
include "iree/compiler/Dialect/Util/IR/UtilInterfaces.td"
include "iree/compiler/Dialect/Util/IR/UtilTypes.td"
include "mlir/IR/BuiltinAttributeInterfaces.td"
include "mlir/IR/FunctionInterfaces.td"
include "mlir/IR/OpAsmInterface.td"
include "mlir/IR/SymbolInterfaces.td"
Expand Down Expand Up @@ -337,7 +338,7 @@ def Util_UnfoldableConstantOp : Util_Op<"unfoldable_constant"> {
let results = (outs AnyType);

let builders = [
OpBuilder<(ins "Attribute":$value),
OpBuilder<(ins "TypedAttr":$value),
[{ build($_builder, $_state, value.getType(), value); }]>];

let hasCanonicalizer = 1;
Expand Down Expand Up @@ -453,7 +454,7 @@ def Util_GlobalOp : Util_Op<"global", [
SymbolNameAttr:$sym_name,
TypeAttr:$type,
UnitAttr:$is_mutable,
OptionalAttr<AnyAttr>:$initial_value
OptionalAttr<TypedAttrInterface>:$initial_value
);

let assemblyFormat = [{
Expand All @@ -470,7 +471,14 @@ def Util_GlobalOp : Util_Op<"global", [
"StringRef":$name,
"bool":$isMutable,
"Type":$type,
"Optional<Attribute>":$initialValue,
"TypedAttr":$initialValue,
CArg<"ArrayRef<NamedAttribute>", "{}">:$attrs
)>,
OpBuilder<(ins
"StringRef":$name,
"bool":$isMutable,
"Type":$type,
"Optional<TypedAttr>":$initialValue,
CArg<"ArrayRef<NamedAttribute>", "{}">:$attrs
)>,
OpBuilder<(ins
Expand Down
Loading

0 comments on commit 8540f68

Please sign in to comment.