-
Notifications
You must be signed in to change notification settings - Fork 13.3k
Allow matching on 3+ variant niche-encoded enums to optimize better #139729
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
base: master
Are you sure you want to change the base?
Changes from all commits
File filter
Filter by extension
Conversations
Jump to
Diff view
Diff view
There are no files selected for viewing
Original file line number | Diff line number | Diff line change |
---|---|---|
|
@@ -3,7 +3,9 @@ use std::fmt; | |
use arrayvec::ArrayVec; | ||
use either::Either; | ||
use rustc_abi as abi; | ||
use rustc_abi::{Align, BackendRepr, FIRST_VARIANT, Primitive, Size, TagEncoding, Variants}; | ||
use rustc_abi::{ | ||
Align, BackendRepr, FIRST_VARIANT, Primitive, Size, TagEncoding, VariantIdx, Variants, | ||
}; | ||
use rustc_middle::mir::interpret::{Pointer, Scalar, alloc_range}; | ||
use rustc_middle::mir::{self, ConstValue}; | ||
use rustc_middle::ty::Ty; | ||
|
@@ -510,6 +512,8 @@ impl<'a, 'tcx, V: CodegenObject> OperandRef<'tcx, V> { | |
); | ||
|
||
let relative_max = niche_variants.end().as_u32() - niche_variants.start().as_u32(); | ||
let tag_range = tag_scalar.valid_range(&dl); | ||
let tag_size = tag_scalar.size(&dl); | ||
|
||
// We have a subrange `niche_start..=niche_end` inside `range`. | ||
// If the value of the tag is inside this subrange, it's a | ||
|
@@ -525,53 +529,189 @@ impl<'a, 'tcx, V: CodegenObject> OperandRef<'tcx, V> { | |
// untagged_variant | ||
// } | ||
// However, we will likely be able to emit simpler code. | ||
let (is_niche, tagged_discr, delta) = if relative_max == 0 { | ||
// Best case scenario: only one tagged variant. This will | ||
// likely become just a comparison and a jump. | ||
// The algorithm is: | ||
// is_niche = tag == niche_start | ||
// discr = if is_niche { | ||
// niche_start | ||
// } else { | ||
// untagged_variant | ||
// } | ||
|
||
// First, the incredibly-common case of a two-variant enum (like | ||
// `Option` or `Result`) where we only need one check. | ||
if relative_max == 0 { | ||
let niche_start = bx.cx().const_uint_big(tag_llty, niche_start); | ||
let is_niche = bx.icmp(IntPredicate::IntEQ, tag, niche_start); | ||
let tagged_discr = | ||
bx.cx().const_uint(cast_to, niche_variants.start().as_u32() as u64); | ||
(is_niche, tagged_discr, 0) | ||
} else { | ||
// The special cases don't apply, so we'll have to go with | ||
// the general algorithm. | ||
let relative_discr = bx.sub(tag, bx.cx().const_uint_big(tag_llty, niche_start)); | ||
let cast_tag = bx.intcast(relative_discr, cast_to, false); | ||
let is_niche = bx.icmp( | ||
IntPredicate::IntULE, | ||
relative_discr, | ||
bx.cx().const_uint(tag_llty, relative_max as u64), | ||
); | ||
|
||
// Thanks to parameter attributes and load metadata, LLVM already knows | ||
// the general valid range of the tag. It's possible, though, for there | ||
// to be an impossible value *in the middle*, which those ranges don't | ||
// communicate, so it's worth an `assume` to let the optimizer know. | ||
if niche_variants.contains(&untagged_variant) | ||
&& bx.cx().sess().opts.optimize != OptLevel::No | ||
let is_natural = bx.icmp(IntPredicate::IntNE, tag, niche_start); | ||
return if untagged_variant == VariantIdx::from_u32(1) | ||
&& *niche_variants.start() == VariantIdx::from_u32(0) | ||
{ | ||
let impossible = | ||
u64::from(untagged_variant.as_u32() - niche_variants.start().as_u32()); | ||
let impossible = bx.cx().const_uint(tag_llty, impossible); | ||
let ne = bx.icmp(IntPredicate::IntNE, relative_discr, impossible); | ||
bx.assume(ne); | ||
// The polarity of the comparison above is picked so we can | ||
// just extend for `Option<T>`, which has these variants. | ||
bx.zext(is_natural, cast_to) | ||
} else { | ||
let tagged_discr = | ||
bx.cx().const_uint(cast_to, u64::from(niche_variants.start().as_u32())); | ||
let untagged_discr = | ||
bx.cx().const_uint(cast_to, u64::from(untagged_variant.as_u32())); | ||
bx.select(is_natural, untagged_discr, tagged_discr) | ||
}; | ||
} | ||
|
||
let niche_end = | ||
tag_size.truncate(u128::from(relative_max).wrapping_add(niche_start)); | ||
|
||
// Next, the layout algorithm prefers to put the niches at one end, | ||
// so look for cases where we don't need to calculate a relative_tag | ||
// at all and can just look at the original tag value directly. | ||
// This also lets us move any possibly-wrapping addition to the end | ||
// where it's easiest to get rid of in the normal uses: it's easy | ||
// to optimize `COMPLICATED + 2 == 7` to `COMPLICATED == (7 - 2)`. | ||
{ | ||
// Work in whichever size is wider, because it's possible for | ||
// the untagged variant to be further away from the niches than | ||
// is possible to represent in the smaller type. | ||
let (wide_size, wide_ibty) = if cast_to_layout.size > tag_size { | ||
There was a problem hiding this comment. Choose a reason for hiding this comment. The reason will be displayed to describe this comment to others. Learn more. I assume |
||
(cast_to_layout.size, cast_to) | ||
} else { | ||
(tag_size, tag_llty) | ||
}; | ||
|
||
struct NoWrapData<V> { | ||
wide_tag: V, | ||
is_niche: V, | ||
needs_assume: bool, | ||
wide_niche_to_variant: u128, | ||
wide_niche_untagged: u128, | ||
} | ||
|
||
(is_niche, cast_tag, niche_variants.start().as_u32() as u128) | ||
}; | ||
let first_variant = u128::from(niche_variants.start().as_u32()); | ||
let untagged_variant = u128::from(untagged_variant.as_u32()); | ||
|
||
let opt_data = if tag_range.no_unsigned_wraparound(tag_size) == Ok(true) { | ||
let wide_tag = bx.zext(tag, wide_ibty); | ||
let extend = |x| x; | ||
There was a problem hiding this comment. Choose a reason for hiding this comment. The reason will be displayed to describe this comment to others. Learn more. Looks like this is to be more similar to the signed case, but if so, can we refactor this out into a function? |
||
let wide_niche_start = extend(niche_start); | ||
let wide_niche_end = extend(niche_end); | ||
debug_assert!(wide_niche_start <= wide_niche_end); | ||
let wide_first_variant = extend(first_variant); | ||
let wide_untagged_variant = extend(untagged_variant); | ||
let wide_niche_to_variant = | ||
wide_first_variant.wrapping_sub(wide_niche_start); | ||
let wide_niche_untagged = wide_size | ||
.truncate(wide_untagged_variant.wrapping_sub(wide_niche_to_variant)); | ||
let (is_niche, needs_assume) = if tag_range.start == niche_start { | ||
let end = bx.cx().const_uint_big(tag_llty, niche_end); | ||
( | ||
bx.icmp(IntPredicate::IntULE, tag, end), | ||
wide_niche_untagged <= wide_niche_end, | ||
) | ||
Comment on lines +583 to +600
There was a problem hiding this comment. Choose a reason for hiding this comment. The reason will be displayed to describe this comment to others. Learn more. I'm not sure how to review this; the number of similarly named variables just overflows my cache :( |
||
} else if tag_range.end == niche_end { | ||
let start = bx.cx().const_uint_big(tag_llty, niche_start); | ||
( | ||
bx.icmp(IntPredicate::IntUGE, tag, start), | ||
wide_niche_untagged >= wide_niche_start, | ||
) | ||
} else { | ||
bug!() | ||
}; | ||
Some(NoWrapData { | ||
wide_tag, | ||
is_niche, | ||
needs_assume, | ||
wide_niche_to_variant, | ||
wide_niche_untagged, | ||
}) | ||
} else if tag_range.no_signed_wraparound(tag_size) == Ok(true) { | ||
let wide_tag = bx.sext(tag, wide_ibty); | ||
let extend = |x| tag_size.sign_extend(x); | ||
let wide_niche_start = extend(niche_start); | ||
let wide_niche_end = extend(niche_end); | ||
debug_assert!(wide_niche_start <= wide_niche_end); | ||
let wide_first_variant = extend(first_variant); | ||
let wide_untagged_variant = extend(untagged_variant); | ||
let wide_niche_to_variant = | ||
wide_first_variant.wrapping_sub(wide_niche_start); | ||
let wide_niche_untagged = wide_size.sign_extend( | ||
wide_untagged_variant | ||
.wrapping_sub(wide_niche_to_variant) | ||
.cast_unsigned(), | ||
); | ||
let (is_niche, needs_assume) = if tag_range.start == niche_start { | ||
let end = bx.cx().const_uint_big(tag_llty, niche_end); | ||
( | ||
bx.icmp(IntPredicate::IntSLE, tag, end), | ||
wide_niche_untagged <= wide_niche_end, | ||
) | ||
} else if tag_range.end == niche_end { | ||
let start = bx.cx().const_uint_big(tag_llty, niche_start); | ||
( | ||
bx.icmp(IntPredicate::IntSGE, tag, start), | ||
wide_niche_untagged >= wide_niche_start, | ||
) | ||
} else { | ||
bug!() | ||
}; | ||
Some(NoWrapData { | ||
wide_tag, | ||
is_niche, | ||
needs_assume, | ||
wide_niche_to_variant: wide_niche_to_variant.cast_unsigned(), | ||
wide_niche_untagged: wide_niche_untagged.cast_unsigned(), | ||
}) | ||
} else { | ||
None | ||
}; | ||
if let Some(NoWrapData { | ||
wide_tag, | ||
is_niche, | ||
needs_assume, | ||
wide_niche_to_variant, | ||
wide_niche_untagged, | ||
}) = opt_data | ||
{ | ||
let wide_niche_untagged = | ||
bx.cx().const_uint_big(wide_ibty, wide_niche_untagged); | ||
if needs_assume && bx.cx().sess().opts.optimize != OptLevel::No { | ||
let not_untagged = | ||
bx.icmp(IntPredicate::IntNE, wide_tag, wide_niche_untagged); | ||
bx.assume(not_untagged); | ||
} | ||
|
||
let wide_niche = bx.select(is_niche, wide_tag, wide_niche_untagged); | ||
let cast_niche = bx.trunc(wide_niche, cast_to); | ||
let discr = if wide_niche_to_variant == 0 { | ||
cast_niche | ||
} else { | ||
let niche_to_variant = | ||
bx.cx().const_uint_big(cast_to, wide_niche_to_variant); | ||
bx.add(cast_niche, niche_to_variant) | ||
}; | ||
return discr; | ||
} | ||
} | ||
|
||
// Otherwise the special cases don't apply, | ||
// so we'll have to go with the general algorithm. | ||
let relative_tag = bx.sub(tag, bx.cx().const_uint_big(tag_llty, niche_start)); | ||
let relative_discr = bx.intcast(relative_tag, cast_to, false); | ||
let is_niche = bx.icmp( | ||
IntPredicate::IntULE, | ||
relative_tag, | ||
bx.cx().const_uint(tag_llty, u64::from(relative_max)), | ||
); | ||
|
||
// Thanks to parameter attributes and load metadata, LLVM already knows | ||
// the general valid range of the tag. It's possible, though, for there | ||
// to be an impossible value *in the middle*, which those ranges don't | ||
// communicate, so it's worth an `assume` to let the optimizer know. | ||
if niche_variants.contains(&untagged_variant) | ||
&& bx.cx().sess().opts.optimize != OptLevel::No | ||
{ | ||
let impossible = | ||
u64::from(untagged_variant.as_u32() - niche_variants.start().as_u32()); | ||
let impossible = bx.cx().const_uint(tag_llty, impossible); | ||
let ne = bx.icmp(IntPredicate::IntNE, relative_tag, impossible); | ||
bx.assume(ne); | ||
} | ||
|
||
let delta = niche_variants.start().as_u32(); | ||
let tagged_discr = if delta == 0 { | ||
tagged_discr | ||
relative_discr | ||
} else { | ||
bx.add(tagged_discr, bx.cx().const_uint_big(cast_to, delta)) | ||
bx.add(relative_discr, bx.cx().const_uint(cast_to, u64::from(delta))) | ||
}; | ||
|
||
let discr = bx.select( | ||
|
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
nit: I think this will be cleaner as a guard, i.e.
`return if a { b } else { c }`
-> `if a { return b }; return c`