|
| 1 | +diff --ruN a/stablehlo/BUILD.bazel b/stablehlo/BUILD.bazel |
| 2 | +--- stablehlo/BUILD.bazel |
| 3 | ++++ stablehlo/BUILD.bazel |
| 4 | +@@ -340,6 +340,21 @@ |
| 5 | + ], |
| 6 | + ) |
| 7 | + |
| 8 | ++gentbl_cc_library( |
| 9 | ++ name = "stablehlo_create_compatibility_expander_inc_gen", |
| 10 | ++ tbl_outs = [ |
| 11 | ++ ( |
| 12 | ++ ["--gen-rewriters"], |
| 13 | ++ "stablehlo/transforms/StablehloCreateCompatibilityExpanderPatterns.h.inc", |
| 14 | ++ ), |
| 15 | ++ ], |
| 16 | ++ tblgen = "@llvm-project//mlir:mlir-tblgen", |
| 17 | ++ td_file = "stablehlo/transforms/StablehloCreateCompatibilityExpanderPatterns.td", |
| 18 | ++ deps = [ |
| 19 | ++ ":stablehlo_ops_td_files", |
| 20 | ++ ], |
| 21 | ++) |
| 22 | ++ |
| 23 | + cc_library( |
| 24 | + name = "interpreter_ops", |
| 25 | + srcs = [ |
| 26 | +@@ -1086,6 +1101,7 @@ |
| 27 | + "stablehlo/transforms/StablehloAggressiveSimplification.cpp", |
| 28 | + "stablehlo/transforms/StablehloCanonicalizeDynamism.cpp", |
| 29 | + "stablehlo/transforms/StablehloConvertToSignless.cpp", |
| 30 | ++ "stablehlo/transforms/StablehloCreateCompatibilityExpander.cpp", |
| 31 | + "stablehlo/transforms/StablehloLegalizeCompositeToCall.cpp", |
| 32 | + "stablehlo/transforms/StablehloLegalizeDeprecatedOps.cpp", |
| 33 | + "stablehlo/transforms/StablehloLegalizeQDQToQuantizedOp.cpp", |
| 34 | +@@ -1109,6 +1125,7 @@ |
| 35 | + ":chlo_ops", |
| 36 | + ":chlo_rewriters_inc_gen", |
| 37 | + ":linalg_passes", |
| 38 | ++ ":stablehlo_create_compatibility_expander_inc_gen", |
| 39 | + ":stablehlo_legalize_deprecated_ops_inc_gen", |
| 40 | + ":stablehlo_ops", |
| 41 | + ":stablehlo_ops_inc_gen", |
| 42 | +diff --ruN a/stablehlo/stablehlo/tests/transforms/stablehlo_create_compatibility_expander.mlir b/stablehlo/stablehlo/tests/transforms/stablehlo_create_compatibility_expander.mlir |
| 43 | +--- stablehlo/stablehlo/tests/transforms/stablehlo_create_compatibility_expander.mlir |
| 44 | ++++ stablehlo/stablehlo/tests/transforms/stablehlo_create_compatibility_expander.mlir |
| 45 | +@@ -0,0 +1,43 @@ |
| 46 | ++// RUN: stablehlo-opt %s -verify-diagnostics -split-input-file -allow-unregistered-dialect --stablehlo-create-compatibility-expander='target=1.0.0' | FileCheck %s --check-prefixes=CHECK |
| 47 | ++// RUN: stablehlo-opt %s -verify-diagnostics -split-input-file --stablehlo-create-compatibility-expander='target=1.6.0' | FileCheck %s --check-prefixes=CHECK-NO-DOWNGRADE |
| 48 | ++ |
| 49 | ++// ----- |
| 50 | ++ |
| 51 | ++// CHECK-LABEL @tan_op_non_complex |
| 52 | ++// CHECK: %[[sine0:.*]] = stablehlo.sine %arg0 : tensor<4xf64> |
| 53 | ++// CHECK-NEXT: %[[cosine1:.*]] = stablehlo.cosine %arg0 : tensor<4xf64> |
| 54 | ++// CHECK-NEXT: %[[div2:.*]] = stablehlo.divide %[[sine0]], %[[cosine1]] : tensor<4xf64> |
| 55 | ++// CHECK-NEXT: return %[[div2]] : tensor<4xf64> |
| 56 | ++func.func @tan_op_non_complex(%arg0: tensor<4xf64>) -> tensor<4xf64> { |
| 57 | ++ // CHECK-NO-DOWNGRADE: stablehlo.tan %arg0 : tensor<4xf64> |
| 58 | ++ %1 = stablehlo.tan %arg0 : tensor<4xf64> |
| 59 | ++ func.return %1 : tensor<4xf64> |
| 60 | ++} |
| 61 | ++ |
| 62 | ++// ----- |
| 63 | ++ |
| 64 | ++// CHECK-LABEL: @tan_op_complex |
| 65 | ++// CHECK: %[[cst:.*]] = stablehlo.constant dense<1.000000e+00> : tensor<4xf64> |
| 66 | ++// CHECK: %[[complex:.*]] = stablehlo.complex %arg0, %arg1 : tensor<4xcomplex<f64>> |
| 67 | ++// CHECK: %[[real:.*]] = stablehlo.real %[[complex]] : (tensor<4xcomplex<f64>>) -> tensor<4xf64> |
| 68 | ++// CHECK: %[[sine:.*]] = stablehlo.sine %[[real]] : tensor<4xf64> |
| 69 | ++// CHECK: %[[cosine:.*]] = stablehlo.cosine %[[real]] : tensor<4xf64> |
| 70 | ++// CHECK: %[[divide1:.*]] = stablehlo.divide %[[sine]], %[[cosine]] : tensor<4xf64> |
| 71 | ++// CHECK: %[[imag:.*]] = stablehlo.imag %[[complex]] : (tensor<4xcomplex<f64>>) -> tensor<4xf64> |
| 72 | ++// CHECK: %[[tanh:.*]] = stablehlo.tanh %[[imag]] : tensor<4xf64> |
| 73 | ++// CHECK: %[[complex2:.*]] = stablehlo.complex %[[divide1]], %[[tanh]] : tensor<4xcomplex<f64>> |
| 74 | ++// CHECK: %[[multiply:.*]] = stablehlo.multiply %[[divide1]], %[[tanh]] : tensor<4xf64> |
| 75 | ++// CHECK: %[[negate:.*]] = stablehlo.negate %[[multiply]] : tensor<4xf64> |
| 76 | ++// CHECK: %[[complex3:.*]] = stablehlo.complex %[[cst]], %[[negate]] : tensor<4xcomplex<f64>> |
| 77 | ++// CHECK: %[[divide2:.*]] = stablehlo.divide %[[complex2]], %[[complex3]] : tensor<4xcomplex<f64>> |
| 78 | ++// CHECK: %[[real2:.*]] = stablehlo.real %[[divide2]] : (tensor<4xcomplex<f64>>) -> tensor<4xf64> |
| 79 | ++// CHECK: %[[imag2:.*]] = stablehlo.imag %[[divide2]] : (tensor<4xcomplex<f64>>) -> tensor<4xf64> |
| 80 | ++// CHECK: return %[[real2]], %[[imag2]] : tensor<4xf64>, tensor<4xf64> |
| 81 | ++func.func @tan_op_complex(%arg0: tensor<4xf64>, %arg1: tensor<4xf64>) -> (tensor<4xf64>, tensor<4xf64>) { |
| 82 | ++ %0 = stablehlo.complex %arg0, %arg1 : tensor<4xcomplex<f64>> |
| 83 | ++ // CHECK-NO-DOWNGRADE: stablehlo.tan %0 : tensor<4xcomplex<f64>> |
| 84 | ++ %1 = stablehlo.tan %0 : tensor<4xcomplex<f64>> |
| 85 | ++ %2 = stablehlo.real %1 : (tensor<4xcomplex<f64>>) -> tensor<4xf64> |
| 86 | ++ %3 = stablehlo.imag %1 : (tensor<4xcomplex<f64>>) -> tensor<4xf64> |
| 87 | ++ func.return %2, %3 : tensor<4xf64>, tensor<4xf64> |
| 88 | ++} |
| 89 | +diff --ruN a/stablehlo/stablehlo/transforms/CMakeLists.txt b/stablehlo/stablehlo/transforms/CMakeLists.txt |
| 90 | +--- stablehlo/stablehlo/transforms/CMakeLists.txt |
| 91 | ++++ stablehlo/stablehlo/transforms/CMakeLists.txt |
| 92 | +@@ -20,6 +20,10 @@ |
| 93 | + mlir_tablegen(ChloDecompositionPatterns.h.inc --gen-rewriters) |
| 94 | + add_public_tablegen_target(ChloDecompositionPatternsIncGen) |
| 95 | + |
| 96 | ++set(LLVM_TARGET_DEFINITIONS StablehloCreateCompatibilityExpanderPatterns.td) |
| 97 | ++mlir_tablegen(StablehloCreateCompatibilityExpanderPatterns.h.inc --gen-rewriters) |
| 98 | ++add_public_tablegen_target(StablehloCreateCompatibilityExpanderPatternsIncGen) |
| 99 | ++ |
| 100 | + set(LLVM_TARGET_DEFINITIONS StablehloLegalizeDeprecatedOpsPatterns.td) |
| 101 | + mlir_tablegen(StablehloLegalizeDeprecatedOpsPatterns.h.inc --gen-rewriters) |
| 102 | + add_public_tablegen_target(StablehloLegalizeDeprecatedOpsPatternsIncGen) |
| 103 | +@@ -27,6 +31,7 @@ |
| 104 | + set(LLVM_TARGET_DEFINITIONS VhloToVersionPatterns.td) |
| 105 | + mlir_tablegen(VhloToVersionPatterns.h.inc --gen-rewriters) |
| 106 | + add_public_tablegen_target(VhloToVersionPatterns) |
| 107 | ++ |
| 108 | + |
| 109 | + add_mlir_dialect_library(StablehloPasses |
| 110 | + PARTIAL_SOURCES_INTENDED |
| 111 | +@@ -37,6 +42,7 @@ |
| 112 | + StablehloAggressiveSimplification.cpp |
| 113 | + StablehloCanonicalizeDynamism.cpp |
| 114 | + StablehloConvertToSignless.cpp |
| 115 | ++ StablehloCreateCompatibilityExpander.cpp |
| 116 | + StablehloLegalizeCompositeToCall.cpp |
| 117 | + StablehloLegalizeDeprecatedOps.cpp |
| 118 | + StablehloLegalizeQuantToMath.cpp |
| 119 | +@@ -53,6 +59,7 @@ |
| 120 | + StablehloLegalizeDeprecatedOpsPatternsIncGen |
| 121 | + PassesIncGen |
| 122 | + VhloToVersionPatterns |
| 123 | ++ StablehloCreateCompatibilityExpanderPatternsIncGen |
| 124 | + |
| 125 | + LINK_LIBS PUBLIC |
| 126 | + ChloOps |
| 127 | +diff --ruN a/stablehlo/stablehlo/transforms/Passes.h b/stablehlo/stablehlo/transforms/Passes.h |
| 128 | +--- stablehlo/stablehlo/transforms/Passes.h |
| 129 | ++++ stablehlo/stablehlo/transforms/Passes.h |
| 130 | +@@ -25,6 +25,7 @@ |
| 131 | + #include "mlir/Pass/Pass.h" |
| 132 | + #include "mlir/Support/LogicalResult.h" |
| 133 | + #include "mlir/Transforms/DialectConversion.h" |
| 134 | ++#include "stablehlo/dialect/Version.h" |
| 135 | + |
| 136 | + namespace mlir { |
| 137 | + namespace stablehlo { |
| 138 | +@@ -96,6 +97,12 @@ |
| 139 | + void populateShapeToStablehloPatterns(MLIRContext *context, |
| 140 | + RewritePatternSet *patterns); |
| 141 | + |
| 142 | ++/// Collection of patterns to create compatibility expander for StableHLO |
| 143 | ++/// operations. |
| 144 | ++void populateStablehloCreateCompatibilityExpanderPatterns( |
| 145 | ++ RewritePatternSet *patterns, MLIRContext *context, |
| 146 | ++ vhlo::Version targetVersion); |
| 147 | ++ |
| 148 | + //// Additional pass constructors //// |
| 149 | + |
| 150 | + std::unique_ptr<OperationPass<ModuleOp>> createStablehloRefineArgumentsPass( |
| 151 | +diff --ruN a/stablehlo/stablehlo/transforms/Passes.td b/stablehlo/stablehlo/transforms/Passes.td |
| 152 | +--- stablehlo/stablehlo/transforms/Passes.td |
| 153 | ++++ stablehlo/stablehlo/transforms/Passes.td |
| 154 | +@@ -292,3 +292,51 @@ |
| 155 | + "mlir::stablehlo::StablehloDialect", |
| 156 | + ]; |
| 157 | + } |
| 158 | ++ |
| 159 | ++def StablehloCreateCompatibilityExpanderPass : Pass<"stablehlo-create-compatibility-expander", "mlir::func::FuncOp"> { |
| 160 | ++ let summary = "Create compatibility expander for StableHLO operations."; |
| 161 | ++ |
| 162 | ++ let description = [{ |
| 163 | ++ StableHLO ops gets updates or new op is introduced in the latest versions. |
| 164 | ++ This opt-in pass expands backward compatibility with older StableHLO |
| 165 | ++ versions by decomposing newer StableHLO operations into equivalent |
| 166 | ++ operations supported by those older versions. |
| 167 | ++ |
| 168 | ++ Why is this an opt-in pass? |
| 169 | ++ |
| 170 | ++ Occasionally, StableHLO op enhancements are used to greatly simplify the |
| 171 | ++ handling of certain common patterns in the OpenXLA ecosystem. This |
| 172 | ++ includes things like TanOp, which has high framework and compiler support, |
| 173 | ++ as well as gather/scatter batching dimensions, which can be represented |
| 174 | ++ using slices, but makes sharding much more difficult. For this category of |
| 175 | ++ new features, we do not offer automatic downgrade, since it may throw away |
| 176 | ++ important information used in subsequent optimizations. This pass can be |
| 177 | ++ used to expand these ops based on a target version to maximize compatibility |
| 178 | ++ at the expense of potentially less optimal compilation. |
| 179 | ++ |
| 180 | ++ ```mlir |
| 181 | ++ func.func @tan_op_non_complex(%arg0: tensor<4xf64>) -> tensor<4xf64> { |
| 182 | ++ %1 = stablehlo.tan %arg0 : tensor<4xf64> |
| 183 | ++ func.return %1 : tensor<4xf64> |
| 184 | ++ } |
| 185 | ++ ``` |
| 186 | ++ |
| 187 | ++ will become: |
| 188 | ++ |
| 189 | ++ ```mlir |
| 190 | ++ func.func @tan_op_non_complex(%arg0: tensor<4xf64>) -> tensor<4xf64> { |
| 191 | ++ %0 = stablehlo.sine %arg0 : tensor<4xf64> |
| 192 | ++ %1 = stablehlo.cosine %arg0 : tensor<4xf64> |
| 193 | ++ %2 = stablehlo.divide %0, %1 : tensor<4xf64> |
| 194 | ++ return %2 : tensor<4xf64> |
| 195 | ++ } |
| 196 | ++ ``` |
| 197 | ++ }]; |
| 198 | ++ let options = [ |
| 199 | ++ Option<"targetVersionOption", "target", "std::string", "", |
| 200 | ++ "The target version. Must be a version of the form #.#.#.">, |
| 201 | ++ ]; |
| 202 | ++ let dependentDialects = [ |
| 203 | ++ "mlir::stablehlo::StablehloDialect", |
| 204 | ++ ]; |
| 205 | ++} |
| 206 | +diff --ruN a/stablehlo/stablehlo/transforms/StablehloCreateCompatibilityExpander.cpp b/stablehlo/stablehlo/transforms/StablehloCreateCompatibilityExpander.cpp |
| 207 | +--- stablehlo/stablehlo/transforms/StablehloCreateCompatibilityExpander.cpp |
| 208 | ++++ stablehlo/stablehlo/transforms/StablehloCreateCompatibilityExpander.cpp |
| 209 | +@@ -0,0 +1,128 @@ |
| 210 | ++/* Copyright 2024 The StableHLO Authors. All Rights Reserved. |
| 211 | ++Licensed under the Apache License, Version 2.0 (the "License"); |
| 212 | ++you may not use this file except in compliance with the License. |
| 213 | ++You may obtain a copy of the License at |
| 214 | ++ http://www.apache.org/licenses/LICENSE-2.0 |
| 215 | ++Unless required by applicable law or agreed to in writing, software |
| 216 | ++distributed under the License is distributed on an "AS IS" BASIS, |
| 217 | ++WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. |
| 218 | ++See the License for the specific language governing permissions and |
| 219 | ++limitations under the License. |
| 220 | ++==============================================================================*/ |
| 221 | ++ |
| 222 | ++#include <fcntl.h> |
| 223 | ++ |
| 224 | ++#include <cassert> |
| 225 | ++ |
| 226 | ++#include "llvm/ADT/APFloat.h" |
| 227 | ++#include "llvm/Support/ErrorHandling.h" |
| 228 | ++#include "mlir/Dialect/Func/IR/FuncOps.h" |
| 229 | ++#include "mlir/IR/BuiltinAttributes.h" |
| 230 | ++#include "mlir/IR/PatternMatch.h" |
| 231 | ++#include "mlir/Support/LLVM.h" |
| 232 | ++#include "mlir/Transforms/DialectConversion.h" |
| 233 | ++#include "mlir/Transforms/GreedyPatternRewriteDriver.h" |
| 234 | ++#include "stablehlo/dialect/StablehloOps.h" |
| 235 | ++#include "stablehlo/dialect/Version.h" |
| 236 | ++#include "stablehlo/transforms/Passes.h" |
| 237 | ++ |
| 238 | ++namespace mlir { |
| 239 | ++namespace stablehlo { |
| 240 | ++#define GEN_PASS_DEF_STABLEHLOCREATECOMPATIBILITYEXPANDERPASS |
| 241 | ++#include "stablehlo/transforms/Passes.h.inc" |
| 242 | ++ |
| 243 | ++namespace { |
| 244 | ++ |
| 245 | ++//===----------------------------------------------------------------------===// |
| 246 | ++// Helpers. |
| 247 | ++//===----------------------------------------------------------------------===// |
| 248 | ++ |
| 249 | ++// Creates a constant with all ones. |
| 250 | ++static Value createConstantWithAllOnes(OpBuilder &b, Location loc, Value val) { |
| 251 | ++ auto shapedTy = dyn_cast<mlir::ShapedType>(val.getType()); |
| 252 | ++ if (!shapedTy) llvm_unreachable("Unsupported shaped type."); |
| 253 | ++ |
| 254 | ++ mlir::DenseElementsAttr elementsAttr = |
| 255 | ++ mlir::DenseElementsAttr::get(shapedTy, 1.0); |
| 256 | ++ |
| 257 | ++ return b.create<mlir::stablehlo::ConstantOp>(loc, val.getType(), |
| 258 | ++ elementsAttr); |
| 259 | ++} |
| 260 | ++ |
| 261 | ++// Check user-specified target version. |
| 262 | ++vhlo::Version validateTargetVersion(llvm::StringRef versionRef) { |
| 263 | ++ auto failOrVersion = vhlo::Version::fromString(versionRef); |
| 264 | ++ if (failed(failOrVersion)) { |
| 265 | ++ assert(!versionRef.empty() && |
| 266 | ++ "No target version specified. Target version must be of the form " |
| 267 | ++ "`#.#.#`."); |
| 268 | ++ assert(versionRef.empty() && |
| 269 | ++ "Invalid target version argument. Target version must be of the " |
| 270 | ++ "form `#.#.#`."); |
| 271 | ++ } |
| 272 | ++ vhlo::Version targetVersion = *failOrVersion; |
| 273 | ++ assert((vhlo::Version::getMinimumVersion() <= targetVersion) && |
| 274 | ++ "target version is less than minimum supported."); |
| 275 | ++ assert((targetVersion <= vhlo::Version::getCurrentVersion()) && |
| 276 | ++ "target version is greater than current version."); |
| 277 | ++ return targetVersion; |
| 278 | ++} |
| 279 | ++ |
| 280 | ++//===----------------------------------------------------------------------===// |
| 281 | ++// Pass |
| 282 | ++//===----------------------------------------------------------------------===// |
| 283 | ++ |
| 284 | ++struct StablehloCreateCompatibilityExpanderPass |
| 285 | ++ : public impl::StablehloCreateCompatibilityExpanderPassBase< |
| 286 | ++ StablehloCreateCompatibilityExpanderPass> { |
| 287 | ++ StablehloCreateCompatibilityExpanderPass() |
| 288 | ++ : StablehloCreateCompatibilityExpanderPassBase< |
| 289 | ++ StablehloCreateCompatibilityExpanderPass>() {} |
| 290 | ++ StablehloCreateCompatibilityExpanderPass( |
| 291 | ++ const StablehloCreateCompatibilityExpanderPassOptions &opts) |
| 292 | ++ : StablehloCreateCompatibilityExpanderPassBase< |
| 293 | ++ StablehloCreateCompatibilityExpanderPass>(opts) {} |
| 294 | ++ |
| 295 | ++ public: |
| 296 | ++ LogicalResult initialize(MLIRContext *context) override { |
| 297 | ++ auto targetVersion = validateTargetVersion(targetVersionOption); |
| 298 | ++ |
| 299 | ++ config.useTopDownTraversal = true; |
| 300 | ++ RewritePatternSet patterns_(context); |
| 301 | ++ populateStablehloCreateCompatibilityExpanderPatterns(&patterns_, context, |
| 302 | ++ targetVersion); |
| 303 | ++ patterns = std::move(patterns_); |
| 304 | ++ return success(); |
| 305 | ++ } |
| 306 | ++ |
| 307 | ++ void runOnOperation() override { |
| 308 | ++ auto func = getOperation(); |
| 309 | ++ if (failed(applyPatternsAndFoldGreedily(func, patterns, config))) { |
| 310 | ++ func.emitError( |
| 311 | ++ "Failed to converge StableHLOCreateCompatibilityExpanderPass in ") |
| 312 | ++ << config.maxIterations << " iterations"; |
| 313 | ++ signalPassFailure(); |
| 314 | ++ } |
| 315 | ++ } |
| 316 | ++ |
| 317 | ++ private: |
| 318 | ++ FrozenRewritePatternSet patterns; |
| 319 | ++ GreedyRewriteConfig config; |
| 320 | ++}; |
| 321 | ++ |
| 322 | ++#include "stablehlo/transforms/StablehloCreateCompatibilityExpanderPatterns.h.inc" |
| 323 | ++ |
| 324 | ++} // namespace |
| 325 | ++ |
| 326 | ++void populateStablehloCreateCompatibilityExpanderPatterns( |
| 327 | ++ RewritePatternSet *patterns, MLIRContext *context, |
| 328 | ++ vhlo::Version targetVersion) { |
| 329 | ++ // StableHLO TanOp is introduced in v1.4.0. |
| 330 | ++ if (targetVersion < vhlo::Version(1, 4, 0)) { |
| 331 | ++ patterns->add<TanOp_ComplexElementType_CompatiblityExpander>(context); |
| 332 | ++ patterns->add<TanOp_CompatiblityExpander>(context); |
| 333 | ++ } |
| 334 | ++} |
| 335 | ++ |
| 336 | ++} // namespace stablehlo |
| 337 | ++} // namespace mlir |
| 338 | +diff --ruN a/stablehlo/stablehlo/transforms/StablehloCreateCompatibilityExpanderPatterns.td b/stablehlo/stablehlo/transforms/StablehloCreateCompatibilityExpanderPatterns.td |
| 339 | +--- stablehlo/stablehlo/transforms/StablehloCreateCompatibilityExpanderPatterns.td |
| 340 | ++++ stablehlo/stablehlo/transforms/StablehloCreateCompatibilityExpanderPatterns.td |
| 341 | +@@ -0,0 +1,47 @@ |
| 342 | ++/* Copyright 2022 The StableHLO Authors. |
| 343 | ++ |
| 344 | ++Licensed under the Apache License, Version 2.0 (the "License"); |
| 345 | ++you may not use this file except in compliance with the License. |
| 346 | ++You may obtain a copy of the License at |
| 347 | ++ |
| 348 | ++ http://www.apache.org/licenses/LICENSE-2.0 |
| 349 | ++ |
| 350 | ++Unless required by applicable law or agreed to in writing, software |
| 351 | ++distributed under the License is distributed on an "AS IS" BASIS, |
| 352 | ++WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. |
| 353 | ++See the License for the specific language governing permissions and |
| 354 | ++limitations under the License. |
| 355 | ++==============================================================================*/ |
| 356 | ++ |
| 357 | ++include "mlir/IR/OpBase.td" |
| 358 | ++include "stablehlo/dialect/StablehloOps.td" |
| 359 | ++ |
| 360 | ++def ComplexElementType : Type< |
| 361 | ++ CPred<"isa<ComplexType>(cast<ShapedType>($_self).getElementType())">, |
| 362 | ++ "Complex element type">; |
| 363 | ++ |
| 364 | ++def NonComplexElementType : Type< |
| 365 | ++ CPred<"!isa<ComplexType>(cast<ShapedType>($_self).getElementType())">, |
| 366 | ++ "Non-complex element type">; |
| 367 | ++ |
| 368 | ++def createConstantWithAllOnes : NativeCodeCall<"createConstantWithAllOnes($_builder, $_loc, $0)">; |
| 369 | ++ |
| 370 | ++// Express `tan` as |
| 371 | ++// sine(x) / cosine(x) |
| 372 | ++def TanOp_CompatiblityExpander : Pat<(StableHLO_TanOp NonComplexElementType:$input), |
| 373 | ++ (StableHLO_DivOp |
| 374 | ++ (StableHLO_SineOp $input), |
| 375 | ++ (StableHLO_CosineOp $input) |
| 376 | ++ )>; |
| 377 | ++ |
| 378 | ++// Express `tan(a + bi)` as |
| 379 | ++// (tan(a) + i tanh(b)) / (1 - i tan(a) * tanh(b)) |
| 380 | ++def TanOp_ComplexElementType_CompatiblityExpander : Pat<(StableHLO_TanOp ComplexElementType:$input), |
| 381 | ++ (StableHLO_DivOp |
| 382 | ++ (StableHLO_ComplexOp |
| 383 | ++ (StableHLO_TanOp:$tan (StableHLO_RealOp $input)), |
| 384 | ++ (StableHLO_TanhOp:$tanh (StableHLO_ImagOp $input))), |
| 385 | ++ (StableHLO_ComplexOp |
| 386 | ++ (createConstantWithAllOnes $tan), |
| 387 | ++ (StableHLO_NegOp (StableHLO_MulOp $tan, $tanh))) |
| 388 | ++ )>; |
1 | 389 |
|
0 commit comments