diff --git a/shardy/dialect/sdy/transforms/import/lift_inlined_meshes.cc b/shardy/dialect/sdy/transforms/import/lift_inlined_meshes.cc index 2722b45d..6dfed9c7 100644 --- a/shardy/dialect/sdy/transforms/import/lift_inlined_meshes.cc +++ b/shardy/dialect/sdy/transforms/import/lift_inlined_meshes.cc @@ -114,6 +114,7 @@ struct LiftInlinedMeshesPass : public impl::LiftInlinedMeshesPassBase { using LiftInlinedMeshesPassBase::LiftInlinedMeshesPassBase; + protected: void runOnOperation() final { ModuleOp moduleOp = getOperation(); SymbolTable symbolTable(moduleOp); @@ -196,6 +197,36 @@ struct LiftInlinedMeshesPass moduleOp.walk([&](stablehlo::CollectiveBroadcastOp op) { processMeshInReplicaGroups(op); }); + + // Attach discardable `stablehlo.mesh` attributes to all named meshes. + // Downgrading to older StableHLO versions before + // `MeshAxesReplicaGroups` was added requires the + // `StablehloCompatibilityExpander` pass to resolve symbol references to + // named meshes and extract a `stablehlo::MeshAttr` from them. Because + // Shardy's `sdy::MeshOp` stores its configuration as an `sdy::MeshAttr` + // and core StableHLO cannot depend on Shardy, attaching this discardable + // attribute ensures compatibility without violating dialect layering. + for (auto meshOp : llvm::make_early_inc_range(moduleOp.getOps())) { + if (meshOp->hasAttr("stablehlo.mesh")) { + continue; + } + MeshAttr sdyMeshAttr = meshOp.getMesh(); + SmallVector shloAxes; + for (auto axisAttr : sdyMeshAttr.getAxes()) { + shloAxes.push_back(mlir::stablehlo::MeshAxisAttr::get( + meshOp.getContext(), axisAttr.getName(), axisAttr.getSize())); + } + DenseIntElementsAttr deviceIds; + if (!sdyMeshAttr.getDeviceIds().empty()) { + auto type = RankedTensorType::get( + {static_cast(sdyMeshAttr.getDeviceIds().size())}, + builder.getI64Type()); + deviceIds = DenseIntElementsAttr::get(type, sdyMeshAttr.getDeviceIds()); + } + auto shloMeshAttr = mlir::stablehlo::MeshAttr::get(meshOp.getContext(), + shloAxes, deviceIds); + meshOp->setAttr("stablehlo.mesh", shloMeshAttr); + } } }; diff --git a/shardy/dialect/sdy/transforms/import/test/lift_inlined_meshes.mlir b/shardy/dialect/sdy/transforms/import/test/lift_inlined_meshes.mlir index de18acb6..77134cf4 100644 --- a/shardy/dialect/sdy/transforms/import/test/lift_inlined_meshes.mlir +++ b/shardy/dialect/sdy/transforms/import/test/lift_inlined_meshes.mlir @@ -256,3 +256,18 @@ func.func private @foo(%arg0: tensor<8x8xf32> {sdy.sharding = #sdy.sharding<@mes %0 = stablehlo.negate %arg0 {sdy.sharding = #sdy.sharding_per_value<[<@mesh, [{"x"}, {}]>]>} : tensor<8x8xf32> return %0 : tensor<8x8xf32> } + +// ----- + +// CHECK: sdy.mesh @mesh = <["a"=4]> {stablehlo.mesh = #stablehlo.mesh]>} +sdy.mesh @mesh = <["a"=4]> + +// CHECK: sdy.mesh @mesh_0 = <["b"=2]> {stablehlo.mesh = #stablehlo.mesh]>} + +// CHECK-LABEL: func @tagged_stablehlo_mesh_attribute +func.func @tagged_stablehlo_mesh_attribute(%arg0: tensor<4x4xf32> {sdy.sharding = #sdy.sharding<@mesh, [{"a"}, {}]>}, %arg1: tensor<4x4xf32>) -> tensor<4x4xf32> { + // CHECK-NEXT: stablehlo.add %arg0, %arg1 {sdy.sharding = #sdy.sharding_per_value<[<@mesh_0, [{"b"}, {}]>]>} + %0 = stablehlo.add %arg0, %arg1 {sdy.sharding = #sdy.sharding_per_value<[, [{"b"}, {}]>]>} : tensor<4x4xf32> + return %0 : tensor<4x4xf32> +} + diff --git a/shardy/dialect/sdy/transforms/propagation/test/propagation_pipeline.mlir b/shardy/dialect/sdy/transforms/propagation/test/propagation_pipeline.mlir index ba78e175..519143a7 100644 --- a/shardy/dialect/sdy/transforms/propagation/test/propagation_pipeline.mlir +++ b/shardy/dialect/sdy/transforms/propagation/test/propagation_pipeline.mlir @@ -1,4 +1,4 @@ -// RUN: sdy_opt %s -split-input-file -sdy-propagation-pipeline 2>&1 | FileCheck %s +// RUN: sdy_opt %s -split-input-file -sdy-propagation-pipeline | FileCheck %s sdy.mesh @mesh = <["a"=2, "b"=2, "c"=2]>