diff --git a/library/core/src/alloc/layout.rs b/library/core/src/alloc/layout.rs index 59ebe5fbe0227..3473ac09e956f 100644 --- a/library/core/src/alloc/layout.rs +++ b/library/core/src/alloc/layout.rs @@ -72,9 +72,8 @@ impl Layout { Layout::from_size_valid_align(size, unsafe { ValidAlign::new_unchecked(align) }) } - /// Internal helper constructor to skip revalidating alignment validity. - #[inline] - const fn from_size_valid_align(size: usize, align: ValidAlign) -> Result { + #[inline(always)] + const fn max_size_for_align(align: ValidAlign) -> usize { // (power-of-two implies align != 0.) // Rounded up size is: @@ -89,7 +88,13 @@ impl Layout { // // Above implies that checking for summation overflow is both // necessary and sufficient. - if size > isize::MAX as usize - (align.as_nonzero().get() - 1) { + isize::MAX as usize - (align.as_usize() - 1) + } + + /// Internal helper constructor to skip revalidating alignment validity. + #[inline] + const fn from_size_valid_align(size: usize, align: ValidAlign) -> Result { + if size > Self::max_size_for_align(align) { return Err(LayoutError); } @@ -128,7 +133,7 @@ impl Layout { without modifying the layout"] #[inline] pub const fn align(&self) -> usize { - self.align.as_nonzero().get() + self.align.as_usize() } /// Constructs a `Layout` suitable for holding a value of type `T`. @@ -410,13 +415,33 @@ impl Layout { /// Creates a layout describing the record for a `[T; n]`. /// - /// On arithmetic overflow, returns `LayoutError`. + /// On arithmetic overflow or when the total size would exceed + /// `isize::MAX`, returns `LayoutError`. #[stable(feature = "alloc_layout_manipulation", since = "1.44.0")] #[inline] pub fn array(n: usize) -> Result { - let array_size = mem::size_of::().checked_mul(n).ok_or(LayoutError)?; - // The safe constructor is called here to enforce the isize size limit. - Layout::from_size_valid_align(array_size, ValidAlign::of::()) + // Reduce the amount of code we need to monomorphize per `T`. + return inner(mem::size_of::(), ValidAlign::of::(), n); + + #[inline] + fn inner(element_size: usize, align: ValidAlign, n: usize) -> Result { + // We need to check two things about the size: + // - That the total size won't overflow a `usize`, and + // - That the total size still fits in an `isize`. + // By using division we can check them both with a single threshold. + // That'd usually be a bad idea, but thankfully here the element size + // and alignment are constants, so the compiler will fold all of it. + if element_size != 0 && n > Layout::max_size_for_align(align) / element_size { + return Err(LayoutError); + } + + let array_size = element_size * n; + + // SAFETY: We just checked above that the `array_size` will not + // exceed `isize::MAX` even when rounded up to the alignment. + // And `ValidAlign` guarantees it's a power of two. + unsafe { Ok(Layout::from_size_align_unchecked(array_size, align.as_usize())) } + } } } diff --git a/library/core/src/mem/valid_align.rs b/library/core/src/mem/valid_align.rs index 4ce6d13cf9027..b9ccc0b4c799f 100644 --- a/library/core/src/mem/valid_align.rs +++ b/library/core/src/mem/valid_align.rs @@ -35,10 +35,15 @@ impl ValidAlign { unsafe { mem::transmute::(align) } } + #[inline] + pub(crate) const fn as_usize(self) -> usize { + self.0 as usize + } + #[inline] pub(crate) const fn as_nonzero(self) -> NonZeroUsize { // SAFETY: All the discriminants are non-zero. - unsafe { NonZeroUsize::new_unchecked(self.0 as usize) } + unsafe { NonZeroUsize::new_unchecked(self.as_usize()) } } /// Returns the base 2 logarithm of the alignment. diff --git a/library/core/tests/alloc.rs b/library/core/tests/alloc.rs index 8a5a06b3440f8..3ceaeadcec6c3 100644 --- a/library/core/tests/alloc.rs +++ b/library/core/tests/alloc.rs @@ -1,4 +1,5 @@ use core::alloc::Layout; +use core::mem::size_of; use core::ptr::{self, NonNull}; #[test] @@ -12,6 +13,49 @@ fn const_unchecked_layout() { assert_eq!(Some(DANGLING), NonNull::new(ptr::invalid_mut(ALIGN))); } +#[test] +fn layout_round_up_to_align_edge_cases() { + const MAX_SIZE: usize = isize::MAX as usize; + + for shift in 0..usize::BITS { + let align = 1_usize << shift; + let edge = (MAX_SIZE + 1) - align; + let low = edge.saturating_sub(10); + let high = edge.saturating_add(10); + assert!(Layout::from_size_align(low, align).is_ok()); + assert!(Layout::from_size_align(high, align).is_err()); + for size in low..=high { + assert_eq!( + Layout::from_size_align(size, align).is_ok(), + size.next_multiple_of(align) <= MAX_SIZE, + ); + } + } +} + +#[test] +fn layout_array_edge_cases() { + for_type::(); + for_type::<[i32; 0b10101]>(); + for_type::<[u8; 0b1010101]>(); + + // Make sure ZSTs don't lead to divide-by-zero + assert_eq!(Layout::array::<()>(usize::MAX).unwrap(), Layout::from_size_align(0, 1).unwrap()); + + fn for_type() { + const MAX_SIZE: usize = isize::MAX as usize; + + let edge = (MAX_SIZE + 1) / size_of::(); + let low = edge.saturating_sub(10); + let high = edge.saturating_add(10); + assert!(Layout::array::(low).is_ok()); + assert!(Layout::array::(high).is_err()); + for n in low..=high { + assert_eq!(Layout::array::(n).is_ok(), n * size_of::() <= MAX_SIZE); + } + } +} + #[test] fn layout_debug_shows_log2_of_alignment() { // `Debug` is not stable, but here's what it does right now diff --git a/src/test/codegen/layout-size-checks.rs b/src/test/codegen/layout-size-checks.rs new file mode 100644 index 0000000000000..d067cc10a948c --- /dev/null +++ b/src/test/codegen/layout-size-checks.rs @@ -0,0 +1,31 @@ +// compile-flags: -O +// only-x86_64 +// ignore-debug: the debug assertions get in the way + +#![crate_type = "lib"] + +use std::alloc::Layout; + +type RGB48 = [u16; 3]; + +// CHECK-LABEL: @layout_array_rgb48 +#[no_mangle] +pub fn layout_array_rgb48(n: usize) -> Layout { + // CHECK-NOT: llvm.umul.with.overflow.i64 + // CHECK: icmp ugt i64 %n, 1537228672809129301 + // CHECK-NOT: llvm.umul.with.overflow.i64 + // CHECK: mul nuw nsw i64 %n, 6 + // CHECK-NOT: llvm.umul.with.overflow.i64 + Layout::array::(n).unwrap() +} + +// CHECK-LABEL: @layout_array_i32 +#[no_mangle] +pub fn layout_array_i32(n: usize) -> Layout { + // CHECK-NOT: llvm.umul.with.overflow.i64 + // CHECK: icmp ugt i64 %n, 2305843009213693951 + // CHECK-NOT: llvm.umul.with.overflow.i64 + // CHECK: shl nuw nsw i64 %n, 2 + // CHECK-NOT: llvm.umul.with.overflow.i64 + Layout::array::(n).unwrap() +}