Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
31 changes: 31 additions & 0 deletions shardy/dialect/sdy/transforms/import/lift_inlined_meshes.cc
Original file line number Diff line number Diff line change
Expand Up @@ -114,6 +114,7 @@ struct LiftInlinedMeshesPass
: public impl::LiftInlinedMeshesPassBase<LiftInlinedMeshesPass> {
using LiftInlinedMeshesPassBase::LiftInlinedMeshesPassBase;

protected:
void runOnOperation() final {
ModuleOp moduleOp = getOperation();
SymbolTable symbolTable(moduleOp);
Expand Down Expand Up @@ -196,6 +197,36 @@ struct LiftInlinedMeshesPass
moduleOp.walk([&](stablehlo::CollectiveBroadcastOp op) {
processMeshInReplicaGroups(op);
});

// Attach discardable `stablehlo.mesh` attributes to all named meshes.
// Downgrading to older StableHLO versions before
// `MeshAxesReplicaGroups` was added requires the
// `StablehloCompatibilityExpander` pass to resolve symbol references to
// named meshes and extract a `stablehlo::MeshAttr` from them. Because
// Shardy's `sdy::MeshOp` stores its configuration as an `sdy::MeshAttr`
// and core StableHLO cannot depend on Shardy, attaching this discardable
// attribute ensures compatibility without violating dialect layering.
// Visit every `MeshOp` found directly via `moduleOp.getOps<MeshOp>()` and tag
// it with a `stablehlo.mesh` attribute built from its `sdy` mesh.
// NOTE(review): the loop body shown here only sets an attribute and never
// erases or moves `meshOp`, so the early-inc range looks like a defensive
// choice rather than a requirement — confirm before simplifying.
for (auto meshOp : llvm::make_early_inc_range(moduleOp.getOps<MeshOp>())) {
// Leave meshes that already carry the attribute untouched, so an existing
// value is never overwritten.
if (meshOp->hasAttr("stablehlo.mesh")) {
continue;
}
MeshAttr sdyMeshAttr = meshOp.getMesh();
// Mirror each sdy axis as a StableHLO mesh axis with the same name and size,
// in the same order.
SmallVector<mlir::stablehlo::MeshAxisAttr> shloAxes;
for (auto axisAttr : sdyMeshAttr.getAxes()) {
shloAxes.push_back(mlir::stablehlo::MeshAxisAttr::get(
meshOp.getContext(), axisAttr.getName(), axisAttr.getSize()));
}
// `deviceIds` stays default-constructed when the sdy mesh lists no device
// ids; otherwise it becomes a rank-1 i64 tensor attribute with one element
// per device id.
DenseIntElementsAttr deviceIds;
if (!sdyMeshAttr.getDeviceIds().empty()) {
// NOTE(review): `builder` is declared outside the visible part of this
// function; it is only used here to obtain the i64 element type.
auto type = RankedTensorType::get(
{static_cast<int64_t>(sdyMeshAttr.getDeviceIds().size())},
builder.getI64Type());
deviceIds = DenseIntElementsAttr::get(type, sdyMeshAttr.getDeviceIds());
}
auto shloMeshAttr = mlir::stablehlo::MeshAttr::get(meshOp.getContext(),
shloAxes, deviceIds);
// Stored under a plain string key on the op, i.e. as an extra attribute next
// to the op's own `sdy` mesh rather than replacing it.
meshOp->setAttr("stablehlo.mesh", shloMeshAttr);
}
}
};

Expand Down
15 changes: 15 additions & 0 deletions shardy/dialect/sdy/transforms/import/test/lift_inlined_meshes.mlir
Original file line number Diff line number Diff line change
Expand Up @@ -256,3 +256,18 @@ func.func private @foo(%arg0: tensor<8x8xf32> {sdy.sharding = #sdy.sharding<@mes
%0 = stablehlo.negate %arg0 {sdy.sharding = #sdy.sharding_per_value<[<@mesh, [{"x"}, {}]>]>} : tensor<8x8xf32>
return %0 : tensor<8x8xf32>
}

// -----

// Test case: every named mesh is expected to end up with a `stablehlo.mesh`
// attribute that mirrors its axes. This covers both the mesh declared in the
// input (@mesh) and the mesh expected to be lifted out of the inlined
// `mesh<["b"=2]>` sharding on the add below (@mesh_0).

// CHECK: sdy.mesh @mesh = <["a"=4]> {stablehlo.mesh = #stablehlo.mesh<axes=[<name = "a", size = 4>]>}
sdy.mesh @mesh = <["a"=4]>

// CHECK: sdy.mesh @mesh_0 = <["b"=2]> {stablehlo.mesh = #stablehlo.mesh<axes=[<name = "b", size = 2>]>}

// CHECK-LABEL: func @tagged_stablehlo_mesh_attribute
func.func @tagged_stablehlo_mesh_attribute(%arg0: tensor<4x4xf32> {sdy.sharding = #sdy.sharding<@mesh, [{"a"}, {}]>}, %arg1: tensor<4x4xf32>) -> tensor<4x4xf32> {
// The inlined mesh in the sharding is expected to be replaced by a reference
// to the lifted symbol @mesh_0.
// CHECK-NEXT: stablehlo.add %arg0, %arg1 {sdy.sharding = #sdy.sharding_per_value<[<@mesh_0, [{"b"}, {}]>]>}
%0 = stablehlo.add %arg0, %arg1 {sdy.sharding = #sdy.sharding_per_value<[<mesh<["b"=2]>, [{"b"}, {}]>]>} : tensor<4x4xf32>
return %0 : tensor<4x4xf32>
}

Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@
// RUN: sdy_opt %s -split-input-file -sdy-propagation-pipeline 2>&1 | FileCheck %s
// RUN: sdy_opt %s -split-input-file -sdy-propagation-pipeline | FileCheck %s

sdy.mesh @mesh = <["a"=2, "b"=2, "c"=2]>

Expand Down
Loading