Skip to content

Commit 66eae85

Browse files
Varchocopybara-github
authored andcommitted
[ReplicaGroupV3][JAX+stablehlo][6/n] Emit RGV3 from JAX in shard_map
PiperOrigin-RevId: 890525943
1 parent 697603d commit 66eae85

3 files changed

Lines changed: 47 additions & 1 deletion

File tree

shardy/dialect/sdy/transforms/import/lift_inlined_meshes.cc

Lines changed: 31 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -114,6 +114,7 @@ struct LiftInlinedMeshesPass
114114
: public impl::LiftInlinedMeshesPassBase<LiftInlinedMeshesPass> {
115115
using LiftInlinedMeshesPassBase::LiftInlinedMeshesPassBase;
116116

117+
protected:
117118
void runOnOperation() final {
118119
ModuleOp moduleOp = getOperation();
119120
SymbolTable symbolTable(moduleOp);
@@ -196,6 +197,36 @@ struct LiftInlinedMeshesPass
196197
moduleOp.walk([&](stablehlo::CollectiveBroadcastOp op) {
197198
processMeshInReplicaGroups(op);
198199
});
200+
201+
// Attach discardable `stablehlo.mesh` attributes to all named meshes.
202+
// Downgrading to older StableHLO versions before
203+
// `MeshAxesReplicaGroups` was added requires the
204+
// `StablehloCompatibilityExpander` pass to resolve symbol references to
205+
// named meshes and extract a `stablehlo::MeshAttr` from them. Because
206+
// Shardy's `sdy::MeshOp` stores its configuration as an `sdy::MeshAttr`
207+
// and core StableHLO cannot depend on Shardy, attaching this discardable
208+
// attribute ensures compatibility without violating dialect layering.
209+
for (auto meshOp : llvm::make_early_inc_range(moduleOp.getOps<MeshOp>())) {
210+
if (meshOp->hasAttr("stablehlo.mesh")) {
211+
continue;
212+
}
213+
MeshAttr sdyMeshAttr = meshOp.getMesh();
214+
SmallVector<mlir::stablehlo::MeshAxisAttr> shloAxes;
215+
for (auto axisAttr : sdyMeshAttr.getAxes()) {
216+
shloAxes.push_back(mlir::stablehlo::MeshAxisAttr::get(
217+
meshOp.getContext(), axisAttr.getName(), axisAttr.getSize()));
218+
}
219+
DenseIntElementsAttr deviceIds;
220+
if (!sdyMeshAttr.getDeviceIds().empty()) {
221+
auto type = RankedTensorType::get(
222+
{static_cast<int64_t>(sdyMeshAttr.getDeviceIds().size())},
223+
builder.getI64Type());
224+
deviceIds = DenseIntElementsAttr::get(type, sdyMeshAttr.getDeviceIds());
225+
}
226+
auto shloMeshAttr = mlir::stablehlo::MeshAttr::get(meshOp.getContext(),
227+
shloAxes, deviceIds);
228+
meshOp->setAttr("stablehlo.mesh", shloMeshAttr);
229+
}
199230
}
200231
};
201232

shardy/dialect/sdy/transforms/import/test/lift_inlined_meshes.mlir

Lines changed: 15 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -256,3 +256,18 @@ func.func private @foo(%arg0: tensor<8x8xf32> {sdy.sharding = #sdy.sharding<@mes
256256
%0 = stablehlo.negate %arg0 {sdy.sharding = #sdy.sharding_per_value<[<@mesh, [{"x"}, {}]>]>} : tensor<8x8xf32>
257257
return %0 : tensor<8x8xf32>
258258
}
259+
260+
// -----
261+
262+
// CHECK: sdy.mesh @mesh = <["a"=4]> {stablehlo.mesh = #stablehlo.mesh<axes=[<name = "a", size = 4>]>}
263+
sdy.mesh @mesh = <["a"=4]>
264+
265+
// CHECK: sdy.mesh @mesh_0 = <["b"=2]> {stablehlo.mesh = #stablehlo.mesh<axes=[<name = "b", size = 2>]>}
266+
267+
// CHECK-LABEL: func @tagged_stablehlo_mesh_attribute
268+
func.func @tagged_stablehlo_mesh_attribute(%arg0: tensor<4x4xf32> {sdy.sharding = #sdy.sharding<@mesh, [{"a"}, {}]>}, %arg1: tensor<4x4xf32>) -> tensor<4x4xf32> {
269+
// CHECK-NEXT: stablehlo.add %arg0, %arg1 {sdy.sharding = #sdy.sharding_per_value<[<@mesh_0, [{"b"}, {}]>]>}
270+
%0 = stablehlo.add %arg0, %arg1 {sdy.sharding = #sdy.sharding_per_value<[<mesh<["b"=2]>, [{"b"}, {}]>]>} : tensor<4x4xf32>
271+
return %0 : tensor<4x4xf32>
272+
}
273+

shardy/dialect/sdy/transforms/propagation/test/propagation_pipeline.mlir

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,4 @@
1-
// RUN: sdy_opt %s -split-input-file -sdy-propagation-pipeline 2>&1 | FileCheck %s
1+
// RUN: sdy_opt %s -split-input-file -sdy-propagation-pipeline | FileCheck %s
22

33
sdy.mesh @mesh = <["a"=2, "b"=2, "c"=2]>
44

0 commit comments

Comments
 (0)