Skip to content

Commit 09debc3

Browse files
abhigunjcopybara-github
authored andcommitted
Add StableHLOCreateCompatibilityExpanderPass to create compatibility expander for StableHLO operations.
PiperOrigin-RevId: 669131010
1 parent 51f3b68 commit 09debc3

File tree

1 file changed

+388
-0
lines changed

1 file changed

+388
-0
lines changed

third_party/stablehlo/temporary.patch

Lines changed: 388 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1 +1,389 @@
1+
diff --ruN a/stablehlo/BUILD.bazel b/stablehlo/BUILD.bazel
2+
--- stablehlo/BUILD.bazel
3+
+++ stablehlo/BUILD.bazel
4+
@@ -340,6 +340,21 @@
5+
],
6+
)
7+
8+
+gentbl_cc_library(
9+
+ name = "stablehlo_create_compatibility_expander_inc_gen",
10+
+ tbl_outs = [
11+
+ (
12+
+ ["--gen-rewriters"],
13+
+ "stablehlo/transforms/StablehloCreateCompatibilityExpanderPatterns.h.inc",
14+
+ ),
15+
+ ],
16+
+ tblgen = "@llvm-project//mlir:mlir-tblgen",
17+
+ td_file = "stablehlo/transforms/StablehloCreateCompatibilityExpanderPatterns.td",
18+
+ deps = [
19+
+ ":stablehlo_ops_td_files",
20+
+ ],
21+
+)
22+
+
23+
cc_library(
24+
name = "interpreter_ops",
25+
srcs = [
26+
@@ -1086,6 +1101,7 @@
27+
"stablehlo/transforms/StablehloAggressiveSimplification.cpp",
28+
"stablehlo/transforms/StablehloCanonicalizeDynamism.cpp",
29+
"stablehlo/transforms/StablehloConvertToSignless.cpp",
30+
+ "stablehlo/transforms/StablehloCreateCompatibilityExpander.cpp",
31+
"stablehlo/transforms/StablehloLegalizeCompositeToCall.cpp",
32+
"stablehlo/transforms/StablehloLegalizeDeprecatedOps.cpp",
33+
"stablehlo/transforms/StablehloLegalizeQDQToQuantizedOp.cpp",
34+
@@ -1109,6 +1125,7 @@
35+
":chlo_ops",
36+
":chlo_rewriters_inc_gen",
37+
":linalg_passes",
38+
+ ":stablehlo_create_compatibility_expander_inc_gen",
39+
":stablehlo_legalize_deprecated_ops_inc_gen",
40+
":stablehlo_ops",
41+
":stablehlo_ops_inc_gen",
42+
diff --ruN a/stablehlo/stablehlo/tests/transforms/stablehlo_create_compatibility_expander.mlir b/stablehlo/stablehlo/tests/transforms/stablehlo_create_compatibility_expander.mlir
43+
--- stablehlo/stablehlo/tests/transforms/stablehlo_create_compatibility_expander.mlir
44+
+++ stablehlo/stablehlo/tests/transforms/stablehlo_create_compatibility_expander.mlir
45+
@@ -0,0 +1,43 @@
46+
+// RUN: stablehlo-opt %s -verify-diagnostics -split-input-file -allow-unregistered-dialect --stablehlo-create-compatibility-expander='target=1.0.0' | FileCheck %s --check-prefixes=CHECK
47+
+// RUN: stablehlo-opt %s -verify-diagnostics -split-input-file --stablehlo-create-compatibility-expander='target=1.6.0' | FileCheck %s --check-prefixes=CHECK-NO-DOWNGRADE
48+
+
49+
+// -----
50+
+
51+
+// CHECK-LABEL @tan_op_non_complex
52+
+// CHECK: %[[sine0:.*]] = stablehlo.sine %arg0 : tensor<4xf64>
53+
+// CHECK-NEXT: %[[cosine1:.*]] = stablehlo.cosine %arg0 : tensor<4xf64>
54+
+// CHECK-NEXT: %[[div2:.*]] = stablehlo.divide %[[sine0]], %[[cosine1]] : tensor<4xf64>
55+
+// CHECK-NEXT: return %[[div2]] : tensor<4xf64>
56+
+func.func @tan_op_non_complex(%arg0: tensor<4xf64>) -> tensor<4xf64> {
57+
+ // CHECK-NO-DOWNGRADE: stablehlo.tan %arg0 : tensor<4xf64>
58+
+ %1 = stablehlo.tan %arg0 : tensor<4xf64>
59+
+ func.return %1 : tensor<4xf64>
60+
+}
61+
+
62+
+// -----
63+
+
64+
+// CHECK-LABEL: @tan_op_complex
65+
+// CHECK: %[[cst:.*]] = stablehlo.constant dense<1.000000e+00> : tensor<4xf64>
66+
+// CHECK: %[[complex:.*]] = stablehlo.complex %arg0, %arg1 : tensor<4xcomplex<f64>>
67+
+// CHECK: %[[real:.*]] = stablehlo.real %[[complex]] : (tensor<4xcomplex<f64>>) -> tensor<4xf64>
68+
+// CHECK: %[[sine:.*]] = stablehlo.sine %[[real]] : tensor<4xf64>
69+
+// CHECK: %[[cosine:.*]] = stablehlo.cosine %[[real]] : tensor<4xf64>
70+
+// CHECK: %[[divide1:.*]] = stablehlo.divide %[[sine]], %[[cosine]] : tensor<4xf64>
71+
+// CHECK: %[[imag:.*]] = stablehlo.imag %[[complex]] : (tensor<4xcomplex<f64>>) -> tensor<4xf64>
72+
+// CHECK: %[[tanh:.*]] = stablehlo.tanh %[[imag]] : tensor<4xf64>
73+
+// CHECK: %[[complex2:.*]] = stablehlo.complex %[[divide1]], %[[tanh]] : tensor<4xcomplex<f64>>
74+
+// CHECK: %[[multiply:.*]] = stablehlo.multiply %[[divide1]], %[[tanh]] : tensor<4xf64>
75+
+// CHECK: %[[negate:.*]] = stablehlo.negate %[[multiply]] : tensor<4xf64>
76+
+// CHECK: %[[complex3:.*]] = stablehlo.complex %[[cst]], %[[negate]] : tensor<4xcomplex<f64>>
77+
+// CHECK: %[[divide2:.*]] = stablehlo.divide %[[complex2]], %[[complex3]] : tensor<4xcomplex<f64>>
78+
+// CHECK: %[[real2:.*]] = stablehlo.real %[[divide2]] : (tensor<4xcomplex<f64>>) -> tensor<4xf64>
79+
+// CHECK: %[[imag2:.*]] = stablehlo.imag %[[divide2]] : (tensor<4xcomplex<f64>>) -> tensor<4xf64>
80+
+// CHECK: return %[[real2]], %[[imag2]] : tensor<4xf64>, tensor<4xf64>
81+
+func.func @tan_op_complex(%arg0: tensor<4xf64>, %arg1: tensor<4xf64>) -> (tensor<4xf64>, tensor<4xf64>) {
82+
+ %0 = stablehlo.complex %arg0, %arg1 : tensor<4xcomplex<f64>>
83+
+ // CHECK-NO-DOWNGRADE: stablehlo.tan %0 : tensor<4xcomplex<f64>>
84+
+ %1 = stablehlo.tan %0 : tensor<4xcomplex<f64>>
85+
+ %2 = stablehlo.real %1 : (tensor<4xcomplex<f64>>) -> tensor<4xf64>
86+
+ %3 = stablehlo.imag %1 : (tensor<4xcomplex<f64>>) -> tensor<4xf64>
87+
+ func.return %2, %3 : tensor<4xf64>, tensor<4xf64>
88+
+}
89+
diff --ruN a/stablehlo/stablehlo/transforms/CMakeLists.txt b/stablehlo/stablehlo/transforms/CMakeLists.txt
90+
--- stablehlo/stablehlo/transforms/CMakeLists.txt
91+
+++ stablehlo/stablehlo/transforms/CMakeLists.txt
92+
@@ -20,6 +20,10 @@
93+
mlir_tablegen(ChloDecompositionPatterns.h.inc --gen-rewriters)
94+
add_public_tablegen_target(ChloDecompositionPatternsIncGen)
95+
96+
+set(LLVM_TARGET_DEFINITIONS StablehloCreateCompatibilityExpanderPatterns.td)
97+
+mlir_tablegen(StablehloCreateCompatibilityExpanderPatterns.h.inc --gen-rewriters)
98+
+add_public_tablegen_target(StablehloCreateCompatibilityExpanderPatternsIncGen)
99+
+
100+
set(LLVM_TARGET_DEFINITIONS StablehloLegalizeDeprecatedOpsPatterns.td)
101+
mlir_tablegen(StablehloLegalizeDeprecatedOpsPatterns.h.inc --gen-rewriters)
102+
add_public_tablegen_target(StablehloLegalizeDeprecatedOpsPatternsIncGen)
103+
@@ -27,6 +31,7 @@
104+
set(LLVM_TARGET_DEFINITIONS VhloToVersionPatterns.td)
105+
mlir_tablegen(VhloToVersionPatterns.h.inc --gen-rewriters)
106+
add_public_tablegen_target(VhloToVersionPatterns)
107+
+
108+
109+
add_mlir_dialect_library(StablehloPasses
110+
PARTIAL_SOURCES_INTENDED
111+
@@ -37,6 +42,7 @@
112+
StablehloAggressiveSimplification.cpp
113+
StablehloCanonicalizeDynamism.cpp
114+
StablehloConvertToSignless.cpp
115+
+ StablehloCreateCompatibilityExpander.cpp
116+
StablehloLegalizeCompositeToCall.cpp
117+
StablehloLegalizeDeprecatedOps.cpp
118+
StablehloLegalizeQuantToMath.cpp
119+
@@ -53,6 +59,7 @@
120+
StablehloLegalizeDeprecatedOpsPatternsIncGen
121+
PassesIncGen
122+
VhloToVersionPatterns
123+
+ StablehloCreateCompatibilityExpanderPatternsIncGen
124+
125+
LINK_LIBS PUBLIC
126+
ChloOps
127+
diff --ruN a/stablehlo/stablehlo/transforms/Passes.h b/stablehlo/stablehlo/transforms/Passes.h
128+
--- stablehlo/stablehlo/transforms/Passes.h
129+
+++ stablehlo/stablehlo/transforms/Passes.h
130+
@@ -25,6 +25,7 @@
131+
#include "mlir/Pass/Pass.h"
132+
#include "mlir/Support/LogicalResult.h"
133+
#include "mlir/Transforms/DialectConversion.h"
134+
+#include "stablehlo/dialect/Version.h"
135+
136+
namespace mlir {
137+
namespace stablehlo {
138+
@@ -96,6 +97,12 @@
139+
void populateShapeToStablehloPatterns(MLIRContext *context,
140+
RewritePatternSet *patterns);
141+
142+
+/// Collection of patterns to create compatibility expander for StableHLO
143+
+/// operations.
144+
+void populateStablehloCreateCompatibilityExpanderPatterns(
145+
+ RewritePatternSet *patterns, MLIRContext *context,
146+
+ vhlo::Version targetVersion);
147+
+
148+
//// Additional pass constructors ////
149+
150+
std::unique_ptr<OperationPass<ModuleOp>> createStablehloRefineArgumentsPass(
151+
diff --ruN a/stablehlo/stablehlo/transforms/Passes.td b/stablehlo/stablehlo/transforms/Passes.td
152+
--- stablehlo/stablehlo/transforms/Passes.td
153+
+++ stablehlo/stablehlo/transforms/Passes.td
154+
@@ -292,3 +292,51 @@
155+
"mlir::stablehlo::StablehloDialect",
156+
];
157+
}
158+
+
159+
+def StablehloCreateCompatibilityExpanderPass : Pass<"stablehlo-create-compatibility-expander", "mlir::func::FuncOp"> {
160+
+ let summary = "Create compatibility expander for StableHLO operations.";
161+
+
162+
+ let description = [{
163+
+ StableHLO ops gets updates or new op is introduced in the latest versions.
164+
+ This opt-in pass expands backward compatibility with older StableHLO
165+
+ versions by decomposing newer StableHLO operations into equivalent
166+
+ operations supported by those older versions.
167+
+
168+
+ Why is this an opt-in pass?
169+
+
170+
+ Occasionally, StableHLO op enhancements are used to greatly simplify the
171+
+ handling of certain common patterns in the OpenXLA ecosystem. This
172+
+ includes things like TanOp, which has high framework and compiler support,
173+
+ as well as gather/scatter batching dimensions, which can be represented
174+
+ using slices, but makes sharding much more difficult. For this category of
175+
+ new features, we do not offer automatic downgrade, since it may throw away
176+
+ important information used in subsequent optimizations. This pass can be
177+
+ used to expand these ops based on a target version to maximize compatibility
178+
+ at the expense of potentially less optimal compilation.
179+
+
180+
+ ```mlir
181+
+ func.func @tan_op_non_complex(%arg0: tensor<4xf64>) -> tensor<4xf64> {
182+
+ %1 = stablehlo.tan %arg0 : tensor<4xf64>
183+
+ func.return %1 : tensor<4xf64>
184+
+ }
185+
+ ```
186+
+
187+
+ will become:
188+
+
189+
+ ```mlir
190+
+ func.func @tan_op_non_complex(%arg0: tensor<4xf64>) -> tensor<4xf64> {
191+
+ %0 = stablehlo.sine %arg0 : tensor<4xf64>
192+
+ %1 = stablehlo.cosine %arg0 : tensor<4xf64>
193+
+ %2 = stablehlo.divide %0, %1 : tensor<4xf64>
194+
+ return %2 : tensor<4xf64>
195+
+ }
196+
+ ```
197+
+ }];
198+
+ let options = [
199+
+ Option<"targetVersionOption", "target", "std::string", "",
200+
+ "The target version. Must be a version of the form #.#.#.">,
201+
+ ];
202+
+ let dependentDialects = [
203+
+ "mlir::stablehlo::StablehloDialect",
204+
+ ];
205+
+}
206+
diff --ruN a/stablehlo/stablehlo/transforms/StablehloCreateCompatibilityExpander.cpp b/stablehlo/stablehlo/transforms/StablehloCreateCompatibilityExpander.cpp
207+
--- stablehlo/stablehlo/transforms/StablehloCreateCompatibilityExpander.cpp
208+
+++ stablehlo/stablehlo/transforms/StablehloCreateCompatibilityExpander.cpp
209+
@@ -0,0 +1,128 @@
210+
+/* Copyright 2024 The StableHLO Authors. All Rights Reserved.
211+
+Licensed under the Apache License, Version 2.0 (the "License");
212+
+you may not use this file except in compliance with the License.
213+
+You may obtain a copy of the License at
214+
+ http://www.apache.org/licenses/LICENSE-2.0
215+
+Unless required by applicable law or agreed to in writing, software
216+
+distributed under the License is distributed on an "AS IS" BASIS,
217+
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
218+
+See the License for the specific language governing permissions and
219+
+limitations under the License.
220+
+==============================================================================*/
221+
+
222+
+#include <fcntl.h>
223+
+
224+
+#include <cassert>
225+
+
226+
+#include "llvm/ADT/APFloat.h"
227+
+#include "llvm/Support/ErrorHandling.h"
228+
+#include "mlir/Dialect/Func/IR/FuncOps.h"
229+
+#include "mlir/IR/BuiltinAttributes.h"
230+
+#include "mlir/IR/PatternMatch.h"
231+
+#include "mlir/Support/LLVM.h"
232+
+#include "mlir/Transforms/DialectConversion.h"
233+
+#include "mlir/Transforms/GreedyPatternRewriteDriver.h"
234+
+#include "stablehlo/dialect/StablehloOps.h"
235+
+#include "stablehlo/dialect/Version.h"
236+
+#include "stablehlo/transforms/Passes.h"
237+
+
238+
+namespace mlir {
239+
+namespace stablehlo {
240+
+#define GEN_PASS_DEF_STABLEHLOCREATECOMPATIBILITYEXPANDERPASS
241+
+#include "stablehlo/transforms/Passes.h.inc"
242+
+
243+
+namespace {
244+
+
245+
+//===----------------------------------------------------------------------===//
246+
+// Helpers.
247+
+//===----------------------------------------------------------------------===//
248+
+
249+
+// Creates a constant with all ones.
250+
+static Value createConstantWithAllOnes(OpBuilder &b, Location loc, Value val) {
251+
+ auto shapedTy = dyn_cast<mlir::ShapedType>(val.getType());
252+
+ if (!shapedTy) llvm_unreachable("Unsupported shaped type.");
253+
+
254+
+ mlir::DenseElementsAttr elementsAttr =
255+
+ mlir::DenseElementsAttr::get(shapedTy, 1.0);
256+
+
257+
+ return b.create<mlir::stablehlo::ConstantOp>(loc, val.getType(),
258+
+ elementsAttr);
259+
+}
260+
+
261+
+// Check user-specified target version.
262+
+vhlo::Version validateTargetVersion(llvm::StringRef versionRef) {
263+
+ auto failOrVersion = vhlo::Version::fromString(versionRef);
264+
+ if (failed(failOrVersion)) {
265+
+ assert(!versionRef.empty() &&
266+
+ "No target version specified. Target version must be of the form "
267+
+ "`#.#.#`.");
268+
+ assert(versionRef.empty() &&
269+
+ "Invalid target version argument. Target version must be of the "
270+
+ "form `#.#.#`.");
271+
+ }
272+
+ vhlo::Version targetVersion = *failOrVersion;
273+
+ assert((vhlo::Version::getMinimumVersion() <= targetVersion) &&
274+
+ "target version is less than minimum supported.");
275+
+ assert((targetVersion <= vhlo::Version::getCurrentVersion()) &&
276+
+ "target version is greater than current version.");
277+
+ return targetVersion;
278+
+}
279+
+
280+
+//===----------------------------------------------------------------------===//
281+
+// Pass
282+
+//===----------------------------------------------------------------------===//
283+
+
284+
+struct StablehloCreateCompatibilityExpanderPass
285+
+ : public impl::StablehloCreateCompatibilityExpanderPassBase<
286+
+ StablehloCreateCompatibilityExpanderPass> {
287+
+ StablehloCreateCompatibilityExpanderPass()
288+
+ : StablehloCreateCompatibilityExpanderPassBase<
289+
+ StablehloCreateCompatibilityExpanderPass>() {}
290+
+ StablehloCreateCompatibilityExpanderPass(
291+
+ const StablehloCreateCompatibilityExpanderPassOptions &opts)
292+
+ : StablehloCreateCompatibilityExpanderPassBase<
293+
+ StablehloCreateCompatibilityExpanderPass>(opts) {}
294+
+
295+
+ public:
296+
+ LogicalResult initialize(MLIRContext *context) override {
297+
+ auto targetVersion = validateTargetVersion(targetVersionOption);
298+
+
299+
+ config.useTopDownTraversal = true;
300+
+ RewritePatternSet patterns_(context);
301+
+ populateStablehloCreateCompatibilityExpanderPatterns(&patterns_, context,
302+
+ targetVersion);
303+
+ patterns = std::move(patterns_);
304+
+ return success();
305+
+ }
306+
+
307+
+ void runOnOperation() override {
308+
+ auto func = getOperation();
309+
+ if (failed(applyPatternsAndFoldGreedily(func, patterns, config))) {
310+
+ func.emitError(
311+
+ "Failed to converge StableHLOCreateCompatibilityExpanderPass in ")
312+
+ << config.maxIterations << " iterations";
313+
+ signalPassFailure();
314+
+ }
315+
+ }
316+
+
317+
+ private:
318+
+ FrozenRewritePatternSet patterns;
319+
+ GreedyRewriteConfig config;
320+
+};
321+
+
322+
+#include "stablehlo/transforms/StablehloCreateCompatibilityExpanderPatterns.h.inc"
323+
+
324+
+} // namespace
325+
+
326+
+void populateStablehloCreateCompatibilityExpanderPatterns(
327+
+ RewritePatternSet *patterns, MLIRContext *context,
328+
+ vhlo::Version targetVersion) {
329+
+ // StableHLO TanOp is introduced in v1.4.0.
330+
+ if (targetVersion < vhlo::Version(1, 4, 0)) {
331+
+ patterns->add<TanOp_ComplexElementType_CompatiblityExpander>(context);
332+
+ patterns->add<TanOp_CompatiblityExpander>(context);
333+
+ }
334+
+}
335+
+
336+
+} // namespace stablehlo
337+
+} // namespace mlir
338+
diff --ruN a/stablehlo/stablehlo/transforms/StablehloCreateCompatibilityExpanderPatterns.td b/stablehlo/stablehlo/transforms/StablehloCreateCompatibilityExpanderPatterns.td
339+
--- stablehlo/stablehlo/transforms/StablehloCreateCompatibilityExpanderPatterns.td
340+
+++ stablehlo/stablehlo/transforms/StablehloCreateCompatibilityExpanderPatterns.td
341+
@@ -0,0 +1,47 @@
342+
+/* Copyright 2022 The StableHLO Authors.
343+
+
344+
+Licensed under the Apache License, Version 2.0 (the "License");
345+
+you may not use this file except in compliance with the License.
346+
+You may obtain a copy of the License at
347+
+
348+
+ http://www.apache.org/licenses/LICENSE-2.0
349+
+
350+
+Unless required by applicable law or agreed to in writing, software
351+
+distributed under the License is distributed on an "AS IS" BASIS,
352+
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
353+
+See the License for the specific language governing permissions and
354+
+limitations under the License.
355+
+==============================================================================*/
356+
+
357+
+include "mlir/IR/OpBase.td"
358+
+include "stablehlo/dialect/StablehloOps.td"
359+
+
360+
+def ComplexElementType : Type<
361+
+ CPred<"isa<ComplexType>(cast<ShapedType>($_self).getElementType())">,
362+
+ "Complex element type">;
363+
+
364+
+def NonComplexElementType : Type<
365+
+ CPred<"!isa<ComplexType>(cast<ShapedType>($_self).getElementType())">,
366+
+ "Non-complex element type">;
367+
+
368+
+def createConstantWithAllOnes : NativeCodeCall<"createConstantWithAllOnes($_builder, $_loc, $0)">;
369+
+
370+
+// Express `tan` as
371+
+// sine(x) / cosine(x)
372+
+def TanOp_CompatiblityExpander : Pat<(StableHLO_TanOp NonComplexElementType:$input),
373+
+ (StableHLO_DivOp
374+
+ (StableHLO_SineOp $input),
375+
+ (StableHLO_CosineOp $input)
376+
+ )>;
377+
+
378+
+// Express `tan(a + bi)` as
379+
+// (tan(a) + i tanh(b)) / (1 - i tan(a) * tanh(b))
380+
+def TanOp_ComplexElementType_CompatiblityExpander : Pat<(StableHLO_TanOp ComplexElementType:$input),
381+
+ (StableHLO_DivOp
382+
+ (StableHLO_ComplexOp
383+
+ (StableHLO_TanOp:$tan (StableHLO_RealOp $input)),
384+
+ (StableHLO_TanhOp:$tanh (StableHLO_ImagOp $input))),
385+
+ (StableHLO_ComplexOp
386+
+ (createConstantWithAllOnes $tan),
387+
+ (StableHLO_NegOp (StableHLO_MulOp $tan, $tanh)))
388+
+ )>;
1389

0 commit comments

Comments
 (0)