Skip to content

Commit

Permalink
Add tag_for_variant query
Browse files Browse the repository at this point in the history
This query allows for sharing code between `rustc_const_eval` and
`rustc_transmutability`.
  • Loading branch information
jswrenn committed Mar 20, 2024
1 parent 9023f90 commit 9179e1b
Show file tree
Hide file tree
Showing 7 changed files with 162 additions and 94 deletions.
20 changes: 19 additions & 1 deletion compiler/rustc_const_eval/src/const_eval/eval_queries.rs
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,7 @@ use either::{Left, Right};
use rustc_hir::def::DefKind;
use rustc_middle::mir::interpret::{AllocId, ErrorHandled, InterpErrorInfo};
use rustc_middle::mir::{self, ConstAlloc, ConstValue};
use rustc_middle::query::TyCtxtAt;
use rustc_middle::query::{Key, TyCtxtAt};
use rustc_middle::traits::Reveal;
use rustc_middle::ty::layout::LayoutOf;
use rustc_middle::ty::print::with_no_trimmed_paths;
Expand Down Expand Up @@ -243,6 +243,24 @@ pub(crate) fn turn_into_const_value<'tcx>(
op_to_const(&ecx, &mplace.into(), /* for diagnostics */ false)
}

/// Computes the tag (if any) for a given type and variant.
#[instrument(skip(tcx), level = "debug")]
pub fn tag_for_variant_provider<'tcx>(
tcx: TyCtxt<'tcx>,
(ty, variant_index): (Ty<'tcx>, abi::VariantIdx),
) -> Option<ty::ScalarInt> {
assert!(ty.is_enum());

let ecx = InterpCx::new(
tcx,
ty.default_span(tcx),
ty::ParamEnv::reveal_all(),
CompileTimeInterpreter::new(CanAccessMutGlobal::No, CheckAlignment::Error),
);

ecx.tag_for_variant(ty, variant_index).unwrap().value()
}

#[instrument(skip(tcx), level = "debug")]
pub fn eval_to_const_value_raw_provider<'tcx>(
tcx: TyCtxt<'tcx>,
Expand Down
174 changes: 106 additions & 68 deletions compiler/rustc_const_eval/src/interpret/discriminant.rs
Original file line number Diff line number Diff line change
Expand Up @@ -2,12 +2,28 @@
use rustc_middle::mir;
use rustc_middle::ty::layout::{LayoutOf, PrimitiveExt};
use rustc_middle::ty::{self, Ty};
use rustc_middle::ty::{self, ScalarInt, Ty};
use rustc_target::abi::{self, TagEncoding};
use rustc_target::abi::{VariantIdx, Variants};

use super::{ImmTy, InterpCx, InterpResult, Machine, Readable, Scalar, Writeable};

/// The tag of an enum discriminant.
pub(crate) enum Tag {
/// No tag; the variant is `Single`-encoded.
None,
/// The variant is tagged.
Tagged { tag: ScalarInt, tag_field: usize },
/// No tag; the variant is identified by its validity.
Untagged,
}

impl Tag {
pub(crate) fn value(self) -> Option<ScalarInt> {
if let Self::Tagged { tag, .. } = self { Some(tag) } else { None }
}
}

impl<'mir, 'tcx: 'mir, M: Machine<'mir, 'tcx>> InterpCx<'mir, 'tcx, M> {
/// Writes the discriminant of the given variant.
///
Expand All @@ -28,78 +44,28 @@ impl<'mir, 'tcx: 'mir, M: Machine<'mir, 'tcx>> InterpCx<'mir, 'tcx, M> {
throw_ub!(UninhabitedEnumVariantWritten(variant_index))
}

match dest.layout().variants {
abi::Variants::Single { index } => {
assert_eq!(index, variant_index);
}
abi::Variants::Multiple {
tag_encoding: TagEncoding::Direct,
tag: tag_layout,
tag_field,
..
} => {
let (tag, tag_field) = match self.tag_for_variant(dest.layout().ty, variant_index)? {
Tag::None => return Ok(()),
Tag::Tagged { tag, tag_field } => {
// No need to validate that the discriminant here because the
// `TyAndLayout::for_variant()` call earlier already checks the variant is valid.

let discr_val = dest
.layout()
.ty
.discriminant_for_variant(*self.tcx, variant_index)
.unwrap()
.val;

// raw discriminants for enums are isize or bigger during
// their computation, but the in-memory tag is the smallest possible
// representation
let size = tag_layout.size(self);
let tag_val = size.truncate(discr_val);

let tag_dest = self.project_field(dest, tag_field)?;
self.write_scalar(Scalar::from_uint(tag_val, size), &tag_dest)?;
// `TyAndLayout::for_variant()` call earlier already checks the
// variant is valid.
(tag, tag_field)
}
abi::Variants::Multiple {
tag_encoding:
TagEncoding::Niche { untagged_variant, ref niche_variants, niche_start },
tag: tag_layout,
tag_field,
..
} => {
// No need to validate that the discriminant here because the
// `TyAndLayout::for_variant()` call earlier already checks the variant is valid.

if variant_index != untagged_variant {
let variants_start = niche_variants.start().as_u32();
let variant_index_relative = variant_index
.as_u32()
.checked_sub(variants_start)
.expect("overflow computing relative variant idx");
// We need to use machine arithmetic when taking into account `niche_start`:
// tag_val = variant_index_relative + niche_start_val
let tag_layout = self.layout_of(tag_layout.primitive().to_int_ty(*self.tcx))?;
let niche_start_val = ImmTy::from_uint(niche_start, tag_layout);
let variant_index_relative_val =
ImmTy::from_uint(variant_index_relative, tag_layout);
let tag_val = self.wrapping_binary_op(
mir::BinOp::Add,
&variant_index_relative_val,
&niche_start_val,
)?;
// Write result.
let niche_dest = self.project_field(dest, tag_field)?;
self.write_immediate(*tag_val, &niche_dest)?;
} else {
// The untagged variant is implicitly encoded simply by having a value that is
// outside the niche variants. But what if the data stored here does not
// actually encode this variant? That would be bad! So let's double-check...
let actual_variant = self.read_discriminant(&dest.to_op(self)?)?;
if actual_variant != variant_index {
throw_ub!(InvalidNichedEnumVariantWritten { enum_ty: dest.layout().ty });
}
Tag::Untagged => {
// The untagged variant is implicitly encoded simply by having a value that is
// outside the niche variants. But what if the data stored here does not
// actually encode this variant? That would be bad! So let's double-check...
let actual_variant = self.read_discriminant(&dest.to_op(self)?)?;
if actual_variant != variant_index {
throw_ub!(InvalidNichedEnumVariantWritten { enum_ty: dest.layout().ty });
}
return Ok(());
}
}
};

Ok(())
let tag_dest = self.project_field(dest, tag_field)?;
self.write_scalar(tag, &tag_dest)
}

/// Read discriminant, return the runtime value as well as the variant index.
Expand Down Expand Up @@ -277,4 +243,76 @@ impl<'mir, 'tcx: 'mir, M: Machine<'mir, 'tcx>> InterpCx<'mir, 'tcx, M> {
};
Ok(ImmTy::from_scalar(discr_value, discr_layout))
}

/// Computes the tag (if any) of a given variant of type `ty`.
pub(crate) fn tag_for_variant(
&self,
ty: Ty<'tcx>,
variant_index: VariantIdx,
) -> InterpResult<'tcx, Tag> {
match self.layout_of(ty)?.variants {
abi::Variants::Single { index } => {
assert_eq!(index, variant_index);
Ok(Tag::None)
}

abi::Variants::Multiple {
tag_encoding: TagEncoding::Direct,
tag: tag_layout,
tag_field,
..
} => {
// raw discriminants for enums are isize or bigger during
// their computation, but the in-memory tag is the smallest possible
// representation
let discr = self.discriminant_for_variant(ty, variant_index)?;
let discr_size = discr.layout.size;
let discr_val = discr.to_scalar().to_bits(discr_size)?;
let tag_size = tag_layout.size(self);
let tag_val = tag_size.truncate(discr_val);
let tag = ScalarInt::try_from_uint(tag_val, tag_size).unwrap();
Ok(Tag::Tagged { tag, tag_field })
}

abi::Variants::Multiple {
tag_encoding: TagEncoding::Niche { untagged_variant, .. },
..
} if untagged_variant == variant_index => {
// The untagged variant is implicitly encoded simply by having a
// value that is outside the niche variants.
Ok(Tag::Untagged)
}

abi::Variants::Multiple {
tag_encoding:
TagEncoding::Niche { untagged_variant, ref niche_variants, niche_start },
tag: tag_layout,
tag_field,
..
} => {
assert!(variant_index != untagged_variant);
let variants_start = niche_variants.start().as_u32();
let variant_index_relative = variant_index
.as_u32()
.checked_sub(variants_start)
.expect("overflow computing relative variant idx");
// We need to use machine arithmetic when taking into account `niche_start`:
// tag_val = variant_index_relative + niche_start_val
let tag_layout = self.layout_of(tag_layout.primitive().to_int_ty(*self.tcx))?;
let niche_start_val = ImmTy::from_uint(niche_start, tag_layout);
let variant_index_relative_val =
ImmTy::from_uint(variant_index_relative, tag_layout);
let tag = self
.wrapping_binary_op(
mir::BinOp::Add,
&variant_index_relative_val,
&niche_start_val,
)?
.to_scalar()
.try_to_int()
.unwrap();
Ok(Tag::Tagged { tag, tag_field })
}
}
}
}
1 change: 1 addition & 0 deletions compiler/rustc_const_eval/src/lib.rs
Original file line number Diff line number Diff line change
Expand Up @@ -40,6 +40,7 @@ rustc_fluent_macro::fluent_messages! { "../messages.ftl" }

pub fn provide(providers: &mut Providers) {
const_eval::provide(providers);
providers.tag_for_variant = const_eval::tag_for_variant_provider;
providers.eval_to_const_value_raw = const_eval::eval_to_const_value_raw_provider;
providers.eval_to_allocation_raw = const_eval::eval_to_allocation_raw_provider;
providers.eval_static_initializer = const_eval::eval_static_initializer_provider;
Expand Down
1 change: 1 addition & 0 deletions compiler/rustc_middle/src/query/erase.rs
Original file line number Diff line number Diff line change
Expand Up @@ -234,6 +234,7 @@ trivial! {
Option<rustc_middle::middle::stability::DeprecationEntry>,
Option<rustc_middle::ty::Destructor>,
Option<rustc_middle::ty::ImplTraitInTraitData>,
Option<rustc_middle::ty::ScalarInt>,
Option<rustc_span::def_id::CrateNum>,
Option<rustc_span::def_id::DefId>,
Option<rustc_span::def_id::LocalDefId>,
Expand Down
9 changes: 9 additions & 0 deletions compiler/rustc_middle/src/query/keys.rs
Original file line number Diff line number Diff line change
Expand Up @@ -13,6 +13,7 @@ use rustc_query_system::query::DefIdCacheSelector;
use rustc_query_system::query::{DefaultCacheSelector, SingleCacheSelector, VecCacheSelector};
use rustc_span::symbol::{Ident, Symbol};
use rustc_span::{Span, DUMMY_SP};
use rustc_target::abi;

/// Placeholder for `CrateNum`'s "local" counterpart
#[derive(Copy, Clone, Debug)]
Expand Down Expand Up @@ -502,6 +503,14 @@ impl<'tcx> Key for (DefId, Ty<'tcx>, GenericArgsRef<'tcx>, ty::ParamEnv<'tcx>) {
}
}

impl<'tcx> Key for (Ty<'tcx>, abi::VariantIdx) {
type CacheSelector = DefaultCacheSelector<Self>;

fn default_span(&self, _tcx: TyCtxt<'_>) -> Span {
DUMMY_SP
}
}

impl<'tcx> Key for (ty::Predicate<'tcx>, traits::WellFormedLoc) {
type CacheSelector = DefaultCacheSelector<Self>;

Expand Down
7 changes: 7 additions & 0 deletions compiler/rustc_middle/src/query/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -1042,6 +1042,13 @@ rustc_queries! {
}
}

/// Computes the tag (if any) for a given type and variant.
query tag_for_variant(
key: (Ty<'tcx>, abi::VariantIdx)
) -> Option<ty::ScalarInt> {
desc { "computing variant tag for enum" }
}

/// Evaluates a constant and returns the computed allocation.
///
/// **Do not use this** directly, use the `eval_to_const_value` or `eval_to_valtree` instead.
Expand Down
44 changes: 19 additions & 25 deletions compiler/rustc_transmute/src/layout/tree.rs
Original file line number Diff line number Diff line change
Expand Up @@ -174,10 +174,10 @@ pub(crate) mod rustc {
use crate::layout::rustc::{Def, Ref};

use rustc_middle::ty::layout::LayoutError;
use rustc_middle::ty::util::Discr;
use rustc_middle::ty::AdtDef;
use rustc_middle::ty::GenericArgsRef;
use rustc_middle::ty::ParamEnv;
use rustc_middle::ty::ScalarInt;
use rustc_middle::ty::VariantDef;
use rustc_middle::ty::{self, Ty, TyCtxt, TypeVisitableExt};
use rustc_span::ErrorGuaranteed;
Expand Down Expand Up @@ -331,14 +331,15 @@ pub(crate) mod rustc {
trace!(?adt_def, "treeifying enum");
let mut tree = Tree::uninhabited();

for (idx, discr) in adt_def.discriminants(tcx) {
for (idx, variant) in adt_def.variants().iter_enumerated() {
let tag = tcx.tag_for_variant((ty, idx));
tree = tree.or(Self::from_repr_c_variant(
ty,
*adt_def,
args_ref,
&layout_summary,
Some(discr),
adt_def.variant(idx),
tag,
variant,
tcx,
)?);
}
Expand Down Expand Up @@ -393,7 +394,7 @@ pub(crate) mod rustc {
adt_def: AdtDef<'tcx>,
args_ref: GenericArgsRef<'tcx>,
layout_summary: &LayoutSummary,
discr: Option<Discr<'tcx>>,
tag: Option<ScalarInt>,
variant_def: &'tcx VariantDef,
tcx: TyCtxt<'tcx>,
) -> Result<Self, Err> {
Expand All @@ -403,9 +404,6 @@ pub(crate) mod rustc {
let min_align = repr.align.unwrap_or(Align::ONE);
let max_align = repr.pack.unwrap_or(Align::MAX);

let clamp =
|align: Align| align.clamp(min_align, max_align).bytes().try_into().unwrap();

let variant_span = trace_span!(
"treeifying variant",
min_align = ?min_align,
Expand All @@ -419,17 +417,12 @@ pub(crate) mod rustc {
)
.unwrap();

// The layout of the variant is prefixed by the discriminant, if any.
if let Some(discr) = discr {
trace!(?discr, "treeifying discriminant");
let discr_layout = alloc::Layout::from_size_align(
layout_summary.discriminant_size,
clamp(layout_summary.discriminant_align),
)
.unwrap();
trace!(?discr_layout, "computed discriminant layout");
variant_layout = variant_layout.extend(discr_layout).unwrap().0;
tree = tree.then(Self::from_discr(discr, tcx, layout_summary.discriminant_size));
// The layout of the variant is prefixed by the tag, if any.
if let Some(tag) = tag {
let tag_layout =
alloc::Layout::from_size_align(tag.size().bytes_usize(), 1).unwrap();
tree = tree.then(Self::from_tag(tag, tcx));
variant_layout = variant_layout.extend(tag_layout).unwrap().0;
}

// Next come fields.
Expand Down Expand Up @@ -469,18 +462,19 @@ pub(crate) mod rustc {
Ok(tree)
}

pub fn from_discr(discr: Discr<'tcx>, tcx: TyCtxt<'tcx>, size: usize) -> Self {
pub fn from_tag(tag: ScalarInt, tcx: TyCtxt<'tcx>) -> Self {
use rustc_target::abi::Endian;

let size = tag.size();
let bits = tag.to_bits(size).unwrap();
let bytes: [u8; 16];
let bytes = match tcx.data_layout.endian {
Endian::Little => {
bytes = discr.val.to_le_bytes();
&bytes[..size]
bytes = bits.to_le_bytes();
&bytes[..size.bytes_usize()]
}
Endian::Big => {
bytes = discr.val.to_be_bytes();
&bytes[bytes.len() - size..]
bytes = bits.to_be_bytes();
&bytes[bytes.len() - size.bytes_usize()..]
}
};
Self::Seq(bytes.iter().map(|&b| Self::from_bits(b)).collect())
Expand Down

0 comments on commit 9179e1b

Please sign in to comment.