diff --git a/third_party/stablehlo/temporary.patch b/third_party/stablehlo/temporary.patch index 8b13789..8bac9da 100755 --- a/third_party/stablehlo/temporary.patch +++ b/third_party/stablehlo/temporary.patch @@ -1 +1,389 @@ +diff --ruN a/stablehlo/BUILD.bazel b/stablehlo/BUILD.bazel +--- stablehlo/BUILD.bazel ++++ stablehlo/BUILD.bazel +@@ -340,6 +340,21 @@ + ], + ) + ++gentbl_cc_library( ++ name = "stablehlo_create_compatibility_expander_inc_gen", ++ tbl_outs = [ ++ ( ++ ["--gen-rewriters"], ++ "stablehlo/transforms/StablehloCreateCompatibilityExpanderPatterns.h.inc", ++ ), ++ ], ++ tblgen = "@llvm-project//mlir:mlir-tblgen", ++ td_file = "stablehlo/transforms/StablehloCreateCompatibilityExpanderPatterns.td", ++ deps = [ ++ ":stablehlo_ops_td_files", ++ ], ++) ++ + cc_library( + name = "interpreter_ops", + srcs = [ +@@ -1086,6 +1101,7 @@ + "stablehlo/transforms/StablehloAggressiveSimplification.cpp", + "stablehlo/transforms/StablehloCanonicalizeDynamism.cpp", + "stablehlo/transforms/StablehloConvertToSignless.cpp", ++ "stablehlo/transforms/StablehloCreateCompatibilityExpander.cpp", + "stablehlo/transforms/StablehloLegalizeCompositeToCall.cpp", + "stablehlo/transforms/StablehloLegalizeDeprecatedOps.cpp", + "stablehlo/transforms/StablehloLegalizeQDQToQuantizedOp.cpp", +@@ -1109,6 +1125,7 @@ + ":chlo_ops", + ":chlo_rewriters_inc_gen", + ":linalg_passes", ++ ":stablehlo_create_compatibility_expander_inc_gen", + ":stablehlo_legalize_deprecated_ops_inc_gen", + ":stablehlo_ops", + ":stablehlo_ops_inc_gen", +diff --ruN a/stablehlo/stablehlo/tests/transforms/stablehlo_create_compatibility_expander.mlir b/stablehlo/stablehlo/tests/transforms/stablehlo_create_compatibility_expander.mlir +--- stablehlo/stablehlo/tests/transforms/stablehlo_create_compatibility_expander.mlir ++++ stablehlo/stablehlo/tests/transforms/stablehlo_create_compatibility_expander.mlir +@@ -0,0 +1,43 @@ ++// RUN: stablehlo-opt %s -verify-diagnostics -split-input-file -allow-unregistered-dialect --stablehlo-create-compatibility-expander='target=1.0.0' | FileCheck %s --check-prefixes=CHECK ++// RUN: stablehlo-opt %s -verify-diagnostics -split-input-file --stablehlo-create-compatibility-expander='target=1.6.0' | FileCheck %s --check-prefixes=CHECK-NO-DOWNGRADE ++ ++// ----- ++ ++// CHECK-LABEL @tan_op_non_complex ++// CHECK: %[[sine0:.*]] = stablehlo.sine %arg0 : tensor<4xf64> ++// CHECK-NEXT: %[[cosine1:.*]] = stablehlo.cosine %arg0 : tensor<4xf64> ++// CHECK-NEXT: %[[div2:.*]] = stablehlo.divide %[[sine0]], %[[cosine1]] : tensor<4xf64> ++// CHECK-NEXT: return %[[div2]] : tensor<4xf64> ++func.func @tan_op_non_complex(%arg0: tensor<4xf64>) -> tensor<4xf64> { ++ // CHECK-NO-DOWNGRADE: stablehlo.tan %arg0 : tensor<4xf64> ++ %1 = stablehlo.tan %arg0 : tensor<4xf64> ++ func.return %1 : tensor<4xf64> ++} ++ ++// ----- ++ ++// CHECK-LABEL: @tan_op_complex ++// CHECK: %[[cst:.*]] = stablehlo.constant dense<1.000000e+00> : tensor<4xf64> ++// CHECK: %[[complex:.*]] = stablehlo.complex %arg0, %arg1 : tensor<4xcomplex> ++// CHECK: %[[real:.*]] = stablehlo.real %[[complex]] : (tensor<4xcomplex>) -> tensor<4xf64> ++// CHECK: %[[sine:.*]] = stablehlo.sine %[[real]] : tensor<4xf64> ++// CHECK: %[[cosine:.*]] = stablehlo.cosine %[[real]] : tensor<4xf64> ++// CHECK: %[[divide1:.*]] = stablehlo.divide %[[sine]], %[[cosine]] : tensor<4xf64> ++// CHECK: %[[imag:.*]] = stablehlo.imag %[[complex]] : (tensor<4xcomplex>) -> tensor<4xf64> ++// CHECK: %[[tanh:.*]] = stablehlo.tanh %[[imag]] : tensor<4xf64> ++// CHECK: %[[complex2:.*]] = stablehlo.complex %[[divide1]], %[[tanh]] : tensor<4xcomplex> ++// CHECK: %[[multiply:.*]] = stablehlo.multiply %[[divide1]], %[[tanh]] : tensor<4xf64> ++// CHECK: %[[negate:.*]] = stablehlo.negate %[[multiply]] : tensor<4xf64> ++// CHECK: %[[complex3:.*]] = stablehlo.complex %[[cst]], %[[negate]] : tensor<4xcomplex> ++// CHECK: %[[divide2:.*]] = stablehlo.divide %[[complex2]], %[[complex3]] : tensor<4xcomplex> ++// CHECK: %[[real2:.*]] = stablehlo.real %[[divide2]] : (tensor<4xcomplex>) -> tensor<4xf64> ++// CHECK: %[[imag2:.*]] = stablehlo.imag %[[divide2]] : (tensor<4xcomplex>) -> tensor<4xf64> ++// CHECK: return %[[real2]], %[[imag2]] : tensor<4xf64>, tensor<4xf64> ++func.func @tan_op_complex(%arg0: tensor<4xf64>, %arg1: tensor<4xf64>) -> (tensor<4xf64>, tensor<4xf64>) { ++ %0 = stablehlo.complex %arg0, %arg1 : tensor<4xcomplex> ++ // CHECK-NO-DOWNGRADE: stablehlo.tan %0 : tensor<4xcomplex> ++ %1 = stablehlo.tan %0 : tensor<4xcomplex> ++ %2 = stablehlo.real %1 : (tensor<4xcomplex>) -> tensor<4xf64> ++ %3 = stablehlo.imag %1 : (tensor<4xcomplex>) -> tensor<4xf64> ++ func.return %2, %3 : tensor<4xf64>, tensor<4xf64> ++} +diff --ruN a/stablehlo/stablehlo/transforms/CMakeLists.txt b/stablehlo/stablehlo/transforms/CMakeLists.txt +--- stablehlo/stablehlo/transforms/CMakeLists.txt ++++ stablehlo/stablehlo/transforms/CMakeLists.txt +@@ -20,6 +20,10 @@ + mlir_tablegen(ChloDecompositionPatterns.h.inc --gen-rewriters) + add_public_tablegen_target(ChloDecompositionPatternsIncGen) + ++set(LLVM_TARGET_DEFINITIONS StablehloCreateCompatibilityExpanderPatterns.td) ++mlir_tablegen(StablehloCreateCompatibilityExpanderPatterns.h.inc --gen-rewriters) ++add_public_tablegen_target(StablehloCreateCompatibilityExpanderPatternsIncGen) ++ + set(LLVM_TARGET_DEFINITIONS StablehloLegalizeDeprecatedOpsPatterns.td) + mlir_tablegen(StablehloLegalizeDeprecatedOpsPatterns.h.inc --gen-rewriters) + add_public_tablegen_target(StablehloLegalizeDeprecatedOpsPatternsIncGen) +@@ -27,6 +31,7 @@ + set(LLVM_TARGET_DEFINITIONS VhloToVersionPatterns.td) + mlir_tablegen(VhloToVersionPatterns.h.inc --gen-rewriters) + add_public_tablegen_target(VhloToVersionPatterns) ++ + + add_mlir_dialect_library(StablehloPasses + PARTIAL_SOURCES_INTENDED +@@ -37,6 +42,7 @@ + StablehloAggressiveSimplification.cpp + StablehloCanonicalizeDynamism.cpp + StablehloConvertToSignless.cpp ++ StablehloCreateCompatibilityExpander.cpp + StablehloLegalizeCompositeToCall.cpp + StablehloLegalizeDeprecatedOps.cpp + StablehloLegalizeQuantToMath.cpp +@@ -53,6 +59,7 @@ + StablehloLegalizeDeprecatedOpsPatternsIncGen + PassesIncGen + VhloToVersionPatterns ++ StablehloCreateCompatibilityExpanderPatternsIncGen + + LINK_LIBS PUBLIC + ChloOps +diff --ruN a/stablehlo/stablehlo/transforms/Passes.h b/stablehlo/stablehlo/transforms/Passes.h +--- stablehlo/stablehlo/transforms/Passes.h ++++ stablehlo/stablehlo/transforms/Passes.h +@@ -25,6 +25,7 @@ + #include "mlir/Pass/Pass.h" + #include "mlir/Support/LogicalResult.h" + #include "mlir/Transforms/DialectConversion.h" ++#include "stablehlo/dialect/Version.h" + + namespace mlir { + namespace stablehlo { +@@ -96,6 +97,12 @@ + void populateShapeToStablehloPatterns(MLIRContext *context, + RewritePatternSet *patterns); + ++/// Collection of patterns to create compatibility expander for StableHLO ++/// operations. ++void populateStablehloCreateCompatibilityExpanderPatterns( ++ RewritePatternSet *patterns, MLIRContext *context, ++ vhlo::Version targetVersion); ++ + //// Additional pass constructors //// + + std::unique_ptr> createStablehloRefineArgumentsPass( +diff --ruN a/stablehlo/stablehlo/transforms/Passes.td b/stablehlo/stablehlo/transforms/Passes.td +--- stablehlo/stablehlo/transforms/Passes.td ++++ stablehlo/stablehlo/transforms/Passes.td +@@ -292,3 +292,51 @@ + "mlir::stablehlo::StablehloDialect", + ]; + } ++ ++def StablehloCreateCompatibilityExpanderPass : Pass<"stablehlo-create-compatibility-expander", "mlir::func::FuncOp"> { ++ let summary = "Create compatibility expander for StableHLO operations."; ++ ++ let description = [{ ++ StableHLO ops gets updates or new op is introduced in the latest versions. ++ This opt-in pass expands backward compatibility with older StableHLO ++ versions by decomposing newer StableHLO operations into equivalent ++ operations supported by those older versions. ++ ++ Why is this an opt-in pass? ++ ++ Occasionally, StableHLO op enhancements are used to greatly simplify the ++ handling of certain common patterns in the OpenXLA ecosystem. This ++ includes things like TanOp, which has high framework and compiler support, ++ as well as gather/scatter batching dimensions, which can be represented ++ using slices, but makes sharding much more difficult. For this category of ++ new features, we do not offer automatic downgrade, since it may throw away ++ important information used in subsequent optimizations. This pass can be ++ used to expand these ops based on a target version to maximize compatibility ++ at the expense of potentially less optimal compilation. ++ ++ ```mlir ++ func.func @tan_op_non_complex(%arg0: tensor<4xf64>) -> tensor<4xf64> { ++ %1 = stablehlo.tan %arg0 : tensor<4xf64> ++ func.return %1 : tensor<4xf64> ++ } ++ ``` ++ ++ will become: ++ ++ ```mlir ++ func.func @tan_op_non_complex(%arg0: tensor<4xf64>) -> tensor<4xf64> { ++ %0 = stablehlo.sine %arg0 : tensor<4xf64> ++ %1 = stablehlo.cosine %arg0 : tensor<4xf64> ++ %2 = stablehlo.divide %0, %1 : tensor<4xf64> ++ return %2 : tensor<4xf64> ++ } ++ ``` ++ }]; ++ let options = [ ++ Option<"targetVersionOption", "target", "std::string", "", ++ "The target version. Must be a version of the form #.#.#.">, ++ ]; ++ let dependentDialects = [ ++ "mlir::stablehlo::StablehloDialect", ++ ]; ++} +diff --ruN a/stablehlo/stablehlo/transforms/StablehloCreateCompatibilityExpander.cpp b/stablehlo/stablehlo/transforms/StablehloCreateCompatibilityExpander.cpp +--- stablehlo/stablehlo/transforms/StablehloCreateCompatibilityExpander.cpp ++++ stablehlo/stablehlo/transforms/StablehloCreateCompatibilityExpander.cpp +@@ -0,0 +1,128 @@ ++/* Copyright 2024 The StableHLO Authors. All Rights Reserved. ++Licensed under the Apache License, Version 2.0 (the "License"); ++you may not use this file except in compliance with the License. ++You may obtain a copy of the License at ++ http://www.apache.org/licenses/LICENSE-2.0 ++Unless required by applicable law or agreed to in writing, software ++distributed under the License is distributed on an "AS IS" BASIS, ++WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. ++See the License for the specific language governing permissions and ++limitations under the License. ++==============================================================================*/ ++ ++#include ++ ++#include ++ ++#include "llvm/ADT/APFloat.h" ++#include "llvm/Support/ErrorHandling.h" ++#include "mlir/Dialect/Func/IR/FuncOps.h" ++#include "mlir/IR/BuiltinAttributes.h" ++#include "mlir/IR/PatternMatch.h" ++#include "mlir/Support/LLVM.h" ++#include "mlir/Transforms/DialectConversion.h" ++#include "mlir/Transforms/GreedyPatternRewriteDriver.h" ++#include "stablehlo/dialect/StablehloOps.h" ++#include "stablehlo/dialect/Version.h" ++#include "stablehlo/transforms/Passes.h" ++ ++namespace mlir { ++namespace stablehlo { ++#define GEN_PASS_DEF_STABLEHLOCREATECOMPATIBILITYEXPANDERPASS ++#include "stablehlo/transforms/Passes.h.inc" ++ ++namespace { ++ ++//===----------------------------------------------------------------------===// ++// Helpers. ++//===----------------------------------------------------------------------===// ++ ++// Creates a constant with all ones. ++static Value createConstantWithAllOnes(OpBuilder &b, Location loc, Value val) { ++ auto shapedTy = dyn_cast(val.getType()); ++ if (!shapedTy) llvm_unreachable("Unsupported shaped type."); ++ ++ mlir::DenseElementsAttr elementsAttr = ++ mlir::DenseElementsAttr::get(shapedTy, 1.0); ++ ++ return b.create(loc, val.getType(), ++ elementsAttr); ++} ++ ++// Check user-specified target version. ++vhlo::Version validateTargetVersion(llvm::StringRef versionRef) { ++ auto failOrVersion = vhlo::Version::fromString(versionRef); ++ if (failed(failOrVersion)) { ++ assert(!versionRef.empty() && ++ "No target version specified. Target version must be of the form " ++ "`#.#.#`."); ++ assert(versionRef.empty() && ++ "Invalid target version argument. Target version must be of the " ++ "form `#.#.#`."); ++ } ++ vhlo::Version targetVersion = *failOrVersion; ++ assert((vhlo::Version::getMinimumVersion() <= targetVersion) && ++ "target version is less than minimum supported."); ++ assert((targetVersion <= vhlo::Version::getCurrentVersion()) && ++ "target version is greater than current version."); ++ return targetVersion; ++} ++ ++//===----------------------------------------------------------------------===// ++// Pass ++//===----------------------------------------------------------------------===// ++ ++struct StablehloCreateCompatibilityExpanderPass ++ : public impl::StablehloCreateCompatibilityExpanderPassBase< ++ StablehloCreateCompatibilityExpanderPass> { ++ StablehloCreateCompatibilityExpanderPass() ++ : StablehloCreateCompatibilityExpanderPassBase< ++ StablehloCreateCompatibilityExpanderPass>() {} ++ StablehloCreateCompatibilityExpanderPass( ++ const StablehloCreateCompatibilityExpanderPassOptions &opts) ++ : StablehloCreateCompatibilityExpanderPassBase< ++ StablehloCreateCompatibilityExpanderPass>(opts) {} ++ ++ public: ++ LogicalResult initialize(MLIRContext *context) override { ++ auto targetVersion = validateTargetVersion(targetVersionOption); ++ ++ config.useTopDownTraversal = true; ++ RewritePatternSet patterns_(context); ++ populateStablehloCreateCompatibilityExpanderPatterns(&patterns_, context, ++ targetVersion); ++ patterns = std::move(patterns_); ++ return success(); ++ } ++ ++ void runOnOperation() override { ++ auto func = getOperation(); ++ if (failed(applyPatternsAndFoldGreedily(func, patterns, config))) { ++ func.emitError( ++ "Failed to converge StableHLOCreateCompatibilityExpanderPass in ") ++ << config.maxIterations << " iterations"; ++ signalPassFailure(); ++ } ++ } ++ ++ private: ++ FrozenRewritePatternSet patterns; ++ GreedyRewriteConfig config; ++}; ++ ++#include "stablehlo/transforms/StablehloCreateCompatibilityExpanderPatterns.h.inc" ++ ++} // namespace ++ ++void populateStablehloCreateCompatibilityExpanderPatterns( ++ RewritePatternSet *patterns, MLIRContext *context, ++ vhlo::Version targetVersion) { ++ // StableHLO TanOp is introduced in v1.4.0. ++ if (targetVersion < vhlo::Version(1, 4, 0)) { ++ patterns->add(context); ++ patterns->add(context); ++ } ++} ++ ++} // namespace stablehlo ++} // namespace mlir +diff --ruN a/stablehlo/stablehlo/transforms/StablehloCreateCompatibilityExpanderPatterns.td b/stablehlo/stablehlo/transforms/StablehloCreateCompatibilityExpanderPatterns.td +--- stablehlo/stablehlo/transforms/StablehloCreateCompatibilityExpanderPatterns.td ++++ stablehlo/stablehlo/transforms/StablehloCreateCompatibilityExpanderPatterns.td +@@ -0,0 +1,47 @@ ++/* Copyright 2022 The StableHLO Authors. ++ ++Licensed under the Apache License, Version 2.0 (the "License"); ++you may not use this file except in compliance with the License. ++You may obtain a copy of the License at ++ ++ http://www.apache.org/licenses/LICENSE-2.0 ++ ++Unless required by applicable law or agreed to in writing, software ++distributed under the License is distributed on an "AS IS" BASIS, ++WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. ++See the License for the specific language governing permissions and ++limitations under the License. ++==============================================================================*/ ++ ++include "mlir/IR/OpBase.td" ++include "stablehlo/dialect/StablehloOps.td" ++ ++def ComplexElementType : Type< ++ CPred<"isa(cast($_self).getElementType())">, ++ "Complex element type">; ++ ++def NonComplexElementType : Type< ++ CPred<"!isa(cast($_self).getElementType())">, ++ "Non-complex element type">; ++ ++def createConstantWithAllOnes : NativeCodeCall<"createConstantWithAllOnes($_builder, $_loc, $0)">; ++ ++// Express `tan` as ++// sine(x) / cosine(x) ++def TanOp_CompatiblityExpander : Pat<(StableHLO_TanOp NonComplexElementType:$input), ++ (StableHLO_DivOp ++ (StableHLO_SineOp $input), ++ (StableHLO_CosineOp $input) ++ )>; ++ ++// Express `tan(a + bi)` as ++// (tan(a) + i tanh(b)) / (1 - i tan(a) * tanh(b)) ++def TanOp_ComplexElementType_CompatiblityExpander : Pat<(StableHLO_TanOp ComplexElementType:$input), ++ (StableHLO_DivOp ++ (StableHLO_ComplexOp ++ (StableHLO_TanOp:$tan (StableHLO_RealOp $input)), ++ (StableHLO_TanhOp:$tanh (StableHLO_ImagOp $input))), ++ (StableHLO_ComplexOp ++ (createConstantWithAllOnes $tan), ++ (StableHLO_NegOp (StableHLO_MulOp $tan, $tanh))) ++ )>;