summaryrefslogtreecommitdiffstats
path: root/compiler/rustc_infer/src/infer/generalize.rs
diff options
context:
space:
mode:
Diffstat (limited to 'compiler/rustc_infer/src/infer/generalize.rs')
-rw-r--r--compiler/rustc_infer/src/infer/generalize.rs479
1 files changed, 479 insertions, 0 deletions
diff --git a/compiler/rustc_infer/src/infer/generalize.rs b/compiler/rustc_infer/src/infer/generalize.rs
new file mode 100644
index 000000000..d4a1dacde
--- /dev/null
+++ b/compiler/rustc_infer/src/infer/generalize.rs
@@ -0,0 +1,479 @@
+use rustc_data_structures::sso::SsoHashMap;
+use rustc_hir::def_id::DefId;
+use rustc_middle::infer::unify_key::{ConstVarValue, ConstVariableValue};
+use rustc_middle::ty::error::TypeError;
+use rustc_middle::ty::relate::{self, Relate, RelateResult, TypeRelation};
+use rustc_middle::ty::{self, InferConst, Term, Ty, TyCtxt, TypeVisitableExt};
+use rustc_span::Span;
+
+use crate::infer::nll_relate::TypeRelatingDelegate;
+use crate::infer::type_variable::TypeVariableValue;
+use crate::infer::{InferCtxt, RegionVariableOrigin};
+
+/// Attempts to generalize `term` for the type variable `for_vid`.
+/// This checks for cycles -- that is, whether the type `term`
+/// references `for_vid`.
+pub(super) fn generalize<'tcx, D: GeneralizerDelegate<'tcx>, T: Into<Term<'tcx>> + Relate<'tcx>>(
+ infcx: &InferCtxt<'tcx>,
+ delegate: &mut D,
+ term: T,
+ for_vid: impl Into<ty::TermVid<'tcx>>,
+ ambient_variance: ty::Variance,
+) -> RelateResult<'tcx, Generalization<T>> {
+ let (for_universe, root_vid) = match for_vid.into() {
+ ty::TermVid::Ty(ty_vid) => (
+ infcx.probe_ty_var(ty_vid).unwrap_err(),
+ ty::TermVid::Ty(infcx.inner.borrow_mut().type_variables().sub_root_var(ty_vid)),
+ ),
+ ty::TermVid::Const(ct_vid) => (
+ infcx.probe_const_var(ct_vid).unwrap_err(),
+ ty::TermVid::Const(infcx.inner.borrow_mut().const_unification_table().find(ct_vid)),
+ ),
+ };
+
+ let mut generalizer = Generalizer {
+ infcx,
+ delegate,
+ ambient_variance,
+ root_vid,
+ for_universe,
+ root_term: term.into(),
+ needs_wf: false,
+ cache: Default::default(),
+ };
+
+ assert!(!term.has_escaping_bound_vars());
+ let value = generalizer.relate(term, term)?;
+ let needs_wf = generalizer.needs_wf;
+ Ok(Generalization { value, needs_wf })
+}
+
+/// Abstracts the handling of region vars between HIR and MIR/NLL typechecking
+/// in the generalizer code.
+pub trait GeneralizerDelegate<'tcx> {
+ fn param_env(&self) -> ty::ParamEnv<'tcx>;
+
+ fn forbid_inference_vars() -> bool;
+
+ fn generalize_region(&mut self, universe: ty::UniverseIndex) -> ty::Region<'tcx>;
+}
+
+pub struct CombineDelegate<'cx, 'tcx> {
+ pub infcx: &'cx InferCtxt<'tcx>,
+ pub param_env: ty::ParamEnv<'tcx>,
+ pub span: Span,
+}
+
+impl<'tcx> GeneralizerDelegate<'tcx> for CombineDelegate<'_, 'tcx> {
+ fn param_env(&self) -> ty::ParamEnv<'tcx> {
+ self.param_env
+ }
+
+ fn forbid_inference_vars() -> bool {
+ false
+ }
+
+ fn generalize_region(&mut self, universe: ty::UniverseIndex) -> ty::Region<'tcx> {
+ // FIXME: This is non-ideal because we don't give a
+ // very descriptive origin for this region variable.
+ self.infcx
+ .next_region_var_in_universe(RegionVariableOrigin::MiscVariable(self.span), universe)
+ }
+}
+
+impl<'tcx, T> GeneralizerDelegate<'tcx> for T
+where
+ T: TypeRelatingDelegate<'tcx>,
+{
+ fn param_env(&self) -> ty::ParamEnv<'tcx> {
+ <Self as TypeRelatingDelegate<'tcx>>::param_env(self)
+ }
+
+ fn forbid_inference_vars() -> bool {
+ <Self as TypeRelatingDelegate<'tcx>>::forbid_inference_vars()
+ }
+
+ fn generalize_region(&mut self, universe: ty::UniverseIndex) -> ty::Region<'tcx> {
+ <Self as TypeRelatingDelegate<'tcx>>::generalize_existential(self, universe)
+ }
+}
+
+/// The "generalizer" is used when handling inference variables.
+///
+/// The basic strategy for handling a constraint like `?A <: B` is to
+/// apply a "generalization strategy" to the term `B` -- this replaces
+/// all the lifetimes in the term `B` with fresh inference variables.
+/// (You can read more about the strategy in this [blog post].)
+///
+/// As an example, if we had `?A <: &'x u32`, we would generalize `&'x
+/// u32` to `&'0 u32` where `'0` is a fresh variable. This becomes the
+/// value of `A`. Finally, we relate `&'0 u32 <: &'x u32`, which
+/// establishes `'0: 'x` as a constraint.
+///
+/// [blog post]: https://is.gd/0hKvIr
+struct Generalizer<'me, 'tcx, D> {
+ infcx: &'me InferCtxt<'tcx>,
+
+ /// This is used to abstract the behaviors of the three previous
+ /// generalizer-like implementations (`Generalizer`, `TypeGeneralizer`,
+ /// and `ConstInferUnifier`). See [`GeneralizerDelegate`] for more
+ /// information.
+ delegate: &'me mut D,
+
+ /// After we generalize this type, we are going to relate it to
+ /// some other type. What will be the variance at this point?
+ ambient_variance: ty::Variance,
+
+ /// The vid of the type variable that is in the process of being
+ /// instantiated. If we find this within the value we are folding,
+ /// that means we would have created a cyclic value.
+ root_vid: ty::TermVid<'tcx>,
+
+ /// The universe of the type variable that is in the process of being
+ /// instantiated. If we find anything that this universe cannot name,
+ /// we reject the relation.
+ for_universe: ty::UniverseIndex,
+
+ /// The root term (const or type) we're generalizing. Used for cycle errors.
+ root_term: Term<'tcx>,
+
+ cache: SsoHashMap<Ty<'tcx>, Ty<'tcx>>,
+
+ /// See the field `needs_wf` in `Generalization`.
+ needs_wf: bool,
+}
+
+impl<'tcx, D> Generalizer<'_, 'tcx, D> {
+ /// Create an error that corresponds to the term kind in `root_term`
+ fn cyclic_term_error(&self) -> TypeError<'tcx> {
+ match self.root_term.unpack() {
+ ty::TermKind::Ty(ty) => TypeError::CyclicTy(ty),
+ ty::TermKind::Const(ct) => TypeError::CyclicConst(ct),
+ }
+ }
+}
+
+impl<'tcx, D> TypeRelation<'tcx> for Generalizer<'_, 'tcx, D>
+where
+ D: GeneralizerDelegate<'tcx>,
+{
+ fn tcx(&self) -> TyCtxt<'tcx> {
+ self.infcx.tcx
+ }
+
+ fn param_env(&self) -> ty::ParamEnv<'tcx> {
+ self.delegate.param_env()
+ }
+
+ fn tag(&self) -> &'static str {
+ "Generalizer"
+ }
+
+ fn a_is_expected(&self) -> bool {
+ true
+ }
+
+ fn relate_item_substs(
+ &mut self,
+ item_def_id: DefId,
+ a_subst: ty::SubstsRef<'tcx>,
+ b_subst: ty::SubstsRef<'tcx>,
+ ) -> RelateResult<'tcx, ty::SubstsRef<'tcx>> {
+ if self.ambient_variance == ty::Variance::Invariant {
+ // Avoid fetching the variance if we are in an invariant
+ // context; no need, and it can induce dependency cycles
+ // (e.g., #41849).
+ relate::relate_substs(self, a_subst, b_subst)
+ } else {
+ let tcx = self.tcx();
+ let opt_variances = tcx.variances_of(item_def_id);
+ relate::relate_substs_with_variances(
+ self,
+ item_def_id,
+ opt_variances,
+ a_subst,
+ b_subst,
+ true,
+ )
+ }
+ }
+
+ #[instrument(level = "debug", skip(self, variance, b), ret)]
+ fn relate_with_variance<T: Relate<'tcx>>(
+ &mut self,
+ variance: ty::Variance,
+ _info: ty::VarianceDiagInfo<'tcx>,
+ a: T,
+ b: T,
+ ) -> RelateResult<'tcx, T> {
+ let old_ambient_variance = self.ambient_variance;
+ self.ambient_variance = self.ambient_variance.xform(variance);
+ debug!(?self.ambient_variance, "new ambient variance");
+ let r = self.relate(a, b)?;
+ self.ambient_variance = old_ambient_variance;
+ Ok(r)
+ }
+
+ #[instrument(level = "debug", skip(self, t2), ret)]
+ fn tys(&mut self, t: Ty<'tcx>, t2: Ty<'tcx>) -> RelateResult<'tcx, Ty<'tcx>> {
+ assert_eq!(t, t2); // we are misusing TypeRelation here; both LHS and RHS ought to be ==
+
+ if let Some(&result) = self.cache.get(&t) {
+ return Ok(result);
+ }
+
+ // Check to see whether the type we are generalizing references
+ // any other type variable related to `vid` via
+ // subtyping. This is basically our "occurs check", preventing
+ // us from creating infinitely sized types.
+ let g = match *t.kind() {
+ ty::Infer(ty::TyVar(_)) | ty::Infer(ty::IntVar(_)) | ty::Infer(ty::FloatVar(_))
+ if D::forbid_inference_vars() =>
+ {
+ bug!("unexpected inference variable encountered in NLL generalization: {t}");
+ }
+
+ ty::Infer(ty::FreshTy(_) | ty::FreshIntTy(_) | ty::FreshFloatTy(_)) => {
+ bug!("unexpected infer type: {t}")
+ }
+
+ ty::Infer(ty::TyVar(vid)) => {
+ let mut inner = self.infcx.inner.borrow_mut();
+ let vid = inner.type_variables().root_var(vid);
+ let sub_vid = inner.type_variables().sub_root_var(vid);
+
+ if ty::TermVid::Ty(sub_vid) == self.root_vid {
+ // If sub-roots are equal, then `root_vid` and
+ // `vid` are related via subtyping.
+ Err(self.cyclic_term_error())
+ } else {
+ let probe = inner.type_variables().probe(vid);
+ match probe {
+ TypeVariableValue::Known { value: u } => {
+ drop(inner);
+ self.relate(u, u)
+ }
+ TypeVariableValue::Unknown { universe } => {
+ match self.ambient_variance {
+ // Invariant: no need to make a fresh type variable
+ // if we can name the universe.
+ ty::Invariant => {
+ if self.for_universe.can_name(universe) {
+ return Ok(t);
+ }
+ }
+
+ // Bivariant: make a fresh var, but we
+ // may need a WF predicate. See
+ // comment on `needs_wf` field for
+ // more info.
+ ty::Bivariant => self.needs_wf = true,
+
+ // Co/contravariant: this will be
+ // sufficiently constrained later on.
+ ty::Covariant | ty::Contravariant => (),
+ }
+
+ let origin = *inner.type_variables().var_origin(vid);
+ let new_var_id =
+ inner.type_variables().new_var(self.for_universe, origin);
+ let u = self.tcx().mk_ty_var(new_var_id);
+
+ // Record that we replaced `vid` with `new_var_id` as part of a generalization
+ // operation. This is needed to detect cyclic types. To see why, see the
+ // docs in the `type_variables` module.
+ inner.type_variables().sub(vid, new_var_id);
+ debug!("replacing original vid={:?} with new={:?}", vid, u);
+ Ok(u)
+ }
+ }
+ }
+ }
+
+ ty::Infer(ty::IntVar(_) | ty::FloatVar(_)) => {
+ // No matter what mode we are in,
+ // integer/floating-point types must be equal to be
+ // relatable.
+ Ok(t)
+ }
+
+ ty::Placeholder(placeholder) => {
+ if self.for_universe.can_name(placeholder.universe) {
+ Ok(t)
+ } else {
+ debug!(
+ "root universe {:?} cannot name placeholder in universe {:?}",
+ self.for_universe, placeholder.universe
+ );
+ Err(TypeError::Mismatch)
+ }
+ }
+
+ _ => relate::structurally_relate_tys(self, t, t),
+ }?;
+
+ self.cache.insert(t, g);
+ Ok(g)
+ }
+
+ #[instrument(level = "debug", skip(self, r2), ret)]
+ fn regions(
+ &mut self,
+ r: ty::Region<'tcx>,
+ r2: ty::Region<'tcx>,
+ ) -> RelateResult<'tcx, ty::Region<'tcx>> {
+ assert_eq!(r, r2); // we are misusing TypeRelation here; both LHS and RHS ought to be ==
+
+ match *r {
+ // Never make variables for regions bound within the type itself,
+ // nor for erased regions.
+ ty::ReLateBound(..) | ty::ReErased => {
+ return Ok(r);
+ }
+
+ // It doesn't really matter for correctness if we generalize ReError,
+ // since we're already on a doomed compilation path.
+ ty::ReError(_) => {
+ return Ok(r);
+ }
+
+ ty::RePlaceholder(..)
+ | ty::ReVar(..)
+ | ty::ReStatic
+ | ty::ReEarlyBound(..)
+ | ty::ReFree(..) => {
+ // see common code below
+ }
+ }
+
+ // If we are in an invariant context, we can re-use the region
+ // as is, unless it happens to be in some universe that we
+ // can't name.
+ if let ty::Invariant = self.ambient_variance {
+ let r_universe = self.infcx.universe_of_region(r);
+ if self.for_universe.can_name(r_universe) {
+ return Ok(r);
+ }
+ }
+
+ Ok(self.delegate.generalize_region(self.for_universe))
+ }
+
+ #[instrument(level = "debug", skip(self, c2), ret)]
+ fn consts(
+ &mut self,
+ c: ty::Const<'tcx>,
+ c2: ty::Const<'tcx>,
+ ) -> RelateResult<'tcx, ty::Const<'tcx>> {
+ assert_eq!(c, c2); // we are misusing TypeRelation here; both LHS and RHS ought to be ==
+
+ match c.kind() {
+ ty::ConstKind::Infer(InferConst::Var(_)) if D::forbid_inference_vars() => {
+ bug!("unexpected inference variable encountered in NLL generalization: {:?}", c);
+ }
+ ty::ConstKind::Infer(InferConst::Var(vid)) => {
+ // If root const vids are equal, then `root_vid` and
+ // `vid` are related and we'd be inferring an infinitely
+ // deep const.
+ if ty::TermVid::Const(
+ self.infcx.inner.borrow_mut().const_unification_table().find(vid),
+ ) == self.root_vid
+ {
+ return Err(self.cyclic_term_error());
+ }
+
+ let mut inner = self.infcx.inner.borrow_mut();
+ let variable_table = &mut inner.const_unification_table();
+ let var_value = variable_table.probe_value(vid);
+ match var_value.val {
+ ConstVariableValue::Known { value: u } => {
+ drop(inner);
+ self.relate(u, u)
+ }
+ ConstVariableValue::Unknown { universe } => {
+ if self.for_universe.can_name(universe) {
+ Ok(c)
+ } else {
+ let new_var_id = variable_table.new_key(ConstVarValue {
+ origin: var_value.origin,
+ val: ConstVariableValue::Unknown { universe: self.for_universe },
+ });
+ Ok(self.tcx().mk_const(new_var_id, c.ty()))
+ }
+ }
+ }
+ }
+ // FIXME: remove this branch once `structurally_relate_consts` is fully
+ // structural.
+ ty::ConstKind::Unevaluated(ty::UnevaluatedConst { def, substs }) => {
+ let substs = self.relate_with_variance(
+ ty::Variance::Invariant,
+ ty::VarianceDiagInfo::default(),
+ substs,
+ substs,
+ )?;
+ Ok(self.tcx().mk_const(ty::UnevaluatedConst { def, substs }, c.ty()))
+ }
+ ty::ConstKind::Placeholder(placeholder) => {
+ if self.for_universe.can_name(placeholder.universe) {
+ Ok(c)
+ } else {
+ debug!(
+ "root universe {:?} cannot name placeholder in universe {:?}",
+ self.for_universe, placeholder.universe
+ );
+ Err(TypeError::Mismatch)
+ }
+ }
+ _ => relate::structurally_relate_consts(self, c, c),
+ }
+ }
+
+ #[instrument(level = "debug", skip(self), ret)]
+ fn binders<T>(
+ &mut self,
+ a: ty::Binder<'tcx, T>,
+ _: ty::Binder<'tcx, T>,
+ ) -> RelateResult<'tcx, ty::Binder<'tcx, T>>
+ where
+ T: Relate<'tcx>,
+ {
+ let result = self.relate(a.skip_binder(), a.skip_binder())?;
+ Ok(a.rebind(result))
+ }
+}
+
+/// Result from a generalization operation. This includes
+/// not only the generalized type, but also a bool flag
+/// indicating whether further WF checks are needed.
+#[derive(Debug)]
+pub struct Generalization<T> {
+ pub value: T,
+
+ /// If true, then the generalized type may not be well-formed,
+ /// even if the source type is well-formed, so we should add an
+ /// additional check to enforce that it is. This arises in
+ /// particular around 'bivariant' type parameters that are only
+ /// constrained by a where-clause. As an example, imagine a type:
+ ///
+ /// struct Foo<A, B> where A: Iterator<Item = B> {
+ /// data: A
+ /// }
+ ///
+ /// here, `A` will be covariant, but `B` is
+ /// unconstrained. However, whatever it is, for `Foo` to be WF, it
+ /// must be equal to `A::Item`. If we have an input `Foo<?A, ?B>`,
+ /// then after generalization we will wind up with a type like
+ /// `Foo<?C, ?D>`. When we enforce that `Foo<?A, ?B> <: Foo<?C,
+ /// ?D>` (or `>:`), we will wind up with the requirement that `?A
+ /// <: ?C`, but no particular relationship between `?B` and `?D`
+ /// (after all, we do not know the variance of the normalized form
+ /// of `A::Item` with respect to `A`). If we do nothing else, this
+ /// may mean that `?D` goes unconstrained (as in #41677). So, in
+ /// this scenario where we create a new type variable in a
+ /// bivariant context, we set the `needs_wf` flag to true. This
+ /// will force the calling code to check that `WF(Foo<?C, ?D>)`
+ /// holds, which in turn implies that `?C::Item == ?D`. So once
+ /// `?C` is constrained, that should suffice to restrict `?D`.
+ pub needs_wf: bool,
+}