diff --git a/compiler/rustc_infer/src/infer/canonical/canonicalizer.rs b/compiler/rustc_infer/src/infer/canonical/canonicalizer.rs index 26ecaebe97f2d..0b543f091f730 100644 --- a/compiler/rustc_infer/src/infer/canonical/canonicalizer.rs +++ b/compiler/rustc_infer/src/infer/canonical/canonicalizer.rs @@ -493,6 +493,10 @@ impl<'cx, 'tcx> TypeFolder> for Canonicalizer<'cx, 'tcx> { ct } } + + fn fold_predicate(&mut self, p: ty::Predicate<'tcx>) -> ty::Predicate<'tcx> { + if p.flags().intersects(self.needs_canonical_flags) { p.super_fold_with(self) } else { p } + } } impl<'cx, 'tcx> Canonicalizer<'cx, 'tcx> { diff --git a/compiler/rustc_next_trait_solver/src/canonicalizer.rs b/compiler/rustc_next_trait_solver/src/canonicalizer.rs index 93b8940ee37d8..addeb3e2b78e0 100644 --- a/compiler/rustc_next_trait_solver/src/canonicalizer.rs +++ b/compiler/rustc_next_trait_solver/src/canonicalizer.rs @@ -4,12 +4,22 @@ use rustc_type_ir::data_structures::{HashMap, ensure_sufficient_stack}; use rustc_type_ir::inherent::*; use rustc_type_ir::solve::{Goal, QueryInput}; use rustc_type_ir::{ - self as ty, Canonical, CanonicalTyVarKind, CanonicalVarKind, InferCtxtLike, Interner, - TypeFoldable, TypeFolder, TypeSuperFoldable, TypeVisitableExt, + self as ty, Canonical, CanonicalTyVarKind, CanonicalVarKind, Flags, InferCtxtLike, Interner, + TypeFlags, TypeFoldable, TypeFolder, TypeSuperFoldable, TypeVisitableExt, }; use crate::delegate::SolverDelegate; +/// Does this have infer/placeholder/param, free regions or ReErased? +const NEEDS_CANONICAL: TypeFlags = TypeFlags::from_bits( + TypeFlags::HAS_INFER.bits() + | TypeFlags::HAS_PLACEHOLDER.bits() + | TypeFlags::HAS_PARAM.bits() + | TypeFlags::HAS_FREE_REGIONS.bits() + | TypeFlags::HAS_RE_ERASED.bits(), +) +.unwrap(); + /// Whether we're canonicalizing a query input or the query response. /// /// When canonicalizing an input we're in the context of the caller @@ -79,7 +89,11 @@ impl<'a, D: SolverDelegate, I: Interner> Canonicalizer<'a, D, I> { cache: Default::default(), }; - let value = value.fold_with(&mut canonicalizer); + let value = if value.has_type_flags(NEEDS_CANONICAL) { + value.fold_with(&mut canonicalizer) + } else { + value + }; assert!(!value.has_infer(), "unexpected infer in {value:?}"); assert!(!value.has_placeholders(), "unexpected placeholders in {value:?}"); let (max_universe, variables) = canonicalizer.finalize(); @@ -111,7 +125,14 @@ impl<'a, D: SolverDelegate, I: Interner> Canonicalizer<'a, D, I> { cache: Default::default(), }; - let param_env = input.goal.param_env.fold_with(&mut env_canonicalizer); + + let param_env = input.goal.param_env; + let param_env = if param_env.has_type_flags(NEEDS_CANONICAL) { + param_env.fold_with(&mut env_canonicalizer) + } else { + param_env + }; + debug_assert_eq!(env_canonicalizer.binder_index, ty::INNERMOST); // Then canonicalize the rest of the input without keeping `'static` // while *mostly* reusing the canonicalizer from above. @@ -134,10 +155,22 @@ impl<'a, D: SolverDelegate, I: Interner> Canonicalizer<'a, D, I> { cache: Default::default(), }; - let predicate = input.goal.predicate.fold_with(&mut rest_canonicalizer); + let predicate = input.goal.predicate; + let predicate = if predicate.has_type_flags(NEEDS_CANONICAL) { + predicate.fold_with(&mut rest_canonicalizer) + } else { + predicate + }; let goal = Goal { param_env, predicate }; + + let predefined_opaques_in_body = input.predefined_opaques_in_body; let predefined_opaques_in_body = - input.predefined_opaques_in_body.fold_with(&mut rest_canonicalizer); + if input.predefined_opaques_in_body.has_type_flags(NEEDS_CANONICAL) { + predefined_opaques_in_body.fold_with(&mut rest_canonicalizer) + } else { + predefined_opaques_in_body + }; + let value = QueryInput { goal, predefined_opaques_in_body }; assert!(!value.has_infer(), "unexpected infer in {value:?}"); @@ -387,7 +420,11 @@ impl<'a, D: SolverDelegate, I: Interner> Canonicalizer<'a, D, I> { | ty::Alias(_, _) | ty::Bound(_, _) | ty::Error(_) => { - return ensure_sufficient_stack(|| t.super_fold_with(self)); + return if t.has_type_flags(NEEDS_CANONICAL) { + ensure_sufficient_stack(|| t.super_fold_with(self)) + } else { + t + }; } }; @@ -522,11 +559,17 @@ impl, I: Interner> TypeFolder for Canonicaliz | ty::ConstKind::Unevaluated(_) | ty::ConstKind::Value(_) | ty::ConstKind::Error(_) - | ty::ConstKind::Expr(_) => return c.super_fold_with(self), + | ty::ConstKind::Expr(_) => { + return if c.has_type_flags(NEEDS_CANONICAL) { c.super_fold_with(self) } else { c }; + } }; let var = self.get_or_insert_bound_var(c, kind); Const::new_anon_bound(self.cx(), self.binder_index, var) } + + fn fold_predicate(&mut self, p: I::Predicate) -> I::Predicate { + if p.flags().intersects(NEEDS_CANONICAL) { p.super_fold_with(self) } else { p } + } } diff --git a/compiler/rustc_next_trait_solver/src/resolve.rs b/compiler/rustc_next_trait_solver/src/resolve.rs index 992c5ddf504ef..39abec2d7d8db 100644 --- a/compiler/rustc_next_trait_solver/src/resolve.rs +++ b/compiler/rustc_next_trait_solver/src/resolve.rs @@ -86,4 +86,8 @@ impl, I: Interner> TypeFolder for EagerResolv } } } + + fn fold_predicate(&mut self, p: I::Predicate) -> I::Predicate { + if p.has_infer() { p.super_fold_with(self) } else { p } + } }