Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
4 changes: 4 additions & 0 deletions shardy/dialect/mpmd/transforms/common/passes.td
Original file line number Diff line number Diff line change
Expand Up @@ -450,6 +450,10 @@ def UniquifyFunctionInputsOutputsPass :
Similarly, if a function returns a block argument, this pass creates an
identity fragment for that block argument, guaranteeing that values are
passed by value to the function, not by reference.

Additionally, when not using transfers, the pass will attempt to merge
each newly created inferred fragment into an existing same-mesh fragment
to reduce the total number of fragments.
}];

let options = [
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -26,11 +26,11 @@ func.func @no_work_needed(%arg0: !mesh_1_tensor, %arg1: !mesh_2_tensor) -> (!mes
func.func @single_mesh_one_return_operand(%arg0: !mesh_1_tensor) -> (!mesh_1_tensor, !mesh_1_tensor, !mesh_1_tensor) attributes {
"topology"=#mpmd.topology<<"m1": <["x"=2]>>>
} {
// Uniquify merges into f2 (the last same-mesh fragment), keeping tensors
// alive for less time.
// CHECK-NEXT: %[[F1:.*]] = mpmd.fragment<mesh="m1", origin=["f1"]>
// CHECK: %[[F2:.*]] = mpmd.fragment<mesh="m1", origin=["f2"]>
// CHECK: %[[UF:.*]]:2 = mpmd.fragment<mesh="m1", origin=[]> (%[[F1]]) {mpmd.inferred_by = ["uniquify"]} (%arg1: tensor<4xf32>) {
// CHECK: mpmd.return %arg1, %arg1 : tensor<4xf32>, tensor<4xf32>
// CHECK: %[[F2]], %[[UF]]#0, %[[UF]]#1
// CHECK: %[[F2:.*]]:3 = mpmd.fragment<mesh="m1", origin=["f2"]> (%[[F1]])
// CHECK: return %[[F2]]#0, %[[F2]]#1, %[[F2]]#2
%0 = mpmd.fragment<mesh="m1", origin=["f1"]> (%arg0) (%arg1: tensor<4xf32>) {
%1 = stablehlo.add %arg1, %arg1 : tensor<4xf32>
mpmd.return %1 : tensor<4xf32>
Expand All @@ -48,11 +48,8 @@ func.func @needs_fragment_for_m1_with_many_values(%arg0: !mesh_1_tensor, %arg1:
} {
// CHECK-NEXT: %[[F1:.*]] = mpmd.fragment<mesh="m1", origin=["f1"]>
// CHECK: %[[F2:.*]] = mpmd.fragment<mesh="m2", origin=["f2"]>
// CHECK: %[[F3:.*]] = mpmd.fragment<mesh="m1", origin=["f3"]>
// CHECK: %[[UF:.*]]:5 = mpmd.fragment<mesh="m1", origin=[]> (%[[F1]], %[[F3]]) {mpmd.inferred_by = ["uniquify"]} (%[[A1:.*]]: tensor<4xf32>, %[[A2:.*]]: tensor<4xf32>)
// CHECK-NEXT: mpmd.return %[[A1]], %[[A1]], %[[A2]], %[[A2]], %[[A2]]
// CHECK-NEXT: }
// CHECK-NEXT: return %[[F2]], %[[UF]]#0, %[[UF]]#2, %[[UF]]#1, %[[UF]]#3, %[[UF]]#4
// CHECK: %[[F3:.*]]:5 = mpmd.fragment<mesh="m1", origin=["f3"]> (%[[F1]], %arg0)
// CHECK: return %[[F2]], %[[F3]]#0, %[[F3]]#2, %[[F3]]#1, %[[F3]]#3, %[[F3]]#4
%0 = mpmd.fragment<mesh="m1", origin=["f1"]> (%arg0) (%arg2: tensor<4xf32>) {
mpmd.return %arg2 : tensor<4xf32>
} : (!mesh_1_tensor) -> !mesh_1_tensor
Expand All @@ -70,9 +67,10 @@ func.func @needs_fragment_for_m1_and_m2(%arg0: !mesh_1_tensor, %arg1: !mesh_2_te
) -> (!mesh_1_tensor, !mesh_2_tensor, !mesh_2_tensor, !mesh_1_tensor, !mesh_1_tensor, !mesh_1_tensor) attributes {
"topology"=#mpmd.topology<<"m1": <["x"=2]>>, <"m2": <["x"=2]>>>
} {
// CHECK: %[[UF1:.*]]:4 = mpmd.fragment<mesh="m1", origin=[]> ({{.*}}) {mpmd.inferred_by = ["uniquify"]}
// CHECK: %[[UF2:.*]]:2 = mpmd.fragment<mesh="m2", origin=[]> ({{.*}}) {mpmd.inferred_by = ["uniquify"]}
// CHECK: return %[[UF1]]#0, %[[UF2]]#0, %[[UF2]]#1, %[[UF1]]#2, %[[UF1]]#1, %[[UF1]]#3
// CHECK: %[[F1:.*]] = mpmd.fragment<mesh="m1", origin=["f1"]>
// CHECK: %[[F2:.*]]:2 = mpmd.fragment<mesh="m2", origin=["f2"]>
// CHECK: %[[F3:.*]]:4 = mpmd.fragment<mesh="m1", origin=["f3"]> (%[[F1]], %arg0)
// CHECK: return %[[F3]]#0, %[[F2]]#0, %[[F2]]#1, %[[F3]]#2, %[[F3]]#1, %[[F3]]#3
%0 = mpmd.fragment<mesh="m1", origin=["f1"]> (%arg0) (%arg2: tensor<4xf32>) {
mpmd.return %arg2 : tensor<4xf32>
} : (!mesh_1_tensor) -> !mesh_1_tensor
Expand All @@ -95,11 +93,11 @@ module {
func.func @single_mesh_one_return_operand_with_global_view(%arg0: !dist_mesh_tensor) -> (!dist_mesh_tensor, !dist_mesh_tensor, !dist_mesh_tensor) attributes {
"topology"=#mpmd.topology<<"m1": <["x"=2]>>>
} {
// Now the uniquify fragment merges into f2 (the last same-mesh fragment),
// keeping tensors alive for less time. f2 gets 3 results.
// CHECK-NEXT: %[[F1:.*]] = mpmd.fragment<mesh="m1", origin=["f1"]>
// CHECK: %[[F2:.*]] = mpmd.fragment<mesh="m1", origin=["f2"]>
// CHECK: %[[UF:.*]]:2 = mpmd.fragment<mesh="m1", origin=[]> (%[[F1]]) {mpmd.inferred_by = ["uniquify"]} (%arg1: tensor<4xf32>) {
// CHECK: mpmd.return %arg1, %arg1 : tensor<4xf32>, tensor<4xf32>
// CHECK: %[[F2]], %[[UF]]#0, %[[UF]]#1
// CHECK: %[[F2:.*]]:3 = mpmd.fragment<mesh="m1", origin=["f2"]> (%[[F1]])
// CHECK: return %[[F2]]#0, %[[F2]]#1, %[[F2]]#2
%0 = mpmd.fragment<mesh="m1", origin=["f1"]> (%arg0) (%arg1: tensor<4xf32>) {
%1 = stablehlo.add %arg1, %arg1 : tensor<4xf32>
mpmd.return %1 : tensor<4xf32>
Expand All @@ -119,13 +117,13 @@ func.func @single_mesh_one_return_operand_with_global_view(%arg0: !dist_mesh_ten
func.func @f(%arg0: !mesh_tensor) -> (!mesh_tensor, !mesh_tensor, !mesh_tensor)
attributes {"topology"=#mpmd.topology<<"m": <["x"=2]>>>}
{
// CHECK-NEXT: %[[F1:.*]] = mpmd.fragment<mesh="m", origin=["f"]> (%arg0) (%arg1: tensor<4xui32>) {
// CHECK-NEXT: return %arg1
// CHECK-NEXT: }
// CHECK-NEXT: %[[F2:.*]]:2 = mpmd.fragment<mesh="m", origin=[]> (%arg0) {mpmd.inferred_by = ["uniquify"]} (%arg1: tensor<4xui32>) {
// CHECK-NEXT: return %arg1, %arg1
// The uniquify fragment for the duplicated %arg0 returns now merges into "f"
// (the last same-mesh fragment). Block args dominate everything, so no
// positioning constraint. Result: single fragment with 3 results.
// CHECK-NEXT: %[[F1:.*]]:3 = mpmd.fragment<mesh="m", origin=["f"]> (%arg0) {mpmd.inferred_by = ["uniquify"]} (%arg1: tensor<4xui32>) {
// CHECK-NEXT: return %arg1, %arg1, %arg1
// CHECK-NEXT: }
// CHECK-NEXT: return %[[F2]]#0, %[[F1]], %[[F2]]#1
// CHECK-NEXT: return %[[F1]]#1, %[[F1]]#0, %[[F1]]#2
%0 = mpmd.fragment<mesh="m", origin=["f"]>(%arg0) (%arg1: tensor<4xui32>) {
mpmd.return %arg1 : tensor<4xui32>
} : (!mesh_tensor) -> !mesh_tensor
Expand All @@ -146,3 +144,58 @@ func.func @identity_function(%arg0: !mesh_tensor) -> !mesh_tensor
// CHECK-NEXT: return %[[F]]
func.return %arg0 : !mesh_tensor
}

// -----

// Block-argument-only inferred fragments now merge into the last same-mesh
// fragment. Block args dominate everything, so there's no positioning
// constraint. This keeps tensors alive for less time.

!mesh_tensor = !mpmd.mesh_tensor<"m", tensor<4xf32>>

// CHECK-LABEL: func @block_arg_only_merges
func.func @block_arg_only_merges(%arg0: !mesh_tensor, %arg1: !mesh_tensor)
-> (!mesh_tensor, !mesh_tensor, !mesh_tensor) attributes {
"topology"=#mpmd.topology<<"m": <["x"=2]>>>
} {
// The existing fragment f1 uses %arg0. The return uses %0, %arg1, %arg1.
// Uniquify creates an inferred fragment for the duplicated %arg1 returns.
// Since f1 is the last same-mesh fragment, the inferred fragment merges into
// f1, which now takes both %arg0 and %arg1.
// CHECK: %[[F1:.*]]:3 = mpmd.fragment<mesh="m", origin=["f1"]> (%arg0, %arg1) {mpmd.inferred_by = ["uniquify"]}
%0 = mpmd.fragment<mesh="m", origin=["f1"]> (%arg0) (%arg2: tensor<4xf32>) {
%1 = stablehlo.add %arg2, %arg2 : tensor<4xf32>
mpmd.return %1 : tensor<4xf32>
} : (!mesh_tensor) -> !mesh_tensor
func.return %0, %arg1, %arg1 : !mesh_tensor, !mesh_tensor, !mesh_tensor
}

// -----

// When there are multiple same-mesh fragments, the inferred fragment merges
// into the last one (f2), keeping tensors alive for less time.

!mesh_tensor = !mpmd.mesh_tensor<"m", tensor<4xf32>>

// CHECK-LABEL: func @merge_into_last_fragment
func.func @merge_into_last_fragment(%arg0: !mesh_tensor)
-> (!mesh_tensor, !mesh_tensor, !mesh_tensor) attributes {
"topology"=#mpmd.topology<<"m": <["x"=2]>>>
} {
// f1 produces %0, f2 consumes %0 and produces %1.
// The return duplicates %1, creating an inferred fragment for uniquify.
// f2 is the last same-mesh fragment, so the inferred fragment merges into f2.
// CHECK: %[[F1:.*]] = mpmd.fragment<mesh="m", origin=["f1"]>
// CHECK: %[[F2:.*]]:2 = mpmd.fragment<mesh="m", origin=["f2"]>
// CHECK-NOT: mpmd.fragment<mesh="m", origin=[]>
// CHECK: return %[[F1]], %[[F2]]#0, %[[F2]]#1
%0 = mpmd.fragment<mesh="m", origin=["f1"]> (%arg0) (%arg2: tensor<4xf32>) {
%1 = stablehlo.add %arg2, %arg2 : tensor<4xf32>
mpmd.return %1 : tensor<4xf32>
} : (!mesh_tensor) -> !mesh_tensor
%1 = mpmd.fragment<mesh="m", origin=["f2"]> (%0) (%arg2: tensor<4xf32>) {
%2 = stablehlo.add %arg2, %arg2 : tensor<4xf32>
mpmd.return %2 : tensor<4xf32>
} : (!mesh_tensor) -> !mesh_tensor
func.return %0, %1, %1 : !mesh_tensor, !mesh_tensor, !mesh_tensor
}
Original file line number Diff line number Diff line change
Expand Up @@ -14,22 +14,22 @@ limitations under the License.
==============================================================================*/

#include <cstdint>
#include <utility>

#include "llvm/ADT/DenseSet.h"
#include "llvm/ADT/MapVector.h"
#include "llvm/ADT/SmallVector.h"
#include "mlir/Dialect/Func/IR/FuncOps.h"
#include "mlir/IR/Builders.h"
#include "mlir/IR/BuiltinAttributes.h"
#include "mlir/IR/MLIRContext.h"
#include "mlir/IR/PatternMatch.h"
#include "mlir/IR/Types.h"
#include "mlir/IR/Value.h"
#include "mlir/Support/LLVM.h"
#include "shardy/common/logging.h"
#include "shardy/dialect/mpmd/ir/dialect.h"
#include "shardy/dialect/mpmd/ir/utils.h"
#include "shardy/dialect/mpmd/transforms/common/passes.h" // IWYU pragma: keep
#include "shardy/dialect/mpmd/transforms/common/utils.h"
#include "shardy/dialect/sdy/ir/dialect.h"

namespace mlir::mpmd {
Expand All @@ -41,17 +41,53 @@ namespace {

using ValueToReturnIndices = llvm::MapVector<Value, SmallVector<int64_t>>;

// Tries to merge the newly created inferred fragment into an existing
// same-mesh fragment in the block.
void MergeInferredFragmentWithExisting(FragmentOp fragment_op,
StringRef mesh_name,
Operation* return_op,
OpBuilder& builder) {
// Find the last same-mesh fragment to merge into. Merging as late as
// possible keeps tensors alive for less time.
FragmentOp merge_target = FindLastFragmentOnMesh(
return_op->getBlock(), mesh_name, {return_op, fragment_op});
if (!merge_target) return;

// Ensure the merge target is after the latest operand producer so all
// operands dominate the merged fragment.
Operation* latest_producer = FindLatestOperandProducer(fragment_op);
if (!EnsureAfter(merge_target, latest_producer)) return;

fragment_op->moveAfter(merge_target);
IRRewriter rewriter(builder.getContext());

// Save the merge target's attributes before MergeRegionOps erases it.
SavedFragmentAttrs saved = SaveFragmentAttrs(merge_target);

FragmentOp merged_fragment = MergeRegionOps(
merge_target, fragment_op, rewriter,
/*num_static_args=*/0, /*replace_producer_use_in_consumer_block=*/
[](OpOperand&, Value) {
SDY_CHECK(false) << "Fragment ops shouldn't have free variables";
},
GetFragmentOriginUnion(merge_target, fragment_op, rewriter),
merge_target.getMeshNameAttr(),
// Uniquify-created fragments have no stage_id, so we preserve the
// merge target's stage_id directly.
/*stage_id=*/merge_target.getStageIdAttr());

RestoreFragmentAttrs(merged_fragment, saved, "uniquify", builder);
}

void CreateReturnFragmentForMesh(StringRef mesh_name, Operation* return_op,
ValueToReturnIndices& value_to_return_indices,
OpBuilder& builder) {
// We remove any entries that require no work, in order to avoid too many
// checks.
// Remove entries that require no work: single-use non-block-arg values
// already have a unique return slot and don't need a fragment.
value_to_return_indices.remove_if([](const auto& it) {
if (it.second.size() == 1) {
Value v = it.first;
return !isa<BlockArgument>(v);
}
return it.second.empty();
// Every entry has at least one index (populated via push_back).
SDY_CHECK(!it.second.empty());
return it.second.size() == 1 && !isa<BlockArgument>(it.first);
});

if (value_to_return_indices.empty()) {
Expand Down Expand Up @@ -98,6 +134,8 @@ void CreateReturnFragmentForMesh(StringRef mesh_name, Operation* return_op,
}
auto block_builder = OpBuilder::atBlockEnd(&fragment_block);
ReturnOp::create(block_builder, loc, returned_values);

MergeInferredFragmentWithExisting(fragment_op, mesh_name, return_op, builder);
}

// Replaces the return values of the function with transfer ops.
Expand Down
87 changes: 87 additions & 0 deletions shardy/dialect/mpmd/transforms/common/utils.cc
Original file line number Diff line number Diff line change
Expand Up @@ -47,6 +47,7 @@ limitations under the License.
#include "mlir/Transforms/RegionUtils.h"
#include "shardy/common/logging.h"
#include "shardy/dialect/mpmd/ir/dialect.h"
#include "shardy/dialect/mpmd/ir/utils.h"

namespace mlir::mpmd {

Expand Down Expand Up @@ -226,6 +227,92 @@ bool IsSplitKeepTransferred(FragmentOp fragment) {
return fragment->hasAttr(kSplitKeepTransferredAttrName);
}

// ---------------------------------------------------------------------------
// Fragment query and positioning utilities.
// ---------------------------------------------------------------------------

FragmentOp FindLastFragmentOnMesh(Block* block, StringRef mesh_name,
ArrayRef<Operation*> exclude) {
FragmentOp result = nullptr;
for (Operation& op : *block) {
if (llvm::is_contained(exclude, &op)) continue;
auto frag = dyn_cast<FragmentOp>(&op);
if (frag && frag.getMeshName() == mesh_name) {
result = frag;
}
}
return result;
}

FragmentOp FindFirstFragmentOnMesh(Block* block, StringRef mesh_name,
ArrayRef<Operation*> exclude) {
for (Operation& op : *block) {
if (llvm::is_contained(exclude, &op)) continue;
auto frag = dyn_cast<FragmentOp>(&op);
if (frag && frag.getMeshName() == mesh_name) {
return frag;
}
}
return nullptr;
}

Operation* FindLatestOperandProducer(Operation* op) {
Operation* latest = nullptr;
for (Value v : op->getOperands()) {
Operation* def = v.getDefiningOp();
if (!def) continue;
if (!latest || latest->isBeforeInBlock(def)) {
latest = def;
}
}
return latest;
}

bool CanMoveAfter(Operation* op_to_move, Operation* target_op) {
if (op_to_move->getBlock() != target_op->getBlock()) return false;
if (!op_to_move->isBeforeInBlock(target_op)) return false;

for (Value result : op_to_move->getResults()) {
for (Operation* user : result.getUsers()) {
if (user->getBlock() == op_to_move->getBlock()) {
if (user == target_op || user->isBeforeInBlock(target_op)) {
return false;
}
}
}
}
return true;
}

bool EnsureAfter(Operation* op, Operation* target) {
if (!target) return true;
if (op == target) return true;
if (!op->isBeforeInBlock(target)) return true; // already after
if (!CanMoveAfter(op, target)) return false;
op->moveAfter(target);
return true;
}

SavedFragmentAttrs SaveFragmentAttrs(FragmentOp fragment) {
return {
fragment->getAttrOfType<ArrayAttr>(kInferredByAttr),
fragment->getAttrOfType<IntegerAttr>(kCallCounterAttrName),
};
}

void RestoreFragmentAttrs(FragmentOp fragment, const SavedFragmentAttrs& saved,
StringRef pass_name, OpBuilder& builder) {
if (saved.call_counter) {
fragment->setAttr(kCallCounterAttrName, saved.call_counter);
}
SmallVector<Attribute> inferred_by;
if (saved.inferred_by) {
inferred_by.append(saved.inferred_by.begin(), saved.inferred_by.end());
}
inferred_by.push_back(builder.getStringAttr(pass_name));
fragment->setAttr(kInferredByAttr, builder.getArrayAttr(inferred_by));
}

namespace detail {
namespace {

Expand Down
42 changes: 42 additions & 0 deletions shardy/dialect/mpmd/transforms/common/utils.h
Original file line number Diff line number Diff line change
Expand Up @@ -127,6 +127,48 @@ SmallVector<T> FilterRange(RangeT range, const BitVector& erase) {
return result;
}

// ---------------------------------------------------------------------------
// Fragment query and positioning utilities.
// ---------------------------------------------------------------------------

// Returns the last FragmentOp on `mesh_name` in `block`, excluding ops in
// `exclude`. Returns nullptr if none found.
FragmentOp FindLastFragmentOnMesh(Block* block, StringRef mesh_name,
ArrayRef<Operation*> exclude = {});

// Returns the first FragmentOp on `mesh_name` in `block`, excluding ops in
// `exclude`. Returns nullptr if none found.
FragmentOp FindFirstFragmentOnMesh(Block* block, StringRef mesh_name,
ArrayRef<Operation*> exclude = {});

// Returns the latest (block-order) op that defines any operand of `op`.
// Returns nullptr if all operands are block arguments.
Operation* FindLatestOperandProducer(Operation* op);

// Returns true if `op_to_move` can be repositioned after `target_op`
// without breaking any use-def chains. Both must be in the same block
// and `op_to_move` must precede `target_op`.
bool CanMoveAfter(Operation* op_to_move, Operation* target_op);

// Moves `op` after `target` if needed and possible. Returns false if the
// move would break use-def chains. No-op (returns true) if `op` is already
// at or after `target`.
bool EnsureAfter(Operation* op, Operation* target);

// Attributes saved from a fragment before MergeRegionOps erases it.
struct SavedFragmentAttrs {
ArrayAttr inferred_by;
IntegerAttr call_counter;
};

// Saves the inferred_by and call_counter attributes from `fragment`.
SavedFragmentAttrs SaveFragmentAttrs(FragmentOp fragment);

// Restores saved attributes onto `fragment`, appending `pass_name` to the
// inferred_by list.
void RestoreFragmentAttrs(FragmentOp fragment, const SavedFragmentAttrs& saved,
StringRef pass_name, OpBuilder& builder);

namespace detail {

// A non-templated version of MergeRegionOps<OpTy> that takes a callback for
Expand Down
Loading
Loading