Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
27 changes: 27 additions & 0 deletions shardy/dialect/sdy/ir/utils.cc
Original file line number Diff line number Diff line change
Expand Up @@ -203,6 +203,33 @@ MeshAttr getMeshAttr(Operation* op, SymbolRefAttr meshSymName) {
return nullptr;
}

TensorShardingAttr inlineMesh(const SymbolTable& symbolTable,
TensorShardingAttr sharding) {
if (auto name = dyn_cast<FlatSymbolRefAttr>(sharding.getMeshOrRef())) {
MeshAttr mesh = getMeshAttr(symbolTable, name);
assert(mesh && "unknown mesh");
return TensorShardingAttr::get(
sharding.getContext(), mesh, sharding.getDimShardings(),
sharding.getReplicatedAxes(), sharding.getUnreducedAxes());
}
return sharding;
}

TensorShardingPerValueAttr inlineMesh(
const SymbolTable& symbolTable,
TensorShardingPerValueAttr shardingPerValue) {
if (!shardingPerValue) {
return shardingPerValue;
}
SmallVector<TensorShardingAttr> inlinedShardings;
inlinedShardings.reserve(shardingPerValue.getShardings().size());
for (TensorShardingAttr shardingAttr : shardingPerValue.getShardings()) {
inlinedShardings.push_back(inlineMesh(symbolTable, shardingAttr));
}
return TensorShardingPerValueAttr::get(shardingPerValue.getContext(),
inlinedShardings);
}

Attribute getCommonMeshOrRef(ArrayRef<TensorShardingAttr> operandShardings,
ArrayRef<TensorShardingAttr> resultsShardings,
const SymbolTable& symbolTable,
Expand Down
11 changes: 11 additions & 0 deletions shardy/dialect/sdy/ir/utils.h
Original file line number Diff line number Diff line change
Expand Up @@ -147,6 +147,17 @@ MeshAttr getMeshAttr(Operation* op, StringRef meshName);
// table, or nullptr otherwise.
MeshAttr getMeshAttr(Operation* op, SymbolRefAttr meshSymName);

// If sharding refers to a mesh by name, returns a new TensorShardingAttr with
// the mesh inlined. Otherwise returns the same sharding.
TensorShardingAttr inlineMesh(const SymbolTable& symbolTable,
TensorShardingAttr sharding);

// For each sharding in shardingPerValue, if it refers to a mesh by name,
// returns a new sharding with the mesh inlined.
TensorShardingPerValueAttr inlineMesh(
const SymbolTable& symbolTable,
TensorShardingPerValueAttr shardingPerValue);

// Returns the common mesh (or a reference to it) bound by all the
// `TensorShardingAttr`s or nullptr if there is none.
//
Expand Down
10 changes: 1 addition & 9 deletions shardy/dialect/sdy/transforms/import/inline_meshes.cc
Original file line number Diff line number Diff line change
Expand Up @@ -43,15 +43,7 @@ struct InlineMeshesPass : public impl::InlineMeshesPassBase<InlineMeshesPass> {
SymbolTable symbolTable(moduleOp);

transformShardings(moduleOp, [&](TensorShardingAttr sharding) {
if (auto name = dyn_cast<FlatSymbolRefAttr>(sharding.getMeshOrRef())) {
MeshAttr mesh = getMeshAttr(symbolTable, name);
assert(mesh && "unknown mesh");
return TensorShardingAttr::get(sharding.getContext(), mesh,
sharding.getDimShardings(),
sharding.getReplicatedAxes(),
sharding.getUnreducedAxes());
}
return sharding;
return inlineMesh(symbolTable, sharding);
});

// Remove all MeshOps.
Expand Down
Loading