Skip to content

Allow matching on 3+ variant niche-encoded enums to optimize better #139729

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Open
wants to merge 2 commits into
base: master
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
35 changes: 34 additions & 1 deletion compiler/rustc_abi/src/lib.rs
Original file line number Diff line number Diff line change
Expand Up @@ -43,7 +43,7 @@ use std::fmt;
#[cfg(feature = "nightly")]
use std::iter::Step;
use std::num::{NonZeroUsize, ParseIntError};
use std::ops::{Add, AddAssign, Mul, RangeInclusive, Sub};
use std::ops::{Add, AddAssign, Mul, RangeFull, RangeInclusive, Sub};
use std::str::FromStr;

use bitflags::bitflags;
Expand Down Expand Up @@ -1162,12 +1162,45 @@ impl WrappingRange {
}

/// Returns `true` if `size` completely fills the range.
///
/// Note that this is *not* the same as `self == WrappingRange::full(size)`.
/// Niche calculations can produce full ranges which are not the canonical one;
/// for example `Option<NonZero<u16>>` gets `valid_range: (..=0) | (1..)`.
#[inline]
fn is_full_for(&self, size: Size) -> bool {
let max_value = size.unsigned_int_max();
debug_assert!(self.start <= max_value && self.end <= max_value);
self.start == (self.end.wrapping_add(1) & max_value)
}

/// Checks whether this range is considered non-wrapping when the values are
/// interpreted as *unsigned* numbers of width `size`.
///
/// Returns `Ok(true)` if there's no wrap-around, `Ok(false)` if there is,
/// and `Err(..)` if the range is full so it depends how you think about it.
#[inline]
pub fn no_unsigned_wraparound(&self, size: Size) -> Result<bool, RangeFull> {
if self.is_full_for(size) { Err(..) } else { Ok(self.start <= self.end) }
}

/// Checks whether this range is considered non-wrapping when the values are
/// interpreted as *signed* numbers of width `size`.
///
/// This is heavily dependent on the `size`, as `100..=200` does wrap when
/// interpreted as `i8`, but doesn't when interpreted as `i16`.
///
/// Returns `Ok(true)` if there's no wrap-around, `Ok(false)` if there is,
/// and `Err(..)` if the range is full so it depends how you think about it.
#[inline]
pub fn no_signed_wraparound(&self, size: Size) -> Result<bool, RangeFull> {
if self.is_full_for(size) {
Err(..)
} else {
let start: i128 = size.sign_extend(self.start);
let end: i128 = size.sign_extend(self.end);
Ok(start <= end)
}
}
}

impl fmt::Debug for WrappingRange {
Expand Down
222 changes: 181 additions & 41 deletions compiler/rustc_codegen_ssa/src/mir/operand.rs
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,9 @@ use std::fmt;
use arrayvec::ArrayVec;
use either::Either;
use rustc_abi as abi;
use rustc_abi::{Align, BackendRepr, FIRST_VARIANT, Primitive, Size, TagEncoding, Variants};
use rustc_abi::{
Align, BackendRepr, FIRST_VARIANT, Primitive, Size, TagEncoding, VariantIdx, Variants,
};
use rustc_middle::mir::interpret::{Pointer, Scalar, alloc_range};
use rustc_middle::mir::{self, ConstValue};
use rustc_middle::ty::Ty;
Expand Down Expand Up @@ -510,6 +512,8 @@ impl<'a, 'tcx, V: CodegenObject> OperandRef<'tcx, V> {
);

let relative_max = niche_variants.end().as_u32() - niche_variants.start().as_u32();
let tag_range = tag_scalar.valid_range(&dl);
let tag_size = tag_scalar.size(&dl);

// We have a subrange `niche_start..=niche_end` inside `range`.
// If the value of the tag is inside this subrange, it's a
Expand All @@ -525,53 +529,189 @@ impl<'a, 'tcx, V: CodegenObject> OperandRef<'tcx, V> {
// untagged_variant
// }
// However, we will likely be able to emit simpler code.
let (is_niche, tagged_discr, delta) = if relative_max == 0 {
// Best case scenario: only one tagged variant. This will
// likely become just a comparison and a jump.
// The algorithm is:
// is_niche = tag == niche_start
// discr = if is_niche {
// niche_start
// } else {
// untagged_variant
// }

// First, the incredibly-common case of a two-variant enum (like
// `Option` or `Result`) where we only need one check.
if relative_max == 0 {
let niche_start = bx.cx().const_uint_big(tag_llty, niche_start);
let is_niche = bx.icmp(IntPredicate::IntEQ, tag, niche_start);
let tagged_discr =
bx.cx().const_uint(cast_to, niche_variants.start().as_u32() as u64);
(is_niche, tagged_discr, 0)
} else {
// The special cases don't apply, so we'll have to go with
// the general algorithm.
let relative_discr = bx.sub(tag, bx.cx().const_uint_big(tag_llty, niche_start));
let cast_tag = bx.intcast(relative_discr, cast_to, false);
let is_niche = bx.icmp(
IntPredicate::IntULE,
relative_discr,
bx.cx().const_uint(tag_llty, relative_max as u64),
);

// Thanks to parameter attributes and load metadata, LLVM already knows
// the general valid range of the tag. It's possible, though, for there
// to be an impossible value *in the middle*, which those ranges don't
// communicate, so it's worth an `assume` to let the optimizer know.
if niche_variants.contains(&untagged_variant)
&& bx.cx().sess().opts.optimize != OptLevel::No
let is_natural = bx.icmp(IntPredicate::IntNE, tag, niche_start);
return if untagged_variant == VariantIdx::from_u32(1)
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

nit: I think this will be cleaner as a guard, i.e. return if a { b } else { c } -> if a { return b }; return c

&& *niche_variants.start() == VariantIdx::from_u32(0)
{
let impossible =
u64::from(untagged_variant.as_u32() - niche_variants.start().as_u32());
let impossible = bx.cx().const_uint(tag_llty, impossible);
let ne = bx.icmp(IntPredicate::IntNE, relative_discr, impossible);
bx.assume(ne);
// The polarity of the comparison above is picked so we can
// just extend for `Option<T>`, which has these variants.
bx.zext(is_natural, cast_to)
} else {
let tagged_discr =
bx.cx().const_uint(cast_to, u64::from(niche_variants.start().as_u32()));
let untagged_discr =
bx.cx().const_uint(cast_to, u64::from(untagged_variant.as_u32()));
bx.select(is_natural, untagged_discr, tagged_discr)
};
}

let niche_end =
tag_size.truncate(u128::from(relative_max).wrapping_add(niche_start));

// Next, the layout algorithm prefers to put the niches at one end,
// so look for cases where we don't need to calculate a relative_tag
// at all and can just look at the original tag value directly.
// This also lets us move any possibly-wrapping addition to the end
// where it's easiest to get rid of in the normal uses: it's easy
// to optimize `COMPLICATED + 2 == 7` to `COMPLICATED == (7 - 2)`.
{
// Work in whichever size is wider, because it's possible for
// the untagged variant to be further away from the niches than
// is possible to represent in the smaller type.
let (wide_size, wide_ibty) = if cast_to_layout.size > tag_size {
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I assume cast_to can be both wider and thinner than the "natural" tag size, to support as u8, as u128? Somewhat surprised that we don't just always return the natural type and let the caller deal with it...

(cast_to_layout.size, cast_to)
} else {
(tag_size, tag_llty)
};

struct NoWrapData<V> {
wide_tag: V,
is_niche: V,
needs_assume: bool,
wide_niche_to_variant: u128,
wide_niche_untagged: u128,
}

(is_niche, cast_tag, niche_variants.start().as_u32() as u128)
};
let first_variant = u128::from(niche_variants.start().as_u32());
let untagged_variant = u128::from(untagged_variant.as_u32());

let opt_data = if tag_range.no_unsigned_wraparound(tag_size) == Ok(true) {
let wide_tag = bx.zext(tag, wide_ibty);
let extend = |x| x;
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Looks like this is to be more similar to the signed case, but if so, can we refactor this out into a function?

let wide_niche_start = extend(niche_start);
let wide_niche_end = extend(niche_end);
debug_assert!(wide_niche_start <= wide_niche_end);
let wide_first_variant = extend(first_variant);
let wide_untagged_variant = extend(untagged_variant);
let wide_niche_to_variant =
wide_first_variant.wrapping_sub(wide_niche_start);
let wide_niche_untagged = wide_size
.truncate(wide_untagged_variant.wrapping_sub(wide_niche_to_variant));
let (is_niche, needs_assume) = if tag_range.start == niche_start {
let end = bx.cx().const_uint_big(tag_llty, niche_end);
(
bx.icmp(IntPredicate::IntULE, tag, end),
wide_niche_untagged <= wide_niche_end,
)
Comment on lines +583 to +600
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I'm not sure how to review this, the amount of similarly named variables just overflows my cache :(

} else if tag_range.end == niche_end {
let start = bx.cx().const_uint_big(tag_llty, niche_start);
(
bx.icmp(IntPredicate::IntUGE, tag, start),
wide_niche_untagged >= wide_niche_start,
)
} else {
bug!()
};
Some(NoWrapData {
wide_tag,
is_niche,
needs_assume,
wide_niche_to_variant,
wide_niche_untagged,
})
} else if tag_range.no_signed_wraparound(tag_size) == Ok(true) {
let wide_tag = bx.sext(tag, wide_ibty);
let extend = |x| tag_size.sign_extend(x);
let wide_niche_start = extend(niche_start);
let wide_niche_end = extend(niche_end);
debug_assert!(wide_niche_start <= wide_niche_end);
let wide_first_variant = extend(first_variant);
let wide_untagged_variant = extend(untagged_variant);
let wide_niche_to_variant =
wide_first_variant.wrapping_sub(wide_niche_start);
let wide_niche_untagged = wide_size.sign_extend(
wide_untagged_variant
.wrapping_sub(wide_niche_to_variant)
.cast_unsigned(),
);
let (is_niche, needs_assume) = if tag_range.start == niche_start {
let end = bx.cx().const_uint_big(tag_llty, niche_end);
(
bx.icmp(IntPredicate::IntSLE, tag, end),
wide_niche_untagged <= wide_niche_end,
)
} else if tag_range.end == niche_end {
let start = bx.cx().const_uint_big(tag_llty, niche_start);
(
bx.icmp(IntPredicate::IntSGE, tag, start),
wide_niche_untagged >= wide_niche_start,
)
} else {
bug!()
};
Some(NoWrapData {
wide_tag,
is_niche,
needs_assume,
wide_niche_to_variant: wide_niche_to_variant.cast_unsigned(),
wide_niche_untagged: wide_niche_untagged.cast_unsigned(),
})
} else {
None
};
if let Some(NoWrapData {
wide_tag,
is_niche,
needs_assume,
wide_niche_to_variant,
wide_niche_untagged,
}) = opt_data
{
let wide_niche_untagged =
bx.cx().const_uint_big(wide_ibty, wide_niche_untagged);
if needs_assume && bx.cx().sess().opts.optimize != OptLevel::No {
let not_untagged =
bx.icmp(IntPredicate::IntNE, wide_tag, wide_niche_untagged);
bx.assume(not_untagged);
}

let wide_niche = bx.select(is_niche, wide_tag, wide_niche_untagged);
let cast_niche = bx.trunc(wide_niche, cast_to);
let discr = if wide_niche_to_variant == 0 {
cast_niche
} else {
let niche_to_variant =
bx.cx().const_uint_big(cast_to, wide_niche_to_variant);
bx.add(cast_niche, niche_to_variant)
};
return discr;
}
}

// Otherwise the special cases don't apply,
// so we'll have to go with the general algorithm.
let relative_tag = bx.sub(tag, bx.cx().const_uint_big(tag_llty, niche_start));
let relative_discr = bx.intcast(relative_tag, cast_to, false);
let is_niche = bx.icmp(
IntPredicate::IntULE,
relative_tag,
bx.cx().const_uint(tag_llty, u64::from(relative_max)),
);

// Thanks to parameter attributes and load metadata, LLVM already knows
// the general valid range of the tag. It's possible, though, for there
// to be an impossible value *in the middle*, which those ranges don't
// communicate, so it's worth an `assume` to let the optimizer know.
if niche_variants.contains(&untagged_variant)
&& bx.cx().sess().opts.optimize != OptLevel::No
{
let impossible =
u64::from(untagged_variant.as_u32() - niche_variants.start().as_u32());
let impossible = bx.cx().const_uint(tag_llty, impossible);
let ne = bx.icmp(IntPredicate::IntNE, relative_tag, impossible);
bx.assume(ne);
}

let delta = niche_variants.start().as_u32();
let tagged_discr = if delta == 0 {
tagged_discr
relative_discr
} else {
bx.add(tagged_discr, bx.cx().const_uint_big(cast_to, delta))
bx.add(relative_discr, bx.cx().const_uint(cast_to, u64::from(delta)))
};

let discr = bx.select(
Expand Down
Loading
Loading