diff --git a/Cargo.lock b/Cargo.lock
index 49c1eb5b45f77..0553cf76f0a2e 100644
--- a/Cargo.lock
+++ b/Cargo.lock
@@ -4562,6 +4562,7 @@ dependencies = [
  "rustc_hir",
  "rustc_middle",
  "rustc_span",
+ "smallvec",
  "tracing",
 ]
diff --git a/compiler/rustc_transmute/Cargo.toml b/compiler/rustc_transmute/Cargo.toml
index f0c783b30020e..0250cc0ea0788 100644
--- a/compiler/rustc_transmute/Cargo.toml
+++ b/compiler/rustc_transmute/Cargo.toml
@@ -5,11 +5,13 @@ edition = "2024"
 
 [dependencies]
 # tidy-alphabetical-start
+itertools = "0.12"
 rustc_abi = { path = "../rustc_abi", optional = true }
 rustc_data_structures = { path = "../rustc_data_structures" }
 rustc_hir = { path = "../rustc_hir", optional = true }
 rustc_middle = { path = "../rustc_middle", optional = true }
 rustc_span = { path = "../rustc_span", optional = true }
+smallvec = "1.8.1"
 tracing = "0.1"
 # tidy-alphabetical-end
 
@@ -20,8 +22,3 @@ rustc = [
     "dep:rustc_middle",
     "dep:rustc_span",
 ]
-
-[dev-dependencies]
-# tidy-alphabetical-start
-itertools = "0.12"
-# tidy-alphabetical-end
diff --git a/compiler/rustc_transmute/src/layout/dfa.rs b/compiler/rustc_transmute/src/layout/dfa.rs
index bb909c54d2bc3..d1f58157b696b 100644
--- a/compiler/rustc_transmute/src/layout/dfa.rs
+++ b/compiler/rustc_transmute/src/layout/dfa.rs
@@ -1,8 +1,9 @@
 use std::fmt;
+use std::ops::RangeInclusive;
 use std::sync::atomic::{AtomicU32, Ordering};
 
 use super::{Byte, Ref, Tree, Uninhabited};
-use crate::Map;
+use crate::{Map, Set};
 
 #[derive(PartialEq)]
 #[cfg_attr(test, derive(Clone))]
@@ -20,7 +21,7 @@ pub(crate) struct Transitions<R>
 where
     R: Ref,
 {
-    byte_transitions: Map<Byte, State>,
+    byte_transitions: EdgeSet<State>,
     ref_transitions: Map<R, State>,
 }
 
@@ -29,7 +30,7 @@ where
     R: Ref,
 {
     fn default() -> Self {
-        Self { byte_transitions: Map::default(), ref_transitions: Map::default() }
+        Self { byte_transitions: EdgeSet::empty(), ref_transitions: Map::default() }
     }
 }
 
@@ -56,15 +57,10 @@ where
 {
     #[cfg(test)]
     pub(crate) fn bool() -> Self {
-        let mut transitions: Map<State, Transitions<R>> = Map::default();
-        let start = State::new();
-        let accept = State::new();
-
-        transitions.entry(start).or_default().byte_transitions.insert(Byte::Init(0x00), accept);
-
-        transitions.entry(start).or_default().byte_transitions.insert(Byte::Init(0x01), accept);
-
-        Self { transitions, start, accept }
+        Self::from_transitions(|accept| Transitions {
+            byte_transitions: EdgeSet::new(Byte::new(0x00..=0x01), accept),
+            ref_transitions: Map::default(),
+        })
     }
 
     pub(crate) fn unit() -> Self {
@@ -76,23 +72,24 @@
     }
 
     pub(crate) fn from_byte(byte: Byte) -> Self {
-        let mut transitions: Map<State, Transitions<R>> = Map::default();
-        let start = State::new();
-        let accept = State::new();
-
-        transitions.entry(start).or_default().byte_transitions.insert(byte, accept);
-
-        Self { transitions, start, accept }
+        Self::from_transitions(|accept| Transitions {
+            byte_transitions: EdgeSet::new(byte, accept),
+            ref_transitions: Map::default(),
+        })
     }
 
     pub(crate) fn from_ref(r: R) -> Self {
-        let mut transitions: Map<State, Transitions<R>> = Map::default();
+        Self::from_transitions(|accept| Transitions {
+            byte_transitions: EdgeSet::empty(),
+            ref_transitions: [(r, accept)].into_iter().collect(),
+        })
+    }
+
+    fn from_transitions(f: impl FnOnce(State) -> Transitions<R>) -> Self {
         let start = State::new();
         let accept = State::new();
 
-        transitions.entry(start).or_default().ref_transitions.insert(r, accept);
-
-        Self { transitions, start, accept }
+        Self { transitions: [(start, f(accept))].into_iter().collect(), start, accept }
     }
 
     pub(crate) fn from_tree(tree: Tree<!, R>) -> Result<Self, Uninhabited> {
@@ -132,13 +129,16 @@ where
         for (source, transition) in other.transitions {
             let fix_state = |state| if state == other.start { self.accept } else { state };
-            let entry = transitions.entry(fix_state(source)).or_default();
-            for (edge, destination) in transition.byte_transitions {
-                entry.byte_transitions.insert(edge, fix_state(destination));
-            }
-            for (edge, destination) in transition.ref_transitions {
-                entry.ref_transitions.insert(edge, fix_state(destination));
-            }
+            let byte_transitions = transition.byte_transitions.map_states(&fix_state);
+            let ref_transitions = transition
+                .ref_transitions
+                .into_iter()
+                .map(|(r, state)| (r, fix_state(state)))
+                .collect();
+
+            let old = transitions
+                .insert(fix_state(source), Transitions { byte_transitions, ref_transitions });
+            assert!(old.is_none());
         }
 
         Self { transitions, start, accept }
@@ -170,67 +170,111 @@ where
         let start = mapped((Some(a.start), Some(b.start)));
         let mut transitions: Map<State, Transitions<R>> = Map::default();
-        let mut queue = vec![(Some(a.start), Some(b.start))];
         let empty_transitions = Transitions::default();
 
-        while let Some((a_src, b_src)) = queue.pop() {
+        struct WorkQueue {
+            queue: Vec<(Option<State>, Option<State>)>,
+            // Track all entries ever enqueued to avoid duplicating work. This
+            // gives us a guarantee that a given (a_state, b_state) pair will
+            // only ever be visited once.
+            enqueued: Set<(Option<State>, Option<State>)>,
+        }
+        impl WorkQueue {
+            fn enqueue(&mut self, a_state: Option<State>, b_state: Option<State>) {
+                if self.enqueued.insert((a_state, b_state)) {
+                    self.queue.push((a_state, b_state));
+                }
+            }
+        }
+        let mut queue = WorkQueue { queue: Vec::new(), enqueued: Set::default() };
+        queue.enqueue(Some(a.start), Some(b.start));
+
+        while let Some((a_src, b_src)) = queue.queue.pop() {
+            let src = mapped((a_src, b_src));
+            if src == accept {
+                // While it's possible to have a DFA whose accept state has
+                // out-edges, these do not affect the semantics of the DFA, and
+                // so there's no point in processing them. Continuing here also
+                // has the advantage of guaranteeing that we only ever process a
+                // given node in the output DFA once. In particular, with the
+                // exception of the accept state, we ensure that we only push a
+                // given node to the `queue` once. This allows the following
+                // code to assume that we're processing a node we've never
+                // processed before, which means we never need to merge two edge
+                // sets - we only ever need to construct a new edge set from
+                // whole cloth.
+                continue;
+            }
+
             let a_transitions =
                 a_src.and_then(|a_src| a.transitions.get(&a_src)).unwrap_or(&empty_transitions);
             let b_transitions =
                 b_src.and_then(|b_src| b.transitions.get(&b_src)).unwrap_or(&empty_transitions);
 
             let byte_transitions =
-                a_transitions.byte_transitions.keys().chain(b_transitions.byte_transitions.keys());
-
-            for byte_transition in byte_transitions {
-                let a_dst = a_transitions.byte_transitions.get(byte_transition).copied();
-                let b_dst = b_transitions.byte_transitions.get(byte_transition).copied();
+                a_transitions.byte_transitions.union(&b_transitions.byte_transitions);
 
+            let byte_transitions = byte_transitions.map_states(|(a_dst, b_dst)| {
                 assert!(a_dst.is_some() || b_dst.is_some());
 
-                let src = mapped((a_src, b_src));
-                let dst = mapped((a_dst, b_dst));
-
-                transitions.entry(src).or_default().byte_transitions.insert(*byte_transition, dst);
-
-                if !transitions.contains_key(&dst) {
-                    queue.push((a_dst, b_dst))
-                }
-            }
+                queue.enqueue(a_dst, b_dst);
+                mapped((a_dst, b_dst))
+            });
 
             let ref_transitions =
                 a_transitions.ref_transitions.keys().chain(b_transitions.ref_transitions.keys());
 
-            for ref_transition in ref_transitions {
-                let a_dst = a_transitions.ref_transitions.get(ref_transition).copied();
-                let b_dst = b_transitions.ref_transitions.get(ref_transition).copied();
+            let ref_transitions = ref_transitions
+                .map(|ref_transition| {
+                    let a_dst = a_transitions.ref_transitions.get(ref_transition).copied();
+                    let b_dst = b_transitions.ref_transitions.get(ref_transition).copied();
 
-                assert!(a_dst.is_some() || b_dst.is_some());
-
-                let src = mapped((a_src, b_src));
-                let dst = mapped((a_dst, b_dst));
+                    assert!(a_dst.is_some() || b_dst.is_some());
 
-                transitions.entry(src).or_default().ref_transitions.insert(*ref_transition, dst);
+                    queue.enqueue(a_dst, b_dst);
+                    (*ref_transition, mapped((a_dst, b_dst)))
+                })
+                .collect();
 
-                if !transitions.contains_key(&dst) {
-                    queue.push((a_dst, b_dst))
-                }
-            }
+            let old = transitions.insert(src, Transitions { byte_transitions, ref_transitions });
+            // See `if src == accept { ... }` above. The comment there explains
+            // why this assert is valid.
+            assert_eq!(old, None);
         }
 
         Self { transitions, start, accept }
     }
 
-    pub(crate) fn bytes_from(&self, start: State) -> Option<&Map<Byte, State>> {
-        Some(&self.transitions.get(&start)?.byte_transitions)
+    pub(crate) fn states_from(
+        &self,
+        state: State,
+        src_validity: RangeInclusive<u8>,
+    ) -> impl Iterator<Item = (Byte, State)> {
+        self.transitions
+            .get(&state)
+            .map(move |t| t.byte_transitions.states_from(src_validity))
+            .into_iter()
+            .flatten()
+    }
+
+    pub(crate) fn get_uninit_edge_dst(&self, state: State) -> Option<State> {
+        let transitions = self.transitions.get(&state)?;
+        transitions.byte_transitions.get_uninit_edge_dst()
     }
 
-    pub(crate) fn byte_from(&self, start: State, byte: Byte) -> Option<State> {
-        self.transitions.get(&start)?.byte_transitions.get(&byte).copied()
+    pub(crate) fn bytes_from(&self, start: State) -> impl Iterator<Item = (Byte, State)> {
+        self.transitions
+            .get(&start)
+            .into_iter()
+            .flat_map(|transitions| transitions.byte_transitions.iter())
     }
 
-    pub(crate) fn refs_from(&self, start: State) -> Option<&Map<R, State>> {
-        Some(&self.transitions.get(&start)?.ref_transitions)
+    pub(crate) fn refs_from(&self, start: State) -> impl Iterator<Item = (R, State)> {
+        self.transitions
+            .get(&start)
+            .into_iter()
+            .flat_map(|transitions| transitions.ref_transitions.iter())
+            .map(|(r, s)| (*r, *s))
     }
 
     #[cfg(test)]
@@ -241,15 +285,25 @@ where
     ) -> Self {
         let start = State(start);
         let accept = State(accept);
-        let mut transitions: Map<State, Transitions<R>> = Map::default();
+        let mut transitions: Map<State, Vec<(Byte, State)>> = Map::default();
 
-        for &(src, edge, dst) in edges {
-            let src = State(src);
-            let dst = State(dst);
-            let old = transitions.entry(src).or_default().byte_transitions.insert(edge.into(), dst);
-            assert!(old.is_none());
+        for (src, edge, dst) in edges.iter().copied() {
+            transitions.entry(State(src)).or_default().push((edge.into(), State(dst)));
         }
 
+        let transitions = transitions
+            .into_iter()
+            .map(|(src, edges)| {
+                (
+                    src,
+                    Transitions {
+                        byte_transitions: EdgeSet::from_edges(edges),
+                        ref_transitions: Map::default(),
+                    },
+                )
+            })
+            .collect();
+
         Self { start, accept, transitions }
     }
 }
@@ -277,3 +331,242 @@ where
         writeln!(f, "}}")
     }
 }
+
+use edge_set::EdgeSet;
+mod edge_set {
+    use std::cmp;
+
+    use run::*;
+    use smallvec::{SmallVec, smallvec};
+
+    use super::*;
+    mod run {
+        use std::ops::{Range, RangeInclusive};
+
+        use super::*;
+        use crate::layout::Byte;
+
+        /// A logical set of edges.
+        ///
+        /// A `Run` encodes one edge for every byte value in `start..=end`
+        /// pointing to `dst`.
+        #[derive(Eq, PartialEq, Copy, Clone, Debug)]
+        pub(super) struct Run<S> {
+            // `start` and `end` are both inclusive (ie, closed) bounds, as this
+            // is required in order to be able to store 0..=255. We provide
+            // setters and getters which operate on closed/open ranges, which
+            // are more intuitive and easier for performing offset math.
+            start: u8,
+            end: u8,
+            pub(super) dst: S,
+        }
+
+        impl<S> Run<S> {
+            pub(super) fn new(range: RangeInclusive<u8>, dst: S) -> Self {
+                Self { start: *range.start(), end: *range.end(), dst }
+            }
+
+            pub(super) fn from_inclusive_exclusive(range: Range<u16>, dst: S) -> Self {
+                Self {
+                    start: range.start.try_into().unwrap(),
+                    end: (range.end - 1).try_into().unwrap(),
+                    dst,
+                }
+            }
+
+            pub(super) fn contains(&self, idx: u16) -> bool {
+                idx >= u16::from(self.start) && idx <= u16::from(self.end)
+            }
+
+            pub(super) fn as_inclusive_exclusive(&self) -> (u16, u16) {
+                (u16::from(self.start), u16::from(self.end) + 1)
+            }
+
+            pub(super) fn as_byte(&self) -> Byte {
+                Byte::new(self.start..=self.end)
+            }
+
+            pub(super) fn map_state<SS>(self, f: impl FnOnce(S) -> SS) -> Run<SS> {
+                let Run { start, end, dst } = self;
+                Run { start, end, dst: f(dst) }
+            }
+
+            /// Produces a new `Run` whose lower bound is the greater of
+            /// `self`'s existing lower bound and `lower_bound`.
+            pub(super) fn clamp_lower(self, lower_bound: u8) -> Self {
+                let Run { start, end, dst } = self;
+                Run { start: cmp::max(start, lower_bound), end, dst }
+            }
+        }
+    }
+
+    /// The set of outbound byte edges associated with a DFA node (not including
+    /// reference edges).
+    #[derive(Eq, PartialEq, Clone, Debug)]
+    pub(super) struct EdgeSet<S> {
+        // A sequence of runs stored in ascending order. Since the graph is a
+        // DFA, these must be non-overlapping with one another.
+        runs: SmallVec<[Run<S>; 1]>,
+        // The edge labeled with the uninit byte, if any.
+        //
+        // FIXME(@joshlf): Make `State` a `NonZero` so that this is NPO'd.
+        uninit: Option<S>,
+    }
+
+    impl<S> EdgeSet<S> {
+        pub(crate) fn new(byte: Byte, dst: S) -> Self {
+            match byte.range() {
+                Some(range) => Self { runs: smallvec![Run::new(range, dst)], uninit: None },
+                None => Self { runs: SmallVec::new(), uninit: Some(dst) },
+            }
+        }
+
+        pub(crate) fn empty() -> Self {
+            Self { runs: SmallVec::new(), uninit: None }
+        }
+
+        #[cfg(test)]
+        pub(crate) fn from_edges(mut edges: Vec<(Byte, S)>) -> Self
+        where
+            S: Ord,
+        {
+            edges.sort();
+            Self {
+                runs: edges
+                    .into_iter()
+                    .map(|(byte, state)| Run::new(byte.range().unwrap(), state))
+                    .collect(),
+                uninit: None,
+            }
+        }
+
+        pub(crate) fn iter(&self) -> impl Iterator<Item = (Byte, S)>
+        where
+            S: Copy,
+        {
+            self.uninit
+                .map(|dst| (Byte::uninit(), dst))
+                .into_iter()
+                .chain(self.runs.iter().map(|run| (run.as_byte(), run.dst)))
+        }
+
+        pub(crate) fn states_from(
+            &self,
+            byte: RangeInclusive<u8>,
+        ) -> impl Iterator<Item = (Byte, S)>
+        where
+            S: Copy,
+        {
+            // FIXME(@joshlf): Optimize this. A manual scan over `self.runs` may
+            // permit us to more efficiently discard runs which will not be
+            // produced by this iterator.
+            self.iter().filter(move |(o, _)| Byte::new(byte.clone()).transmutable_into(&o))
+        }
+
+        pub(crate) fn get_uninit_edge_dst(&self) -> Option<S>
+        where
+            S: Copy,
+        {
+            self.uninit
+        }
+
+        pub(crate) fn map_states<SS>(self, mut f: impl FnMut(S) -> SS) -> EdgeSet<SS> {
+            EdgeSet {
+                // NOTE: It appears as though `<Vec<_> as
+                // IntoIterator>::IntoIter` and `std::iter::Map` both implement
+                // `TrustedLen`, which in turn means that this `.collect()`
+                // allocates the correct number of elements once up-front [1].
+                //
+                // [1] https://doc.rust-lang.org/1.85.0/src/alloc/vec/spec_from_iter_nested.rs.html#47
+                runs: self.runs.into_iter().map(|run| run.map_state(&mut f)).collect(),
+                uninit: self.uninit.map(f),
+            }
+        }
+
+        /// Unions two edge sets together.
+        ///
+        /// If `u = a.union(b)`, then for each byte value, `u` will have an edge
+        /// with that byte value and with the destination `(Some(_), None)`,
+        /// `(None, Some(_))`, or `(Some(_), Some(_))` depending on whether `a`,
+        /// `b`, or both have an edge with that byte value.
+        ///
+        /// If neither `a` nor `b` have an edge with a particular byte value,
+        /// then no edge with that value will be present in `u`.
+        pub(crate) fn union(&self, other: &Self) -> EdgeSet<(Option<S>, Option<S>)>
+        where
+            S: Copy,
+        {
+            let uninit = match (self.uninit, other.uninit) {
+                (None, None) => None,
+                (s, o) => Some((s, o)),
+            };
+
+            let mut runs = SmallVec::new();
+
+            // Iterate over `self.runs` and `other.runs` simultaneously,
+            // advancing `idx` as we go. At each step, we advance `idx` as far
+            // as we can without crossing a run boundary in either `self.runs`
+            // or `other.runs`.
+
+            // INVARIANT: `idx < s[0].end && idx < o[0].end`.
+            let (mut s, mut o) = (self.runs.as_slice(), other.runs.as_slice());
+            let mut idx = 0u16;
+            while let (Some((s_run, s_rest)), Some((o_run, o_rest))) =
+                (s.split_first(), o.split_first())
+            {
+                let (s_start, s_end) = s_run.as_inclusive_exclusive();
+                let (o_start, o_end) = o_run.as_inclusive_exclusive();
+
+                // Compute `end` as the end of the current run (which starts
+                // with `idx`).
+                let (end, dst) = match (s_run.contains(idx), o_run.contains(idx)) {
+                    // `idx` is in an existing run in both `s` and `o`, so `end`
+                    // is equal to the smallest of the two ends of those runs.
+                    (true, true) => (cmp::min(s_end, o_end), (Some(s_run.dst), Some(o_run.dst))),
+                    // `idx` is in an existing run in `s`, but not in any run in
+                    // `o`. `end` is either the end of the `s` run or the
+                    // beginning of the next `o` run, whichever comes first.
+                    (true, false) => (cmp::min(s_end, o_start), (Some(s_run.dst), None)),
+                    // The inverse of the previous case.
+                    (false, true) => (cmp::min(s_start, o_end), (None, Some(o_run.dst))),
+                    // `idx` is not in a run in either `s` or `o`, so advance it
+                    // to the beginning of the next run.
+                    (false, false) => {
+                        idx = cmp::min(s_start, o_start);
+                        continue;
+                    }
+                };
+
+                // FIXME(@joshlf): If this is contiguous with the previous run
+                // and has the same `dst`, just merge it into that run rather
+                // than adding a new one.
+                runs.push(Run::from_inclusive_exclusive(idx..end, dst));
+                idx = end;
+
+                if idx >= s_end {
+                    s = s_rest;
+                }
+                if idx >= o_end {
+                    o = o_rest;
+                }
+            }
+
+            // At this point, either `s` or `o` have been exhausted, so the
+            // remaining elements in the other slice are guaranteed to be
+            // non-overlapping. We can add all remaining runs to `runs` with no
+            // further processing.
+            if let Ok(idx) = u8::try_from(idx) {
+                let (slc, map) = if !s.is_empty() {
+                    let map: fn(_) -> _ = |st| (Some(st), None);
+                    (s, map)
+                } else {
+                    let map: fn(_) -> _ = |st| (None, Some(st));
+                    (o, map)
+                };
+                runs.extend(slc.iter().map(|run| run.clamp_lower(idx).map_state(map)));
+            }
+
+            EdgeSet { runs, uninit }
+        }
+    }
+}
diff --git a/compiler/rustc_transmute/src/layout/mod.rs b/compiler/rustc_transmute/src/layout/mod.rs
index c940f7c42a82f..d555ea702a9fe 100644
--- a/compiler/rustc_transmute/src/layout/mod.rs
+++ b/compiler/rustc_transmute/src/layout/mod.rs
@@ -1,5 +1,6 @@
 use std::fmt::{self, Debug};
 use std::hash::Hash;
+use std::ops::RangeInclusive;
 
 pub(crate) mod tree;
 pub(crate) use tree::Tree;
@@ -10,18 +11,56 @@ pub(crate) use dfa::Dfa;
 
 #[derive(Debug)]
 pub(crate) struct Uninhabited;
 
-/// An instance of a byte is either initialized to a particular value, or uninitialized.
-#[derive(Hash, Eq, PartialEq, Clone, Copy)]
-pub(crate) enum Byte {
-    Uninit,
-    Init(u8),
+/// A range of byte values, or the uninit byte.
+#[derive(Hash, Eq, PartialEq, Ord, PartialOrd, Clone, Copy)]
+pub(crate) struct Byte {
+    // An inclusive-inclusive range. We use this instead of `RangeInclusive`
+    // because `RangeInclusive: !Copy`.
+    //
+    // `None` means uninit.
+    //
+    // FIXME(@joshlf): Optimize this representation. Some pairs of values (where
+    // `lo > hi`) are illegal, and we could use these to represent `None`.
+    range: Option<(u8, u8)>,
+}
+
+impl Byte {
+    fn new(range: RangeInclusive<u8>) -> Self {
+        Self { range: Some((*range.start(), *range.end())) }
+    }
+
+    fn from_val(val: u8) -> Self {
+        Self { range: Some((val, val)) }
+    }
+
+    pub(crate) fn uninit() -> Byte {
+        Byte { range: None }
+    }
+
+    /// Returns `None` if `self` is the uninit byte.
+    pub(crate) fn range(&self) -> Option<RangeInclusive<u8>> {
+        self.range.map(|(lo, hi)| lo..=hi)
+    }
+
+    /// Are any of the values in `self` transmutable into `other`?
+    ///
+    /// Note two special cases: An uninit byte is only transmutable into another
+    /// uninit byte. Any byte is transmutable into an uninit byte.
+    pub(crate) fn transmutable_into(&self, other: &Byte) -> bool {
+        match (self.range, other.range) {
+            (None, None) => true,
+            (None, Some(_)) => false,
+            (Some(_), None) => true,
+            (Some((slo, shi)), Some((olo, ohi))) => slo <= ohi && olo <= shi,
+        }
+    }
 }
 
 impl fmt::Debug for Byte {
     fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
-        match &self {
-            Self::Uninit => f.write_str("??u8"),
-            Self::Init(b) => write!(f, "{b:#04x}u8"),
+        match self.range {
+            None => write!(f, "uninit"),
+            Some((lo, hi)) => write!(f, "{lo}..={hi}"),
         }
     }
 }
@@ -29,7 +68,7 @@ impl fmt::Debug for Byte {
 #[cfg(test)]
 impl From<u8> for Byte {
     fn from(src: u8) -> Self {
-        Self::Init(src)
+        Self::from_val(src)
     }
 }
@@ -62,6 +101,21 @@ impl Ref for ! {
     }
 }
 
+#[cfg(test)]
+impl<const N: usize> Ref for [(); N] {
+    fn min_align(&self) -> usize {
+        N
+    }
+
+    fn size(&self) -> usize {
+        N
+    }
+
+    fn is_mutable(&self) -> bool {
+        false
+    }
+}
+
 #[cfg(feature = "rustc")]
 pub mod rustc {
     use std::fmt::{self, Write};
diff --git a/compiler/rustc_transmute/src/layout/tree.rs b/compiler/rustc_transmute/src/layout/tree.rs
index 70ecc75403fd8..6a09be18ef944 100644
--- a/compiler/rustc_transmute/src/layout/tree.rs
+++ b/compiler/rustc_transmute/src/layout/tree.rs
@@ -54,22 +54,22 @@ where
 
     /// A `Tree` containing a single, uninitialized byte.
     pub(crate) fn uninit() -> Self {
-        Self::Byte(Byte::Uninit)
+        Self::Byte(Byte::uninit())
     }
 
     /// A `Tree` representing the layout of `bool`.
     pub(crate) fn bool() -> Self {
-        Self::from_bits(0x00).or(Self::from_bits(0x01))
+        Self::Byte(Byte::new(0x00..=0x01))
     }
 
     /// A `Tree` whose layout matches that of a `u8`.
     pub(crate) fn u8() -> Self {
-        Self::Alt((0u8..=255).map(Self::from_bits).collect())
+        Self::Byte(Byte::new(0x00..=0xFF))
     }
 
     /// A `Tree` whose layout accepts exactly the given bit pattern.
     pub(crate) fn from_bits(bits: u8) -> Self {
-        Self::Byte(Byte::Init(bits))
+        Self::Byte(Byte::from_val(bits))
     }
 
     /// A `Tree` whose layout is a number of the given width.
diff --git a/compiler/rustc_transmute/src/lib.rs b/compiler/rustc_transmute/src/lib.rs
index 76fa6ceabe7e7..ce18dad55179c 100644
--- a/compiler/rustc_transmute/src/lib.rs
+++ b/compiler/rustc_transmute/src/lib.rs
@@ -1,8 +1,9 @@
 // tidy-alphabetical-start
+#![cfg_attr(test, feature(test))]
 #![feature(never_type)]
 // tidy-alphabetical-end
 
-pub(crate) use rustc_data_structures::fx::FxIndexMap as Map;
+pub(crate) use rustc_data_structures::fx::{FxIndexMap as Map, FxIndexSet as Set};
 
 pub mod layout;
 mod maybe_transmutable;
diff --git a/compiler/rustc_transmute/src/maybe_transmutable/mod.rs b/compiler/rustc_transmute/src/maybe_transmutable/mod.rs
index db0e1ab8e986a..0a19cccc2ed03 100644
--- a/compiler/rustc_transmute/src/maybe_transmutable/mod.rs
+++ b/compiler/rustc_transmute/src/maybe_transmutable/mod.rs
@@ -1,10 +1,14 @@
+use std::rc::Rc;
+use std::{cmp, iter};
+
+use itertools::Either;
 use tracing::{debug, instrument, trace};
 
 pub(crate) mod query_context;
 #[cfg(test)]
 mod tests;
 
-use crate::layout::{self, Byte, Def, Dfa, Ref, Tree, Uninhabited, dfa};
+use crate::layout::{self, Byte, Def, Dfa, Ref, Tree, dfa};
 use crate::maybe_transmutable::query_context::QueryContext;
 use crate::{Answer, Condition, Map, Reason};
 
@@ -111,7 +115,7 @@ where
         // the `src` type do not exist.
         let src = match Dfa::from_tree(src) {
             Ok(src) => src,
-            Err(Uninhabited) => return Answer::Yes,
+            Err(layout::Uninhabited) => return Answer::Yes,
         };
 
         // Convert `dst` from a tree-based representation to an DFA-based
@@ -122,7 +126,7 @@ where
         // free of safety invariants.
         let dst = match Dfa::from_tree(dst) {
             Ok(dst) => dst,
-            Err(Uninhabited) => return Answer::No(Reason::DstMayHaveSafetyInvariants),
+            Err(layout::Uninhabited) => return Answer::No(Reason::DstMayHaveSafetyInvariants),
         };
 
         MaybeTransmutableQuery { src, dst, assume, context }.answer()
@@ -174,8 +178,8 @@ where
             // are able to safely transmute, even with truncation.
             Answer::Yes
         } else if src_state == self.src.accept {
-            // extension: `size_of(Src) >= size_of(Dst)`
-            if let Some(dst_state_prime) = self.dst.byte_from(dst_state, Byte::Uninit) {
+            // extension: `size_of(Src) <= size_of(Dst)`
+            if let Some(dst_state_prime) = self.dst.get_uninit_edge_dst(dst_state) {
                 self.answer_memo(cache, src_state, dst_state_prime)
             } else {
                 Answer::No(Reason::DstIsTooBig)
@@ -193,26 +197,120 @@ where
             Quantifier::ForAll
         };
 
+        let c = &core::cell::RefCell::new(&mut *cache);
         let bytes_answer = src_quantifier.apply(
-            // for each of the byte transitions out of the `src_state`...
-            self.src.bytes_from(src_state).unwrap_or(&Map::default()).into_iter().map(
-                |(&src_validity, &src_state_prime)| {
-                    // ...try to find a matching transition out of `dst_state`.
-                    if let Some(dst_state_prime) =
-                        self.dst.byte_from(dst_state, src_validity)
-                    {
-                        self.answer_memo(cache, src_state_prime, dst_state_prime)
-                    } else if let Some(dst_state_prime) =
-                        // otherwise, see if `dst_state` has any outgoing `Uninit` transitions
-                        // (any init byte is a valid uninit byte)
-                        self.dst.byte_from(dst_state, Byte::Uninit)
-                    {
-                        self.answer_memo(cache, src_state_prime, dst_state_prime)
-                    } else {
-                        // otherwise, we've exhausted our options.
-                        // the DFAs, from this point onwards, are bit-incompatible.
-                        Answer::No(Reason::DstIsBitIncompatible)
+            // for each of the byte set transitions out of the `src_state`...
+            self.src.bytes_from(src_state).flat_map(
+                move |(src_validity, src_state_prime)| {
+                    // ...find all matching transitions out of `dst_state`.
+
+                    let Some(src_validity) = src_validity.range() else {
+                        // NOTE: We construct an iterator here rather
+                        // than just computing the value directly (via
+                        // `self.answer_memo`) so that, if the iterator
+                        // we produce from this branch is
+                        // short-circuited, we don't waste time
+                        // computing `self.answer_memo` unnecessarily.
+                        // That will specifically happen if
+                        // `src_quantifier == Quantifier::ThereExists`,
+                        // since we emit `Answer::Yes` first (before
+                        // chaining `answer_iter`).
+                        let answer_iter = if let Some(dst_state_prime) =
+                            self.dst.get_uninit_edge_dst(dst_state)
+                        {
+                            Either::Left(iter::once_with(move || {
+                                let mut c = c.borrow_mut();
+                                self.answer_memo(&mut *c, src_state_prime, dst_state_prime)
+                            }))
+                        } else {
+                            Either::Right(iter::once(Answer::No(
+                                Reason::DstIsBitIncompatible,
+                            )))
+                        };
+
+                        // When `answer == Answer::No(...)`, there are
+                        // two cases to consider:
+                        // - If `assume.validity`, then we should
+                        //   succeed because the user is responsible for
+                        //   ensuring that the *specific* byte value
+                        //   appearing at runtime is valid for the
+                        //   destination type. When `assume.validity`,
+                        //   `src_quantifier ==
+                        //   Quantifier::ThereExists`, so adding an
+                        //   `Answer::Yes` has the effect of ensuring
+                        //   that the "there exists" is always
+                        //   satisfied.
+                        // - If `!assume.validity`, then we should fail.
+                        //   In this case, `src_quantifier ==
+                        //   Quantifier::ForAll`, so adding an
+                        //   `Answer::Yes` has no effect.
+                        return Either::Left(iter::once(Answer::Yes).chain(answer_iter));
+                    };
+
+                    #[derive(Copy, Clone, Debug)]
+                    struct Accum {
+                        // The number of matching byte edges that we
+                        // have found in the destination so far.
+                        sum: usize,
+                        found_uninit: bool,
+                    }
+
+                    let accum1 = Rc::new(std::cell::Cell::new(Accum {
+                        sum: 0,
+                        found_uninit: false,
+                    }));
+                    let accum2 = Rc::clone(&accum1);
+                    let sv = src_validity.clone();
+                    let update_accum = move |mut accum: Accum, dst_validity: Byte| {
+                        if let Some(dst_validity) = dst_validity.range() {
+                            // Only add the part of `dst_validity` that
+                            // overlaps with `src_validity`.
+                            let start = cmp::max(*sv.start(), *dst_validity.start());
+                            let end = cmp::min(*sv.end(), *dst_validity.end());
+
+                            // We add 1 here to account for the fact
+                            // that `end` is an inclusive bound.
+                            accum.sum += 1 + usize::from(end.saturating_sub(start));
+                        } else {
+                            accum.found_uninit = true;
+                        }
+                        accum
+                    };
+
+                    let answers = self
+                        .dst
+                        .states_from(dst_state, src_validity.clone())
+                        .map(move |(dst_validity, dst_state_prime)| {
+                            let mut c = c.borrow_mut();
+                            accum1.set(update_accum(accum1.get(), dst_validity));
+                            let answer =
+                                self.answer_memo(&mut *c, src_state_prime, dst_state_prime);
+                            answer
+                        })
+                        .chain(
+                            iter::once_with(move || {
+                                let src_validity_len = usize::from(*src_validity.end())
+                                    - usize::from(*src_validity.start())
+                                    + 1;
+                                let accum = accum2.get();
+
+                                // If this condition is false, then
+                                // there are some byte values in the
+                                // source which have no corresponding
+                                // transition in the destination DFA. In
+                                // that case, we add a `No` to our list
+                                // of answers. When
+                                // `!self.assume.validity`, this will
+                                // cause the query to fail.
+                                if accum.found_uninit || accum.sum == src_validity_len {
+                                    None
+                                } else {
+                                    Some(Answer::No(Reason::DstIsBitIncompatible))
+                                }
+                            })
+                            .flatten(),
+                        );
+                    Either::Right(answers)
                 },
             ),
         );
@@ -235,48 +333,38 @@ where
 
         let refs_answer = src_quantifier.apply(
             // for each reference transition out of `src_state`...
-            self.src.refs_from(src_state).unwrap_or(&Map::default()).into_iter().map(
-                |(&src_ref, &src_state_prime)| {
-                    // ...there exists a reference transition out of `dst_state`...
-                    Quantifier::ThereExists.apply(
-                        self.dst
-                            .refs_from(dst_state)
-                            .unwrap_or(&Map::default())
-                            .into_iter()
-                            .map(|(&dst_ref, &dst_state_prime)| {
-                                if !src_ref.is_mutable() && dst_ref.is_mutable() {
-                                    Answer::No(Reason::DstIsMoreUnique)
-                                } else if !self.assume.alignment
-                                    && src_ref.min_align() < dst_ref.min_align()
-                                {
-                                    Answer::No(Reason::DstHasStricterAlignment {
-                                        src_min_align: src_ref.min_align(),
-                                        dst_min_align: dst_ref.min_align(),
-                                    })
-                                } else if dst_ref.size() > src_ref.size() {
-                                    Answer::No(Reason::DstRefIsTooBig {
-                                        src: src_ref,
-                                        dst: dst_ref,
-                                    })
-                                } else {
-                                    // ...such that `src` is transmutable into `dst`, if
-                                    // `src_ref` is transmutability into `dst_ref`.
-                                    and(
-                                        Answer::If(Condition::IfTransmutable {
-                                            src: src_ref,
-                                            dst: dst_ref,
-                                        }),
-                                        self.answer_memo(
-                                            cache,
-                                            src_state_prime,
-                                            dst_state_prime,
-                                        ),
-                                    )
-                                }
-                            }),
-                    )
-                },
-            ),
+            self.src.refs_from(src_state).map(|(src_ref, src_state_prime)| {
+                // ...there exists a reference transition out of `dst_state`...
+                Quantifier::ThereExists.apply(self.dst.refs_from(dst_state).map(
+                    |(dst_ref, dst_state_prime)| {
+                        if !src_ref.is_mutable() && dst_ref.is_mutable() {
+                            Answer::No(Reason::DstIsMoreUnique)
+                        } else if !self.assume.alignment
+                            && src_ref.min_align() < dst_ref.min_align()
+                        {
+                            Answer::No(Reason::DstHasStricterAlignment {
+                                src_min_align: src_ref.min_align(),
+                                dst_min_align: dst_ref.min_align(),
+                            })
+                        } else if dst_ref.size() > src_ref.size() {
+                            Answer::No(Reason::DstRefIsTooBig {
+                                src: src_ref,
+                                dst: dst_ref,
+                            })
+                        } else {
+                            // ...such that `src` is transmutable into `dst`, if
+                            // `src_ref` is transmutable into `dst_ref`.
+                            and(
+                                Answer::If(Condition::IfTransmutable {
+                                    src: src_ref,
+                                    dst: dst_ref,
+                                }),
+                                self.answer_memo(cache, src_state_prime, dst_state_prime),
+                            )
+                        }
+                    },
+                ))
+            }),
         );
 
         if self.assume.validity {
diff --git a/compiler/rustc_transmute/src/maybe_transmutable/query_context.rs b/compiler/rustc_transmute/src/maybe_transmutable/query_context.rs
index f8b59bdf32684..214da101be375 100644
--- a/compiler/rustc_transmute/src/maybe_transmutable/query_context.rs
+++ b/compiler/rustc_transmute/src/maybe_transmutable/query_context.rs
@@ -8,9 +8,17 @@ pub(crate) trait QueryContext {
 
 #[cfg(test)]
 pub(crate) mod test {
+    use std::marker::PhantomData;
+
     use super::QueryContext;
 
-    pub(crate) struct UltraMinimal;
+    pub(crate) struct UltraMinimal<R>(PhantomData<R>);
+
+    impl<R> Default for UltraMinimal<R> {
+        fn default() -> Self {
+            Self(PhantomData)
+        }
+    }
 
     #[derive(Debug, Hash, Eq, PartialEq, Clone, Copy)]
     pub(crate) enum Def {
@@ -24,9 +32,9 @@ pub(crate) mod test {
         }
     }
 
-    impl QueryContext for UltraMinimal {
+    impl<R: crate::layout::Ref> QueryContext for UltraMinimal<R> {
         type Def = Def;
-        type Ref = !;
+        type Ref = R;
     }
 }
diff --git a/compiler/rustc_transmute/src/maybe_transmutable/tests.rs b/compiler/rustc_transmute/src/maybe_transmutable/tests.rs
index cc6a4dce17b63..24e2a1acadd68 100644
--- a/compiler/rustc_transmute/src/maybe_transmutable/tests.rs
+++ b/compiler/rustc_transmute/src/maybe_transmutable/tests.rs
@@ -1,3 +1,5 @@
+extern crate test;
+
 use itertools::Itertools;
 
 use super::query_context::test::{Def, UltraMinimal};
@@ -12,15 +14,25 @@ trait Representation {
 
 impl Representation for Tree<Def, !> {
     fn is_transmutable(src: Self, dst: Self, assume: Assume) -> Answer<!> {
-        crate::maybe_transmutable::MaybeTransmutableQuery::new(src, dst, assume, UltraMinimal)
-            .answer()
+        crate::maybe_transmutable::MaybeTransmutableQuery::new(
+            src,
+            dst,
+            assume,
+            UltraMinimal::default(),
+        )
+        .answer()
     }
 }
 
 impl Representation for Dfa<!> {
     fn is_transmutable(src: Self, dst: Self, assume: Assume) -> Answer<!> {
-        crate::maybe_transmutable::MaybeTransmutableQuery::new(src, dst, assume, UltraMinimal)
-            .answer()
+        crate::maybe_transmutable::MaybeTransmutableQuery::new(
+            src,
+            dst,
+            assume,
+            UltraMinimal::default(),
+        )
+        .answer()
     }
 }
 
@@ -89,6 +101,36 @@ mod safety {
     }
 }
 
+mod size {
+    use super::*;
+
+    #[test]
+    fn size() {
+        let small = Tree::number(1);
+        let large = Tree::number(2);
+
+        for alignment in [false, true] {
+            for lifetimes in [false, true] {
+                for safety in [false, true] {
+                    for validity in [false, true] {
+                        let assume = Assume { alignment, lifetimes, safety, validity };
+                        assert_eq!(
+                            is_transmutable(&small, &large, assume),
+                            Answer::No(Reason::DstIsTooBig),
+                            "assume: {assume:?}"
+                        );
+                        assert_eq!(
+                            is_transmutable(&large, &small, assume),
+                            Answer::Yes,
+                            "assume: {assume:?}"
+                        );
+                    }
+                }
+            }
+        }
+    }
+}
+
 mod bool {
     use super::*;
 
@@ -112,6 +154,27 @@ mod bool {
         );
     }
 
+    #[test]
+    fn transmute_u8() {
+        let bool = &Tree::bool();
+        let u8 = &Tree::u8();
+        for (src, dst, assume_validity, answer) in [
+            (bool, u8, false, Answer::Yes),
+            (bool, u8, true, Answer::Yes),
+            (u8, bool, false, Answer::No(Reason::DstIsBitIncompatible)),
+            (u8, bool, true, Answer::Yes),
+        ] {
+            assert_eq!(
+                is_transmutable(
+                    src,
+                    dst,
+                    Assume { validity: assume_validity, ..Assume::default() }
+                ),
+                answer
+            );
+        }
+    }
+
     #[test]
     fn should_permit_validity_expansion_and_reject_contraction() {
         let b0 = layout::Tree::<Def, !>::from_bits(0);
@@ -175,6 +238,62 @@ mod bool {
     }
 }
 
+mod uninit {
+    use super::*;
+
+    #[test]
+    fn size() {
+        let mu = Tree::uninit();
+        let u8 = Tree::u8();
+
+        for alignment in [false, true] {
+            for lifetimes in [false, true] {
+                for safety in [false, true] {
+                    for validity in [false, true] {
+                        let assume = Assume { alignment, lifetimes, safety, validity };
+
+                        let want = if validity {
+                            Answer::Yes
+                        } else {
+                            Answer::No(Reason::DstIsBitIncompatible)
+                        };
+
+                        assert_eq!(is_transmutable(&mu, &u8, assume), want, "assume: {assume:?}");
+                        assert_eq!(
+                            is_transmutable(&u8, &mu, assume),
+                            Answer::Yes,
+                            "assume: {assume:?}"
+                        );
+                    }
+                }
+            }
+        }
+    }
+}
+
+mod alt {
+    use super::*;
+    use crate::Answer;
+
+    #[test]
+    fn should_permit_identity_transmutation() {
+        type Tree = layout::Tree<Def, !>;
+
+        let x = Tree::Seq(vec![Tree::from_bits(0), Tree::from_bits(0)]);
+        let y = Tree::Seq(vec![Tree::bool(), Tree::from_bits(1)]);
+        let layout = Tree::Alt(vec![x, y]);
+
+        let answer = crate::maybe_transmutable::MaybeTransmutableQuery::new(
+            layout.clone(),
+            layout.clone(),
+            crate::Assume::default(),
+            UltraMinimal::default(),
+        )
+        .answer();
+        assert_eq!(answer, Answer::Yes, "layout:{:#?}", layout);
+    }
+}
+
 mod union {
     use super::*;
 
@@ -203,3 +322,59 @@ mod union {
         assert_eq!(is_transmutable(&t, &u, Assume::default()), Answer::Yes);
     }
 }
+
+mod r#ref {
+    use super::*;
+
+    #[test]
+    fn should_permit_identity_transmutation() {
+        type Tree = crate::layout::Tree<Def, [(); 1]>;
+
+        let layout = Tree::Seq(vec![Tree::from_bits(0), Tree::Ref([()])]);
+
+        let answer = crate::maybe_transmutable::MaybeTransmutableQuery::new(
+            layout.clone(),
+            layout,
+            Assume::default(),
+            UltraMinimal::default(),
+        )
+        .answer();
+        assert_eq!(answer, Answer::If(crate::Condition::IfTransmutable { src: [()], dst: [()] }));
+    }
+}
+
+mod benches {
+    use std::hint::black_box;
+
+    use test::Bencher;
+
+    use super::*;
+
+    #[bench]
+    fn bench_dfa_from_tree(b: &mut Bencher) {
+        let num = Tree::number(8).prune(&|_| false);
+        let num = black_box(num);
+
+        b.iter(|| {
+            let _ = black_box(Dfa::from_tree(num.clone()));
+        })
+    }
+
+    #[bench]
+    fn bench_transmute(b: &mut Bencher) {
+        let num = Tree::number(8).prune(&|_| false);
+        let dfa = black_box(Dfa::from_tree(num).unwrap());
+
+        b.iter(|| {
+            let answer = crate::maybe_transmutable::MaybeTransmutableQuery::new(
+                dfa.clone(),
+                dfa.clone(),
+                Assume::default(),
+                UltraMinimal::default(),
+            )
+            .answer();
+            let answer = std::hint::black_box(answer);
+            assert_eq!(answer, Answer::Yes);
+        })
+    }
+}
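
Reviewer note (illustrative only, not part of the patch): the core of the new representation is that a `Byte` is now an interval of byte values, so byte-level compatibility reduces to interval intersection plus two uninit special cases. Below is a minimal standalone sketch of the rule that the patch's `Byte::transmutable_into` implements; `ByteSketch` is a stand-in name, not the patch's API.

#[derive(Clone, Copy, Debug)]
struct ByteSketch {
    // `None` models the uninit byte; `Some((lo, hi))` is an inclusive range.
    range: Option<(u8, u8)>,
}

impl ByteSketch {
    fn transmutable_into(self, other: ByteSketch) -> bool {
        match (self.range, other.range) {
            // An uninit source byte is only satisfied by an uninit
            // destination byte...
            (None, None) => true,
            (None, Some(_)) => false,
            // ...while any initialized byte may fill an uninit destination.
            (Some(_), None) => true,
            // Two inclusive ranges intersect iff each one starts no later
            // than the other ends.
            (Some((slo, shi)), Some((olo, ohi))) => slo <= ohi && olo <= shi,
        }
    }
}

fn main() {
    let bool_byte = ByteSketch { range: Some((0x00, 0x01)) };
    let u8_byte = ByteSketch { range: Some((0x00, 0xFF)) };
    let uninit = ByteSketch { range: None };
    assert!(bool_byte.transmutable_into(u8_byte));
    assert!(u8_byte.transmutable_into(bool_byte)); // some, but not all, values match
    assert!(u8_byte.transmutable_into(uninit));
    assert!(!uninit.transmutable_into(u8_byte));
}

This is also why `Tree::u8()` collapses from a 256-arm `Alt` into the single byte range 0x00..=0xFF, which is what makes the new benchmarks tractable.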
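
Likewise illustrative: `EdgeSet::union` walks two sorted, non-overlapping run lists and cuts the output at every boundary of either input, so that each output run carries a constant `(Option<S>, Option<S>)` destination pair. The sketch below computes the same partition with a simpler boundary-collection strategy (at the cost of extra allocation); `Run` and `union` here are simplified stand-ins for the patch's `run::Run<S>` and `EdgeSet::union`, and the uninit edge is omitted.

#[derive(Clone, Copy, Debug, PartialEq)]
struct Run<S> {
    start: u8,
    end: u8, // inclusive, so that 0..=255 is representable
    dst: S,
}

fn union<S: Copy>(a: &[Run<S>], b: &[Run<S>]) -> Vec<Run<(Option<S>, Option<S>)>> {
    // Collect every run boundary from both inputs, then emit one output run
    // per maximal interval on which the (a_dst, b_dst) pair is constant.
    let mut cuts: Vec<u16> = Vec::new();
    for r in a.iter().chain(b) {
        cuts.push(u16::from(r.start));
        cuts.push(u16::from(r.end) + 1);
    }
    cuts.sort_unstable();
    cuts.dedup();

    let mut out = Vec::new();
    for w in cuts.windows(2) {
        let (lo, hi) = (w[0] as u8, (w[1] - 1) as u8);
        // Find the destination (if any) that each input assigns to `lo`; by
        // construction it is the same for every value in `lo..=hi`.
        let find = |runs: &[Run<S>]| {
            runs.iter().find(|r| r.start <= lo && lo <= r.end).map(|r| r.dst)
        };
        let dst = (find(a), find(b));
        if dst.0.is_some() || dst.1.is_some() {
            out.push(Run { start: lo, end: hi, dst });
        }
    }
    out
}

fn main() {
    let a = [Run { start: 0x00, end: 0x0F, dst: 1u32 }];
    let b = [Run { start: 0x08, end: 0xFF, dst: 2u32 }];
    for r in union(&a, &b) {
        println!("{:#04x}..={:#04x} -> {:?}", r.start, r.end, r.dst);
    }
    // 0x00..=0x07 -> (Some(1), None)
    // 0x08..=0x0f -> (Some(1), Some(2))
    // 0x10..=0xff -> (None, Some(2))
}

The patch's two-pointer loop produces this partition in a single pass without collecting boundaries, which matters because `union` runs once per state pair visited while unioning two DFAs.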
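
Finally, a sketch of the counting argument that the new `Accum` struct encodes in `answer_memo`: when validity is not assumed, a source byte range is only acceptable if the destination either has an uninit edge or its overlapping runs cover every value of the source range. This works because destination runs never overlap one another, so overlap widths can simply be summed. `src_fully_matched` is a hypothetical helper written for this note, not the patch's API.

fn src_fully_matched(src: std::ops::RangeInclusive<u8>, dst_edges: &[Option<(u8, u8)>]) -> bool {
    let mut sum = 0usize;
    let mut found_uninit = false;
    for edge in dst_edges {
        match edge {
            // `None` models the uninit edge: it absorbs any source byte.
            None => found_uninit = true,
            Some((lo, hi)) => {
                // Count only the part of the destination run that overlaps
                // the source range; `+ 1` accounts for inclusive bounds.
                let start = (*src.start()).max(*lo);
                let end = (*src.end()).min(*hi);
                if start <= end {
                    sum += 1 + usize::from(end - start);
                }
            }
        }
    }
    let src_len = usize::from(*src.end()) - usize::from(*src.start()) + 1;
    found_uninit || sum == src_len
}

fn main() {
    // u8 -> bool: only 0x00..=0x01 of the source is covered, so this fails
    // unless Assume::validity flips the quantifier to ThereExists.
    assert!(!src_fully_matched(0x00..=0xFF, &[Some((0x00, 0x01))]));
    // bool -> u8: the source range sits entirely inside the destination run.
    assert!(src_fully_matched(0x00..=0x01, &[Some((0x00, 0xFF))]));
    // Anything -> uninit: an uninit edge absorbs every source value.
    assert!(src_fully_matched(0x00..=0xFF, &[None]));
}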