From d0dea8323eb9b884e5eb7619e17f53989d566527 Mon Sep 17 00:00:00 2001 From: Kamalesh Palanisamy Date: Sun, 14 Jan 2024 00:36:27 -0500 Subject: [PATCH] Modify GenericArg and Term structs to use strict provenance rules --- compiler/rustc_middle/src/ty/generic_args.rs | 42 +++++++++++++++----- compiler/rustc_middle/src/ty/mod.rs | 32 +++++++++++---- 2 files changed, 56 insertions(+), 18 deletions(-) diff --git a/compiler/rustc_middle/src/ty/generic_args.rs b/compiler/rustc_middle/src/ty/generic_args.rs index 63f4bab7914ed..5a4b86eeffd76 100644 --- a/compiler/rustc_middle/src/ty/generic_args.rs +++ b/compiler/rustc_middle/src/ty/generic_args.rs @@ -20,6 +20,7 @@ use std::marker::PhantomData; use std::mem; use std::num::NonZeroUsize; use std::ops::{ControlFlow, Deref}; +use std::ptr::NonNull; /// An entity in the Rust type system, which can be one of /// several kinds (types, lifetimes, and consts). @@ -31,10 +32,29 @@ use std::ops::{ControlFlow, Deref}; /// `Region` and `Const` are all interned. #[derive(Copy, Clone, PartialEq, Eq, Hash)] pub struct GenericArg<'tcx> { - ptr: NonZeroUsize, + ptr: NonNull<()>, marker: PhantomData<(Ty<'tcx>, ty::Region<'tcx>, ty::Const<'tcx>)>, } +#[cfg(parallel_compiler)] +unsafe impl<'tcx> rustc_data_structures::sync::DynSend for GenericArg<'tcx> where + &'tcx (Ty<'tcx>, ty::Region<'tcx>, ty::Const<'tcx>): rustc_data_structures::sync::DynSend +{ +} +#[cfg(parallel_compiler)] +unsafe impl<'tcx> rustc_data_structures::sync::DynSync for GenericArg<'tcx> where + &'tcx (Ty<'tcx>, ty::Region<'tcx>, ty::Const<'tcx>): rustc_data_structures::sync::DynSync +{ +} +unsafe impl<'tcx> Send for GenericArg<'tcx> where + &'tcx (Ty<'tcx>, ty::Region<'tcx>, ty::Const<'tcx>): Send +{ +} +unsafe impl<'tcx> Sync for GenericArg<'tcx> where + &'tcx (Ty<'tcx>, ty::Region<'tcx>, ty::Const<'tcx>): Sync +{ +} + impl<'tcx> IntoDiagnosticArg for GenericArg<'tcx> { fn into_diagnostic_arg(self) -> DiagnosticArgValue<'static> { self.to_string().into_diagnostic_arg() @@ -60,21 +80,21 @@ impl<'tcx> GenericArgKind<'tcx> { GenericArgKind::Lifetime(lt) => { // Ensure we can use the tag bits. assert_eq!(mem::align_of_val(&*lt.0.0) & TAG_MASK, 0); - (REGION_TAG, lt.0.0 as *const ty::RegionKind<'tcx> as usize) + (REGION_TAG, NonNull::from(lt.0.0).cast()) } GenericArgKind::Type(ty) => { // Ensure we can use the tag bits. assert_eq!(mem::align_of_val(&*ty.0.0) & TAG_MASK, 0); - (TYPE_TAG, ty.0.0 as *const WithCachedTypeInfo> as usize) + (TYPE_TAG, NonNull::from(ty.0.0).cast()) } GenericArgKind::Const(ct) => { // Ensure we can use the tag bits. assert_eq!(mem::align_of_val(&*ct.0.0) & TAG_MASK, 0); - (CONST_TAG, ct.0.0 as *const WithCachedTypeInfo> as usize) + (CONST_TAG, NonNull::from(ct.0.0).cast()) } }; - GenericArg { ptr: unsafe { NonZeroUsize::new_unchecked(ptr | tag) }, marker: PhantomData } + GenericArg { ptr: ptr.map_addr(|addr| addr | tag), marker: PhantomData } } } @@ -123,20 +143,22 @@ impl<'tcx> From> for GenericArg<'tcx> { impl<'tcx> GenericArg<'tcx> { #[inline] pub fn unpack(self) -> GenericArgKind<'tcx> { - let ptr = self.ptr.get(); + let ptr = unsafe { + self.ptr.map_addr(|addr| NonZeroUsize::new_unchecked(addr.get() & !TAG_MASK)) + }; // SAFETY: use of `Interned::new_unchecked` here is ok because these // pointers were originally created from `Interned` types in `pack()`, // and this is just going in the other direction. unsafe { - match ptr & TAG_MASK { + match self.ptr.addr().get() & TAG_MASK { REGION_TAG => GenericArgKind::Lifetime(ty::Region(Interned::new_unchecked( - &*((ptr & !TAG_MASK) as *const ty::RegionKind<'tcx>), + ptr.cast::>().as_ref(), ))), TYPE_TAG => GenericArgKind::Type(Ty(Interned::new_unchecked( - &*((ptr & !TAG_MASK) as *const WithCachedTypeInfo>), + ptr.cast::>>().as_ref(), ))), CONST_TAG => GenericArgKind::Const(ty::Const(Interned::new_unchecked( - &*((ptr & !TAG_MASK) as *const WithCachedTypeInfo>), + ptr.cast::>>().as_ref(), ))), _ => intrinsics::unreachable(), } diff --git a/compiler/rustc_middle/src/ty/mod.rs b/compiler/rustc_middle/src/ty/mod.rs index ad9296a4cc88a..fc22b8422c0b4 100644 --- a/compiler/rustc_middle/src/ty/mod.rs +++ b/compiler/rustc_middle/src/ty/mod.rs @@ -63,6 +63,7 @@ use std::marker::PhantomData; use std::mem; use std::num::NonZeroUsize; use std::ops::ControlFlow; +use std::ptr::NonNull; use std::{fmt, str}; pub use crate::ty::diagnostics::*; @@ -848,10 +849,23 @@ pub type PolyCoercePredicate<'tcx> = ty::Binder<'tcx, CoercePredicate<'tcx>>; #[derive(Clone, Copy, PartialEq, Eq, PartialOrd, Ord, Hash)] pub struct Term<'tcx> { - ptr: NonZeroUsize, + ptr: NonNull<()>, marker: PhantomData<(Ty<'tcx>, Const<'tcx>)>, } +#[cfg(parallel_compiler)] +unsafe impl<'tcx> rustc_data_structures::sync::DynSend for Term<'tcx> where + &'tcx (Ty<'tcx>, Const<'tcx>): rustc_data_structures::sync::DynSend +{ +} +#[cfg(parallel_compiler)] +unsafe impl<'tcx> rustc_data_structures::sync::DynSync for Term<'tcx> where + &'tcx (Ty<'tcx>, Const<'tcx>): rustc_data_structures::sync::DynSync +{ +} +unsafe impl<'tcx> Send for Term<'tcx> where &'tcx (Ty<'tcx>, Const<'tcx>): Send {} +unsafe impl<'tcx> Sync for Term<'tcx> where &'tcx (Ty<'tcx>, Const<'tcx>): Sync {} + impl Debug for Term<'_> { fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result { let data = if let Some(ty) = self.ty() { @@ -914,17 +928,19 @@ impl<'tcx, D: TyDecoder>> Decodable for Term<'tcx> { impl<'tcx> Term<'tcx> { #[inline] pub fn unpack(self) -> TermKind<'tcx> { - let ptr = self.ptr.get(); + let ptr = unsafe { + self.ptr.map_addr(|addr| NonZeroUsize::new_unchecked(addr.get() & !TAG_MASK)) + }; // SAFETY: use of `Interned::new_unchecked` here is ok because these // pointers were originally created from `Interned` types in `pack()`, // and this is just going in the other direction. unsafe { - match ptr & TAG_MASK { + match self.ptr.addr().get() & TAG_MASK { TYPE_TAG => TermKind::Ty(Ty(Interned::new_unchecked( - &*((ptr & !TAG_MASK) as *const WithCachedTypeInfo>), + ptr.cast::>>().as_ref(), ))), CONST_TAG => TermKind::Const(ty::Const(Interned::new_unchecked( - &*((ptr & !TAG_MASK) as *const WithCachedTypeInfo>), + ptr.cast::>>().as_ref(), ))), _ => core::intrinsics::unreachable(), } @@ -986,16 +1002,16 @@ impl<'tcx> TermKind<'tcx> { TermKind::Ty(ty) => { // Ensure we can use the tag bits. assert_eq!(mem::align_of_val(&*ty.0.0) & TAG_MASK, 0); - (TYPE_TAG, ty.0.0 as *const WithCachedTypeInfo> as usize) + (TYPE_TAG, NonNull::from(ty.0.0).cast()) } TermKind::Const(ct) => { // Ensure we can use the tag bits. assert_eq!(mem::align_of_val(&*ct.0.0) & TAG_MASK, 0); - (CONST_TAG, ct.0.0 as *const WithCachedTypeInfo> as usize) + (CONST_TAG, NonNull::from(ct.0.0).cast()) } }; - Term { ptr: unsafe { NonZeroUsize::new_unchecked(ptr | tag) }, marker: PhantomData } + Term { ptr: ptr.map_addr(|addr| addr | tag), marker: PhantomData } } }