Skip to content

Commit 25a9fd6

Browse files
petebu authored and copybara-github committed
[mpmd] Merge sideways in export pipeline.
Adds sideways merging of inferred fragments to the export pipeline to allow merging fragments on the same mesh that are separated by operations on other meshes. This ensures that inferred fragments created in the uniquify pass are merged correctly.

PiperOrigin-RevId: 899501008
1 parent 7fe99e0 commit 25a9fd6

2 files changed

Lines changed: 40 additions & 1 deletion

File tree

shardy/dialect/mpmd/transforms/export/export_pipeline.cc

Lines changed: 6 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -90,6 +90,12 @@ void addExportPipeline(OpPassManager& pm, const ExportOptions& options) {
9090
// fragments to compile) and may cause performance regressions. Thus, we merge
9191
// them with other fragments.
9292
pm.addNestedPass<FuncOp>(createMergeInferredFragmentsPass());
93+
{
94+
MergeInferredFragmentsPassOptions mergeInferredOptions;
95+
mergeInferredOptions.mergeSideways = true;
96+
pm.addNestedPass<FuncOp>(
97+
createMergeInferredFragmentsPass(std::move(mergeInferredOptions)));
98+
}
9399

94100
// Mark each fragment with the inputs and outputs which are offloaded to host
95101
// memory.

shardy/dialect/mpmd/transforms/export/test/export_pipeline.mlir

Lines changed: 34 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,4 @@
1-
// RUN: mpmd_opt %s -mpmd-export-pipeline 2>&1 | FileCheck %s
1+
// RUN: mpmd_opt %s -mpmd-export-pipeline -split-input-file 2>&1 | FileCheck %s
22

33
!mesh_1_tensor_4_8_f32 = !mpmd.mesh_tensor<"m1", tensor<4x8xf32>>
44

@@ -17,3 +17,36 @@ func.func @main(%arg0: !mesh_1_tensor_4_8_f32 {tf.aliasing_output = 0: i32}, %ar
1717
} : (!mesh_1_tensor_4_8_f32, !mesh_1_tensor_4_8_f32) -> (!mesh_1_tensor_4_8_f32)
1818
func.return %0 : !mesh_1_tensor_4_8_f32
1919
}
20+
21+
// -----
22+
23+
!mesh_1_tensor_4_8_f32 = !mpmd.mesh_tensor<"m1", tensor<4x8xf32>>
24+
!mesh_2_tensor_4_8_f32 = !mpmd.mesh_tensor<"m2", tensor<4x8xf32>>
25+
26+
// This test verifies that an explicit fragment and an inferred fragment
27+
// (created by the UniquifyFunctionInputsOutputsPass for the duplicated return
28+
// of the transfer result) are merged sideways. Without sideways merge, the
29+
// transfer result would produce a separate inferred fragment call on m1.
30+
// The function-level returns remain unique SSA values (%[[RES]]#0, #1, #2),
31+
// preserving the invariant established by the uniquify pass, even though the
32+
// fragment body internally returns the same value in multiple positions.
33+
// CHECK-LABEL: func.func @test_sideways_merge
34+
func.func @test_sideways_merge(%arg0: !mesh_1_tensor_4_8_f32, %arg1: !mesh_2_tensor_4_8_f32)
35+
-> (!mesh_1_tensor_4_8_f32, !mesh_1_tensor_4_8_f32, !mesh_1_tensor_4_8_f32) attributes {
36+
"topology"=#mpmd.topology<
37+
<"m1": <["x"=2]>>,
38+
<"m2": <["x"=2]>>
39+
>} {
40+
// CHECK: %[[RES:.*]]:3 = mpmd.fragment_call<mesh="m1", origin=["f1"]> @[[CALLEE_M1:.*]]
41+
// CHECK-NOT: mpmd.fragment_call<mesh="m1"
42+
// CHECK: return %[[RES]]#0, %[[RES]]#1, %[[RES]]#2
43+
44+
%0 = mpmd.fragment<mesh="m1", origin=["f1"]> (%arg0) (%arg2: tensor<4x8xf32>) {
45+
%4 = stablehlo.add %arg2, %arg2 : tensor<4x8xf32>
46+
mpmd.return %4 : tensor<4x8xf32>
47+
} : (!mesh_1_tensor_4_8_f32) -> !mesh_1_tensor_4_8_f32
48+
49+
%1 = mpmd.transfer %arg1 : (!mesh_2_tensor_4_8_f32) -> !mesh_1_tensor_4_8_f32
50+
51+
func.return %0, %1, %1 : !mesh_1_tensor_4_8_f32, !mesh_1_tensor_4_8_f32, !mesh_1_tensor_4_8_f32
52+
}

0 commit comments

Comments
 (0)