diff --git a/shardy/dialect/mpmd/transforms/common/passes.td b/shardy/dialect/mpmd/transforms/common/passes.td index 0b7dabc4..78795adc 100644 --- a/shardy/dialect/mpmd/transforms/common/passes.td +++ b/shardy/dialect/mpmd/transforms/common/passes.td @@ -450,6 +450,10 @@ def UniquifyFunctionInputsOutputsPass : Similarly, if a function returns a block argument, this pass creates an identity fragment for that block argument, guaranteeing that values are passed by value to the function, not by reference. + + Additionally, when not using transfers, the pass will attempt to merge + each newly created inferred fragment into an existing same-mesh fragment + to reduce the total number of fragments. }]; let options = [ diff --git a/shardy/dialect/mpmd/transforms/common/test/uniquify_function_inputs_outputs_with_reshard.mlir b/shardy/dialect/mpmd/transforms/common/test/uniquify_function_inputs_outputs_with_reshard.mlir index 172c9631..ffba5276 100644 --- a/shardy/dialect/mpmd/transforms/common/test/uniquify_function_inputs_outputs_with_reshard.mlir +++ b/shardy/dialect/mpmd/transforms/common/test/uniquify_function_inputs_outputs_with_reshard.mlir @@ -26,11 +26,11 @@ func.func @no_work_needed(%arg0: !mesh_1_tensor, %arg1: !mesh_2_tensor) -> (!mes func.func @single_mesh_one_return_operand(%arg0: !mesh_1_tensor) -> (!mesh_1_tensor, !mesh_1_tensor, !mesh_1_tensor) attributes { "topology"=#mpmd.topology<<"m1": <["x"=2]>>> } { + // Uniquify merges into f2 (the last same-mesh fragment), keeping tensors + // alive for less time. 
// CHECK-NEXT: %[[F1:.*]] = mpmd.fragment - // CHECK: %[[F2:.*]] = mpmd.fragment - // CHECK: %[[UF:.*]]:2 = mpmd.fragment (%[[F1]]) {mpmd.inferred_by = ["uniquify"]} (%arg1: tensor<4xf32>) { - // CHECK: mpmd.return %arg1, %arg1 : tensor<4xf32>, tensor<4xf32> - // CHECK: %[[F2]], %[[UF]]#0, %[[UF]]#1 + // CHECK: %[[F2:.*]]:3 = mpmd.fragment (%[[F1]]) + // CHECK: return %[[F2]]#0, %[[F2]]#1, %[[F2]]#2 %0 = mpmd.fragment (%arg0) (%arg1: tensor<4xf32>) { %1 = stablehlo.add %arg1, %arg1 : tensor<4xf32> mpmd.return %1 : tensor<4xf32> @@ -48,11 +48,8 @@ func.func @needs_fragment_for_m1_with_many_values(%arg0: !mesh_1_tensor, %arg1: } { // CHECK-NEXT: %[[F1:.*]] = mpmd.fragment // CHECK: %[[F2:.*]] = mpmd.fragment - // CHECK: %[[F3:.*]] = mpmd.fragment - // CHECK: %[[UF:.*]]:5 = mpmd.fragment (%[[F1]], %[[F3]]) {mpmd.inferred_by = ["uniquify"]} (%[[A1:.*]]: tensor<4xf32>, %[[A2:.*]]: tensor<4xf32>) - // CHECK-NEXT: mpmd.return %[[A1]], %[[A1]], %[[A2]], %[[A2]], %[[A2]] - // CHECK-NEXT: } - // CHECK-NEXT: return %[[F2]], %[[UF]]#0, %[[UF]]#2, %[[UF]]#1, %[[UF]]#3, %[[UF]]#4 + // CHECK: %[[F3:.*]]:5 = mpmd.fragment (%[[F1]], %arg0) + // CHECK: return %[[F2]], %[[F3]]#0, %[[F3]]#2, %[[F3]]#1, %[[F3]]#3, %[[F3]]#4 %0 = mpmd.fragment (%arg0) (%arg2: tensor<4xf32>) { mpmd.return %arg2 : tensor<4xf32> } : (!mesh_1_tensor) -> !mesh_1_tensor @@ -70,9 +67,10 @@ func.func @needs_fragment_for_m1_and_m2(%arg0: !mesh_1_tensor, %arg1: !mesh_2_te ) -> (!mesh_1_tensor, !mesh_2_tensor, !mesh_2_tensor, !mesh_1_tensor, !mesh_1_tensor, !mesh_1_tensor) attributes { "topology"=#mpmd.topology<<"m1": <["x"=2]>>, <"m2": <["x"=2]>>> } { - // CHECK: %[[UF1:.*]]:4 = mpmd.fragment ({{.*}}) {mpmd.inferred_by = ["uniquify"]} - // CHECK: %[[UF2:.*]]:2 = mpmd.fragment ({{.*}}) {mpmd.inferred_by = ["uniquify"]} - // CHECK: return %[[UF1]]#0, %[[UF2]]#0, %[[UF2]]#1, %[[UF1]]#2, %[[UF1]]#1, %[[UF1]]#3 + // CHECK: %[[F1:.*]] = mpmd.fragment + // CHECK: %[[F2:.*]]:2 = mpmd.fragment + // CHECK: %[[F3:.*]]:4 = 
mpmd.fragment (%[[F1]], %arg0) + // CHECK: return %[[F3]]#0, %[[F2]]#0, %[[F2]]#1, %[[F3]]#2, %[[F3]]#1, %[[F3]]#3 %0 = mpmd.fragment (%arg0) (%arg2: tensor<4xf32>) { mpmd.return %arg2 : tensor<4xf32> } : (!mesh_1_tensor) -> !mesh_1_tensor @@ -95,11 +93,11 @@ module { func.func @single_mesh_one_return_operand_with_global_view(%arg0: !dist_mesh_tensor) -> (!dist_mesh_tensor, !dist_mesh_tensor, !dist_mesh_tensor) attributes { "topology"=#mpmd.topology<<"m1": <["x"=2]>>> } { + // Now the uniquify fragment merges into f2 (the last same-mesh fragment), + // keeping tensors alive for less time. f2 gets 3 results. // CHECK-NEXT: %[[F1:.*]] = mpmd.fragment - // CHECK: %[[F2:.*]] = mpmd.fragment - // CHECK: %[[UF:.*]]:2 = mpmd.fragment (%[[F1]]) {mpmd.inferred_by = ["uniquify"]} (%arg1: tensor<4xf32>) { - // CHECK: mpmd.return %arg1, %arg1 : tensor<4xf32>, tensor<4xf32> - // CHECK: %[[F2]], %[[UF]]#0, %[[UF]]#1 + // CHECK: %[[F2:.*]]:3 = mpmd.fragment (%[[F1]]) + // CHECK: return %[[F2]]#0, %[[F2]]#1, %[[F2]]#2 %0 = mpmd.fragment (%arg0) (%arg1: tensor<4xf32>) { %1 = stablehlo.add %arg1, %arg1 : tensor<4xf32> mpmd.return %1 : tensor<4xf32> @@ -119,13 +117,13 @@ func.func @single_mesh_one_return_operand_with_global_view(%arg0: !dist_mesh_ten func.func @f(%arg0: !mesh_tensor) -> (!mesh_tensor, !mesh_tensor, !mesh_tensor) attributes {"topology"=#mpmd.topology<<"m": <["x"=2]>>>} { - // CHECK-NEXT: %[[F1:.*]] = mpmd.fragment (%arg0) (%arg1: tensor<4xui32>) { - // CHECK-NEXT: return %arg1 - // CHECK-NEXT: } - // CHECK-NEXT: %[[F2:.*]]:2 = mpmd.fragment (%arg0) {mpmd.inferred_by = ["uniquify"]} (%arg1: tensor<4xui32>) { - // CHECK-NEXT: return %arg1, %arg1 + // The uniquify fragment for the duplicated %arg0 returns now merges into "f" + // (the last same-mesh fragment). Block args dominate everything, so no + // positioning constraint. Result: single fragment with 3 results. 
+ // CHECK-NEXT: %[[F1:.*]]:3 = mpmd.fragment (%arg0) {mpmd.inferred_by = ["uniquify"]} (%arg1: tensor<4xui32>) { + // CHECK-NEXT: return %arg1, %arg1, %arg1 // CHECK-NEXT: } - // CHECK-NEXT: return %[[F2]]#0, %[[F1]], %[[F2]]#1 + // CHECK-NEXT: return %[[F1]]#1, %[[F1]]#0, %[[F1]]#2 %0 = mpmd.fragment(%arg0) (%arg1: tensor<4xui32>) { mpmd.return %arg1 : tensor<4xui32> } : (!mesh_tensor) -> !mesh_tensor @@ -146,3 +144,58 @@ func.func @identity_function(%arg0: !mesh_tensor) -> !mesh_tensor // CHECK-NEXT: return %[[F]] func.return %arg0 : !mesh_tensor } + +// ----- + +// Block-argument-only inferred fragments now merge into the last same-mesh +// fragment. Block args dominate everything, so there's no positioning +// constraint. This keeps tensors alive for less time. + +!mesh_tensor = !mpmd.mesh_tensor<"m", tensor<4xf32>> + +// CHECK-LABEL: func @block_arg_only_merges +func.func @block_arg_only_merges(%arg0: !mesh_tensor, %arg1: !mesh_tensor) + -> (!mesh_tensor, !mesh_tensor, !mesh_tensor) attributes { + "topology"=#mpmd.topology<<"m": <["x"=2]>>> +} { + // The existing fragment f1 uses %arg0. The return uses %0, %arg1, %arg1. + // Uniquify creates an inferred fragment for the duplicated %arg1 returns. + // Since f1 is the last same-mesh fragment, the inferred fragment merges into + // f1, which now takes both %arg0 and %arg1. + // CHECK: %[[F1:.*]]:3 = mpmd.fragment (%arg0, %arg1) {mpmd.inferred_by = ["uniquify"]} + %0 = mpmd.fragment (%arg0) (%arg2: tensor<4xf32>) { + %1 = stablehlo.add %arg2, %arg2 : tensor<4xf32> + mpmd.return %1 : tensor<4xf32> + } : (!mesh_tensor) -> !mesh_tensor + func.return %0, %arg1, %arg1 : !mesh_tensor, !mesh_tensor, !mesh_tensor +} + +// ----- + +// When there are multiple same-mesh fragments, the inferred fragment merges +// into the last one (f2), keeping tensors alive for less time. 
+ +!mesh_tensor = !mpmd.mesh_tensor<"m", tensor<4xf32>> + +// CHECK-LABEL: func @merge_into_last_fragment +func.func @merge_into_last_fragment(%arg0: !mesh_tensor) + -> (!mesh_tensor, !mesh_tensor, !mesh_tensor) attributes { + "topology"=#mpmd.topology<<"m": <["x"=2]>>> +} { + // f1 produces %0, f2 consumes %0 and produces %1. + // The return duplicates %1, creating an inferred fragment for uniquify. + // f2 is the last same-mesh fragment, so the inferred fragment merges into f2. + // CHECK: %[[F1:.*]] = mpmd.fragment + // CHECK: %[[F2:.*]]:2 = mpmd.fragment + // CHECK-NOT: mpmd.fragment + // CHECK: return %[[F1]], %[[F2]]#0, %[[F2]]#1 + %0 = mpmd.fragment (%arg0) (%arg2: tensor<4xf32>) { + %1 = stablehlo.add %arg2, %arg2 : tensor<4xf32> + mpmd.return %1 : tensor<4xf32> + } : (!mesh_tensor) -> !mesh_tensor + %1 = mpmd.fragment (%0) (%arg2: tensor<4xf32>) { + %2 = stablehlo.add %arg2, %arg2 : tensor<4xf32> + mpmd.return %2 : tensor<4xf32> + } : (!mesh_tensor) -> !mesh_tensor + func.return %0, %1, %1 : !mesh_tensor, !mesh_tensor, !mesh_tensor +} diff --git a/shardy/dialect/mpmd/transforms/common/uniquify_function_inputs_outputs.cc b/shardy/dialect/mpmd/transforms/common/uniquify_function_inputs_outputs.cc index 70040bfd..8bdb7f37 100644 --- a/shardy/dialect/mpmd/transforms/common/uniquify_function_inputs_outputs.cc +++ b/shardy/dialect/mpmd/transforms/common/uniquify_function_inputs_outputs.cc @@ -14,15 +14,14 @@ limitations under the License. ==============================================================================*/ #include -#include -#include "llvm/ADT/DenseSet.h" #include "llvm/ADT/MapVector.h" #include "llvm/ADT/SmallVector.h" #include "mlir/Dialect/Func/IR/FuncOps.h" #include "mlir/IR/Builders.h" #include "mlir/IR/BuiltinAttributes.h" #include "mlir/IR/MLIRContext.h" +#include "mlir/IR/PatternMatch.h" #include "mlir/IR/Types.h" #include "mlir/IR/Value.h" #include "mlir/Support/LLVM.h" @@ -30,6 +29,7 @@ limitations under the License. 
#include "shardy/dialect/mpmd/ir/dialect.h"
 #include "shardy/dialect/mpmd/ir/utils.h"
 #include "shardy/dialect/mpmd/transforms/common/passes.h"  // IWYU pragma: keep
+#include "shardy/dialect/mpmd/transforms/common/utils.h"
 #include "shardy/dialect/sdy/ir/dialect.h"
 
 namespace mlir::mpmd {
@@ -41,17 +41,53 @@ namespace {
 
 using ValueToReturnIndices = llvm::MapVector>;
 
+// Folds the freshly created inferred `fragment_op` into the last fragment
+// assigned to `mesh_name` in the block of `return_op`, when one exists.
+void MergeInferredFragmentWithExisting(FragmentOp fragment_op,
+                                       StringRef mesh_name,
+                                       Operation* return_op,
+                                       OpBuilder& builder) {
+  // The latest same-mesh fragment is preferred as the destination: merging
+  // as late as possible keeps tensors alive for less time.
+  Block* block = return_op->getBlock();
+  FragmentOp target =
+      FindLastFragmentOnMesh(block, mesh_name, {return_op, fragment_op});
+  if (!target) {
+    return;
+  }
+  // Every operand must dominate the merged fragment, so the destination has
+  // to sit after the latest op that produces one of them.
+  if (!EnsureAfter(target, FindLatestOperandProducer(fragment_op))) {
+    return;
+  }
+  fragment_op->moveAfter(target);
+  // MergeRegionOps erases the destination, so capture its attributes first.
+  SavedFragmentAttrs target_attrs = SaveFragmentAttrs(target);
+  IRRewriter rewriter(builder.getContext());
+  auto on_free_variable = [](OpOperand&, Value) {
+    SDY_CHECK(false) << "Fragment ops shouldn't have free variables";
+  };
+  // Uniquify-created fragments have no stage_id, so the destination's
+  // stage_id is preserved directly.
+  FragmentOp merged = MergeRegionOps(
+      target, fragment_op, rewriter,
+      /*num_static_args=*/0,
+      /*replace_producer_use_in_consumer_block=*/on_free_variable,
+      GetFragmentOriginUnion(target, fragment_op, rewriter),
+      target.getMeshNameAttr(), /*stage_id=*/target.getStageIdAttr());
+
+  RestoreFragmentAttrs(merged, target_attrs, "uniquify", builder);
+}
+
 void CreateReturnFragmentForMesh(StringRef mesh_name, Operation* return_op,
                                  ValueToReturnIndices& value_to_return_indices,
                                  OpBuilder& builder) {
-  // We remove any entries that require no work, in order to avoid too many
-  // checks.
+  // Remove entries that require no work: single-use non-block-arg values
+  // already have a unique return slot and don't need a fragment.
   value_to_return_indices.remove_if([](const auto& it) {
-    if (it.second.size() == 1) {
-      Value v = it.first;
-      return !isa(v);
-    }
-    return it.second.empty();
+    // Every entry has at least one index (populated via push_back).
+    SDY_CHECK(!it.second.empty());
+    return it.second.size() == 1 && !isa(it.first);
   });
 
   if (value_to_return_indices.empty()) {
@@ -98,6 +134,8 @@ void CreateReturnFragmentForMesh(StringRef mesh_name, Operation* return_op,
   }
   auto block_builder = OpBuilder::atBlockEnd(&fragment_block);
   ReturnOp::create(block_builder, loc, returned_values);
+
+  MergeInferredFragmentWithExisting(fragment_op, mesh_name, return_op, builder);
 }
 
 // Replaces the return values of the function with transfer ops.
diff --git a/shardy/dialect/mpmd/transforms/common/utils.cc b/shardy/dialect/mpmd/transforms/common/utils.cc
index 6bca7025..56639bf9 100644
--- a/shardy/dialect/mpmd/transforms/common/utils.cc
+++ b/shardy/dialect/mpmd/transforms/common/utils.cc
@@ -47,6 +47,7 @@ limitations under the License.
 #include "mlir/Transforms/RegionUtils.h"
 #include "shardy/common/logging.h"
 #include "shardy/dialect/mpmd/ir/dialect.h"
+#include "shardy/dialect/mpmd/ir/utils.h"
 
 namespace mlir::mpmd {
 
@@ -226,6 +227,92 @@ bool IsSplitKeepTransferred(FragmentOp fragment) {
   return fragment->hasAttr(kSplitKeepTransferredAttrName);
 }
 
+// ---------------------------------------------------------------------------
+// Fragment query and positioning utilities.
+// --------------------------------------------------------------------------- + +FragmentOp FindLastFragmentOnMesh(Block* block, StringRef mesh_name, + ArrayRef exclude) { + FragmentOp result = nullptr; + for (Operation& op : *block) { + if (llvm::is_contained(exclude, &op)) continue; + auto frag = dyn_cast(&op); + if (frag && frag.getMeshName() == mesh_name) { + result = frag; + } + } + return result; +} + +FragmentOp FindFirstFragmentOnMesh(Block* block, StringRef mesh_name, + ArrayRef exclude) { + for (Operation& op : *block) { + if (llvm::is_contained(exclude, &op)) continue; + auto frag = dyn_cast(&op); + if (frag && frag.getMeshName() == mesh_name) { + return frag; + } + } + return nullptr; +} + +Operation* FindLatestOperandProducer(Operation* op) { + Operation* latest = nullptr; + for (Value v : op->getOperands()) { + Operation* def = v.getDefiningOp(); + if (!def) continue; + if (!latest || latest->isBeforeInBlock(def)) { + latest = def; + } + } + return latest; +} + +bool CanMoveAfter(Operation* op_to_move, Operation* target_op) { + if (op_to_move->getBlock() != target_op->getBlock()) return false; + if (!op_to_move->isBeforeInBlock(target_op)) return false; + + for (Value result : op_to_move->getResults()) { + for (Operation* user : result.getUsers()) { + if (user->getBlock() == op_to_move->getBlock()) { + if (user == target_op || user->isBeforeInBlock(target_op)) { + return false; + } + } + } + } + return true; +} + +bool EnsureAfter(Operation* op, Operation* target) { + if (!target) return true; + if (op == target) return true; + if (!op->isBeforeInBlock(target)) return true; // already after + if (!CanMoveAfter(op, target)) return false; + op->moveAfter(target); + return true; +} + +SavedFragmentAttrs SaveFragmentAttrs(FragmentOp fragment) { + return { + fragment->getAttrOfType(kInferredByAttr), + fragment->getAttrOfType(kCallCounterAttrName), + }; +} + +void RestoreFragmentAttrs(FragmentOp fragment, const SavedFragmentAttrs& saved, + StringRef 
pass_name, OpBuilder& builder) { + if (saved.call_counter) { + fragment->setAttr(kCallCounterAttrName, saved.call_counter); + } + SmallVector inferred_by; + if (saved.inferred_by) { + inferred_by.append(saved.inferred_by.begin(), saved.inferred_by.end()); + } + inferred_by.push_back(builder.getStringAttr(pass_name)); + fragment->setAttr(kInferredByAttr, builder.getArrayAttr(inferred_by)); +} + namespace detail { namespace { diff --git a/shardy/dialect/mpmd/transforms/common/utils.h b/shardy/dialect/mpmd/transforms/common/utils.h index e924c24d..86df753a 100644 --- a/shardy/dialect/mpmd/transforms/common/utils.h +++ b/shardy/dialect/mpmd/transforms/common/utils.h @@ -127,6 +127,48 @@ SmallVector FilterRange(RangeT range, const BitVector& erase) { return result; } +// --------------------------------------------------------------------------- +// Fragment query and positioning utilities. +// --------------------------------------------------------------------------- + +// Returns the last FragmentOp on `mesh_name` in `block`, excluding ops in +// `exclude`. Returns nullptr if none found. +FragmentOp FindLastFragmentOnMesh(Block* block, StringRef mesh_name, + ArrayRef exclude = {}); + +// Returns the first FragmentOp on `mesh_name` in `block`, excluding ops in +// `exclude`. Returns nullptr if none found. +FragmentOp FindFirstFragmentOnMesh(Block* block, StringRef mesh_name, + ArrayRef exclude = {}); + +// Returns the latest (block-order) op that defines any operand of `op`. +// Returns nullptr if all operands are block arguments. +Operation* FindLatestOperandProducer(Operation* op); + +// Returns true if `op_to_move` can be repositioned after `target_op` +// without breaking any use-def chains. Both must be in the same block +// and `op_to_move` must precede `target_op`. +bool CanMoveAfter(Operation* op_to_move, Operation* target_op); + +// Moves `op` after `target` if needed and possible. Returns false if the +// move would break use-def chains. 
No-op (returns true) if `op` is already +// at or after `target`. +bool EnsureAfter(Operation* op, Operation* target); + +// Attributes saved from a fragment before MergeRegionOps erases it. +struct SavedFragmentAttrs { + ArrayAttr inferred_by; + IntegerAttr call_counter; +}; + +// Saves the inferred_by and call_counter attributes from `fragment`. +SavedFragmentAttrs SaveFragmentAttrs(FragmentOp fragment); + +// Restores saved attributes onto `fragment`, appending `pass_name` to the +// inferred_by list. +void RestoreFragmentAttrs(FragmentOp fragment, const SavedFragmentAttrs& saved, + StringRef pass_name, OpBuilder& builder); + namespace detail { // A non-templated version of MergeRegionOps that takes a callback for diff --git a/shardy/dialect/mpmd/transforms/common/utils_test.cc b/shardy/dialect/mpmd/transforms/common/utils_test.cc index 459dee0e..e641a3a5 100644 --- a/shardy/dialect/mpmd/transforms/common/utils_test.cc +++ b/shardy/dialect/mpmd/transforms/common/utils_test.cc @@ -107,5 +107,507 @@ TEST(MergeRegionOps, PreservesControlOperands) { f2->getResult(0)); } +// --------------------------------------------------------------------------- +// A reusable two-mesh program for fragment query tests. +// Four fragments total: f0(m1), f1(m2), f2(m1), f3(m1). 
+// --------------------------------------------------------------------------- +constexpr char kTwoMeshProgram[] = R"mlir( + !m1_tensor = !mpmd.mesh_tensor<"m1", tensor<4x8xf32>> + !m2_tensor = !mpmd.mesh_tensor<"m2", tensor<4x8xf32>> + func.func @main(%arg0: !m1_tensor, %arg1: !m2_tensor) + -> (!m1_tensor, !m2_tensor) + attributes {"topology"=#mpmd.topology<<"m1": <["x"=2]>>, <"m2": <["y"=4]>>>} { + %0 = mpmd.fragment (%arg0) + (%a0: tensor<4x8xf32>) { + mpmd.return %a0 : tensor<4x8xf32> + } : (!m1_tensor) -> !m1_tensor + + %1 = mpmd.fragment (%arg1) + (%a1: tensor<4x8xf32>) { + mpmd.return %a1 : tensor<4x8xf32> + } : (!m2_tensor) -> !m2_tensor + + %2 = mpmd.fragment (%0) + (%a2: tensor<4x8xf32>) { + mpmd.return %a2 : tensor<4x8xf32> + } : (!m1_tensor) -> !m1_tensor + + %3 = mpmd.fragment (%2) + (%a3: tensor<4x8xf32>) { + mpmd.return %a3 : tensor<4x8xf32> + } : (!m1_tensor) -> !m1_tensor + + return %3, %1 : !m1_tensor, !m2_tensor + } +)mlir"; + +// Helper: parse the two-mesh program and return (func, {f0, f1, f2, f3}). +struct ParsedTwoMeshProgram { + OwningOpRef module; + FuncOp func_op; + FragmentOp f0; // m1 + FragmentOp f1; // m2 + FragmentOp f2; // m1 + FragmentOp f3; // m1 +}; + +ParsedTwoMeshProgram ParseTwoMeshProgram(MLIRContext& context) { + ParsedTwoMeshProgram p; + p.module = parseSourceString(kTwoMeshProgram, &context); + p.func_op = GetMainFunction(*p.module); + auto it = p.func_op.getOps().begin(); + p.f0 = mlir::cast(*it++); + p.f1 = mlir::cast(*it++); + p.f2 = mlir::cast(*it++); + p.f3 = mlir::cast(*it); + return p; +} + +// ===== FindLastFragmentOnMesh tests ===== + +TEST(FindLastFragmentOnMesh, ReturnsLastOnMatchingMesh) { + MLIRContext context; + loadAllRequiredDialects(&context); + auto p = ParseTwoMeshProgram(context); + + Block* block = &p.func_op.getBody().front(); + // m1 fragments are f0, f2, f3. The last one should be f3. 
+ FragmentOp last = FindLastFragmentOnMesh(block, "m1"); + ASSERT_NE(last, nullptr); + EXPECT_EQ(last, p.f3); +} + +TEST(FindLastFragmentOnMesh, ReturnsNullptrForMissingMesh) { + MLIRContext context; + loadAllRequiredDialects(&context); + auto p = ParseTwoMeshProgram(context); + + Block* block = &p.func_op.getBody().front(); + FragmentOp result = FindLastFragmentOnMesh(block, "nonexistent"); + EXPECT_EQ(result, nullptr); +} + +TEST(FindLastFragmentOnMesh, RespectsExcludeList) { + MLIRContext context; + loadAllRequiredDialects(&context); + auto p = ParseTwoMeshProgram(context); + + Block* block = &p.func_op.getBody().front(); + // Exclude f3 — the last m1 fragment should now be f2. + FragmentOp last = + FindLastFragmentOnMesh(block, "m1", {p.f3.getOperation()}); + ASSERT_NE(last, nullptr); + EXPECT_EQ(last, p.f2); +} + +TEST(FindLastFragmentOnMesh, ExcludeAllReturnsNullptr) { + MLIRContext context; + loadAllRequiredDialects(&context); + auto p = ParseTwoMeshProgram(context); + + Block* block = &p.func_op.getBody().front(); + // Exclude all m1 fragments. + FragmentOp last = FindLastFragmentOnMesh( + block, "m1", + {p.f0.getOperation(), p.f2.getOperation(), p.f3.getOperation()}); + EXPECT_EQ(last, nullptr); +} + +TEST(FindLastFragmentOnMesh, SingleMeshFragment) { + MLIRContext context; + loadAllRequiredDialects(&context); + auto p = ParseTwoMeshProgram(context); + + Block* block = &p.func_op.getBody().front(); + // m2 has only f1. + FragmentOp last = FindLastFragmentOnMesh(block, "m2"); + ASSERT_NE(last, nullptr); + EXPECT_EQ(last, p.f1); +} + +// ===== FindFirstFragmentOnMesh tests ===== + +TEST(FindFirstFragmentOnMesh, ReturnsFirstOnMatchingMesh) { + MLIRContext context; + loadAllRequiredDialects(&context); + auto p = ParseTwoMeshProgram(context); + + Block* block = &p.func_op.getBody().front(); + // m1 fragments are f0, f2, f3. The first one should be f0. 
+ FragmentOp first = FindFirstFragmentOnMesh(block, "m1"); + ASSERT_NE(first, nullptr); + EXPECT_EQ(first, p.f0); +} + +TEST(FindFirstFragmentOnMesh, ReturnsNullptrForMissingMesh) { + MLIRContext context; + loadAllRequiredDialects(&context); + auto p = ParseTwoMeshProgram(context); + + Block* block = &p.func_op.getBody().front(); + FragmentOp result = FindFirstFragmentOnMesh(block, "nonexistent"); + EXPECT_EQ(result, nullptr); +} + +TEST(FindFirstFragmentOnMesh, RespectsExcludeList) { + MLIRContext context; + loadAllRequiredDialects(&context); + auto p = ParseTwoMeshProgram(context); + + Block* block = &p.func_op.getBody().front(); + // Exclude f0 — the first m1 fragment should now be f2. + FragmentOp first = + FindFirstFragmentOnMesh(block, "m1", {p.f0.getOperation()}); + ASSERT_NE(first, nullptr); + EXPECT_EQ(first, p.f2); +} + +TEST(FindFirstFragmentOnMesh, ExcludeAllReturnsNullptr) { + MLIRContext context; + loadAllRequiredDialects(&context); + auto p = ParseTwoMeshProgram(context); + + Block* block = &p.func_op.getBody().front(); + FragmentOp first = FindFirstFragmentOnMesh( + block, "m1", + {p.f0.getOperation(), p.f2.getOperation(), p.f3.getOperation()}); + EXPECT_EQ(first, nullptr); +} + +// ===== FindLatestOperandProducer tests ===== + +TEST(FindLatestOperandProducer, ReturnsLatestProducer) { + MLIRContext context; + loadAllRequiredDialects(&context); + auto p = ParseTwoMeshProgram(context); + + // f3 takes %2 (result of f2) as operand. f2 is the latest producer. + Operation* latest = FindLatestOperandProducer(p.f3); + ASSERT_NE(latest, nullptr); + EXPECT_EQ(latest, p.f2.getOperation()); +} + +TEST(FindLatestOperandProducer, ReturnsNullptrForBlockArgOnly) { + MLIRContext context; + loadAllRequiredDialects(&context); + auto p = ParseTwoMeshProgram(context); + + // f0 takes %arg0 (a block argument) as operand — no defining op. 
+ Operation* latest = FindLatestOperandProducer(p.f0); + EXPECT_EQ(latest, nullptr); +} + +TEST(FindLatestOperandProducer, ReturnsNullptrForBlockArgOnly_M2) { + MLIRContext context; + loadAllRequiredDialects(&context); + auto p = ParseTwoMeshProgram(context); + + // f1 takes %arg1 (a block argument) as operand — no defining op. + Operation* latest = FindLatestOperandProducer(p.f1); + EXPECT_EQ(latest, nullptr); +} + +TEST(FindLatestOperandProducer, WithMultipleOperands) { + // Build a program where a fragment has operands from two different producers + // at different positions in the block. + const char kProgram[] = R"mlir( + !mt = !mpmd.mesh_tensor<"m1", tensor<4x8xf32>> + func.func @main(%arg0: !mt) -> !mt + attributes {"topology"=#mpmd.topology<<"m1": <["x"=2]>>>} { + %0 = mpmd.fragment (%arg0) + (%a0: tensor<4x8xf32>) { + mpmd.return %a0 : tensor<4x8xf32> + } : (!mt) -> !mt + + %1 = mpmd.fragment (%arg0) + (%a1: tensor<4x8xf32>) { + mpmd.return %a1 : tensor<4x8xf32> + } : (!mt) -> !mt + + %2 = mpmd.fragment (%0, %1) + (%a2: tensor<4x8xf32>, %a3: tensor<4x8xf32>) { + mpmd.return %a2 : tensor<4x8xf32> + } : (!mt, !mt) -> !mt + + return %2 : !mt + } + )mlir"; + MLIRContext context; + loadAllRequiredDialects(&context); + OwningOpRef module = + parseSourceString(kProgram, &context); + FuncOp func_op = GetMainFunction(*module); + + auto it = func_op.getOps().begin(); + ++it; // skip "early" + Operation* late = &*it++; + Operation* consumer = &*it; + + // The consumer uses results from both early and late. + // late is the latest operand producer. + Operation* latest = FindLatestOperandProducer(consumer); + ASSERT_NE(latest, nullptr); + EXPECT_EQ(latest, late); +} + +// ===== CanMoveAfter tests ===== + +TEST(CanMoveAfter, CanMoveWhenNoIntermediateUsers) { + MLIRContext context; + loadAllRequiredDialects(&context); + auto p = ParseTwoMeshProgram(context); + + // f0's result is used by f2 (not between f0 and f1). + // f1 is on m2 and doesn't use f0's result. 
+ // So f0 can be moved after f1 only if no user of f0 is at or before f1. + // Actually: f0 -> f1 -> f2 -> f3. f0's user is f2. + // f2 is after f1, so f0 CAN be moved after f1. + EXPECT_TRUE(CanMoveAfter(p.f0, p.f1)); +} + +TEST(CanMoveAfter, CannotMoveWhenUserIsTarget) { + MLIRContext context; + loadAllRequiredDialects(&context); + auto p = ParseTwoMeshProgram(context); + + // f0's result is used by f2. Trying to move f0 after f2 fails because + // f2 itself uses f0's result (user == target). + EXPECT_FALSE(CanMoveAfter(p.f0, p.f2)); +} + +TEST(CanMoveAfter, CannotMoveWhenUserIsBeforeTarget) { + MLIRContext context; + loadAllRequiredDialects(&context); + auto p = ParseTwoMeshProgram(context); + + // f0's result is used by f2. Trying to move f0 after f3: f2 is before f3, + // so the user f2 would be before the target f3. + EXPECT_FALSE(CanMoveAfter(p.f0, p.f3)); +} + +TEST(CanMoveAfter, ReturnsFalseWhenNotBefore) { + MLIRContext context; + loadAllRequiredDialects(&context); + auto p = ParseTwoMeshProgram(context); + + // f3 is after f0 in block order — CanMoveAfter requires op_to_move to + // precede target_op. + EXPECT_FALSE(CanMoveAfter(p.f3, p.f0)); +} + +TEST(CanMoveAfter, ReturnsFalseForDifferentBlocks) { + MLIRContext context; + loadAllRequiredDialects(&context); + auto p = ParseTwoMeshProgram(context); + + // Operations inside different blocks (the body of f0 vs. the func block). + // Get an op from inside f0's region. + Operation* inner_op = &p.f0.getRegion().front().front(); + EXPECT_FALSE(CanMoveAfter(inner_op, p.f1)); +} + +// ===== EnsureAfter tests ===== + +TEST(EnsureAfter, NoOpWhenAlreadyAfter) { + MLIRContext context; + loadAllRequiredDialects(&context); + auto p = ParseTwoMeshProgram(context); + + // f2 is already after f0 — EnsureAfter should be a no-op returning true. + EXPECT_TRUE(EnsureAfter(p.f2, p.f0)); + // f2 should still be after f0. 
+ EXPECT_TRUE(p.f0->isBeforeInBlock(p.f2)); +} + +TEST(EnsureAfter, NoOpWhenSameOp) { + MLIRContext context; + loadAllRequiredDialects(&context); + auto p = ParseTwoMeshProgram(context); + + // op == target — should be a no-op returning true. + EXPECT_TRUE(EnsureAfter(p.f0, p.f0)); +} + +TEST(EnsureAfter, NoOpWhenTargetIsNull) { + MLIRContext context; + loadAllRequiredDialects(&context); + auto p = ParseTwoMeshProgram(context); + + // nullptr target — should return true. + EXPECT_TRUE(EnsureAfter(p.f0, nullptr)); +} + +TEST(EnsureAfter, MovesWhenSafe) { + // f1(m2) has no users in the block (only used by return). + // We can move f1 after f2 safely since f1's result is only used by return + // which comes after everything. + const char kProgram[] = R"mlir( + !m1 = !mpmd.mesh_tensor<"m1", tensor<4x8xf32>> + !m2 = !mpmd.mesh_tensor<"m2", tensor<4x8xf32>> + func.func @main(%arg0: !m1, %arg1: !m2) -> (!m1, !m2) + attributes {"topology"=#mpmd.topology<<"m1": <["x"=2]>>, <"m2": <["y"=4]>>>} { + %0 = mpmd.fragment (%arg0) + (%a: tensor<4x8xf32>) { + mpmd.return %a : tensor<4x8xf32> + } : (!m1) -> !m1 + + %1 = mpmd.fragment (%arg1) + (%b: tensor<4x8xf32>) { + mpmd.return %b : tensor<4x8xf32> + } : (!m2) -> !m2 + + %2 = mpmd.fragment (%0) + (%c: tensor<4x8xf32>) { + mpmd.return %c : tensor<4x8xf32> + } : (!m1) -> !m1 + + return %2, %1 : !m1, !m2 + } + )mlir"; + MLIRContext context; + loadAllRequiredDialects(&context); + OwningOpRef module = + parseSourceString(kProgram, &context); + FuncOp func_op = GetMainFunction(*module); + + auto it = func_op.getOps().begin(); + ++it; // skip "a" + FragmentOp b = mlir::cast(*it++); + FragmentOp c = mlir::cast(*it); + + // b is before c; move b after c. + ASSERT_TRUE(b->isBeforeInBlock(c)); + EXPECT_TRUE(EnsureAfter(b, c)); + // Now b should be after c. 
+ EXPECT_TRUE(c->isBeforeInBlock(b)); +} + +TEST(EnsureAfter, ReturnsFalseWhenUnsafe) { + MLIRContext context; + loadAllRequiredDialects(&context); + auto p = ParseTwoMeshProgram(context); + + // f2's result is used by f3. Moving f2 after f3 is unsafe. + EXPECT_FALSE(EnsureAfter(p.f2, p.f3)); + // f2 should NOT have been moved — still before f3. + EXPECT_TRUE(p.f2->isBeforeInBlock(p.f3)); +} + +// ===== SaveFragmentAttrs / RestoreFragmentAttrs tests ===== + +// A reusable single-mesh program with two fragments. +constexpr char kTwoFragmentProgram[] = R"mlir( + !mt = !mpmd.mesh_tensor<"m1", tensor<4x8xf32>> + func.func @main(%arg0: !mt) -> !mt + attributes {"topology"=#mpmd.topology<<"m1": <["x"=2]>>>} { + %0 = mpmd.fragment (%arg0) + (%a0: tensor<4x8xf32>) { + mpmd.return %a0 : tensor<4x8xf32> + } : (!mt) -> !mt + + %1 = mpmd.fragment (%0) + (%a1: tensor<4x8xf32>) { + mpmd.return %a1 : tensor<4x8xf32> + } : (!mt) -> !mt + + return %1 : !mt + } +)mlir"; + +// A reusable single-mesh program with one fragment. +constexpr char kSingleFragmentProgram[] = R"mlir( + !mt = !mpmd.mesh_tensor<"m1", tensor<4x8xf32>> + func.func @main(%arg0: !mt) -> !mt + attributes {"topology"=#mpmd.topology<<"m1": <["x"=2]>>>} { + %0 = mpmd.fragment (%arg0) + (%a0: tensor<4x8xf32>) { + mpmd.return %a0 : tensor<4x8xf32> + } : (!mt) -> !mt + return %0 : !mt + } +)mlir"; + +TEST(SaveRestoreFragmentAttrs, SavesAndRestoresInferredByAndCallCounter) { + MLIRContext context; + loadAllRequiredDialects(&context); + OwningOpRef module = + parseSourceString(kTwoFragmentProgram, &context); + FuncOp func_op = GetMainFunction(*module); + + auto it = func_op.getOps().begin(); + FragmentOp f0 = mlir::cast(*it++); + FragmentOp f1 = mlir::cast(*it); + + OpBuilder builder(&context); + + // Set up f0 with inferred_by and call_counter attributes. + SetInferredByAttr(f0, "pass_a", builder); + f0->setAttr(kCallCounterAttrName, builder.getUI32IntegerAttr(42)); + + // Save the attributes from f0. 
+ SavedFragmentAttrs saved = SaveFragmentAttrs(f0); + ASSERT_NE(saved.inferred_by, nullptr); + ASSERT_NE(saved.call_counter, nullptr); + EXPECT_EQ(saved.call_counter.getValue().getZExtValue(), 42u); + + // Restore onto f1, adding "pass_b" to the inferred_by list. + RestoreFragmentAttrs(f1, saved, "pass_b", builder); + + // Verify call_counter was restored. + auto restored_counter = + f1->getAttrOfType(kCallCounterAttrName); + ASSERT_NE(restored_counter, nullptr); + EXPECT_EQ(restored_counter.getValue().getZExtValue(), 42u); + + // Verify inferred_by was restored with pass_b appended. + auto restored_inferred_by = + f1->getAttrOfType(kInferredByAttr); + ASSERT_NE(restored_inferred_by, nullptr); + ASSERT_EQ(restored_inferred_by.size(), 2); + EXPECT_EQ(mlir::cast(restored_inferred_by[0]).getValue(), + "pass_a"); + EXPECT_EQ(mlir::cast(restored_inferred_by[1]).getValue(), + "pass_b"); +} + +TEST(SaveRestoreFragmentAttrs, SavesNullWhenNoAttrs) { + MLIRContext context; + loadAllRequiredDialects(&context); + OwningOpRef module = + parseSourceString(kSingleFragmentProgram, &context); + FuncOp func_op = GetMainFunction(*module); + FragmentOp f0 = mlir::cast(*func_op.getOps().begin()); + + // f0 has no inferred_by or call_counter attrs. + SavedFragmentAttrs saved = SaveFragmentAttrs(f0); + EXPECT_EQ(saved.inferred_by, nullptr); + EXPECT_EQ(saved.call_counter, nullptr); +} + +TEST(SaveRestoreFragmentAttrs, RestoreWithNoSavedInferredByCreatesNew) { + MLIRContext context; + loadAllRequiredDialects(&context); + OwningOpRef module = + parseSourceString(kSingleFragmentProgram, &context); + FuncOp func_op = GetMainFunction(*module); + FragmentOp f0 = mlir::cast(*func_op.getOps().begin()); + + OpBuilder builder(&context); + + // Saved attrs have no inferred_by. + SavedFragmentAttrs saved = {/*inferred_by=*/nullptr, + /*call_counter=*/nullptr}; + RestoreFragmentAttrs(f0, saved, "my_pass", builder); + + // Should create an inferred_by list with just "my_pass". 
+ auto inferred_by = f0->getAttrOfType(kInferredByAttr); + ASSERT_NE(inferred_by, nullptr); + ASSERT_EQ(inferred_by.size(), 1); + EXPECT_EQ(mlir::cast(inferred_by[0]).getValue(), "my_pass"); + + // No call_counter should be set. + EXPECT_FALSE(f0->hasAttr(kCallCounterAttrName)); +} + } // namespace } // namespace mlir::mpmd diff --git a/shardy/dialect/mpmd/transforms/export/export_pipeline.cc b/shardy/dialect/mpmd/transforms/export/export_pipeline.cc index 3b4218db..619aa385 100644 --- a/shardy/dialect/mpmd/transforms/export/export_pipeline.cc +++ b/shardy/dialect/mpmd/transforms/export/export_pipeline.cc @@ -86,11 +86,6 @@ void addExportPipeline(OpPassManager& pm, const ExportOptions& options) { // identity fragments, which would be canonicalized away. pm.addNestedPass(createUniquifyFunctionInputsOutputsPass()); - // The fragments created by the pass above maybe slowdown compilation (more - // fragments to compile) and may cause performance regressions. Thus, we merge - // them with other fragments. - pm.addNestedPass(createMergeInferredFragmentsPass()); - // Mark each fragment with the inputs and outputs which are offloaded to host // memory. 
pm.addNestedPass(createMarkOffloadedInputOutputPass()); diff --git a/shardy/dialect/mpmd/transforms/export/test/export_pipeline.mlir b/shardy/dialect/mpmd/transforms/export/test/export_pipeline.mlir index adb7f6d8..6b249e5e 100644 --- a/shardy/dialect/mpmd/transforms/export/test/export_pipeline.mlir +++ b/shardy/dialect/mpmd/transforms/export/test/export_pipeline.mlir @@ -1,4 +1,4 @@ -// RUN: mpmd_opt %s -mpmd-export-pipeline 2>&1 | FileCheck %s +// RUN: mpmd_opt %s -mpmd-export-pipeline -split-input-file 2>&1 | FileCheck %s !mesh_1_tensor_4_8_f32 = !mpmd.mesh_tensor<"m1", tensor<4x8xf32>> @@ -17,3 +17,36 @@ func.func @main(%arg0: !mesh_1_tensor_4_8_f32 {tf.aliasing_output = 0: i32}, %ar } : (!mesh_1_tensor_4_8_f32, !mesh_1_tensor_4_8_f32) -> (!mesh_1_tensor_4_8_f32) func.return %0 : !mesh_1_tensor_4_8_f32 } + +// ----- + +!mesh_1_tensor_4_8_f32 = !mpmd.mesh_tensor<"m1", tensor<4x8xf32>> +!mesh_2_tensor_4_8_f32 = !mpmd.mesh_tensor<"m2", tensor<4x8xf32>> + +// This test verifies that an explicit fragment and an inferred fragment +// (created by the UniquifyFunctionInputsOutputsPass for the duplicated return +// of the transfer result) are merged sideways. Without sideways merge, the +// transfer result would produce a separate inferred fragment call on m1. +// The function-level returns remain unique SSA values (%[[RES]]#0, #1, #2), +// preserving the invariant established by the uniquify pass, even though the +// fragment body internally returns the same value in multiple positions. 
+// CHECK-LABEL: func.func @test_sideways_merge +func.func @test_sideways_merge(%arg0: !mesh_1_tensor_4_8_f32, %arg1: !mesh_2_tensor_4_8_f32) + -> (!mesh_1_tensor_4_8_f32, !mesh_1_tensor_4_8_f32, !mesh_1_tensor_4_8_f32) attributes { + "topology"=#mpmd.topology< + <"m1": <["x"=2]>>, + <"m2": <["x"=2]>> + >} { + // CHECK: %[[RES:.*]]:3 = mpmd.fragment_call @[[CALLEE_M1:.*]] + // CHECK-NOT: mpmd.fragment_call (%arg0) (%arg2: tensor<4x8xf32>) { + %4 = stablehlo.add %arg2, %arg2 : tensor<4x8xf32> + mpmd.return %4 : tensor<4x8xf32> + } : (!mesh_1_tensor_4_8_f32) -> !mesh_1_tensor_4_8_f32 + + %1 = mpmd.transfer %arg1 : (!mesh_2_tensor_4_8_f32) -> !mesh_1_tensor_4_8_f32 + + func.return %0, %1, %1 : !mesh_1_tensor_4_8_f32, !mesh_1_tensor_4_8_f32, !mesh_1_tensor_4_8_f32 +}