Skip to content

Commit

Permalink
Rollup merge of rust-lang#75644 - c410-f3r:array, r=yaahc
Browse files Browse the repository at this point in the history
Add 'core::array::from_fn' and 'core::array::try_from_fn'

These auxiliary methods fill uninitialized arrays in a safe way and are particularly useful for elements that don't implement `Default`.

```rust
// Foo doesn't implement Default
struct Foo(usize);

let _array = core::array::from_fn::<_, _, 2>(|idx| Foo(idx));
```

Different from `FromIterator`, it is guaranteed that the array will be fully filled and no error regarding uninitialized state will be throw. In certain scenarios, however, the creation of an **element** can fail and that is why the `try_from_fn` function is also provided.

```rust
#[derive(Debug, PartialEq)]
enum SomeError {
    Foo,
}

let array = core::array::try_from_fn(|i| Ok::<_, SomeError>(i));
assert_eq!(array, Ok([0, 1, 2, 3, 4]));

let another_array = core::array::try_from_fn(|_| Err(SomeError::Foo));
assert_eq!(another_array, Err(SomeError::Foo));
 ```
  • Loading branch information
GuillaumeGomez authored Oct 9, 2021
2 parents bb918d0 + 85c4a52 commit 1069985
Show file tree
Hide file tree
Showing 3 changed files with 186 additions and 19 deletions.
120 changes: 103 additions & 17 deletions library/core/src/array/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -20,6 +20,69 @@ mod iter;
#[stable(feature = "array_value_iter", since = "1.51.0")]
pub use iter::IntoIter;

/// Creates an array `[T; N]` where each array element `T` is returned by the `cb` call.
///
/// # Arguments
///
/// * `cb`: Callback where the passed argument is the current array index.
///
/// # Example
///
/// ```rust
/// #![feature(array_from_fn)]
///
/// let array = core::array::from_fn(|i| i);
/// assert_eq!(array, [0, 1, 2, 3, 4]);
/// ```
#[inline]
#[unstable(feature = "array_from_fn", issue = "89379")]
pub fn from_fn<F, T, const N: usize>(mut cb: F) -> [T; N]
where
F: FnMut(usize) -> T,
{
let mut idx = 0;
[(); N].map(|_| {
let res = cb(idx);
idx += 1;
res
})
}

/// Creates an array `[T; N]` where each fallible array element `T` is returned by the `cb` call.
/// Unlike `core::array::from_fn`, where the element creation can't fail, this version will return an error
/// if any element creation was unsuccessful.
///
/// # Arguments
///
/// * `cb`: Callback where the passed argument is the current array index.
///
/// # Example
///
/// ```rust
/// #![feature(array_from_fn)]
///
/// #[derive(Debug, PartialEq)]
/// enum SomeError {
/// Foo,
/// }
///
/// let array = core::array::try_from_fn(|i| Ok::<_, SomeError>(i));
/// assert_eq!(array, Ok([0, 1, 2, 3, 4]));
///
/// let another_array = core::array::try_from_fn::<SomeError, _, (), 2>(|_| Err(SomeError::Foo));
/// assert_eq!(another_array, Err(SomeError::Foo));
/// ```
#[inline]
#[unstable(feature = "array_from_fn", issue = "89379")]
pub fn try_from_fn<E, F, T, const N: usize>(cb: F) -> Result<[T; N], E>
where
F: FnMut(usize) -> Result<T, E>,
{
// SAFETY: we know for certain that this iterator will yield exactly `N`
// items.
unsafe { collect_into_array_rslt_unchecked(&mut (0..N).map(cb)) }
}

/// Converts a reference to `T` into a reference to an array of length 1 (without copying).
#[stable(feature = "array_from_ref", since = "1.53.0")]
pub fn from_ref<T>(s: &T) -> &[T; 1] {
Expand Down Expand Up @@ -448,13 +511,15 @@ impl<T, const N: usize> [T; N] {
///
/// It is up to the caller to guarantee that `iter` yields at least `N` items.
/// Violating this condition causes undefined behavior.
unsafe fn collect_into_array_unchecked<I, const N: usize>(iter: &mut I) -> [I::Item; N]
unsafe fn collect_into_array_rslt_unchecked<E, I, T, const N: usize>(
iter: &mut I,
) -> Result<[T; N], E>
where
// Note: `TrustedLen` here is somewhat of an experiment. This is just an
// internal function, so feel free to remove if this bound turns out to be a
// bad idea. In that case, remember to also remove the lower bound
// `debug_assert!` below!
I: Iterator + TrustedLen,
I: Iterator<Item = Result<T, E>> + TrustedLen,
{
debug_assert!(N <= iter.size_hint().1.unwrap_or(usize::MAX));
debug_assert!(N <= iter.size_hint().0);
Expand All @@ -463,6 +528,21 @@ where
unsafe { collect_into_array(iter).unwrap_unchecked() }
}

// Infallible version of `collect_into_array_rslt_unchecked`.
unsafe fn collect_into_array_unchecked<I, const N: usize>(iter: &mut I) -> [I::Item; N]
where
I: Iterator + TrustedLen,
{
let mut map = iter.map(Ok::<_, Infallible>);

// SAFETY: The same safety considerations w.r.t. the iterator length
// apply for `collect_into_array_rslt_unchecked` as for
// `collect_into_array_unchecked`
match unsafe { collect_into_array_rslt_unchecked(&mut map) } {
Ok(array) => array,
}
}

/// Pulls `N` items from `iter` and returns them as an array. If the iterator
/// yields fewer than `N` items, `None` is returned and all already yielded
/// items are dropped.
Expand All @@ -473,43 +553,49 @@ where
///
/// If `iter.next()` panicks, all items already yielded by the iterator are
/// dropped.
fn collect_into_array<I, const N: usize>(iter: &mut I) -> Option<[I::Item; N]>
fn collect_into_array<E, I, T, const N: usize>(iter: &mut I) -> Option<Result<[T; N], E>>
where
I: Iterator,
I: Iterator<Item = Result<T, E>>,
{
if N == 0 {
// SAFETY: An empty array is always inhabited and has no validity invariants.
return unsafe { Some(mem::zeroed()) };
return unsafe { Some(Ok(mem::zeroed())) };
}

struct Guard<T, const N: usize> {
ptr: *mut T,
struct Guard<'a, T, const N: usize> {
array_mut: &'a mut [MaybeUninit<T>; N],
initialized: usize,
}

impl<T, const N: usize> Drop for Guard<T, N> {
impl<T, const N: usize> Drop for Guard<'_, T, N> {
fn drop(&mut self) {
debug_assert!(self.initialized <= N);

let initialized_part = crate::ptr::slice_from_raw_parts_mut(self.ptr, self.initialized);

// SAFETY: this raw slice will contain only initialized objects.
// SAFETY: this slice will contain only initialized objects.
unsafe {
crate::ptr::drop_in_place(initialized_part);
crate::ptr::drop_in_place(MaybeUninit::slice_assume_init_mut(
&mut self.array_mut.get_unchecked_mut(..self.initialized),
));
}
}
}

let mut array = MaybeUninit::uninit_array::<N>();
let mut guard: Guard<_, N> =
Guard { ptr: MaybeUninit::slice_as_mut_ptr(&mut array), initialized: 0 };
let mut guard = Guard { array_mut: &mut array, initialized: 0 };

while let Some(item_rslt) = iter.next() {
let item = match item_rslt {
Err(err) => {
return Some(Err(err));
}
Ok(elem) => elem,
};

while let Some(item) = iter.next() {
// SAFETY: `guard.initialized` starts at 0, is increased by one in the
// loop and the loop is aborted once it reaches N (which is
// `array.len()`).
unsafe {
array.get_unchecked_mut(guard.initialized).write(item);
guard.array_mut.get_unchecked_mut(guard.initialized).write(item);
}
guard.initialized += 1;

Expand All @@ -520,7 +606,7 @@ where
// SAFETY: the condition above asserts that all elements are
// initialized.
let out = unsafe { MaybeUninit::array_assume_init(array) };
return Some(out);
return Some(Ok(out));
}
}

Expand Down
84 changes: 82 additions & 2 deletions library/core/tests/array.rs
Original file line number Diff line number Diff line change
@@ -1,5 +1,6 @@
use core::array;
use core::convert::TryFrom;
use core::sync::atomic::{AtomicUsize, Ordering};

#[test]
fn array_from_ref() {
Expand Down Expand Up @@ -303,8 +304,6 @@ fn array_map() {
#[test]
#[should_panic(expected = "test succeeded")]
fn array_map_drop_safety() {
use core::sync::atomic::AtomicUsize;
use core::sync::atomic::Ordering;
static DROPPED: AtomicUsize = AtomicUsize::new(0);
struct DropCounter;
impl Drop for DropCounter {
Expand Down Expand Up @@ -356,3 +355,84 @@ fn cell_allows_array_cycle() {
b3.a[0].set(Some(&b1));
b3.a[1].set(Some(&b2));
}

#[test]
fn array_from_fn() {
let array = core::array::from_fn(|idx| idx);
assert_eq!(array, [0, 1, 2, 3, 4]);
}

#[test]
fn array_try_from_fn() {
#[derive(Debug, PartialEq)]
enum SomeError {
Foo,
}

let array = core::array::try_from_fn(|i| Ok::<_, SomeError>(i));
assert_eq!(array, Ok([0, 1, 2, 3, 4]));

let another_array = core::array::try_from_fn::<SomeError, _, (), 2>(|_| Err(SomeError::Foo));
assert_eq!(another_array, Err(SomeError::Foo));
}

#[cfg(not(panic = "abort"))]
#[test]
fn array_try_from_fn_drops_inserted_elements_on_err() {
static DROP_COUNTER: AtomicUsize = AtomicUsize::new(0);

struct CountDrop;
impl Drop for CountDrop {
fn drop(&mut self) {
DROP_COUNTER.fetch_add(1, Ordering::SeqCst);
}
}

let _ = catch_unwind_silent(move || {
let _: Result<[CountDrop; 4], ()> = core::array::try_from_fn(|idx| {
if idx == 2 {
return Err(());
}
Ok(CountDrop)
});
});

assert_eq!(DROP_COUNTER.load(Ordering::SeqCst), 2);
}

#[cfg(not(panic = "abort"))]
#[test]
fn array_try_from_fn_drops_inserted_elements_on_panic() {
static DROP_COUNTER: AtomicUsize = AtomicUsize::new(0);

struct CountDrop;
impl Drop for CountDrop {
fn drop(&mut self) {
DROP_COUNTER.fetch_add(1, Ordering::SeqCst);
}
}

let _ = catch_unwind_silent(move || {
let _: Result<[CountDrop; 4], ()> = core::array::try_from_fn(|idx| {
if idx == 2 {
panic!("peek a boo");
}
Ok(CountDrop)
});
});

assert_eq!(DROP_COUNTER.load(Ordering::SeqCst), 2);
}

#[cfg(not(panic = "abort"))]
// https://stackoverflow.com/a/59211505
fn catch_unwind_silent<F, R>(f: F) -> std::thread::Result<R>
where
F: FnOnce() -> R + core::panic::UnwindSafe,
{
let prev_hook = std::panic::take_hook();
std::panic::set_hook(Box::new(|_| {}));
let result = std::panic::catch_unwind(f);
std::panic::set_hook(prev_hook);
result
}
1 change: 1 addition & 0 deletions library/core/tests/lib.rs
Original file line number Diff line number Diff line change
Expand Up @@ -26,6 +26,7 @@
#![feature(extern_types)]
#![feature(flt2dec)]
#![feature(fmt_internals)]
#![feature(array_from_fn)]
#![feature(hashmap_internals)]
#![feature(try_find)]
#![feature(is_sorted)]
Expand Down

0 comments on commit 1069985

Please sign in to comment.