Skip to content

Commit

Permalink
Add StableHLOCreateCompatibilityExpanderPass to create compatibility …
Browse files Browse the repository at this point in the history
…expander for StableHLO operations.

PiperOrigin-RevId: 669131010
  • Loading branch information
abhigunj authored and copybara-github committed Aug 30, 2024
1 parent 51f3b68 commit 09debc3
Showing 1 changed file with 388 additions and 0 deletions.
388 changes: 388 additions & 0 deletions third_party/stablehlo/temporary.patch
Original file line number Diff line number Diff line change
@@ -1 +1,389 @@
diff --ruN a/stablehlo/BUILD.bazel b/stablehlo/BUILD.bazel
--- stablehlo/BUILD.bazel
+++ stablehlo/BUILD.bazel
@@ -340,6 +340,21 @@
],
)

+gentbl_cc_library(
+ name = "stablehlo_create_compatibility_expander_inc_gen",
+ tbl_outs = [
+ (
+ ["--gen-rewriters"],
+ "stablehlo/transforms/StablehloCreateCompatibilityExpanderPatterns.h.inc",
+ ),
+ ],
+ tblgen = "@llvm-project//mlir:mlir-tblgen",
+ td_file = "stablehlo/transforms/StablehloCreateCompatibilityExpanderPatterns.td",
+ deps = [
+ ":stablehlo_ops_td_files",
+ ],
+)
+
cc_library(
name = "interpreter_ops",
srcs = [
@@ -1086,6 +1101,7 @@
"stablehlo/transforms/StablehloAggressiveSimplification.cpp",
"stablehlo/transforms/StablehloCanonicalizeDynamism.cpp",
"stablehlo/transforms/StablehloConvertToSignless.cpp",
+ "stablehlo/transforms/StablehloCreateCompatibilityExpander.cpp",
"stablehlo/transforms/StablehloLegalizeCompositeToCall.cpp",
"stablehlo/transforms/StablehloLegalizeDeprecatedOps.cpp",
"stablehlo/transforms/StablehloLegalizeQDQToQuantizedOp.cpp",
@@ -1109,6 +1125,7 @@
":chlo_ops",
":chlo_rewriters_inc_gen",
":linalg_passes",
+ ":stablehlo_create_compatibility_expander_inc_gen",
":stablehlo_legalize_deprecated_ops_inc_gen",
":stablehlo_ops",
":stablehlo_ops_inc_gen",
diff --ruN a/stablehlo/stablehlo/tests/transforms/stablehlo_create_compatibility_expander.mlir b/stablehlo/stablehlo/tests/transforms/stablehlo_create_compatibility_expander.mlir
--- stablehlo/stablehlo/tests/transforms/stablehlo_create_compatibility_expander.mlir
+++ stablehlo/stablehlo/tests/transforms/stablehlo_create_compatibility_expander.mlir
@@ -0,0 +1,43 @@
+// RUN: stablehlo-opt %s -verify-diagnostics -split-input-file -allow-unregistered-dialect --stablehlo-create-compatibility-expander='target=1.0.0' | FileCheck %s --check-prefixes=CHECK
+// RUN: stablehlo-opt %s -verify-diagnostics -split-input-file --stablehlo-create-compatibility-expander='target=1.6.0' | FileCheck %s --check-prefixes=CHECK-NO-DOWNGRADE
+
+// -----
+
+// CHECK-LABEL @tan_op_non_complex
+// CHECK: %[[sine0:.*]] = stablehlo.sine %arg0 : tensor<4xf64>
+// CHECK-NEXT: %[[cosine1:.*]] = stablehlo.cosine %arg0 : tensor<4xf64>
+// CHECK-NEXT: %[[div2:.*]] = stablehlo.divide %[[sine0]], %[[cosine1]] : tensor<4xf64>
+// CHECK-NEXT: return %[[div2]] : tensor<4xf64>
+func.func @tan_op_non_complex(%arg0: tensor<4xf64>) -> tensor<4xf64> {
+ // CHECK-NO-DOWNGRADE: stablehlo.tan %arg0 : tensor<4xf64>
+ %1 = stablehlo.tan %arg0 : tensor<4xf64>
+ func.return %1 : tensor<4xf64>
+}
+
+// -----
+
+// CHECK-LABEL: @tan_op_complex
+// CHECK: %[[cst:.*]] = stablehlo.constant dense<1.000000e+00> : tensor<4xf64>
+// CHECK: %[[complex:.*]] = stablehlo.complex %arg0, %arg1 : tensor<4xcomplex<f64>>
+// CHECK: %[[real:.*]] = stablehlo.real %[[complex]] : (tensor<4xcomplex<f64>>) -> tensor<4xf64>
+// CHECK: %[[sine:.*]] = stablehlo.sine %[[real]] : tensor<4xf64>
+// CHECK: %[[cosine:.*]] = stablehlo.cosine %[[real]] : tensor<4xf64>
+// CHECK: %[[divide1:.*]] = stablehlo.divide %[[sine]], %[[cosine]] : tensor<4xf64>
+// CHECK: %[[imag:.*]] = stablehlo.imag %[[complex]] : (tensor<4xcomplex<f64>>) -> tensor<4xf64>
+// CHECK: %[[tanh:.*]] = stablehlo.tanh %[[imag]] : tensor<4xf64>
+// CHECK: %[[complex2:.*]] = stablehlo.complex %[[divide1]], %[[tanh]] : tensor<4xcomplex<f64>>
+// CHECK: %[[multiply:.*]] = stablehlo.multiply %[[divide1]], %[[tanh]] : tensor<4xf64>
+// CHECK: %[[negate:.*]] = stablehlo.negate %[[multiply]] : tensor<4xf64>
+// CHECK: %[[complex3:.*]] = stablehlo.complex %[[cst]], %[[negate]] : tensor<4xcomplex<f64>>
+// CHECK: %[[divide2:.*]] = stablehlo.divide %[[complex2]], %[[complex3]] : tensor<4xcomplex<f64>>
+// CHECK: %[[real2:.*]] = stablehlo.real %[[divide2]] : (tensor<4xcomplex<f64>>) -> tensor<4xf64>
+// CHECK: %[[imag2:.*]] = stablehlo.imag %[[divide2]] : (tensor<4xcomplex<f64>>) -> tensor<4xf64>
+// CHECK: return %[[real2]], %[[imag2]] : tensor<4xf64>, tensor<4xf64>
+func.func @tan_op_complex(%arg0: tensor<4xf64>, %arg1: tensor<4xf64>) -> (tensor<4xf64>, tensor<4xf64>) {
+ %0 = stablehlo.complex %arg0, %arg1 : tensor<4xcomplex<f64>>
+ // CHECK-NO-DOWNGRADE: stablehlo.tan %0 : tensor<4xcomplex<f64>>
+ %1 = stablehlo.tan %0 : tensor<4xcomplex<f64>>
+ %2 = stablehlo.real %1 : (tensor<4xcomplex<f64>>) -> tensor<4xf64>
+ %3 = stablehlo.imag %1 : (tensor<4xcomplex<f64>>) -> tensor<4xf64>
+ func.return %2, %3 : tensor<4xf64>, tensor<4xf64>
+}
diff --ruN a/stablehlo/stablehlo/transforms/CMakeLists.txt b/stablehlo/stablehlo/transforms/CMakeLists.txt
--- stablehlo/stablehlo/transforms/CMakeLists.txt
+++ stablehlo/stablehlo/transforms/CMakeLists.txt
@@ -20,6 +20,10 @@
mlir_tablegen(ChloDecompositionPatterns.h.inc --gen-rewriters)
add_public_tablegen_target(ChloDecompositionPatternsIncGen)

+set(LLVM_TARGET_DEFINITIONS StablehloCreateCompatibilityExpanderPatterns.td)
+mlir_tablegen(StablehloCreateCompatibilityExpanderPatterns.h.inc --gen-rewriters)
+add_public_tablegen_target(StablehloCreateCompatibilityExpanderPatternsIncGen)
+
set(LLVM_TARGET_DEFINITIONS StablehloLegalizeDeprecatedOpsPatterns.td)
mlir_tablegen(StablehloLegalizeDeprecatedOpsPatterns.h.inc --gen-rewriters)
add_public_tablegen_target(StablehloLegalizeDeprecatedOpsPatternsIncGen)
@@ -27,6 +31,7 @@
set(LLVM_TARGET_DEFINITIONS VhloToVersionPatterns.td)
mlir_tablegen(VhloToVersionPatterns.h.inc --gen-rewriters)
add_public_tablegen_target(VhloToVersionPatterns)
+

add_mlir_dialect_library(StablehloPasses
PARTIAL_SOURCES_INTENDED
@@ -37,6 +42,7 @@
StablehloAggressiveSimplification.cpp
StablehloCanonicalizeDynamism.cpp
StablehloConvertToSignless.cpp
+ StablehloCreateCompatibilityExpander.cpp
StablehloLegalizeCompositeToCall.cpp
StablehloLegalizeDeprecatedOps.cpp
StablehloLegalizeQuantToMath.cpp
@@ -53,6 +59,7 @@
StablehloLegalizeDeprecatedOpsPatternsIncGen
PassesIncGen
VhloToVersionPatterns
+ StablehloCreateCompatibilityExpanderPatternsIncGen

LINK_LIBS PUBLIC
ChloOps
diff --ruN a/stablehlo/stablehlo/transforms/Passes.h b/stablehlo/stablehlo/transforms/Passes.h
--- stablehlo/stablehlo/transforms/Passes.h
+++ stablehlo/stablehlo/transforms/Passes.h
@@ -25,6 +25,7 @@
#include "mlir/Pass/Pass.h"
#include "mlir/Support/LogicalResult.h"
#include "mlir/Transforms/DialectConversion.h"
+#include "stablehlo/dialect/Version.h"

namespace mlir {
namespace stablehlo {
@@ -96,6 +97,12 @@
void populateShapeToStablehloPatterns(MLIRContext *context,
RewritePatternSet *patterns);

+/// Collection of patterns to create compatibility expander for StableHLO
+/// operations.
+void populateStablehloCreateCompatibilityExpanderPatterns(
+ RewritePatternSet *patterns, MLIRContext *context,
+ vhlo::Version targetVersion);
+
//// Additional pass constructors ////

std::unique_ptr<OperationPass<ModuleOp>> createStablehloRefineArgumentsPass(
diff --ruN a/stablehlo/stablehlo/transforms/Passes.td b/stablehlo/stablehlo/transforms/Passes.td
--- stablehlo/stablehlo/transforms/Passes.td
+++ stablehlo/stablehlo/transforms/Passes.td
@@ -292,3 +292,51 @@
"mlir::stablehlo::StablehloDialect",
];
}
+
+def StablehloCreateCompatibilityExpanderPass : Pass<"stablehlo-create-compatibility-expander", "mlir::func::FuncOp"> {
+ let summary = "Create compatibility expander for StableHLO operations.";
+
+ let description = [{
+ StableHLO ops gets updates or new op is introduced in the latest versions.
+ This opt-in pass expands backward compatibility with older StableHLO
+ versions by decomposing newer StableHLO operations into equivalent
+ operations supported by those older versions.
+
+ Why is this an opt-in pass?
+
+ Occasionally, StableHLO op enhancements are used to greatly simplify the
+ handling of certain common patterns in the OpenXLA ecosystem. This
+ includes things like TanOp, which has high framework and compiler support,
+ as well as gather/scatter batching dimensions, which can be represented
+ using slices, but makes sharding much more difficult. For this category of
+ new features, we do not offer automatic downgrade, since it may throw away
+ important information used in subsequent optimizations. This pass can be
+ used to expand these ops based on a target version to maximize compatibility
+ at the expense of potentially less optimal compilation.
+
+ ```mlir
+ func.func @tan_op_non_complex(%arg0: tensor<4xf64>) -> tensor<4xf64> {
+ %1 = stablehlo.tan %arg0 : tensor<4xf64>
+ func.return %1 : tensor<4xf64>
+ }
+ ```
+
+ will become:
+
+ ```mlir
+ func.func @tan_op_non_complex(%arg0: tensor<4xf64>) -> tensor<4xf64> {
+ %0 = stablehlo.sine %arg0 : tensor<4xf64>
+ %1 = stablehlo.cosine %arg0 : tensor<4xf64>
+ %2 = stablehlo.divide %0, %1 : tensor<4xf64>
+ return %2 : tensor<4xf64>
+ }
+ ```
+ }];
+ let options = [
+ Option<"targetVersionOption", "target", "std::string", "",
+ "The target version. Must be a version of the form #.#.#.">,
+ ];
+ let dependentDialects = [
+ "mlir::stablehlo::StablehloDialect",
+ ];
+}
diff --ruN a/stablehlo/stablehlo/transforms/StablehloCreateCompatibilityExpander.cpp b/stablehlo/stablehlo/transforms/StablehloCreateCompatibilityExpander.cpp
--- stablehlo/stablehlo/transforms/StablehloCreateCompatibilityExpander.cpp
+++ stablehlo/stablehlo/transforms/StablehloCreateCompatibilityExpander.cpp
@@ -0,0 +1,128 @@
+/* Copyright 2024 The StableHLO Authors. All Rights Reserved.
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+ http://www.apache.org/licenses/LICENSE-2.0
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+==============================================================================*/
+
+#include <fcntl.h>
+
+#include <cassert>
+
+#include "llvm/ADT/APFloat.h"
+#include "llvm/Support/ErrorHandling.h"
+#include "mlir/Dialect/Func/IR/FuncOps.h"
+#include "mlir/IR/BuiltinAttributes.h"
+#include "mlir/IR/PatternMatch.h"
+#include "mlir/Support/LLVM.h"
+#include "mlir/Transforms/DialectConversion.h"
+#include "mlir/Transforms/GreedyPatternRewriteDriver.h"
+#include "stablehlo/dialect/StablehloOps.h"
+#include "stablehlo/dialect/Version.h"
+#include "stablehlo/transforms/Passes.h"
+
+namespace mlir {
+namespace stablehlo {
+#define GEN_PASS_DEF_STABLEHLOCREATECOMPATIBILITYEXPANDERPASS
+#include "stablehlo/transforms/Passes.h.inc"
+
+namespace {
+
+//===----------------------------------------------------------------------===//
+// Helpers.
+//===----------------------------------------------------------------------===//
+
+// Creates a constant with all ones.
+static Value createConstantWithAllOnes(OpBuilder &b, Location loc, Value val) {
+ auto shapedTy = dyn_cast<mlir::ShapedType>(val.getType());
+ if (!shapedTy) llvm_unreachable("Unsupported shaped type.");
+
+ mlir::DenseElementsAttr elementsAttr =
+ mlir::DenseElementsAttr::get(shapedTy, 1.0);
+
+ return b.create<mlir::stablehlo::ConstantOp>(loc, val.getType(),
+ elementsAttr);
+}
+
+// Check user-specified target version.
+vhlo::Version validateTargetVersion(llvm::StringRef versionRef) {
+ auto failOrVersion = vhlo::Version::fromString(versionRef);
+ if (failed(failOrVersion)) {
+ assert(!versionRef.empty() &&
+ "No target version specified. Target version must be of the form "
+ "`#.#.#`.");
+ assert(versionRef.empty() &&
+ "Invalid target version argument. Target version must be of the "
+ "form `#.#.#`.");
+ }
+ vhlo::Version targetVersion = *failOrVersion;
+ assert((vhlo::Version::getMinimumVersion() <= targetVersion) &&
+ "target version is less than minimum supported.");
+ assert((targetVersion <= vhlo::Version::getCurrentVersion()) &&
+ "target version is greater than current version.");
+ return targetVersion;
+}
+
+//===----------------------------------------------------------------------===//
+// Pass
+//===----------------------------------------------------------------------===//
+
+struct StablehloCreateCompatibilityExpanderPass
+ : public impl::StablehloCreateCompatibilityExpanderPassBase<
+ StablehloCreateCompatibilityExpanderPass> {
+ StablehloCreateCompatibilityExpanderPass()
+ : StablehloCreateCompatibilityExpanderPassBase<
+ StablehloCreateCompatibilityExpanderPass>() {}
+ StablehloCreateCompatibilityExpanderPass(
+ const StablehloCreateCompatibilityExpanderPassOptions &opts)
+ : StablehloCreateCompatibilityExpanderPassBase<
+ StablehloCreateCompatibilityExpanderPass>(opts) {}
+
+ public:
+ LogicalResult initialize(MLIRContext *context) override {
+ auto targetVersion = validateTargetVersion(targetVersionOption);
+
+ config.useTopDownTraversal = true;
+ RewritePatternSet patterns_(context);
+ populateStablehloCreateCompatibilityExpanderPatterns(&patterns_, context,
+ targetVersion);
+ patterns = std::move(patterns_);
+ return success();
+ }
+
+ void runOnOperation() override {
+ auto func = getOperation();
+ if (failed(applyPatternsAndFoldGreedily(func, patterns, config))) {
+ func.emitError(
+ "Failed to converge StableHLOCreateCompatibilityExpanderPass in ")
+ << config.maxIterations << " iterations";
+ signalPassFailure();
+ }
+ }
+
+ private:
+ FrozenRewritePatternSet patterns;
+ GreedyRewriteConfig config;
+};
+
+#include "stablehlo/transforms/StablehloCreateCompatibilityExpanderPatterns.h.inc"
+
+} // namespace
+
+void populateStablehloCreateCompatibilityExpanderPatterns(
+ RewritePatternSet *patterns, MLIRContext *context,
+ vhlo::Version targetVersion) {
+ // StableHLO TanOp is introduced in v1.4.0.
+ if (targetVersion < vhlo::Version(1, 4, 0)) {
+ patterns->add<TanOp_ComplexElementType_CompatiblityExpander>(context);
+ patterns->add<TanOp_CompatiblityExpander>(context);
+ }
+}
+
+} // namespace stablehlo
+} // namespace mlir
diff --ruN a/stablehlo/stablehlo/transforms/StablehloCreateCompatibilityExpanderPatterns.td b/stablehlo/stablehlo/transforms/StablehloCreateCompatibilityExpanderPatterns.td
--- stablehlo/stablehlo/transforms/StablehloCreateCompatibilityExpanderPatterns.td
+++ stablehlo/stablehlo/transforms/StablehloCreateCompatibilityExpanderPatterns.td
@@ -0,0 +1,47 @@
+/* Copyright 2022 The StableHLO Authors.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+ http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+==============================================================================*/
+
+include "mlir/IR/OpBase.td"
+include "stablehlo/dialect/StablehloOps.td"
+
+def ComplexElementType : Type<
+ CPred<"isa<ComplexType>(cast<ShapedType>($_self).getElementType())">,
+ "Complex element type">;
+
+def NonComplexElementType : Type<
+ CPred<"!isa<ComplexType>(cast<ShapedType>($_self).getElementType())">,
+ "Non-complex element type">;
+
+def createConstantWithAllOnes : NativeCodeCall<"createConstantWithAllOnes($_builder, $_loc, $0)">;
+
+// Express `tan` as
+// sine(x) / cosine(x)
+def TanOp_CompatiblityExpander : Pat<(StableHLO_TanOp NonComplexElementType:$input),
+ (StableHLO_DivOp
+ (StableHLO_SineOp $input),
+ (StableHLO_CosineOp $input)
+ )>;
+
+// Express `tan(a + bi)` as
+// (tan(a) + i tanh(b)) / (1 - i tan(a) * tanh(b))
+def TanOp_ComplexElementType_CompatiblityExpander : Pat<(StableHLO_TanOp ComplexElementType:$input),
+ (StableHLO_DivOp
+ (StableHLO_ComplexOp
+ (StableHLO_TanOp:$tan (StableHLO_RealOp $input)),
+ (StableHLO_TanhOp:$tanh (StableHLO_ImagOp $input))),
+ (StableHLO_ComplexOp
+ (createConstantWithAllOnes $tan),
+ (StableHLO_NegOp (StableHLO_MulOp $tan, $tanh)))
+ )>;

0 comments on commit 09debc3

Please sign in to comment.