Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Bounds and alignment analysis through bitwise ops #8574

Open
wants to merge 10 commits into
base: main
Choose a base branch
from
2 changes: 1 addition & 1 deletion doc/RunGen.md
Original file line number Diff line number Diff line change
Expand Up @@ -125,7 +125,7 @@ Generator, and inits every element to zero:

```
# Input is a 3-dimensional image with extent 123, 456, and 3
# (bluring an image of all zeroes isn't very interesting, of course)
# (blurring an image of all zeros isn't very interesting, of course)
$ ./bin/local_laplacian.rungen --output_extents=[100,200,3] input=zero:[123,456,3] levels=8 alpha=1 beta=1 output=/tmp/out.png
```

Expand Down
4 changes: 4 additions & 0 deletions src/ConstantBounds.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -125,6 +125,10 @@ ConstantInterval bounds_helper(const Expr &e,
ConstantInterval cq = recurse(op->args[2]);
ConstantInterval rounding_term = 1 << (cq - 1);
return (ca * cb + rounding_term) >> cq;
} else if (op->is_intrinsic(Call::bitwise_not)) {
// We can't do much with the other bitwise ops, but we can treat
// bitwise_not as an all-ones bit pattern minus the argument.
return recurse(make_const(e.type(), -1) - op->args[0]);
}
// If you add a new intrinsic here, also add it to the expression
// generator in test/correctness/lossless_cast.cpp
Expand Down
2 changes: 1 addition & 1 deletion src/HexagonOffload.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -1049,7 +1049,7 @@ Buffer<uint8_t> compile_module_to_hexagon_shared_object(const Module &device_cod
// This will cause a difference in MemSize and FileSize like so:
// FileSize = (MemSize - size_of_bss)
// When the Hexagon loader is used on 8998 and later targets,
// the difference is filled with zeroes thereby initializing the .bss
// the difference is filled with zeros thereby initializing the .bss
// section.
bss->set_type(Elf::Section::SHT_PROGBITS);
std::fill(bss->contents_begin(), bss->contents_end(), 0);
Expand Down
157 changes: 155 additions & 2 deletions src/Simplify.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -91,7 +91,7 @@ void Simplify::ScopedFact::learn_false(const Expr &fact) {
Simplify::VarInfo info;
info.old_uses = info.new_uses = 0;
if (const Variable *v = fact.as<Variable>()) {
info.replacement = const_false(fact.type().lanes());
info.replacement = Halide::Internal::const_false(fact.type().lanes());
simplify->var_info.push(v->name, info);
pop_list.push_back(v);
} else if (const NE *ne = fact.as<NE>()) {
Expand Down Expand Up @@ -178,7 +178,7 @@ void Simplify::ScopedFact::learn_true(const Expr &fact) {
Simplify::VarInfo info;
info.old_uses = info.new_uses = 0;
if (const Variable *v = fact.as<Variable>()) {
info.replacement = const_true(fact.type().lanes());
info.replacement = Halide::Internal::const_true(fact.type().lanes());
simplify->var_info.push(v->name, info);
pop_list.push_back(v);
} else if (const EQ *eq = fact.as<EQ>()) {
Expand Down Expand Up @@ -496,5 +496,158 @@ bool can_prove(Expr e, const Scope<Interval> &bounds) {
return is_const_one(e);
}

Simplify::ExprInfo::BitsKnown Simplify::ExprInfo::to_bits_known(const Type &type) const {
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

A lot of this bitwise math is a little tricky to follow. Have you thrown this in an SMT solver? I think this should be verified

BitsKnown result = {0, 0};

if (!(type.is_int() || type.is_uint())) {
// Let's not claim we know anything about the bit patterns of
// non-integer types for now.
return result;
}

// Identify the largest power of two in the modulus to get some low bits
if (alignment.modulus) {
result.mask = largest_power_of_two_factor(alignment.modulus) - 1;
result.value = result.mask & alignment.remainder;
} else {
// This value is just a constant
result.mask = (uint64_t)(-1);
result.value = alignment.remainder;
return result;
}

// Compute a mask which is 1 for all the leading zeros of a uint64
auto leading_zeros_mask = [](uint64_t x) {
if (x == 0) {
// They're all leading zeros, but clz64 is UB on zero. Really we
// should have returned early above, but it's hard to guarantee that
// the alignment analysis catches constants at the same time as
// bounds analysis does.
return (uint64_t)-1;
}
return (uint64_t)(-1) << (64 - clz64(x));
};

auto leading_ones_mask = [=](uint64_t x) {
return leading_zeros_mask(~x);
};

// The bounds and the type tell us a bunch of high bits are zero or one
if (type.is_uint()) {
// Narrow uints are always zero-extended.
if (type.bits() < 64) {
result.mask |= (uint64_t)(-1) << type.bits();
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Isn't this true regardless of the bounds?

}
// A lower bound might tell us that there are some leading ones, and an
// upper bound might tell us that there are some leading
// zeros. Unfortunately we'll never learn about leading ones, because
// uint64_ts that start with leading ones can't have a min represented
// as an int64_t, which is what ConstantInverval uses, so
// bounds.min_defined will never be true.
if (bounds.max_defined) {
result.mask |= leading_zeros_mask(bounds.max);
}
} else {
internal_assert(type.is_int());
// A mask which is 1 for the sign bit and above.
uint64_t sign_bit_and_above = (uint64_t)(-1) << (type.bits() - 1);
if (bounds >= 0) {
// We know this int is positive, so the sign bit and above are zero.
result.mask |= sign_bit_and_above;
if (bounds.max_defined) {
// We also have an upper bound, so there may be more zero bits,
// depending on how many leading zeros there are in the upper
// bound.
result.mask |= leading_zeros_mask(bounds.max);
}
} else if (bounds < 0) {
// This int is negative, so the sign bit and above are one.
result.mask |= sign_bit_and_above;
result.value |= sign_bit_and_above;
if (bounds.min_defined) {
// We have a lower bound, so there may be more leading one bits,
// depending on how many leading ones there are in the lower
// bound.
uint64_t leading_ones = leading_ones_mask(bounds.min);
result.mask |= leading_ones;
result.value |= leading_ones;
}
}
}

return result;
}

// Populate this ExprInfo (alignment and bounds) from a known-bits
// representation produced by to_bits_known (1s in known.mask are bit
// positions whose value is given by known.value). Trailing known bits become
// the power-of-two modulus/remainder; leading known bits become bounds.
void Simplify::ExprInfo::from_bits_known(Simplify::ExprInfo::BitsKnown known, const Type &type) {
// Normalize everything to 64-bits by sign- or zero-extending known bits for
// the type.

// A mask which is one for all the new bits resulting from sign or zero
// extension.
uint64_t missing_bits = 0;
if (type.bits() < 64) {
missing_bits = (uint64_t)(-1) << type.bits();
}

if (missing_bits) {
if (type.is_uint()) {
// For a uint the high bits are known to be zero
known.mask |= missing_bits;
known.value &= ~missing_bits;
} else if (type.is_int()) {
// For an int we need to know the sign to know the high bits,
// because sign-extension replicates the sign bit upwards.
bool sign_bit_known = (known.mask >> (type.bits() - 1)) & 1;
bool negative = (known.value >> (type.bits() - 1)) & 1;
if (!sign_bit_known) {
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

These cases need comments explaining each of them

// We don't know the sign bit, so we don't know any of the
// extended bits. Mark them as unknown in the mask and zero them
// out in the value too just for ease of debugging.
known.mask &= ~missing_bits;
known.value &= ~missing_bits;
} else if (negative) {
// We know the sign bit is 1, so all of the extended bits are 1
// too.
known.mask |= missing_bits;
known.value |= missing_bits;
} else if (!negative) {
// We know the sign bit is zero, so all of the extended bits are
// zero too.
known.mask |= missing_bits;
known.value &= ~missing_bits;
}
}
}

// We can get the trailing one bits by adding one and taking the largest
// power of two factor. Note that this works out correctly when we know all
// the bits - the modulus comes out as zero, and the remainder is the entire
// number, which is how we represent constants in ModulusRemainder.
alignment.modulus = largest_power_of_two_factor(known.mask + 1);
alignment.remainder = known.value & (alignment.modulus - 1);

// (int64_t)known.mask < 0 tests bit 63 of the mask, i.e. whether the top
// bit (and hence the sign of the 64-bit pattern) is known.
if ((int64_t)known.mask < 0) {
// We know some leading bits

// Set all unknown bits to zero
uint64_t min_val = known.value & known.mask;
// Set all unknown bits to one
uint64_t max_val = known.value | ~known.mask;

if (type.is_uint() && (int64_t)known.value < 0) {
// We know it's out of range at the top end for our ConstantInterval
// class. At the time of writing, to_bits_known can't produce this
// directly, and bits_known is never propagated through other
// operations, so this code is unreachable. Nonetheless we'll do the
// best job we can at representing this case in case this code
// becomes reachable in future.
bounds = ConstantInterval::bounded_below((1ULL << 63) - 1);
} else {
// In all other cases, the bounds are representable as an int64
// and don't span zero (because we know the high bit).
bounds = ConstantInterval{(int64_t)min_val, (int64_t)max_val};
}
}
}

} // namespace Internal
} // namespace Halide
64 changes: 38 additions & 26 deletions src/Simplify_And.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -5,7 +5,7 @@ namespace Internal {

Expr Simplify::visit(const And *op, ExprInfo *info) {
if (falsehoods.count(op)) {
return const_false(op->type.lanes());
return const_false(op->type.lanes(), info);
}

Expr a = mutate(op->a, nullptr);
Expand All @@ -16,33 +16,17 @@ Expr Simplify::visit(const And *op, ExprInfo *info) {
std::swap(a, b);
}

if (info) {
info->cast_to(op->type);
}

auto rewrite = IRMatcher::rewriter(IRMatcher::and_op(a, b), op->type);

// clang-format off
if (EVAL_IN_LAMBDA
(rewrite(x && true, a) ||
rewrite(x && false, b) ||
rewrite(x && x, a) ||

rewrite((x && y) && x, a) ||
rewrite(x && (x && y), b) ||
rewrite((x && y) && y, a) ||
rewrite(y && (x && y), b) ||

rewrite(((x && y) && z) && x, a) ||
rewrite(x && ((x && y) && z), b) ||
rewrite((z && (x && y)) && x, a) ||
rewrite(x && (z && (x && y)), b) ||
rewrite(((x && y) && z) && y, a) ||
rewrite(y && ((x && y) && z), b) ||
rewrite((z && (x && y)) && y, a) ||
rewrite(y && (z && (x && y)), b) ||

rewrite((x || y) && x, b) ||
rewrite(x && (x || y), a) ||
rewrite((x || y) && y, b) ||
rewrite(y && (x || y), a) ||

// Cases that fold to a constant
if (EVAL_IN_LAMBDA
(rewrite(x && false, false) ||
rewrite(x != y && x == y, false) ||
rewrite(x != y && y == x, false) ||
rewrite((z && x != y) && x == y, false) ||
Expand All @@ -57,7 +41,6 @@ Expr Simplify::visit(const And *op, ExprInfo *info) {
rewrite(!x && x, false) ||
rewrite(y <= x && x < y, false) ||
rewrite(y < x && x < y, false) ||
rewrite(x != c0 && x == c1, b, c0 != c1) ||
rewrite(x == c0 && x == c1, false, c0 != c1) ||
// Note: In the predicate below, if undefined overflow
// occurs, the predicate counts as false. If well-defined
Expand All @@ -69,7 +52,36 @@ Expr Simplify::visit(const And *op, ExprInfo *info) {
rewrite(x <= c1 && c0 < x, false, c1 <= c0) ||
rewrite(c0 <= x && x < c1, false, c1 <= c0) ||
rewrite(c0 <= x && x <= c1, false, c1 < c0) ||
rewrite(x <= c1 && c0 <= x, false, c1 < c0) ||
rewrite(x <= c1 && c0 <= x, false, c1 < c0))) {
set_expr_info_to_constant(info, false);
return rewrite.result;
}

// Cases that fold to one of the args
if (EVAL_IN_LAMBDA
(rewrite(x && true, a) ||
rewrite(x && x, a) ||

rewrite((x && y) && x, a) ||
rewrite(x && (x && y), b) ||
rewrite((x && y) && y, a) ||
rewrite(y && (x && y), b) ||

rewrite(((x && y) && z) && x, a) ||
rewrite(x && ((x && y) && z), b) ||
rewrite((z && (x && y)) && x, a) ||
rewrite(x && (z && (x && y)), b) ||
rewrite(((x && y) && z) && y, a) ||
rewrite(y && ((x && y) && z), b) ||
rewrite((z && (x && y)) && y, a) ||
rewrite(y && (z && (x && y)), b) ||

rewrite((x || y) && x, b) ||
rewrite(x && (x || y), a) ||
rewrite((x || y) && y, b) ||
rewrite(y && (x || y), a) ||

rewrite(x != c0 && x == c1, b, c0 != c1) ||
rewrite(c0 < x && c1 < x, fold(max(c0, c1)) < x) ||
rewrite(c0 <= x && c1 <= x, fold(max(c0, c1)) <= x) ||
rewrite(x < c0 && x < c1, x < fold(min(c0, c1))) ||
Expand Down
Loading
Loading