Skip to content

Commit

Permalink
Auto merge of #105750 - oli-obk:valtrees, r=lcnr
Browse files Browse the repository at this point in the history
Always fall back to PartialEq when a constant in a pattern is not recursively structural-eq

Right now we destructure the constant as far as we can, but with this PR we just don't take it apart anymore. This is preparatory work for moving to always using valtrees, as these will just do a single conversion of the constant to a valtree at the start, and if that fails, fall back to `PartialEq`.

This removes a few cases where we emitted the `unreachable pattern` lint, because we stop looking into the constant deeply enough to detect that a constant is already covered by another pattern.

Previous work: #70743

This is groundwork towards fixing #83085 and #105047
  • Loading branch information
bors committed May 16, 2023
2 parents a673ad6 + 2282258 commit 9239760
Show file tree
Hide file tree
Showing 6 changed files with 137 additions and 94 deletions.
49 changes: 39 additions & 10 deletions compiler/rustc_mir_build/src/build/matches/test.rs
Original file line number Diff line number Diff line change
Expand Up @@ -380,18 +380,19 @@ impl<'a, 'tcx> Builder<'a, 'tcx> {
);
}

/// Compare two `&T` values using `<T as std::compare::PartialEq>::eq`
/// Compare two values using `<T as std::compare::PartialEq>::eq`.
/// If the values are already references, just call it directly, otherwise
/// take a reference to the values first and then call it.
fn non_scalar_compare(
&mut self,
block: BasicBlock,
make_target_blocks: impl FnOnce(&mut Self) -> Vec<BasicBlock>,
source_info: SourceInfo,
value: ConstantKind<'tcx>,
place: Place<'tcx>,
mut val: Place<'tcx>,
mut ty: Ty<'tcx>,
) {
let mut expect = self.literal_operand(source_info.span, value);
let mut val = Operand::Copy(place);

// If we're using `b"..."` as a pattern, we need to insert an
// unsizing coercion, as the byte string has the type `&[u8; N]`.
Expand Down Expand Up @@ -421,9 +422,13 @@ impl<'a, 'tcx> Builder<'a, 'tcx> {
block,
source_info,
temp,
Rvalue::Cast(CastKind::Pointer(PointerCast::Unsize), val, ty),
Rvalue::Cast(
CastKind::Pointer(PointerCast::Unsize),
Operand::Copy(val),
ty,
),
);
val = Operand::Move(temp);
val = temp;
}
if opt_ref_test_ty.is_some() {
let slice = self.temp(ty, source_info.span);
Expand All @@ -438,12 +443,36 @@ impl<'a, 'tcx> Builder<'a, 'tcx> {
}
}

let ty::Ref(_, deref_ty, _) = *ty.kind() else {
bug!("non_scalar_compare called on non-reference type: {}", ty);
};
match *ty.kind() {
ty::Ref(_, deref_ty, _) => ty = deref_ty,
_ => {
// non_scalar_compare called on non-reference type
let temp = self.temp(ty, source_info.span);
self.cfg.push_assign(block, source_info, temp, Rvalue::Use(expect));
let ref_ty = self.tcx.mk_imm_ref(self.tcx.lifetimes.re_erased, ty);
let ref_temp = self.temp(ref_ty, source_info.span);

self.cfg.push_assign(
block,
source_info,
ref_temp,
Rvalue::Ref(self.tcx.lifetimes.re_erased, BorrowKind::Shared, temp),
);
expect = Operand::Move(ref_temp);

let ref_temp = self.temp(ref_ty, source_info.span);
self.cfg.push_assign(
block,
source_info,
ref_temp,
Rvalue::Ref(self.tcx.lifetimes.re_erased, BorrowKind::Shared, val),
);
val = ref_temp;
}
}

let eq_def_id = self.tcx.require_lang_item(LangItem::PartialEq, Some(source_info.span));
let method = trait_method(self.tcx, eq_def_id, sym::eq, [deref_ty, deref_ty]);
let method = trait_method(self.tcx, eq_def_id, sym::eq, [ty, ty]);

let bool_ty = self.tcx.types.bool;
let eq_result = self.temp(bool_ty, source_info.span);
Expand All @@ -463,7 +492,7 @@ impl<'a, 'tcx> Builder<'a, 'tcx> {

literal: method,
})),
args: vec![val, expect],
args: vec![Operand::Copy(val), expect],
destination: eq_result,
target: Some(eq_block),
unwind: UnwindAction::Continue,
Expand Down
65 changes: 25 additions & 40 deletions compiler/rustc_mir_build/src/thir/pattern/const_to_pat.rs
Original file line number Diff line number Diff line change
Expand Up @@ -62,21 +62,13 @@ struct ConstToPat<'tcx> {
treat_byte_string_as_slice: bool,
}

mod fallback_to_const_ref {
#[derive(Debug)]
/// This error type signals that we encountered a non-struct-eq situation behind a reference.
/// We bubble this up in order to get back to the reference destructuring and make that emit
/// a const pattern instead of a deref pattern. This allows us to simply call `PartialEq::eq`
/// on such patterns (since that function takes a reference) and not have to jump through any
/// hoops to get a reference to the value.
pub(super) struct FallbackToConstRef(());

pub(super) fn fallback_to_const_ref(c2p: &super::ConstToPat<'_>) -> FallbackToConstRef {
assert!(c2p.behind_reference.get());
FallbackToConstRef(())
}
}
use fallback_to_const_ref::{fallback_to_const_ref, FallbackToConstRef};
/// This error type signals that we encountered a non-struct-eq situation.
/// We bubble this up in order to get back to the reference destructuring and make that emit
/// a const pattern instead of a deref pattern. This allows us to simply call `PartialEq::eq`
/// on such patterns (since that function takes a reference) and not have to jump through any
/// hoops to get a reference to the value.
#[derive(Debug)]
struct FallbackToConstRef;

impl<'tcx> ConstToPat<'tcx> {
fn new(
Expand Down Expand Up @@ -236,13 +228,13 @@ impl<'tcx> ConstToPat<'tcx> {

let kind = match cv.ty().kind() {
ty::Float(_) => {
tcx.emit_spanned_lint(
lint::builtin::ILLEGAL_FLOATING_POINT_LITERAL_PATTERN,
id,
span,
FloatPattern,
);
PatKind::Constant { value: cv }
tcx.emit_spanned_lint(
lint::builtin::ILLEGAL_FLOATING_POINT_LITERAL_PATTERN,
id,
span,
FloatPattern,
);
return Err(FallbackToConstRef);
}
ty::Adt(adt_def, _) if adt_def.is_union() => {
// Matching on union fields is unsafe, we can't hide it in constants
Expand Down Expand Up @@ -289,7 +281,7 @@ impl<'tcx> ConstToPat<'tcx> {
// Since we are behind a reference, we can just bubble the error up so we get a
// constant at reference type, making it easy to let the fallback call
// `PartialEq::eq` on it.
return Err(fallback_to_const_ref(self));
return Err(FallbackToConstRef);
}
ty::Adt(adt_def, _) if !self.type_marked_structural(cv.ty()) => {
debug!(
Expand Down Expand Up @@ -393,11 +385,11 @@ impl<'tcx> ConstToPat<'tcx> {
self.behind_reference.set(old);
val
}
// Backwards compatibility hack: support references to non-structural types.
// We'll lower
// this pattern to a `PartialEq::eq` comparison and `PartialEq::eq` takes a
// reference. This makes the rest of the matching logic simpler as it doesn't have
// to figure out how to get a reference again.
// Backwards compatibility hack: support references to non-structural types,
// but hard error if we aren't behind a double reference. We could just use
// the fallback code path below, but that would allow *more* of this fishy
// code to compile, as then it only goes through the future incompat lint
// instead of a hard error.
ty::Adt(_, _) if !self.type_marked_structural(*pointee_ty) => {
if self.behind_reference.get() {
if !self.saw_const_match_error.get()
Expand All @@ -411,7 +403,7 @@ impl<'tcx> ConstToPat<'tcx> {
IndirectStructuralMatch { non_sm_ty: *pointee_ty },
);
}
PatKind::Constant { value: cv }
return Err(FallbackToConstRef);
} else {
if !self.saw_const_match_error.get() {
self.saw_const_match_error.set(true);
Expand All @@ -435,24 +427,17 @@ impl<'tcx> ConstToPat<'tcx> {
PatKind::Wild
} else {
let old = self.behind_reference.replace(true);
// In case there are structural-match violations somewhere in this subpattern,
// we fall back to a const pattern. If we do not do this, we may end up with
// a !structural-match constant that is not of reference type, which makes it
// very hard to invoke `PartialEq::eq` on it as a fallback.
let val = match self.recur(tcx.deref_mir_constant(self.param_env.and(cv)), false) {
Ok(subpattern) => PatKind::Deref { subpattern },
Err(_) => PatKind::Constant { value: cv },
};
let subpattern = self.recur(tcx.deref_mir_constant(self.param_env.and(cv)), false)?;
self.behind_reference.set(old);
val
PatKind::Deref { subpattern }
}
}
},
ty::Bool | ty::Char | ty::Int(_) | ty::Uint(_) | ty::FnDef(..) => {
PatKind::Constant { value: cv }
}
ty::RawPtr(pointee) if pointee.ty.is_sized(tcx, param_env) => {
PatKind::Constant { value: cv }
return Err(FallbackToConstRef);
}
// FIXME: these can have very surprising behaviour where optimization levels or other
// compilation choices change the runtime behaviour of the match.
Expand All @@ -469,7 +454,7 @@ impl<'tcx> ConstToPat<'tcx> {
PointerPattern
);
}
PatKind::Constant { value: cv }
return Err(FallbackToConstRef);
}
_ => {
self.saw_const_match_error.set(true);
Expand Down
8 changes: 4 additions & 4 deletions compiler/rustc_mir_build/src/thir/pattern/deconstruct_pat.rs
Original file line number Diff line number Diff line change
Expand Up @@ -844,8 +844,8 @@ impl<'tcx> Constructor<'tcx> {
}

/// Faster version of `is_covered_by` when applied to many constructors. `used_ctors` is
/// assumed to be built from `matrix.head_ctors()` with wildcards filtered out, and `self` is
/// assumed to have been split from a wildcard.
/// assumed to be built from `matrix.head_ctors()` with wildcards and opaques filtered out,
/// and `self` is assumed to have been split from a wildcard.
fn is_covered_by_any<'p>(
&self,
pcx: &PatCtxt<'_, 'p, 'tcx>,
Expand Down Expand Up @@ -894,7 +894,7 @@ impl<'tcx> Constructor<'tcx> {
/// in `to_ctors`: in some cases we only return `Missing`.
#[derive(Debug)]
pub(super) struct SplitWildcard<'tcx> {
/// Constructors seen in the matrix.
/// Constructors (other than wildcards and opaques) seen in the matrix.
matrix_ctors: Vec<Constructor<'tcx>>,
/// All the constructors for this type
all_ctors: SmallVec<[Constructor<'tcx>; 1]>,
Expand Down Expand Up @@ -1037,7 +1037,7 @@ impl<'tcx> SplitWildcard<'tcx> {
// Since `all_ctors` never contains wildcards, this won't recurse further.
self.all_ctors =
self.all_ctors.iter().flat_map(|ctor| ctor.split(pcx, ctors.clone())).collect();
self.matrix_ctors = ctors.filter(|c| !c.is_wildcard()).cloned().collect();
self.matrix_ctors = ctors.filter(|c| !matches!(c, Wildcard | Opaque)).cloned().collect();
}

/// Whether there are any value constructors for this type that are not present in the matrix.
Expand Down
16 changes: 16 additions & 0 deletions compiler/rustc_mir_build/src/thir/pattern/usefulness.rs
Original file line number Diff line number Diff line change
Expand Up @@ -288,6 +288,22 @@
//!
//! The details are not necessary to understand this file, so we explain them in
//! [`super::deconstruct_pat`]. Splitting is done by the [`Constructor::split`] function.
//!
//! # Constants in patterns
//!
//! There are two kinds of constants in patterns:
//!
//! * literals (`1`, `true`, `"foo"`)
//! * named or inline consts (`FOO`, `const { 5 + 6 }`)
//!
//! The latter are converted into other patterns with literals at the leaves. For example
//! `const_to_pat(const { [1, 2, 3] })` becomes an `Array(vec![Const(1), Const(2), Const(3)])`
//! pattern. This gets problematic when comparing the constant via `==` would behave differently
//! from matching on the constant converted to a pattern. Situations like that can occur, when
//! the user implements `PartialEq` manually, and thus could make `==` behave arbitrarily different.
//! In order to honor the `==` implementation, constants of types that implement `PartialEq` manually
//! stay as a full constant and become an `Opaque` pattern. These `Opaque` patterns do not participate
//! in exhaustiveness, specialization or overlap checking.
use self::ArmType::*;
use self::Usefulness::*;
Expand Down
18 changes: 13 additions & 5 deletions tests/ui/pattern/usefulness/consts-opaque.rs
Original file line number Diff line number Diff line change
Expand Up @@ -20,11 +20,12 @@ const BAR: Bar = Bar;
#[derive(PartialEq)]
enum Baz {
Baz1,
Baz2
Baz2,
}
impl Eq for Baz {}
const BAZ: Baz = Baz::Baz1;

#[rustfmt::skip]
fn main() {
match FOO {
FOO => {}
Expand Down Expand Up @@ -124,8 +125,16 @@ fn main() {

match WRAPQUUX {
Wrap(_) => {}
WRAPQUUX => {} // detected unreachable because we do inspect the `Wrap` layer
//~^ ERROR unreachable pattern
WRAPQUUX => {}
}

match WRAPQUUX {
Wrap(_) => {}
}

match WRAPQUUX {
//~^ ERROR: non-exhaustive patterns: `Wrap(_)` not covered
WRAPQUUX => {}
}

#[derive(PartialEq, Eq)]
Expand All @@ -138,8 +147,7 @@ fn main() {
match WHOKNOWSQUUX {
WHOKNOWSQUUX => {}
WhoKnows::Yay(_) => {}
WHOKNOWSQUUX => {} // detected unreachable because we do inspect the `WhoKnows` layer
//~^ ERROR unreachable pattern
WHOKNOWSQUUX => {}
WhoKnows::Nope => {}
}
}
Loading

0 comments on commit 9239760

Please sign in to comment.