1- // RUN: mpmd_opt %s -mpmd-export-pipeline 2>&1 | FileCheck %s
1+ // RUN: mpmd_opt %s -mpmd-export-pipeline -split-input-file 2>&1 | FileCheck %s
22
33!mesh_1_tensor_4_8_f32 = !mpmd.mesh_tensor <" m1" , tensor <4 x8 xf32 >>
44
@@ -17,3 +17,36 @@ func.func @main(%arg0: !mesh_1_tensor_4_8_f32 {tf.aliasing_output = 0: i32}, %ar
1717 } : (!mesh_1_tensor_4_8_f32 , !mesh_1_tensor_4_8_f32 ) -> (!mesh_1_tensor_4_8_f32 )
1818 func.return %0 : !mesh_1_tensor_4_8_f32
1919}
20+
21+ // -----
22+
23+ !mesh_1_tensor_4_8_f32 = !mpmd.mesh_tensor <" m1" , tensor <4 x8 xf32 >>
24+ !mesh_2_tensor_4_8_f32 = !mpmd.mesh_tensor <" m2" , tensor <4 x8 xf32 >>
25+
26+ // This test verifies that an explicit fragment and an inferred fragment
27+ // (created by the UniquifyFunctionInputsOutputsPass for the duplicated return
28+ // of the transfer result) are merged sideways. Without sideways merge, the
29+ // transfer result would produce a separate inferred fragment call on m1.
30+ // The function-level returns remain unique SSA values (%[[RES]]#0, #1, #2),
31+ // preserving the invariant established by the uniquify pass, even though the
32+ // fragment body internally returns the same value in multiple positions.
33+ // CHECK-LABEL: func.func @test_sideways_merge
34+ func.func @test_sideways_merge (%arg0: !mesh_1_tensor_4_8_f32 , %arg1: !mesh_2_tensor_4_8_f32 )
35+ -> (!mesh_1_tensor_4_8_f32 , !mesh_1_tensor_4_8_f32 , !mesh_1_tensor_4_8_f32 ) attributes {
36+ " topology" =#mpmd.topology <
37+ <" m1" : <[" x" =2 ]>>,
38+ <" m2" : <[" x" =2 ]>>
39+ >} {
40+ // CHECK: %[[RES:.*]]:3 = mpmd.fragment_call<mesh="m1", origin=["f1"]> @[[CALLEE_M1:.*]]
41+ // CHECK-NOT: mpmd.fragment_call<mesh="m1"
42+ // CHECK: return %[[RES]]#0, %[[RES]]#1, %[[RES]]#2
43+
44+ %0 = mpmd.fragment <mesh =" m1" , origin =[" f1" ]> (%arg0 ) (%arg2: tensor <4 x8 xf32 >) {
45+ %4 = stablehlo.add %arg2 , %arg2 : tensor <4 x8 xf32 >
46+ mpmd.return %4 : tensor <4 x8 xf32 >
47+ } : (!mesh_1_tensor_4_8_f32 ) -> !mesh_1_tensor_4_8_f32
48+
49+ %1 = mpmd.transfer %arg1 : (!mesh_2_tensor_4_8_f32 ) -> !mesh_1_tensor_4_8_f32
50+
51+ func.return %0 , %1 , %1 : !mesh_1_tensor_4_8_f32 , !mesh_1_tensor_4_8_f32 , !mesh_1_tensor_4_8_f32
52+ }
0 commit comments