Skip to content

Commit

Permalink
Fix slice::ChunksMut aliasing
Browse files Browse the repository at this point in the history
  • Loading branch information
saethlin committed Jul 3, 2022
1 parent 5f98537 commit 7919e42
Show file tree
Hide file tree
Showing 2 changed files with 143 additions and 72 deletions.
171 changes: 99 additions & 72 deletions library/core/src/slice/iter.rs
Original file line number Diff line number Diff line change
Expand Up @@ -1629,14 +1629,15 @@ unsafe impl<'a, T> TrustedRandomAccessNoCoerce for Chunks<'a, T> {
#[stable(feature = "rust1", since = "1.0.0")]
#[must_use = "iterators are lazy and do nothing unless consumed"]
pub struct ChunksMut<'a, T: 'a> {
v: &'a mut [T],
v: *mut [T],
chunk_size: usize,
_marker: PhantomData<&'a mut T>,
}

impl<'a, T: 'a> ChunksMut<'a, T> {
#[inline]
pub(super) fn new(slice: &'a mut [T], size: usize) -> Self {
Self { v: slice, chunk_size: size }
Self { v: slice, chunk_size: size, _marker: PhantomData }
}
}

Expand All @@ -1650,10 +1651,11 @@ impl<'a, T> Iterator for ChunksMut<'a, T> {
None
} else {
let sz = cmp::min(self.v.len(), self.chunk_size);
let tmp = mem::replace(&mut self.v, &mut []);
let (head, tail) = tmp.split_at_mut(sz);
// SAFETY: sz cannot exceed the slice length based on the calculation above
let (head, tail) = unsafe { self.v.split_at_mut(sz) };
self.v = tail;
Some(head)
// SAFETY: Nothing points to or will point to the contents of this slice
Some(unsafe { &mut *head })
}
}

Expand Down Expand Up @@ -1685,11 +1687,13 @@ impl<'a, T> Iterator for ChunksMut<'a, T> {
Some(sum) => cmp::min(self.v.len(), sum),
None => self.v.len(),
};
let tmp = mem::replace(&mut self.v, &mut []);
let (head, tail) = tmp.split_at_mut(end);
let (_, nth) = head.split_at_mut(start);
// SAFETY: end is inbounds because we compared above against self.v.len()
let (head, tail) = unsafe { self.v.split_at_mut(end) };
// SAFETY: start is inbounds because
let (_, nth) = unsafe { head.split_at_mut(start) };
self.v = tail;
Some(nth)
// SAFETY: Nothing points to or will point to the contents of this slice
Some(unsafe { &mut *nth })
}
}

Expand All @@ -1699,7 +1703,8 @@ impl<'a, T> Iterator for ChunksMut<'a, T> {
None
} else {
let start = (self.v.len() - 1) / self.chunk_size * self.chunk_size;
Some(&mut self.v[start..])
// SAFETY: Nothing points to or will point to the contents of this slice
Some(unsafe { &mut *self.v.get_unchecked_mut(start..) })
}
}

Expand Down Expand Up @@ -1727,12 +1732,12 @@ impl<'a, T> DoubleEndedIterator for ChunksMut<'a, T> {
} else {
let remainder = self.v.len() % self.chunk_size;
let sz = if remainder != 0 { remainder } else { self.chunk_size };
let tmp = mem::replace(&mut self.v, &mut []);
let tmp_len = tmp.len();
let len = self.v.len();
// SAFETY: Similar to `Chunks::next_back`
let (head, tail) = unsafe { tmp.split_at_mut_unchecked(tmp_len - sz) };
let (head, tail) = unsafe { self.v.split_at_mut_unchecked(len - sz) };
self.v = head;
Some(tail)
// SAFETY: Nothing points to or will point to the contents of this slice
Some(unsafe { &mut *tail })
}
}

Expand All @@ -1748,10 +1753,12 @@ impl<'a, T> DoubleEndedIterator for ChunksMut<'a, T> {
Some(res) => cmp::min(self.v.len(), res),
None => self.v.len(),
};
let (temp, _tail) = mem::replace(&mut self.v, &mut []).split_at_mut(end);
let (head, nth_back) = temp.split_at_mut(start);
// SAFETY: end is inbounds because we compared above against self.v.len()
let (temp, _tail) = unsafe { self.v.split_at_mut(end) };
let (head, nth_back) = unsafe { temp.split_at_mut(start) };
self.v = head;
Some(nth_back)
// SAFETY: Nothing points to or will point to the contents of this slice
Some(unsafe { &mut *nth_back })
}
}
}
Expand Down Expand Up @@ -1957,9 +1964,10 @@ unsafe impl<'a, T> TrustedRandomAccessNoCoerce for ChunksExact<'a, T> {
#[stable(feature = "chunks_exact", since = "1.31.0")]
#[must_use = "iterators are lazy and do nothing unless consumed"]
pub struct ChunksExactMut<'a, T: 'a> {
v: &'a mut [T],
rem: &'a mut [T],
v: *mut [T],
rem: &'a mut [T], // The iterator never yields from here, so this can be unique
chunk_size: usize,
_marker: PhantomData<&'a mut T>,
}

impl<'a, T> ChunksExactMut<'a, T> {
Expand All @@ -1969,7 +1977,7 @@ impl<'a, T> ChunksExactMut<'a, T> {
let fst_len = slice.len() - rem;
// SAFETY: 0 <= fst_len <= slice.len() by construction above
let (fst, snd) = unsafe { slice.split_at_mut_unchecked(fst_len) };
Self { v: fst, rem: snd, chunk_size }
Self { v: fst, rem: snd, chunk_size, _marker: PhantomData }
}

/// Returns the remainder of the original slice that is not going to be
Expand All @@ -1991,10 +1999,11 @@ impl<'a, T> Iterator for ChunksExactMut<'a, T> {
if self.v.len() < self.chunk_size {
None
} else {
let tmp = mem::replace(&mut self.v, &mut []);
let (head, tail) = tmp.split_at_mut(self.chunk_size);
// SAFETY: self.chunk_size is inbounds because we compared above against self.v.len()
let (head, tail) = unsafe { self.v.split_at_mut(self.chunk_size) };
self.v = tail;
Some(head)
// SAFETY: Nothing points to or will point to the contents of this slice
Some(unsafe { &mut *head })
}
}

Expand All @@ -2016,8 +2025,7 @@ impl<'a, T> Iterator for ChunksExactMut<'a, T> {
self.v = &mut [];
None
} else {
let tmp = mem::replace(&mut self.v, &mut []);
let (_, snd) = tmp.split_at_mut(start);
let (_, snd) = unsafe { self.v.split_at_mut(start) };
self.v = snd;
self.next()
}
Expand All @@ -2042,11 +2050,11 @@ impl<'a, T> DoubleEndedIterator for ChunksExactMut<'a, T> {
if self.v.len() < self.chunk_size {
None
} else {
let tmp = mem::replace(&mut self.v, &mut []);
let tmp_len = tmp.len();
let (head, tail) = tmp.split_at_mut(tmp_len - self.chunk_size);
// SAFETY: This subtraction is inbounds because of the check above
let (head, tail) = unsafe { self.v.split_at_mut(self.v.len() - self.chunk_size) };
self.v = head;
Some(tail)
// SAFETY: Nothing points to or will point to the contents of this slice
Some(unsafe { &mut *tail })
}
}

Expand All @@ -2059,10 +2067,11 @@ impl<'a, T> DoubleEndedIterator for ChunksExactMut<'a, T> {
} else {
let start = (len - 1 - n) * self.chunk_size;
let end = start + self.chunk_size;
let (temp, _tail) = mem::replace(&mut self.v, &mut []).split_at_mut(end);
let (head, nth_back) = temp.split_at_mut(start);
let (temp, _tail) = unsafe { mem::replace(&mut self.v, &mut []).split_at_mut(end) };
let (head, nth_back) = unsafe { temp.split_at_mut(start) };
self.v = head;
Some(nth_back)
// SAFETY: Nothing points to or will point to the contents of this slice
Some(unsafe { &mut *nth_back })
}
}
}
Expand Down Expand Up @@ -2646,14 +2655,15 @@ unsafe impl<'a, T> TrustedRandomAccessNoCoerce for RChunks<'a, T> {
#[stable(feature = "rchunks", since = "1.31.0")]
#[must_use = "iterators are lazy and do nothing unless consumed"]
pub struct RChunksMut<'a, T: 'a> {
v: &'a mut [T],
v: *mut [T],
chunk_size: usize,
_marker: PhantomData<&'a mut T>,
}

impl<'a, T: 'a> RChunksMut<'a, T> {
#[inline]
pub(super) fn new(slice: &'a mut [T], size: usize) -> Self {
Self { v: slice, chunk_size: size }
Self { v: slice, chunk_size: size, _marker: PhantomData }
}
}

Expand All @@ -2667,16 +2677,16 @@ impl<'a, T> Iterator for RChunksMut<'a, T> {
None
} else {
let sz = cmp::min(self.v.len(), self.chunk_size);
let tmp = mem::replace(&mut self.v, &mut []);
let tmp_len = tmp.len();
let len = self.v.len();
// SAFETY: split_at_mut_unchecked just requires the argument be less
// than the length. This could only happen if the expression
// `tmp_len - sz` overflows. This could only happen if `sz >
// tmp_len`, which is impossible as we initialize it as the `min` of
// `self.v.len()` (e.g. `tmp_len`) and `self.chunk_size`.
let (head, tail) = unsafe { tmp.split_at_mut_unchecked(tmp_len - sz) };
// `len - sz` overflows. This could only happen if `sz >
// len`, which is impossible as we initialize it as the `min` of
// `self.v.len()` (e.g. `len`) and `self.chunk_size`.
let (head, tail) = unsafe { self.v.split_at_mut_unchecked(len - sz) };
self.v = head;
Some(tail)
// SAFETY: Nothing points to or will point to the contents of this slice
Some(unsafe { &mut *tail })
}
}

Expand Down Expand Up @@ -2710,11 +2720,11 @@ impl<'a, T> Iterator for RChunksMut<'a, T> {
Some(sum) => sum,
None => 0,
};
let tmp = mem::replace(&mut self.v, &mut []);
let (head, tail) = tmp.split_at_mut(start);
let (nth, _) = tail.split_at_mut(end - start);
let (head, tail) = unsafe { self.v.split_at_mut(start) };
let (nth, _) = unsafe { tail.split_at_mut(end - start) };
self.v = head;
Some(nth)
// SAFETY: Nothing points to or will point to the contents of this slice
Some(unsafe { &mut *nth })
}
}

Expand All @@ -2725,7 +2735,8 @@ impl<'a, T> Iterator for RChunksMut<'a, T> {
} else {
let rem = self.v.len() % self.chunk_size;
let end = if rem == 0 { self.chunk_size } else { rem };
Some(&mut self.v[0..end])
// SAFETY: Nothing points to or will point to the contents of this slice
Some(unsafe { &mut *self.v.get_unchecked_mut(0..end) })
}
}

Expand All @@ -2750,11 +2761,11 @@ impl<'a, T> DoubleEndedIterator for RChunksMut<'a, T> {
} else {
let remainder = self.v.len() % self.chunk_size;
let sz = if remainder != 0 { remainder } else { self.chunk_size };
let tmp = mem::replace(&mut self.v, &mut []);
// SAFETY: Similar to `Chunks::next_back`
let (head, tail) = unsafe { tmp.split_at_mut_unchecked(sz) };
let (head, tail) = unsafe { self.v.split_at_mut_unchecked(sz) };
self.v = tail;
Some(head)
// SAFETY: Nothing points to or will point to the contents of this slice
Some(unsafe { &mut *head })
}
}

Expand All @@ -2769,10 +2780,11 @@ impl<'a, T> DoubleEndedIterator for RChunksMut<'a, T> {
let offset_from_end = (len - 1 - n) * self.chunk_size;
let end = self.v.len() - offset_from_end;
let start = end.saturating_sub(self.chunk_size);
let (tmp, tail) = mem::replace(&mut self.v, &mut []).split_at_mut(end);
let (_, nth_back) = tmp.split_at_mut(start);
let (tmp, tail) = unsafe { self.v.split_at_mut(end) };
let (_, nth_back) = unsafe { tmp.split_at_mut(start) };
self.v = tail;
Some(nth_back)
// SAFETY: Nothing points to or will point to the contents of this slice
Some(unsafe { &mut *nth_back })
}
}
}
Expand Down Expand Up @@ -2898,8 +2910,7 @@ impl<'a, T> Iterator for RChunksExact<'a, T> {
unsafe fn __iterator_get_unchecked(&mut self, idx: usize) -> Self::Item {
let end = self.v.len() - idx * self.chunk_size;
let start = end - self.chunk_size;
// SAFETY:
// SAFETY: mostmy identical to `Chunks::__iterator_get_unchecked`.
// SAFETY: mostly identical to `Chunks::__iterator_get_unchecked`.
unsafe { from_raw_parts(self.v.as_ptr().add(start), self.chunk_size) }
}
}
Expand Down Expand Up @@ -2982,7 +2993,7 @@ unsafe impl<'a, T> TrustedRandomAccessNoCoerce for RChunksExact<'a, T> {
#[stable(feature = "rchunks", since = "1.31.0")]
#[must_use = "iterators are lazy and do nothing unless consumed"]
pub struct RChunksExactMut<'a, T: 'a> {
v: &'a mut [T],
v: *mut [T],
rem: &'a mut [T],
chunk_size: usize,
}
Expand Down Expand Up @@ -3015,11 +3026,11 @@ impl<'a, T> Iterator for RChunksExactMut<'a, T> {
if self.v.len() < self.chunk_size {
None
} else {
let tmp = mem::replace(&mut self.v, &mut []);
let tmp_len = tmp.len();
let (head, tail) = tmp.split_at_mut(tmp_len - self.chunk_size);
let len = self.v.len();
let (head, tail) = unsafe { self.v.split_at_mut(len - self.chunk_size) };
self.v = head;
Some(tail)
// SAFETY: Nothing points to or will point to the contents of this slice
Some(unsafe { &mut *tail })
}
}

Expand All @@ -3041,9 +3052,8 @@ impl<'a, T> Iterator for RChunksExactMut<'a, T> {
self.v = &mut [];
None
} else {
let tmp = mem::replace(&mut self.v, &mut []);
let tmp_len = tmp.len();
let (fst, _) = tmp.split_at_mut(tmp_len - end);
let len = self.v.len();
let (fst, _) = unsafe { self.v.split_at_mut(len - end) };
self.v = fst;
self.next()
}
Expand All @@ -3069,10 +3079,10 @@ impl<'a, T> DoubleEndedIterator for RChunksExactMut<'a, T> {
if self.v.len() < self.chunk_size {
None
} else {
let tmp = mem::replace(&mut self.v, &mut []);
let (head, tail) = tmp.split_at_mut(self.chunk_size);
let (head, tail) = unsafe { self.v.split_at_mut(self.chunk_size) };
self.v = tail;
Some(head)
// SAFETY: Nothing points to or will point to the contents of this slice
Some(unsafe { &mut *head })
}
}

Expand All @@ -3088,10 +3098,11 @@ impl<'a, T> DoubleEndedIterator for RChunksExactMut<'a, T> {
let offset = (len - n) * self.chunk_size;
let start = self.v.len() - offset;
let end = start + self.chunk_size;
let (tmp, tail) = mem::replace(&mut self.v, &mut []).split_at_mut(end);
let (_, nth_back) = tmp.split_at_mut(start);
let (tmp, tail) = unsafe { self.v.split_at_mut(end) };
let (_, nth_back) = unsafe { tmp.split_at_mut(start) };
self.v = tail;
Some(nth_back)
// SAFETY: Nothing points to or will point to the contents of this slice
Some(unsafe { &mut *nth_back })
}
}
}
Expand Down Expand Up @@ -3174,7 +3185,11 @@ where
let mut len = 1;
let mut iter = self.slice.windows(2);
while let Some([l, r]) = iter.next() {
if (self.predicate)(l, r) { len += 1 } else { break }
if (self.predicate)(l, r) {
len += 1
} else {
break;
}
}
let (head, tail) = self.slice.split_at(len);
self.slice = tail;
Expand Down Expand Up @@ -3206,7 +3221,11 @@ where
let mut len = 1;
let mut iter = self.slice.windows(2);
while let Some([l, r]) = iter.next_back() {
if (self.predicate)(l, r) { len += 1 } else { break }
if (self.predicate)(l, r) {
len += 1
} else {
break;
}
}
let (head, tail) = self.slice.split_at(self.slice.len() - len);
self.slice = head;
Expand Down Expand Up @@ -3261,7 +3280,11 @@ where
let mut len = 1;
let mut iter = self.slice.windows(2);
while let Some([l, r]) = iter.next() {
if (self.predicate)(l, r) { len += 1 } else { break }
if (self.predicate)(l, r) {
len += 1
} else {
break;
}
}
let slice = mem::take(&mut self.slice);
let (head, tail) = slice.split_at_mut(len);
Expand Down Expand Up @@ -3294,7 +3317,11 @@ where
let mut len = 1;
let mut iter = self.slice.windows(2);
while let Some([l, r]) = iter.next_back() {
if (self.predicate)(l, r) { len += 1 } else { break }
if (self.predicate)(l, r) {
len += 1
} else {
break;
}
}
let slice = mem::take(&mut self.slice);
let (head, tail) = slice.split_at_mut(slice.len() - len);
Expand Down
Loading

0 comments on commit 7919e42

Please sign in to comment.