Skip to content

Commit

Permalink
Fix provenance issues
Browse files Browse the repository at this point in the history
The uninitialized region of a Vec cannot be accessed by slicing the Vec
first, see rust-lang/rust#92097

Similarly, it is never valid to merge adjacent slices, because any
pointer derived from a slice only has provenance over that slice, not
anything adjacent. So we pass raw pointers and a length around to avoid
narrowing provenance by converting to a reference.
  • Loading branch information
saethlin authored and cuviper committed Apr 1, 2022
1 parent 1c5277f commit 0d8def7
Show file tree
Hide file tree
Showing 3 changed files with 79 additions and 47 deletions.
107 changes: 66 additions & 41 deletions src/iter/collect/consumer.rs
Original file line number Diff line number Diff line change
Expand Up @@ -4,16 +4,34 @@ use std::mem::MaybeUninit;
use std::ptr;
use std::slice;

/// A thin wrapper around a raw `*mut T` that is marked `Send`.
///
/// We need to store raw pointers in a `Send` struct to remember
/// provenance (see `CollectResult`).
#[derive(Clone, Copy)]
struct SendPtr<T>(*mut T);

// SAFETY: !Send for raw pointers is not for safety, just as a lint.
// Note that this impl places no `Send` bound on `T` itself.
unsafe impl<T> Send for SendPtr<T> {}

pub(super) struct CollectConsumer<'c, T: Send> {
/// A slice covering the target memory, not yet initialized!
target: &'c mut [MaybeUninit<T>],
/// See `CollectResult` for explanation of why this is not a slice
start: SendPtr<MaybeUninit<T>>,
len: usize,
marker: PhantomData<&'c mut T>,
}

impl<'c, T: Send + 'c> CollectConsumer<'c, T> {
/// The target memory is considered uninitialized, and will be
/// overwritten without reading or dropping existing values.
pub(super) fn new(target: &'c mut [MaybeUninit<T>]) -> Self {
CollectConsumer { target }
pub(super) fn new(
start: *mut MaybeUninit<T>,
len: usize,
marker: PhantomData<&'c mut T>,
) -> Self {
CollectConsumer {
start: SendPtr(start),
len,
marker,
}
}
}

Expand All @@ -23,10 +41,14 @@ impl<'c, T: Send + 'c> CollectConsumer<'c, T> {
/// the elements will be dropped, unless its ownership is released before then.
#[must_use]
pub(super) struct CollectResult<'c, T> {
/// A slice covering the target memory, initialized up to our separate `len`.
target: &'c mut [MaybeUninit<T>],
/// The current initialized length in `target`
len: usize,
/// This pointer and length have the same representation as a slice,
/// but retain the provenance of the entire array so that we can merge
/// these regions together in `CollectReducer`.
/// Constructing a slice from this start + total_len would narrow that
/// provenance to the slice alone, so we keep a raw pointer and length instead.
start: SendPtr<MaybeUninit<T>>,
total_len: usize,
/// The current initialized length after `start`
initialized_len: usize,
/// Lifetime invariance guarantees that the data flows from consumer to result,
/// especially for the `scope_fn` callback in `Collect::with_consumer`.
invariant_lifetime: PhantomData<&'c mut &'c mut [T]>,
Expand All @@ -37,25 +59,26 @@ unsafe impl<'c, T> Send for CollectResult<'c, T> where T: Send {}
impl<'c, T> CollectResult<'c, T> {
/// The current length of the collect result
pub(super) fn len(&self) -> usize {
self.len
self.initialized_len
}

/// Release ownership of the slice of elements, and return the length
pub(super) fn release_ownership(mut self) -> usize {
let ret = self.len;
self.len = 0;
let ret = self.initialized_len;
self.initialized_len = 0;
ret
}
}

impl<'c, T> Drop for CollectResult<'c, T> {
fn drop(&mut self) {
// Drop the first `self.len` elements, which have been recorded
// Drop the first `self.initialized_len` elements, which have been recorded
// to be initialized by the folder.
unsafe {
// TODO: use `MaybeUninit::slice_as_mut_ptr`
let start = self.target.as_mut_ptr() as *mut T;
ptr::drop_in_place(slice::from_raw_parts_mut(start, self.len));
ptr::drop_in_place(slice::from_raw_parts_mut(
self.start.0 as *mut T,
self.initialized_len,
));
}
}
}
Expand All @@ -66,24 +89,27 @@ impl<'c, T: Send + 'c> Consumer<T> for CollectConsumer<'c, T> {
type Result = CollectResult<'c, T>;

fn split_at(self, index: usize) -> (Self, Self, CollectReducer) {
let CollectConsumer { target } = self;

// Produce new consumers. Normal slicing ensures that the
// memory range given to each consumer is disjoint.
let (left, right) = target.split_at_mut(index);
(
CollectConsumer::new(left),
CollectConsumer::new(right),
CollectReducer,
)
let CollectConsumer { start, len, marker } = self;

// Produce new consumers.
// SAFETY: This assert checks that `index` is a valid offset for `start`
unsafe {
assert!(index <= len);
(
CollectConsumer::new(start.0, index, marker),
CollectConsumer::new(start.0.add(index), len - index, marker),
CollectReducer,
)
}
}

fn into_folder(self) -> Self::Folder {
// Create a result/folder that consumes values and writes them
// into target. The initial result has length 0.
// into the region after start. The initial result has length 0.
CollectResult {
target: self.target,
len: 0,
start: self.start,
total_len: self.len,
initialized_len: 0,
invariant_lifetime: PhantomData,
}
}
Expand All @@ -97,15 +123,15 @@ impl<'c, T: Send + 'c> Folder<T> for CollectResult<'c, T> {
type Result = Self;

fn consume(mut self, item: T) -> Self {
let dest = self
.target
.get_mut(self.len)
.expect("too many values pushed to consumer");
assert!(
self.initialized_len < self.total_len,
"too many values pushed to consumer"
);

// Write item and increase the initialized length
unsafe {
dest.as_mut_ptr().write(item);
self.len += 1;
(*self.start.0.add(self.initialized_len)).write(item);
self.initialized_len += 1;
}

self
Expand Down Expand Up @@ -146,14 +172,13 @@ impl<'c, T> Reducer<CollectResult<'c, T>> for CollectReducer {
// Merge if the CollectResults are adjacent and in left to right order
// else: drop the right piece now and total length will end up short in the end,
// when the correctness of the collected result is asserted.
let left_end = left.target[left.len..].as_ptr();
if left_end == right.target.as_ptr() {
let len = left.len + right.release_ownership();
unsafe {
left.target = slice::from_raw_parts_mut(left.target.as_mut_ptr(), len);
unsafe {
let left_end = left.start.0.add(left.initialized_len);
if left_end == right.start.0 {
left.total_len += right.total_len;
left.initialized_len += right.release_ownership();
}
left.len = len;
left
}
left
}
}
14 changes: 11 additions & 3 deletions src/iter/collect/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -90,7 +90,11 @@ impl<'c, T: Send + 'c> Collect<'c, T> {
F: FnOnce(CollectConsumer<'_, T>) -> CollectResult<'_, T>,
{
let slice = Self::reserve_get_tail_slice(&mut self.vec, self.len);
let result = scope_fn(CollectConsumer::new(slice));
let result = scope_fn(CollectConsumer::new(
slice.as_mut_ptr(),
slice.len(),
std::marker::PhantomData,
));

// The CollectResult represents a contiguous part of the
// slice, that has been written to.
Expand Down Expand Up @@ -136,8 +140,12 @@ impl<'c, T: Send + 'c> Collect<'c, T> {
// TODO: use `Vec::spare_capacity_mut` instead
// SAFETY: `MaybeUninit<T>` is guaranteed to have the same layout
// as `T`, and we already made sure to have the additional space.
// This pointer is derived from `Vec` directly, not through a `Deref`,
// so it has provenance over the whole allocation.
let start = vec.len();
let tail_ptr = vec[start..].as_mut_ptr() as *mut MaybeUninit<T>;
unsafe { slice::from_raw_parts_mut(tail_ptr, len) }
unsafe {
let tail_ptr = vec.as_mut_ptr() as *mut MaybeUninit<T>;
slice::from_raw_parts_mut(tail_ptr.add(start), len)
}
}
}
5 changes: 2 additions & 3 deletions src/vec.rs
Original file line number Diff line number Diff line change
Expand Up @@ -143,9 +143,8 @@ impl<'data, T: Send> IndexedParallelIterator for Drain<'data, T> {

// Create the producer as the exclusive "owner" of the slice.
let producer = {
// Get a correct borrow lifetime, then extend it to the original length.
let mut slice = &mut self.vec[start..];
slice = slice::from_raw_parts_mut(slice.as_mut_ptr(), self.range.len());
let ptr = self.vec.as_mut_ptr().add(start);
let slice = slice::from_raw_parts_mut(ptr, self.range.len());
DrainProducer::new(slice)
};

Expand Down

0 comments on commit 0d8def7

Please sign in to comment.