Skip to content

Commit b4149a0

Browse files
authored
[MLIR][Presburger] Fix Gaussian elimination (#164437)
In the Presburger library, there are two minor bugs of Gaussian elimination. In Barvinok.cpp, the `if (equations(i, i) != 0) continue;` is intended to skip only the row-swapping, but it in fact skipped the whole loop body altogether, including the elimination parts. In IntegerRelation.cpp, the Gaussian elimination forgets to advance `firstVar` (the number of finished columns) when it finishes a column. Moreover, when it checks the pivot row of each column, it didn't ignore the rows considered. As an example, suppose the constraints are ``` 1 0 0 1 2 = 0 0 1 0 0 3 = 0 0 0 0 1 4 = 0 ... ``` For the 4th column, it will think the pivot is the first row `1 0 0 1 2 = 0`, rather than the correct 3rd row `0 0 0 1 4 = 0`. (This bug is left undiscovered, because if we don't advance `firstVar` then this Gaussian elimination process will simply do nothing. Moreover, it is called only in `simplify()`, and the existing test cases doesn't care whether a set has been simplified.)
1 parent bfc4571 commit b4149a0

File tree

4 files changed

+51
-9
lines changed

4 files changed

+51
-9
lines changed

mlir/lib/Analysis/Presburger/Barvinok.cpp

Lines changed: 7 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -178,13 +178,13 @@ mlir::presburger::detail::solveParametricEquations(FracMatrix equations) {
178178
for (unsigned i = 0; i < d; ++i) {
179179
// First ensure that the diagonal element is nonzero, by swapping
180180
// it with a row that is non-zero at column i.
181-
if (equations(i, i) != 0)
182-
continue;
183-
for (unsigned j = i + 1; j < d; ++j) {
184-
if (equations(j, i) == 0)
185-
continue;
186-
equations.swapRows(j, i);
187-
break;
181+
if (equations(i, i) == 0) {
182+
for (unsigned j = i + 1; j < d; ++j) {
183+
if (equations(j, i) == 0)
184+
continue;
185+
equations.swapRows(j, i);
186+
break;
187+
}
188188
}
189189

190190
Fraction diagElement = equations(i, i);

mlir/lib/Analysis/Presburger/IntegerRelation.cpp

Lines changed: 20 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -1112,15 +1112,29 @@ unsigned IntegerRelation::gaussianEliminateVars(unsigned posStart,
11121112
return posLimit - posStart;
11131113
}
11141114

1115+
static std::optional<unsigned>
1116+
findEqualityWithNonZeroAfterRow(IntegerRelation &rel, unsigned fromRow,
1117+
unsigned colIdx) {
1118+
assert(fromRow < rel.getNumEqualities() && colIdx < rel.getNumCols() &&
1119+
"position out of bounds");
1120+
for (unsigned rowIdx = fromRow, e = rel.getNumEqualities(); rowIdx < e;
1121+
++rowIdx) {
1122+
if (rel.atEq(rowIdx, colIdx) != 0)
1123+
return rowIdx;
1124+
}
1125+
return std::nullopt;
1126+
}
1127+
11151128
bool IntegerRelation::gaussianEliminate() {
11161129
gcdTightenInequalities();
11171130
unsigned firstVar = 0, vars = getNumVars();
11181131
unsigned nowDone, eqs;
11191132
std::optional<unsigned> pivotRow;
11201133
for (nowDone = 0, eqs = getNumEqualities(); nowDone < eqs; ++nowDone) {
1121-
// Finds the first non-empty column.
1134+
// Finds the first non-empty column that we haven't dealt with.
11221135
for (; firstVar < vars; ++firstVar) {
1123-
if ((pivotRow = findConstraintWithNonZeroAt(firstVar, /*isEq=*/true)))
1136+
if ((pivotRow =
1137+
findEqualityWithNonZeroAfterRow(*this, nowDone, firstVar)))
11241138
break;
11251139
}
11261140
// The matrix has been normalized to row echelon form.
@@ -1143,6 +1157,10 @@ bool IntegerRelation::gaussianEliminate() {
11431157
inequalities.normalizeRow(i);
11441158
}
11451159
gcdTightenInequalities();
1160+
1161+
// The column is finished. Tell the next iteration to start at the next
1162+
// column.
1163+
firstVar++;
11461164
}
11471165

11481166
// No redundant rows.

mlir/unittests/Analysis/Presburger/BarvinokTest.cpp

Lines changed: 9 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -301,3 +301,12 @@ TEST(BarvinokTest, computeNumTermsPolytope) {
301301
gf = count[0].second;
302302
EXPECT_EQ(gf.getNumerators().size(), 24u);
303303
}
304+
305+
TEST(BarvinokTest, solveParametricEquations) {
306+
FracMatrix equations = makeFracMatrix(2, 3, {{2, 3, -4}, {2, 6, -7}});
307+
auto maybeSolution = solveParametricEquations(equations);
308+
ASSERT_TRUE(maybeSolution.has_value());
309+
FracMatrix solution = *maybeSolution;
310+
EXPECT_EQ(solution.at(0, 0), Fraction(1, 2));
311+
EXPECT_EQ(solution.at(1, 0), 1);
312+
}

mlir/unittests/Analysis/Presburger/IntegerRelationTest.cpp

Lines changed: 15 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -725,3 +725,18 @@ TEST(IntegerRelationTest, addLocalModulo) {
725725
EXPECT_TRUE(rel.containsPointNoLocal({x, x % 32}));
726726
}
727727
}
728+
729+
TEST(IntegerRelationTest, simplify) {
730+
IntegerRelation rel =
731+
parseRelationFromSet("(x, y, z): (2*x + y - 4*z - 3 == 0, "
732+
"3*x - y - 3*z + 2 == 0, x + 3*y - 5*z - 8 == 0,"
733+
"x - y + z >= 0)",
734+
2);
735+
IntegerRelation copy = rel;
736+
rel.simplify();
737+
738+
EXPECT_TRUE(rel.isEqual(copy));
739+
// The third equality is redundant and should be removed.
740+
// It can be obtained from 2 times the first equality minus the second.
741+
EXPECT_TRUE(rel.getNumEqualities() == 2);
742+
}

0 commit comments

Comments
 (0)