Skip to content

Commit

Permalink
Auto merge of #126541 - scottmcm:more-ptr-metadata-gvn, r=cjgillot
Browse files Browse the repository at this point in the history
More ptr metadata gvn

There's basically 3 parts to this PR.

1. Allow references as arguments to `UnOp::PtrMetadata`

This is a MIR semantics addition, so
r? mir

Rather than just raw pointers, also allow references to be passed to `PtrMetadata`.  That means the length of a slice can be just `PtrMetadata(_1)` instead of also needing a ref-to-pointer statement (`_2 = &raw *_1` + `PtrMetadata(_2)`).

AFAIK there should be no provenance or tagging implications of looking at the *metadata* of a pointer, and the code in the backends actually already supported it (other than a debug assert, given that they don't care about ptr vs reference, really), so we might as well allow it.

2. Simplify the argument to `PtrMetadata` in GVN

Because the specific kind of pointer-like thing isn't that important, GVN can simplify all those details away.  Things like `*const`-to-`*mut` casts and `&mut`-to-`&` reborrows are irrelevant, and skipping them lets it see more interesting things.

cc `@cjgillot`

Notably, unsizing casts for arrays.  GVN supported that for `Len`, and now it sees it for `PtrMetadata` as well, allowing `PtrMetadata(pointer)` to become a constant if that pointer came from an array-to-slice unsizing, even through a bunch of other possible steps.

3. Replace `NormalizeArrayLen` with GVN

The `NormalizeArrayLen` pass hasn't been running even in optimized builds for well over a year, and it turns out that GVN -- which *is* on in optimized builds -- can do everything it was trying to do.

So the code for the pass is deleted, but the tests are kept, just changed to the different pass.

As part of this, `LowerSliceLen` was changed to emit `PtrMetadata(_1)` instead of `Len(*_1)`, a small step on the road to eventually eliminating `Rvalue::Len`.
  • Loading branch information
bors committed Jun 21, 2024
2 parents 4e6de37 + 55d1337 commit e32ea48
Show file tree
Hide file tree
Showing 51 changed files with 622 additions and 323 deletions.
4 changes: 3 additions & 1 deletion compiler/rustc_codegen_ssa/src/mir/rvalue.rs
Original file line number Diff line number Diff line change
Expand Up @@ -639,7 +639,9 @@ impl<'a, 'tcx, Bx: BuilderMethods<'a, 'tcx>> FunctionCx<'a, 'tcx, Bx> {
(OperandValue::Immediate(llval), operand.layout)
}
mir::UnOp::PtrMetadata => {
debug_assert!(operand.layout.ty.is_unsafe_ptr());
debug_assert!(
operand.layout.ty.is_unsafe_ptr() || operand.layout.ty.is_ref(),
);
let (_, meta) = operand.val.pointer_parts();
assert_eq!(operand.layout.fields.count() > 1, meta.is_some());
if let Some(meta) = meta {
Expand Down
2 changes: 1 addition & 1 deletion compiler/rustc_const_eval/src/interpret/operator.rs
Original file line number Diff line number Diff line change
Expand Up @@ -460,7 +460,7 @@ impl<'tcx, M: Machine<'tcx>> InterpCx<'tcx, M> {
let res = ScalarInt::truncate_from_uint(res, layout.size).0;
Ok(ImmTy::from_scalar(res.into(), layout))
}
ty::RawPtr(..) => {
ty::RawPtr(..) | ty::Ref(..) => {
assert_eq!(un_op, PtrMetadata);
let (_, meta) = val.to_scalar_and_meta();
Ok(match meta {
Expand Down
6 changes: 4 additions & 2 deletions compiler/rustc_middle/src/mir/syntax.rs
Original file line number Diff line number Diff line change
Expand Up @@ -1446,10 +1446,12 @@ pub enum UnOp {
Not,
/// The `-` operator for negation
Neg,
/// Get the metadata `M` from a `*const/mut impl Pointee<Metadata = M>`.
/// Gets the metadata `M` from a `*const`/`*mut`/`&`/`&mut` to
/// `impl Pointee<Metadata = M>`.
///
/// For example, this will give a `()` from `*const i32`, a `usize` from
/// `*mut [u8]`, or a pointer to a vtable from a `*const dyn Foo`.
/// `&mut [u8]`, or a `ptr::DynMetadata<dyn Foo>` (internally a pointer)
/// from a `*mut dyn Foo`.
///
/// Allowed only in [`MirPhase::Runtime`]; earlier it's an intrinsic.
PtrMetadata,
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -193,6 +193,10 @@ impl<'tcx, 'body> ParseCtxt<'tcx, 'body> {
let source = self.parse_operand(args[0])?;
Ok(Rvalue::Cast(CastKind::Transmute, source, expr.ty))
},
@call(mir_cast_ptr_to_ptr, args) => {
let source = self.parse_operand(args[0])?;
Ok(Rvalue::Cast(CastKind::PtrToPtr, source, expr.ty))
},
@call(mir_checked, args) => {
parse_by_kind!(self, args[0], _, "binary op",
ExprKind::Binary { op, lhs, rhs } => {
Expand Down
134 changes: 105 additions & 29 deletions compiler/rustc_mir_transform/src/gvn.rs
Original file line number Diff line number Diff line change
Expand Up @@ -836,12 +836,8 @@ impl<'body, 'tcx> VnState<'body, 'tcx> {
}
Value::BinaryOp(op, lhs, rhs)
}
Rvalue::UnaryOp(op, ref mut arg) => {
let arg = self.simplify_operand(arg, location)?;
if let Some(value) = self.simplify_unary(op, arg) {
return Some(value);
}
Value::UnaryOp(op, arg)
Rvalue::UnaryOp(op, ref mut arg_op) => {
return self.simplify_unary(op, arg_op, location);
}
Rvalue::Discriminant(ref mut place) => {
let place = self.simplify_place_value(place, location)?;
Expand Down Expand Up @@ -949,13 +945,8 @@ impl<'body, 'tcx> VnState<'body, 'tcx> {
was_updated = true;
}

if was_updated {
if let Some(const_) = self.try_as_constant(fields[0]) {
field_ops[FieldIdx::ZERO] = Operand::Constant(Box::new(const_));
} else if let Some(local) = self.try_as_local(fields[0], location) {
field_ops[FieldIdx::ZERO] = Operand::Copy(Place::from(local));
self.reused_locals.insert(local);
}
if was_updated && let Some(op) = self.try_as_operand(fields[0], location) {
field_ops[FieldIdx::ZERO] = op;
}
}

Expand All @@ -965,11 +956,8 @@ impl<'body, 'tcx> VnState<'body, 'tcx> {
let first = fields[0];
if fields.iter().all(|&v| v == first) {
let len = ty::Const::from_target_usize(self.tcx, fields.len().try_into().unwrap());
if let Some(const_) = self.try_as_constant(first) {
*rvalue = Rvalue::Repeat(Operand::Constant(Box::new(const_)), len);
} else if let Some(local) = self.try_as_local(first, location) {
*rvalue = Rvalue::Repeat(Operand::Copy(local.into()), len);
self.reused_locals.insert(local);
if let Some(op) = self.try_as_operand(first, location) {
*rvalue = Rvalue::Repeat(op, len);
}
return Some(self.insert(Value::Repeat(first, len)));
}
Expand All @@ -979,8 +967,71 @@ impl<'body, 'tcx> VnState<'body, 'tcx> {
}

#[instrument(level = "trace", skip(self), ret)]
fn simplify_unary(&mut self, op: UnOp, value: VnIndex) -> Option<VnIndex> {
let value = match (op, self.get(value)) {
fn simplify_unary(
&mut self,
op: UnOp,
arg_op: &mut Operand<'tcx>,
location: Location,
) -> Option<VnIndex> {
let mut arg_index = self.simplify_operand(arg_op, location)?;

// PtrMetadata doesn't care about *const vs *mut vs & vs &mut,
// so start by removing those distinctions so we can update the `Operand`
if op == UnOp::PtrMetadata {
let mut was_updated = false;
loop {
match self.get(arg_index) {
// Pointer casts that preserve metadata, such as
// `*const [i32]` <-> `*mut [i32]` <-> `*mut [f32]`.
// It's critical that this not eliminate cases like
// `*const [T]` -> `*const T` which remove metadata.
// We run on potentially-generic MIR, though, so unlike codegen
// we can't always know exactly what the metadata are.
// Thankfully, equality on `ptr_metadata_ty_or_tail` gives us
// what we need: `Ok(meta_ty)` if the metadata is known, or
// `Err(tail_ty)` if not. Matching metadata is ok, but if
// that's not known, then matching tail types is also ok,
// allowing things like `*mut (?A, ?T)` <-> `*mut (?B, ?T)`.
// FIXME: Would it be worth trying to normalize, rather than
// passing the identity closure? Or are the types in the
// Cast realistically about as normalized as we can get anyway?
Value::Cast { kind: CastKind::PtrToPtr, value: inner, from, to }
if from
.builtin_deref(true)
.unwrap()
.ptr_metadata_ty_or_tail(self.tcx, |t| t)
== to
.builtin_deref(true)
.unwrap()
.ptr_metadata_ty_or_tail(self.tcx, |t| t) =>
{
arg_index = *inner;
was_updated = true;
continue;
}

// `&mut *p`, `&raw *p`, etc don't change metadata.
Value::Address { place, kind: _, provenance: _ }
if let PlaceRef { local, projection: [PlaceElem::Deref] } =
place.as_ref()
&& let Some(local_index) = self.locals[local] =>
{
arg_index = local_index;
was_updated = true;
continue;
}

_ => {
if was_updated && let Some(op) = self.try_as_operand(arg_index, location) {
*arg_op = op;
}
break;
}
}
}
}

let value = match (op, self.get(arg_index)) {
(UnOp::Not, Value::UnaryOp(UnOp::Not, inner)) => return Some(*inner),
(UnOp::Neg, Value::UnaryOp(UnOp::Neg, inner)) => return Some(*inner),
(UnOp::Not, Value::BinaryOp(BinOp::Eq, lhs, rhs)) => {
Expand All @@ -992,9 +1043,26 @@ impl<'body, 'tcx> VnState<'body, 'tcx> {
(UnOp::PtrMetadata, Value::Aggregate(AggregateTy::RawPtr { .. }, _, fields)) => {
return Some(fields[1]);
}
_ => return None,
// We have an unsizing cast, which assigns the length to fat pointer metadata.
(
UnOp::PtrMetadata,
Value::Cast {
kind: CastKind::PointerCoercion(ty::adjustment::PointerCoercion::Unsize),
from,
to,
..
},
) if let ty::Slice(..) = to.builtin_deref(true).unwrap().kind()
&& let ty::Array(_, len) = from.builtin_deref(true).unwrap().kind() =>
{
return self.insert_constant(Const::from_ty_const(
*len,
self.tcx.types.usize,
self.tcx,
));
}
_ => Value::UnaryOp(op, arg_index),
};

Some(self.insert(value))
}

Expand Down Expand Up @@ -1174,13 +1242,8 @@ impl<'body, 'tcx> VnState<'body, 'tcx> {
}
}

if was_updated {
if let Some(const_) = self.try_as_constant(value) {
*operand = Operand::Constant(Box::new(const_));
} else if let Some(local) = self.try_as_local(value, location) {
*operand = Operand::Copy(local.into());
self.reused_locals.insert(local);
}
if was_updated && let Some(op) = self.try_as_operand(value, location) {
*operand = op;
}

Some(self.insert(Value::Cast { kind: *kind, value, from, to }))
Expand Down Expand Up @@ -1296,6 +1359,19 @@ fn op_to_prop_const<'tcx>(
}

impl<'tcx> VnState<'_, 'tcx> {
/// If either [`Self::try_as_constant`] as [`Self::try_as_local`] succeeds,
/// returns that result as an [`Operand`].
fn try_as_operand(&mut self, index: VnIndex, location: Location) -> Option<Operand<'tcx>> {
if let Some(const_) = self.try_as_constant(index) {
Some(Operand::Constant(Box::new(const_)))
} else if let Some(local) = self.try_as_local(index, location) {
self.reused_locals.insert(local);
Some(Operand::Copy(local.into()))
} else {
None
}
}

/// If `index` is a `Value::Constant`, return the `Constant` to be put in the MIR.
fn try_as_constant(&mut self, index: VnIndex) -> Option<ConstOperand<'tcx>> {
// This was already constant in MIR, do not change it.
Expand Down
4 changes: 0 additions & 4 deletions compiler/rustc_mir_transform/src/lib.rs
Original file line number Diff line number Diff line change
Expand Up @@ -88,7 +88,6 @@ mod lower_slice_len;
mod match_branches;
mod mentioned_items;
mod multiple_return_terminators;
mod normalize_array_len;
mod nrvo;
mod prettify;
mod promote_consts;
Expand Down Expand Up @@ -581,9 +580,6 @@ fn run_optimization_passes<'tcx>(tcx: TyCtxt<'tcx>, body: &mut Body<'tcx>) {
&o1(simplify::SimplifyCfg::AfterUnreachableEnumBranching),
// Inlining may have introduced a lot of redundant code and a large move pattern.
// Now, we need to shrink the generated MIR.

// Has to run after `slice::len` lowering
&normalize_array_len::NormalizeArrayLen,
&ref_prop::ReferencePropagation,
&sroa::ScalarReplacementOfAggregates,
&match_branches::MatchBranchSimplification,
Expand Down
24 changes: 8 additions & 16 deletions compiler/rustc_mir_transform/src/lower_slice_len.rs
Original file line number Diff line number Diff line change
@@ -1,10 +1,9 @@
//! This pass lowers calls to core::slice::len to just Len op.
//! This pass lowers calls to core::slice::len to just PtrMetadata op.
//! It should run before inlining!
use rustc_hir::def_id::DefId;
use rustc_index::IndexSlice;
use rustc_middle::mir::*;
use rustc_middle::ty::{self, TyCtxt};
use rustc_middle::ty::TyCtxt;

pub struct LowerSliceLenCalls;

Expand All @@ -29,16 +28,11 @@ pub fn lower_slice_len_calls<'tcx>(tcx: TyCtxt<'tcx>, body: &mut Body<'tcx>) {
let basic_blocks = body.basic_blocks.as_mut_preserves_cfg();
for block in basic_blocks {
// lower `<[_]>::len` calls
lower_slice_len_call(tcx, block, &body.local_decls, slice_len_fn_item_def_id);
lower_slice_len_call(block, slice_len_fn_item_def_id);
}
}

fn lower_slice_len_call<'tcx>(
tcx: TyCtxt<'tcx>,
block: &mut BasicBlockData<'tcx>,
local_decls: &IndexSlice<Local, LocalDecl<'tcx>>,
slice_len_fn_item_def_id: DefId,
) {
fn lower_slice_len_call<'tcx>(block: &mut BasicBlockData<'tcx>, slice_len_fn_item_def_id: DefId) {
let terminator = block.terminator();
if let TerminatorKind::Call {
func,
Expand All @@ -50,19 +44,17 @@ fn lower_slice_len_call<'tcx>(
} = &terminator.kind
// some heuristics for fast rejection
&& let [arg] = &args[..]
&& let Some(arg) = arg.node.place()
&& let ty::FnDef(fn_def_id, _) = func.ty(local_decls, tcx).kind()
&& *fn_def_id == slice_len_fn_item_def_id
&& let Some((fn_def_id, _)) = func.const_fn_def()
&& fn_def_id == slice_len_fn_item_def_id
{
// perform modifications from something like:
// _5 = core::slice::<impl [u8]>::len(move _6) -> bb1
// into:
// _5 = Len(*_6)
// _5 = PtrMetadata(move _6)
// goto bb1

// make new RValue for Len
let deref_arg = tcx.mk_place_deref(arg);
let r_value = Rvalue::Len(deref_arg);
let r_value = Rvalue::UnaryOp(UnOp::PtrMetadata, arg.node.clone());
let len_statement_kind = StatementKind::Assign(Box::new((*destination, r_value)));
let add_statement =
Statement { kind: len_statement_kind, source_info: terminator.source_info };
Expand Down
Loading

0 comments on commit e32ea48

Please sign in to comment.