Use AxisRefsWithTail instead of ArrayRef of AxisRefs, so the trailing AxisRefAttr can be modified on factor-axis pairs.

This prepares for cases where we consider assigning modified axes to a factor, such as a trimmed version of the existing axes. For example, if factor i is assigned to {"y":(4)2}, then some other factor with a sharding {"x", "y", "z"} can now be assigned to {"x", "y":(1)4}. Since the trailing axis "y":(1)4 is not in the original sharding {"x", "y", "z"}, it cannot be referred to through an ArrayRef into that sharding; hence the trailing axis is held as a separate AxisRefAttr.

This change is a no-op, as the trailing axis is always empty.
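
As a minimal, self-contained C++ sketch (not the sdy API: Axis, AxesWithTail, and toVector below are simplified stand-ins for AxisRefAttr, AxisRefsWithTail, and the pass's toVector helper), the pair keeps a prefix of the original axes and carries the trimmed sub-axis as a separate trailing element:

#include <cstdint>
#include <iostream>
#include <optional>
#include <string>
#include <utility>
#include <vector>

// Simplified stand-in for AxisRefAttr: an axis name plus an optional
// (pre-size, size) sub-axis split; no split means the full axis.
struct Axis {
  std::string name;
  std::optional<std::pair<int64_t, int64_t>> sub_axis;
  bool empty() const { return name.empty(); }
  std::string str() const {
    if (!sub_axis) return "\"" + name + "\"";
    return "\"" + name + "\":(" + std::to_string(sub_axis->first) + ")" +
           std::to_string(sub_axis->second);
  }
};

// Mirrors AxisRefsWithTail: a non-empty prefix of the original axes plus an
// optional trailing axis that need not exist in the original array.
using AxesWithTail = std::pair<std::vector<Axis>, Axis>;

// Concatenates the prefix and the trailing axis, like toVector in the pass.
std::vector<Axis> toVector(const AxesWithTail& axes) {
  std::vector<Axis> result = axes.first;
  if (!axes.second.empty()) result.push_back(axes.second);
  return result;
}

int main() {
  // Original sharding of some other factor: {"x", "y", "z"}.
  std::vector<Axis> original = {
      {"x", std::nullopt}, {"y", std::nullopt}, {"z", std::nullopt}};
  // Trimmed assignment {"x", "y":(1)4}: the prefix {"x"} comes from the
  // original axes, while the trimmed sub-axis "y":(1)4 is carried as the
  // tail because it is not an element of the original array.
  AxesWithTail trimmed;
  trimmed.first = {original[0]};
  trimmed.second = Axis{"y", std::pair<int64_t, int64_t>(1, 4)};
  for (const Axis& axis : toVector(trimmed)) std::cout << axis.str() << " ";
  std::cout << "\n";  // prints: "x" "y":(1)4
}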

PiperOrigin-RevId: 696503004
Google-ML-Automation authored and copybara-github committed Nov 14, 2024
1 parent 3f43aab commit e35d7ec
Showing 2 changed files with 90 additions and 64 deletions.
105 changes: 67 additions & 38 deletions shardy/dialect/sdy/transforms/export/insert_explicit_reshards.cc
@@ -154,22 +154,54 @@ void insertExplicitReshards(Operation* op, const ShardingProjection& projection,
}
}

// Checks if any two axes, one from the first array, and the other from the
// second array, overlap.
// AxisRefsWithTail holds a pair of an axes-array and a 'tail' axis which
// together define a sequence of axes as the concatenation of the two. The
// axes-array of the pair cannot be empty, while the 'tail' axis may or may
// not be empty.
using AxisRefsWithTail = std::pair<ArrayRef<AxisRefAttr>, AxisRefAttr>;
using FactorAxesPair = std::pair<int64_t, AxisRefsWithTail>;

// Checks if `axisRef` overlaps with any of the axes of
// `againstAxisRefsWithTail`.
// TODO(enver): Optimize by using a set of AxisRefAttr.
bool axisRefsOverlap(ArrayRef<AxisRefAttr> first,
ArrayRef<AxisRefAttr> second) {
for (const auto& firstAxisRef : first) {
for (const auto& secondAxisRef : second) {
if (firstAxisRef.overlaps(secondAxisRef)) {
return true;
}
bool axisRefsOverlap(AxisRefAttr axisRef,
AxisRefsWithTail againstAxisRefsWithTail) {
auto& [againstAxisRefs, againstTailAxisRef] = againstAxisRefsWithTail;
for (const auto& againstAxisRef : againstAxisRefs) {
if (axisRef.overlaps(againstAxisRef)) {
return true;
}
}
if (againstTailAxisRef && axisRef.overlaps(againstTailAxisRef)) {
return true;
}
return false;
}

using FactorAxesPair = std::pair<int64_t, ArrayRef<AxisRefAttr>>;
// Checks if any two axes, one from `axisRefsWithTail` and the other from
// `againstAxisRefsWithTail`, overlap.
// TODO(enver): Optimize by using a set of AxisRefAttr.
bool axisRefsOverlap(AxisRefsWithTail axisRefsWithTail,
AxisRefsWithTail againstAxisRefsWithTail) {
auto& [axisRefs, tailAxisRef] = axisRefsWithTail;
for (const auto& axisRef : axisRefs) {
if (axisRefsOverlap(axisRef, againstAxisRefsWithTail)) {
return true;
}
}
if (tailAxisRef && axisRefsOverlap(tailAxisRef, againstAxisRefsWithTail)) {
return true;
}
return false;
}

SmallVector<AxisRefAttr> toVector(AxisRefsWithTail axisRefsWithTail) {
auto& [axisRefs, tailAxisRef] = axisRefsWithTail;
SmallVector<AxisRefAttr> resultAxisRefs = llvm::to_vector(axisRefs);
if (tailAxisRef) {
resultAxisRefs.push_back(tailAxisRef);
}
return resultAxisRefs;
}

// Broadly the algorithm is, at each iteration, to pick a {factor,axis} pair
// with the largest count from a list that is initialized with all the
@@ -180,8 +212,7 @@ using FactorAxesPair = std::pair<int64_t, ArrayRef<AxisRefAttr>>;
AxesPerFactor findCommonAxesUsingMajorityVoteHeuristic(
const ShardingProjection& projection, int64_t numFactors) {
AxesPerFactor factorAxisRefs(numFactors);
SmallVector<DenseMap<ArrayRef<AxisRefAttr>, int64_t>> factorAxesCounts(
numFactors);
DenseMap<FactorAxesPair, int64_t> factorAxesCounts;
int64_t maxCount = 0;
FactorAxesPair bestFactorAxes;
for (const TensorFactorShardings& tensorFactorSharding :
@@ -193,10 +224,11 @@ AxesPerFactor findCommonAxesUsingMajorityVoteHeuristic(
continue;
}
ArrayRef<AxisRefAttr> axisRefs = factorSharding.axisRefs;
int64_t axesCount = ++factorAxesCounts[factorIndex][axisRefs];
FactorAxesPair factorAxes(factorIndex, {axisRefs, AxisRefAttr()});
int64_t axesCount = ++factorAxesCounts[factorAxes];
if (axesCount > maxCount) {
maxCount = axesCount;
bestFactorAxes = FactorAxesPair(factorIndex, axisRefs);
bestFactorAxes = factorAxes;
}
}
}
@@ -207,38 +239,35 @@ AxesPerFactor findCommonAxesUsingMajorityVoteHeuristic(
// the count of [x] prefix will be two for this factor.
// TODO(enver): Assign an axis to a factor immediately if the count is more
// than floor(n/2) where n is the number of tensors.
BitVector unseenFactors(numFactors, true);
// TODO(enver): Optimize to mark unseen only the factors with an axis.
while (maxCount > 0) {
factorAxisRefs[bestFactorAxes.first] =
llvm::to_vector(bestFactorAxes.second);
unseenFactors.reset(bestFactorAxes.first);
factorAxisRefs[bestFactorAxes.first] = toVector(bestFactorAxes.second);
// TODO(enver): Tie-breaking currently depends on the order of iteration.
// Consider some heuristic for breaking ties.
// Invalidate axes that overlap with the picked one across all unseen
// factors. During the iteration, also find the new best.
maxCount = 0;
FactorAxesPair nextBestFactorAxes;
for (int factorIndex : unseenFactors.set_bits()) {
auto& axesCounts = factorAxesCounts[factorIndex];
for (const auto& [axisRefs, count] : axesCounts) {
// TODO(enver): Relax the overlap check. We need to erase in case of an
// overlap only if the factor indices appear together in any of the
// operands or results.
if (axisRefsOverlap(bestFactorAxes.second, axisRefs)) {
// TODO(enver): Optimize to flip unseen if all the axes of the factor
// have zero count.
// Clear the count of overlapping axis, effectively erasing.
// TODO(enver): Instead of removing from the list, trim the axisRefs,
// to use the largest prefix that does not overlap with bestAxisRefs.
axesCounts[axisRefs] = 0;
continue;
}
if (count > maxCount) {
maxCount = count;
nextBestFactorAxes = FactorAxesPair(factorIndex, axisRefs);
}
for (auto factorAxesCountIt = factorAxesCounts.begin();
factorAxesCountIt != factorAxesCounts.end();) {
const auto& [factorAxes, count] = *factorAxesCountIt;
// TODO(enver): Relax the overlap check. We need to erase in case of an
// overlap only if the factor indices appear together in any of the
// operands or results.
if (factorAxes.first == bestFactorAxes.first ||
axisRefsOverlap(factorAxes.second, bestFactorAxes.second)) {
// TODO(enver): Optimize to flip unseen if all the axes of the factor
// have zero count.
// Clear the count of overlapping axis, effectively erasing.
// TODO(enver): Instead of removing from the list, trim the axisRefs,
// to use the largest prefix that does not overlap with bestAxisRefs.
factorAxesCounts.erase(factorAxesCountIt++);
continue;
}
if (count > maxCount) {
maxCount = count;
nextBestFactorAxes = factorAxes;
}
++factorAxesCountIt;
}
bestFactorAxes = nextBestFactorAxes;
}
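
Before the test-file changes, a standalone sketch of the greedy majority-vote heuristic implemented above may help. It is a simplification, not the sdy code: Axes, FactorAxes, overlaps, and findCommonAxes are illustrative stand-ins, axis overlap is reduced to name equality (the real pass compares AxisRefAttr sub-axis ranges), and the real pass uses llvm::DenseMap and tracks a running maximum instead of rescanning.

#include <cstdint>
#include <iostream>
#include <map>
#include <string>
#include <utility>
#include <vector>

// Simplified stand-ins: an axis is just a name, and a candidate assignment is
// a (factor index, axes) pair, mirroring FactorAxesPair.
using Axes = std::vector<std::string>;
using FactorAxes = std::pair<int64_t, Axes>;

// Name-equality overlap; the real pass compares AxisRefAttr sub-axis ranges.
bool overlaps(const Axes& a, const Axes& b) {
  for (const std::string& x : a)
    for (const std::string& y : b)
      if (x == y) return true;
  return false;
}

// Greedy majority vote: count each candidate, repeatedly commit the most
// common one, then drop remaining candidates that reuse the same factor or
// overlap the committed axes.
std::map<int64_t, Axes> findCommonAxes(const std::vector<FactorAxes>& votes) {
  std::map<FactorAxes, int64_t> counts;
  for (const FactorAxes& vote : votes) ++counts[vote];

  std::map<int64_t, Axes> result;
  while (!counts.empty()) {
    auto best = counts.begin();
    for (auto it = counts.begin(); it != counts.end(); ++it)
      if (it->second > best->second) best = it;
    const FactorAxes picked = best->first;
    result[picked.first] = picked.second;
    for (auto it = counts.begin(); it != counts.end();) {
      if (it->first.first == picked.first ||
          overlaps(it->first.second, picked.second)) {
        it = counts.erase(it);
      } else {
        ++it;
      }
    }
  }
  return result;
}

int main() {
  // Two tensors vote {"x"} for factor 0; one votes {"x", "y"} for factor 1.
  std::vector<FactorAxes> votes = {{0, {"x"}}, {0, {"x"}}, {1, {"x", "y"}}};
  for (const auto& [factor, axes] : findCommonAxes(votes))
    std::cout << "factor " << factor << " gets " << axes.size() << " axes\n";
  // Factor 0 keeps {"x"}; factor 1's candidate overlaps "x" and is dropped.
}
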
@@ -122,8 +122,8 @@ func.func @dot_compatible_contracting_dim_empty(%arg0: tensor<8x32xf32> {sdy.sha
}

// CHECK-LABEL: func @dot_incompatible_same_non_contracting_dims_out_empty
// CHECK-NEXT: %[[RESHARD1:.*]] = sdy.reshard %arg1 <@mesh, [{"y"}, {}]> : tensor<32x16xf32>
// CHECK-NEXT: %[[DOT:.*]] = stablehlo.dot %arg0, %[[RESHARD1]] {sdy.sharding = #sdy.sharding_per_value<[<@mesh, [{"x"}, {}]>]>, sdy.sharding_rule = #sdy.op_sharding_rule<([i, k], [k, j])->([i, j]) {i=8, j=16, k=32}>} : (tensor<8x32xf32>, tensor<32x16xf32>) -> tensor<8x16xf32>
// CHECK-NEXT: %[[RESHARD1:.*]] = sdy.reshard %arg0 <@mesh, [{}, {"y"}]> : tensor<8x32xf32>
// CHECK-NEXT: %[[DOT:.*]] = stablehlo.dot %[[RESHARD1]], %arg1 {sdy.sharding = #sdy.sharding_per_value<[<@mesh, [{}, {"x"}]>]>, sdy.sharding_rule = #sdy.op_sharding_rule<([i, k], [k, j])->([i, j]) {i=8, j=16, k=32}>} : (tensor<8x32xf32>, tensor<32x16xf32>) -> tensor<8x16xf32>
// CHECK-NEXT: %[[RESHARD2:.*]] = sdy.reshard %[[DOT]] <@mesh, [{}, {}]> : tensor<8x16xf32>
// CHECK-NEXT: return %[[RESHARD2]] : tensor<8x16xf32>
func.func @dot_incompatible_same_non_contracting_dims_out_empty(%arg0: tensor<8x32xf32> {sdy.sharding = #sdy.sharding<@mesh, [{"x"}, {"y"}]>}, %arg1: tensor<32x16xf32> {sdy.sharding = #sdy.sharding<@mesh, [{"y"}, {"x"}]>}) -> tensor<8x16xf32> {
Expand Down Expand Up @@ -195,20 +195,19 @@ func.func @dot_incompatible_i_mismatch(%arg0: tensor<8x32xf32> {sdy.sharding = #
}

// CHECK-LABEL: func @dot_incompatible_in_out_mismatch_same_axis_on_different_factors_lhs_non_contracting_dim_is_sharded
// CHECK-NEXT: %[[DOT:.*]] = stablehlo.dot %arg0, %arg1 {sdy.sharding = #sdy.sharding_per_value<[<@mesh, [{"x"}, {}]>]>, sdy.sharding_rule = #sdy.op_sharding_rule<([i, k], [k, j])->([i, j]) {i=8, j=16, k=32}>} : (tensor<8x32xf32>, tensor<32x16xf32>) -> tensor<8x16xf32>
// CHECK-NEXT: %[[RESHARD:.*]] = sdy.reshard %[[DOT]] <@mesh, [{}, {"x"}]> : tensor<8x16xf32>
// CHECK-NEXT: return %[[RESHARD]] : tensor<8x16xf32>
// TODO(enver): A better solution could be to reshard operands, depending on the factor sizes.
// CHECK-NEXT: %[[RESHARD1:.*]] = sdy.reshard %arg0 <@mesh, [{}, {"y"}]> : tensor<8x32xf32>
// CHECK-NEXT: %[[RESHARD2:.*]] = sdy.reshard %arg1 <@mesh, [{"y"}, {"x"}]> : tensor<32x16xf32>
// CHECK-NEXT: %[[DOT:.*]] = stablehlo.dot %[[RESHARD1]], %[[RESHARD2]] {sdy.sharding = #sdy.sharding_per_value<[<@mesh, [{}, {"x"}]>]>, sdy.sharding_rule = #sdy.op_sharding_rule<([i, k], [k, j])->([i, j]) {i=8, j=16, k=32}>} : (tensor<8x32xf32>, tensor<32x16xf32>) -> tensor<8x16xf32>
// CHECK-NEXT: return %[[DOT]] : tensor<8x16xf32>
func.func @dot_incompatible_in_out_mismatch_same_axis_on_different_factors_lhs_non_contracting_dim_is_sharded(%arg0: tensor<8x32xf32> {sdy.sharding = #sdy.sharding<@mesh, [{"x"}, {"y"}]>}, %arg1: tensor<32x16xf32> {sdy.sharding = #sdy.sharding<@mesh, [{"y"}, {}]>}) -> (tensor<8x16xf32> {sdy.sharding = #sdy.sharding<@mesh, [{}, {"x"}]>}) {
%0 = stablehlo.dot %arg0, %arg1 {sdy.sharding = #sdy.sharding_per_value<[<@mesh, [{}, {"x"}]>]>, sdy.sharding_rule = #sdy.op_sharding_rule<([i, k],[k, j])->([i, j]) {i=8, j=16, k=32}>} : (tensor<8x32xf32>, tensor<32x16xf32>) -> tensor<8x16xf32>
return %0 : tensor<8x16xf32>
}

// CHECK-LABEL: func @dot_incompatible_in_out_mismatch_same_axis_on_different_factors_rhs_non_contracting_dim_is_sharded
// CHECK-NEXT: %[[RESHARD1:.*]] = sdy.reshard %arg0 <@mesh, [{"x"}, {"y"}]> : tensor<8x32xf32>
// CHECK-NEXT: %[[RESHARD2:.*]] = sdy.reshard %arg1 <@mesh, [{"y"}, {}]> : tensor<32x16xf32>
// CHECK-NEXT: %[[DOT:.*]] = stablehlo.dot %[[RESHARD1]], %[[RESHARD2]] {sdy.sharding = #sdy.sharding_per_value<[<@mesh, [{"x"}, {}]>]>, sdy.sharding_rule = #sdy.op_sharding_rule<([i, k], [k, j])->([i, j]) {i=8, j=16, k=32}>} : (tensor<8x32xf32>, tensor<32x16xf32>) -> tensor<8x16xf32>
// CHECK-NEXT: return %[[DOT]] : tensor<8x16xf32>
// CHECK-NEXT: %[[DOT:.*]] = stablehlo.dot %arg0, %arg1 {sdy.sharding = #sdy.sharding_per_value<[<@mesh, [{}, {"x"}]>]>, sdy.sharding_rule = #sdy.op_sharding_rule<([i, k], [k, j])->([i, j]) {i=8, j=16, k=32}>} : (tensor<8x32xf32>, tensor<32x16xf32>) -> tensor<8x16xf32>
// CHECK-NEXT: %[[RESHARD:.*]] = sdy.reshard %[[DOT]] <@mesh, [{"x"}, {}]> : tensor<8x16xf32>
// CHECK-NEXT: return %[[RESHARD]] : tensor<8x16xf32>
func.func @dot_incompatible_in_out_mismatch_same_axis_on_different_factors_rhs_non_contracting_dim_is_sharded(%arg0: tensor<8x32xf32> {sdy.sharding = #sdy.sharding<@mesh, [{}, {"y"}]>}, %arg1: tensor<32x16xf32> {sdy.sharding = #sdy.sharding<@mesh, [{"y"}, {"x"}]>}) -> (tensor<8x16xf32> {sdy.sharding = #sdy.sharding<@mesh, [{"x"}, {}]>}) {
%0 = stablehlo.dot %arg0, %arg1 {sdy.sharding = #sdy.sharding_per_value<[<@mesh, [{"x"}, {}]>]>, sdy.sharding_rule = #sdy.op_sharding_rule<([i, k],[k, j])->([i, j]) {i=8, j=16, k=32}>} : (tensor<8x32xf32>, tensor<32x16xf32>) -> tensor<8x16xf32>
return %0 : tensor<8x16xf32>
@@ -225,20 +224,20 @@ func.func @dot_incompatible_in_out_mismatch_i_j_swapped(%arg0: tensor<8x32xf32>
}

// CHECK-LABEL: func @dot_incompatible_sub_axis_overlaps
// CHECK-NEXT: %[[DOT:.*]] = stablehlo.dot %arg0, %arg1 {sdy.sharding = #sdy.sharding_per_value<[<@mesh, [{"x":(2)2}, {}]>]>, sdy.sharding_rule = #sdy.op_sharding_rule<([i, k], [k, j])->([i, j]) {i=8, j=16, k=32}>} : (tensor<8x32xf32>, tensor<32x16xf32>) -> tensor<8x16xf32>
// CHECK-NEXT: %[[RESHARD:.*]] = sdy.reshard %[[DOT]] <@mesh, [{}, {"x"}]> : tensor<8x16xf32>
// return %[[RESHARD]] : tensor<8x16xf32>
// CHECK-NEXT: %[[RESHARD1:.*]] = sdy.reshard %arg0 <@mesh, [{}, {"y"}]> : tensor<8x32xf32>
// CHECK-NEXT: %[[RESHARD2:.*]] = sdy.reshard %arg1 <@mesh, [{"y"}, {"x"}]> : tensor<32x16xf32>
// CHECK-NEXT: %[[DOT:.*]] = stablehlo.dot %[[RESHARD1]], %[[RESHARD2]] {sdy.sharding = #sdy.sharding_per_value<[<@mesh, [{}, {"x"}]>]>, sdy.sharding_rule = #sdy.op_sharding_rule<([i, k], [k, j])->([i, j]) {i=8, j=16, k=32}>} : (tensor<8x32xf32>, tensor<32x16xf32>) -> tensor<8x16xf32>
// CHECK-NEXT: return %[[DOT]] : tensor<8x16xf32>
func.func @dot_incompatible_sub_axis_overlaps(%arg0: tensor<8x32xf32> {sdy.sharding = #sdy.sharding<@mesh, [{"x":(2)2}, {"y"}]>}, %arg1: tensor<32x16xf32> {sdy.sharding = #sdy.sharding<@mesh, [{"y"}, {}]>}) -> (tensor<8x16xf32> {sdy.sharding = #sdy.sharding<@mesh, [{}, {"x"}]>}) {
%0 = stablehlo.dot %arg0, %arg1 {sdy.sharding = #sdy.sharding_per_value<[<@mesh, [{}, {"x"}]>]>, sdy.sharding_rule = #sdy.op_sharding_rule<([i, k],[k, j])->([i, j]) {i=8, j=16, k=32}>} : (tensor<8x32xf32>, tensor<32x16xf32>) -> tensor<8x16xf32>
return %0 : tensor<8x16xf32>
}

// CHECK-LABEL: func @dot_incompatible_all_factors_mismatch
// CHECK-NEXT: %[[RESHARD1:.*]] = sdy.reshard %arg0 <@mesh, [{"x"}, {}]> : tensor<8x32xf32>
// CHECK-NEXT: %[[RESHARD2:.*]] = sdy.reshard %arg1 <@mesh, [{}, {"y"}]> : tensor<32x16xf32>
// CHECK-NEXT: %[[DOT:.*]] = stablehlo.dot %[[RESHARD1]], %[[RESHARD2]] {sdy.sharding = #sdy.sharding_per_value<[<@mesh, [{"x"}, {"y"}]>]>, sdy.sharding_rule = #sdy.op_sharding_rule<([i, k], [k, j])->([i, j]) {i=8, j=16, k=32}>} : (tensor<8x32xf32>, tensor<32x16xf32>) -> tensor<8x16xf32>
// CHECK-NEXT: %[[RESHARD3:.*]] = sdy.reshard %[[DOT]] <@mesh, [{"y"}, {"x"}]> : tensor<8x16xf32>
// CHECK-NEXT: return %[[RESHARD3]] : tensor<8x16xf32>
// CHECK-NEXT: %[[RESHARD1:.*]] = sdy.reshard %arg1 <@mesh, [{"y"}, {}]> : tensor<32x16xf32>
// CHECK-NEXT: %[[DOT:.*]] = stablehlo.dot %arg0, %[[RESHARD1]] {sdy.sharding = #sdy.sharding_per_value<[<@mesh, [{"x"}, {}]>]>, sdy.sharding_rule = #sdy.op_sharding_rule<([i, k], [k, j])->([i, j]) {i=8, j=16, k=32}>} : (tensor<8x32xf32>, tensor<32x16xf32>) -> tensor<8x16xf32>
// CHECK-NEXT: %[[RESHARD2:.*]] = sdy.reshard %[[DOT]] <@mesh, [{"y"}, {"x"}]> : tensor<8x16xf32>
// CHECK-NEXT: return %[[RESHARD2]] : tensor<8x16xf32>
func.func @dot_incompatible_all_factors_mismatch(%arg0: tensor<8x32xf32> {sdy.sharding = #sdy.sharding<@mesh, [{"x"}, {"y"}]>}, %arg1: tensor<32x16xf32> {sdy.sharding = #sdy.sharding<@mesh, [{"x"}, {"y"}]>}) -> (tensor<8x16xf32> {sdy.sharding = #sdy.sharding<@mesh, [{"y"}, {"x"}]>}) {
%0 = stablehlo.dot %arg0, %arg1 {sdy.sharding = #sdy.sharding_per_value<[<@mesh, [{"y"}, {"x"}]>]>, sdy.sharding_rule = #sdy.op_sharding_rule<([i, k],[k, j])->([i, j]) {i=8, j=16, k=32}>} : (tensor<8x32xf32>, tensor<32x16xf32>) -> tensor<8x16xf32>
return %0 : tensor<8x16xf32>
@@ -247,10 +246,9 @@ func.func @dot_incompatible_all_factors_mismatch(%arg0: tensor<8x32xf32> {sdy.sh
// CHECK-LABEL: func @dot_reshard_is_local
func.func @dot_reshard_is_local(%arg0: tensor<8x32xf32> {sdy.sharding = #sdy.sharding<@mesh, [{}, {"y"}]>}, %arg1: tensor<32x16xf32> {sdy.sharding = #sdy.sharding<@mesh, [{"y"}, {"x"}]>}) -> (tensor<8x16xf32> {sdy.sharding = #sdy.sharding<@mesh, [{"x"}, {}]>}) {
%0 = stablehlo.negate %arg1 {sdy.sharding = #sdy.sharding_per_value<[<@mesh, [{"y"}, {"x"}]>]>} : tensor<32x16xf32>
// CHECK: %[[RESHARD1:.*]] = sdy.reshard %arg0 <@mesh, [{"x"}, {"y"}]> : tensor<8x32xf32>
// CHECK-NEXT: %[[RESHARD2:.*]] = sdy.reshard %0 <@mesh, [{"y"}, {}]> : tensor<32x16xf32>
// CHECK-NEXT: %[[DOT:.*]] = stablehlo.dot %[[RESHARD1]], %[[RESHARD2]] {sdy.sharding = #sdy.sharding_per_value<[<@mesh, [{"x"}, {}]>]>, sdy.sharding_rule = #sdy.op_sharding_rule<([i, k], [k, j])->([i, j]) {i=8, j=16, k=32}>} : (tensor<8x32xf32>, tensor<32x16xf32>) -> tensor<8x16xf32>
// CHECK-NEXT: %[[NEGATE:.*]] = stablehlo.negate %[[DOT]] {sdy.sharding = #sdy.sharding_per_value<[<@mesh, [{"x"}, {}]>]>} : tensor<8x16xf32>
// CHECK: %[[DOT:.*]] = stablehlo.dot %arg0, %0 {sdy.sharding = #sdy.sharding_per_value<[<@mesh, [{}, {"x"}]>]>, sdy.sharding_rule = #sdy.op_sharding_rule<([i, k], [k, j])->([i, j]) {i=8, j=16, k=32}>} : (tensor<8x32xf32>, tensor<32x16xf32>) -> tensor<8x16xf32>
// CHECK-NEXT: %[[RESHARD:.*]] = sdy.reshard %[[DOT]] <@mesh, [{"x"}, {}]> : tensor<8x16xf32>
// CHECK-NEXT: %[[NEGATE:.*]] = stablehlo.negate %[[RESHARD]] {sdy.sharding = #sdy.sharding_per_value<[<@mesh, [{"x"}, {}]>]>} : tensor<8x16xf32>
// CHECK-NEXT: return %[[NEGATE]] : tensor<8x16xf32>
%1 = stablehlo.dot %arg0, %0 {sdy.sharding = #sdy.sharding_per_value<[<@mesh, [{"x"}, {}]>]>, sdy.sharding_rule = #sdy.op_sharding_rule<([i, k],[k, j])->([i, j]) {i=8, j=16, k=32}>} : (tensor<8x32xf32>, tensor<32x16xf32>) -> tensor<8x16xf32>
%2 = stablehlo.negate %1 {sdy.sharding = #sdy.sharding_per_value<[<@mesh, [{"x"}, {}]>]>} : tensor<8x16xf32>
@@ -267,10 +265,9 @@ func.func @dot_reshard_does_not_change_input_sharding(%arg0: tensor<8x32xf32> {s
}

// CHECK-LABEL: func @dot_without_sharding_rule
// CHECK-NEXT: %[[RESHARD1:.*]] = sdy.reshard %arg0 <@mesh, [{"x"}, {"y"}]> : tensor<8x32xf32>
// CHECK-NEXT: %[[RESHARD2:.*]] = sdy.reshard %arg1 <@mesh, [{"y"}, {}]> : tensor<32x16xf32>
// CHECK-NEXT: %[[DOT:.*]] = stablehlo.dot %[[RESHARD1]], %[[RESHARD2]] {sdy.sharding = #sdy.sharding_per_value<[<@mesh, [{"x"}, {}]>]>} : (tensor<8x32xf32>, tensor<32x16xf32>) -> tensor<8x16xf32>
// CHECK-NEXT: return %[[DOT]] : tensor<8x16xf32>
// CHECK-NEXT: %[[DOT:.*]] = stablehlo.dot %arg0, %arg1 {sdy.sharding = #sdy.sharding_per_value<[<@mesh, [{}, {"x"}]>]>} : (tensor<8x32xf32>, tensor<32x16xf32>) -> tensor<8x16xf32>
// CHECK-NEXT: %[[RESHARD:.*]] = sdy.reshard %[[DOT]] <@mesh, [{"x"}, {}]> : tensor<8x16xf32>
// CHECK-NEXT: return %[[RESHARD]] : tensor<8x16xf32>
func.func @dot_without_sharding_rule(%arg0: tensor<8x32xf32> {sdy.sharding = #sdy.sharding<@mesh, [{}, {"y"}]>}, %arg1: tensor<32x16xf32> {sdy.sharding = #sdy.sharding<@mesh, [{"y"}, {"x"}]>}) -> (tensor<8x16xf32> {sdy.sharding = #sdy.sharding<@mesh, [{"x"}, {}]>}) {
%0 = stablehlo.dot %arg0, %arg1 {sdy.sharding = #sdy.sharding_per_value<[<@mesh, [{"x"}, {}]>]>} : (tensor<8x32xf32>, tensor<32x16xf32>) -> tensor<8x16xf32>
return %0 : tensor<8x16xf32>