Skip to content

Commit

Permalink
Find common axes deterministically by not depending on DenseMap itera…
Browse files Browse the repository at this point in the history
…tion order to find the best factor-axes pair.

For this, change FactorAxesPair from std::pair to struct, and implement a comparator method, so that factor-axes pairs imposes a total order.

It also helps that unit-testing is not flaky.

PiperOrigin-RevId: 696603604
  • Loading branch information
Google-ML-Automation authored and copybara-github committed Nov 14, 2024
1 parent 3f43aab commit 2a11733
Show file tree
Hide file tree
Showing 3 changed files with 62 additions and 11 deletions.
6 changes: 6 additions & 0 deletions shardy/dialect/sdy/ir/attrs.td
Original file line number Diff line number Diff line change
Expand Up @@ -187,6 +187,12 @@ def Sdy_AxisRef : AttrDef<Sdy_Dialect, "AxisRef"> {
static std::function<bool(AxisRefAttr lhs, AxisRefAttr rhs)>
getMeshComparator(MeshAttr mesh);

// Returns a comparator that order axis names lexiographically. If both
// axis-refs have the same name, if one is a sub-axis and the other is the
// full axis, then the sub-axis comes first. If both are sub-axes then the
// smaller sub-axes comes first based on SubAxisInfoAttr comparator.
bool operator<(const AxisRefAttr &rhs) const;

std::string toString() const;

// Returns the size of this axis or sub-axis.
Expand Down
18 changes: 18 additions & 0 deletions shardy/dialect/sdy/ir/dialect.cc
Original file line number Diff line number Diff line change
Expand Up @@ -254,6 +254,24 @@ AxisRefAttr::getMeshComparator(MeshAttr mesh) {
};
}

bool AxisRefAttr::operator<(const AxisRefAttr& rhs) const {
StringRef name = getName();
StringRef rhsName = rhs.getName();
if (name != rhsName) {
return name < rhsName;
}
// Both axis-refs have the same name, if one is a sub-axis and the other
// is the full axis, then the sub-axis comes first.
if (!getSubAxisInfo()) {
return false;
}
if (!rhs.getSubAxisInfo()) {
return true;
}
// Both axis-refs are sub-axes.
return getSubAxisInfo() < rhs.getSubAxisInfo();
}

std::string AxisRefAttr::toString() const {
return strippedAttrString(*this, /*stripMnemonic=*/true);
}
Expand Down
49 changes: 38 additions & 11 deletions shardy/dialect/sdy/transforms/export/insert_explicit_reshards.cc
Original file line number Diff line number Diff line change
Expand Up @@ -169,7 +169,30 @@ bool axisRefsOverlap(ArrayRef<AxisRefAttr> first,
return false;
}

using FactorAxesPair = std::pair<int64_t, ArrayRef<AxisRefAttr>>;
struct FactorAxesPair {
int64_t factorIndex;
ArrayRef<AxisRefAttr> axisRefs;

FactorAxesPair(int64_t factorIndex, ArrayRef<AxisRefAttr> axisRefs)
: factorIndex(factorIndex), axisRefs(axisRefs) {}

FactorAxesPair() = default;

bool operator<(const FactorAxesPair& rhs) const {
if (factorIndex != rhs.factorIndex) {
return factorIndex < rhs.factorIndex;
}
if (axisRefs.size() != rhs.axisRefs.size()) {
return axisRefs.size() < rhs.axisRefs.size();
}
for (auto [axisRef, rhsAxisRef] : llvm::zip_equal(axisRefs, rhs.axisRefs)) {
if (axisRef != rhsAxisRef) {
return axisRef < rhsAxisRef;
}
}
return false;
}
};

// Broadly the algorithm is, at each iteration, to pick a {factor,axis} pair
// with the largest count from a list that is initialized with all the
Expand All @@ -192,11 +215,13 @@ AxesPerFactor findCommonAxesUsingMajorityVoteHeuristic(
if (factorSharding.axisRefs.empty()) {
continue;
}
ArrayRef<AxisRefAttr> axisRefs = factorSharding.axisRefs;
int64_t axesCount = ++factorAxesCounts[factorIndex][axisRefs];
if (axesCount > maxCount) {
FactorAxesPair factorAxes(factorIndex, factorSharding.axisRefs);
int64_t axesCount =
++factorAxesCounts[factorAxes.factorIndex][factorAxes.axisRefs];
if (axesCount > maxCount ||
(axesCount == maxCount && factorAxes < bestFactorAxes)) {
maxCount = axesCount;
bestFactorAxes = FactorAxesPair(factorIndex, axisRefs);
bestFactorAxes = factorAxes;
}
}
}
Expand All @@ -210,9 +235,9 @@ AxesPerFactor findCommonAxesUsingMajorityVoteHeuristic(
BitVector unseenFactors(numFactors, true);
// TODO(enver): Optimize to mark unseen only the factors with an axis.
while (maxCount > 0) {
factorAxisRefs[bestFactorAxes.first] =
llvm::to_vector(bestFactorAxes.second);
unseenFactors.reset(bestFactorAxes.first);
factorAxisRefs[bestFactorAxes.factorIndex] =
llvm::to_vector(bestFactorAxes.axisRefs);
unseenFactors.reset(bestFactorAxes.factorIndex);
// TODO(enver): Tie-breaking currently depends on the order of iteration.
// Consider some heuristic for breaking ties.
// Invalidate axes that overlaps with the picked one across all unseen
Expand All @@ -225,7 +250,7 @@ AxesPerFactor findCommonAxesUsingMajorityVoteHeuristic(
// TODO(enver): Relax the overlap check. We need to erase in case of an
// overlap only if the factor indices appear together in any of the
// operands or results.
if (axisRefsOverlap(bestFactorAxes.second, axisRefs)) {
if (axisRefsOverlap(bestFactorAxes.axisRefs, axisRefs)) {
// TODO(enver): Optimize to flip unseen if all the axes of the factor
// have zero count.
// Clear the count of overlapping axis, effectively erasing.
Expand All @@ -234,9 +259,11 @@ AxesPerFactor findCommonAxesUsingMajorityVoteHeuristic(
axesCounts[axisRefs] = 0;
continue;
}
if (count > maxCount) {
FactorAxesPair factorAxes(factorIndex, axisRefs);
if (count > maxCount ||
(count == maxCount && factorAxes < nextBestFactorAxes)) {
maxCount = count;
nextBestFactorAxes = FactorAxesPair(factorIndex, axisRefs);
nextBestFactorAxes = factorAxes;
}
}
}
Expand Down

0 comments on commit 2a11733

Please sign in to comment.