Skip to content

Commit

Permalink
Merge #566
Browse files Browse the repository at this point in the history
566: Audit most uses of unsafe r=torkleyy a=torkleyy

Partial fix for #548 

## Checklist

* [x] I've added tests for all code changes and additions (where applicable)
* [x] I've added a demonstration of the new feature to one or more examples
* [x] I've updated the book to reflect my changes
* [x] Usage of new public items is shown in the API docs



Co-authored-by: Thomas Schaller <[email protected]>
  • Loading branch information
bors[bot] and torkleyy committed Mar 31, 2019
2 parents 72a854d + 8b5f9c8 commit 75dde3e
Show file tree
Hide file tree
Showing 7 changed files with 150 additions and 0 deletions.
4 changes: 4 additions & 0 deletions src/bitset.rs
Original file line number Diff line number Diff line change
Expand Up @@ -19,9 +19,13 @@ macro_rules! define_bit_join {
type Type = Index;
type Value = ();
type Mask = $bitset;

// SAFETY: This just moves a `BitSet`; invariants of `Join` are fulfilled, since `Self::Value` cannot be mutated.
unsafe fn open(self) -> (Self::Mask, Self::Value) {
(self, ())
}

// SAFETY: No unsafe code and no invariants to meet.
unsafe fn get(_: &mut Self::Value, id: Index) -> Self::Type {
id
}
Expand Down
13 changes: 13 additions & 0 deletions src/changeset.rs
Original file line number Diff line number Diff line change
Expand Up @@ -59,10 +59,12 @@ impl<T> ChangeSet<T> {
T: AddAssign,
{
if self.mask.contains(entity.id()) {
// SAFETY: we checked the mask, thus it's safe to call
unsafe {
*self.inner.get_mut(entity.id()) += value;
}
} else {
// SAFETY: we checked the mask, thus it's safe to call
unsafe {
self.inner.insert(entity.id(), value);
}
Expand All @@ -73,6 +75,7 @@ impl<T> ChangeSet<T> {
/// Clear the changeset
pub fn clear(&mut self) {
for id in &self.mask {
// SAFETY: we checked the mask, thus it's safe to call
unsafe {
self.inner.remove(id);
}
Expand Down Expand Up @@ -110,10 +113,13 @@ impl<'a, T> Join for &'a mut ChangeSet<T> {
type Value = &'a mut DenseVecStorage<T>;
type Mask = &'a BitSet;

// SAFETY: No unsafe code and no invariants to meet.
unsafe fn open(self) -> (Self::Mask, Self::Value) {
(&self.mask, &mut self.inner)
}

// SAFETY: No unsafe code and no invariants to meet.
// `DistinctStorage` invariants are also met, but no `ParJoin` implementation exists yet.
unsafe fn get(v: &mut Self::Value, id: Index) -> Self::Type {
let value: *mut Self::Value = v as *mut Self::Value;
(*value).get_mut(id)
Expand All @@ -125,24 +131,31 @@ impl<'a, T> Join for &'a ChangeSet<T> {
type Value = &'a DenseVecStorage<T>;
type Mask = &'a BitSet;

// SAFETY: No unsafe code and no invariants to meet.
unsafe fn open(self) -> (Self::Mask, Self::Value) {
(&self.mask, &self.inner)
}

// SAFETY: No unsafe code and no invariants to meet.
// `DistinctStorage` invariants are also met, but no `ParJoin` implementation exists yet.
unsafe fn get(value: &mut Self::Value, id: Index) -> Self::Type {
value.get(id)
}
}

/// A `Join` implementation for `ChangeSet` that simply removes all the entries on a call to `get`.
impl<T> Join for ChangeSet<T> {
type Type = T;
type Value = DenseVecStorage<T>;
type Mask = BitSet;

// SAFETY: No unsafe code and no invariants to meet.
unsafe fn open(self) -> (Self::Mask, Self::Value) {
(self.mask, self.inner)
}

// SAFETY: No unsafe code and no invariants to meet.
// `DistinctStorage` invariants are also met, but no `ParJoin` implementation exists yet.
unsafe fn get(value: &mut Self::Value, id: Index) -> Self::Type {
value.remove(id)
}
Expand Down
40 changes: 40 additions & 0 deletions src/join/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -223,13 +223,20 @@ pub trait Join {

/// Open this join by returning the mask and the storages.
///
/// # Safety
///
/// This is unsafe because implementations of this trait can permit
/// the `Value` to be mutated independently of the `Mask`.
/// If the `Mask` does not correctly report the status of the `Value`
/// then illegal memory access can occur.
unsafe fn open(self) -> (Self::Mask, Self::Value);

/// Get a joined component value by a given index.
///
/// # Safety
///
/// * A call to `get` must be preceded by a check if `id` is part of `Self::Mask`
/// * The implementation of this method may use unsafe code, but has no invariants to meet
unsafe fn get(value: &mut Self::Value, id: Index) -> Self::Type;

/// If this `Join` typically returns all indices in the mask, then iterating over only it
Expand Down Expand Up @@ -261,10 +268,15 @@ where
type Type = Option<<T as Join>::Type>;
type Value = (<T as Join>::Mask, <T as Join>::Value);
type Mask = BitSetAll;

// SAFETY: This wraps another implementation of `open`, making it dependent on `J`'s correctness.
// We can safely assume `J` is valid, thus this must be valid, too. No invariants to meet.
unsafe fn open(self) -> (Self::Mask, Self::Value) {
let (mask, value) = self.0.open();
(BitSetAll, (mask, value))
}

// SAFETY: No invariants to meet and the unsafe code checks the mask, thus fulfills the requirements for calling `get`
unsafe fn get((mask, value): &mut Self::Value, id: Index) -> Self::Type {
if mask.contains(id) {
Some(<T as Join>::get(value, id))
Expand Down Expand Up @@ -293,6 +305,7 @@ impl<J: Join> JoinIter<J> {
println!("WARNING: `Join` possibly iterating through all indices, you might've made a join with all `MaybeJoin`s, which is unbounded in length.");
}

// SAFETY: We do not swap out the mask or the values, nor do we allow it by exposing them.
let (keys, values) = unsafe { j.open() };
JoinIter {
keys: keys.iter(),
Expand Down Expand Up @@ -353,6 +366,7 @@ impl<J: Join> JoinIter<J> {
/// ```
pub fn get(&mut self, entity: Entity, entities: &Entities) -> Option<J::Type> {
if self.keys.contains(entity.id()) && entities.is_alive(entity) {
// SAFETY: the mask (`keys`) is checked as specified in the docs of `get`.
Some(unsafe { J::get(&mut self.values, entity.id()) })
} else {
None
Expand All @@ -367,6 +381,7 @@ impl<J: Join> JoinIter<J> {
/// so the caller should ensure it instead.
pub fn get_unchecked(&mut self, index: Index) -> Option<J::Type> {
if self.keys.contains(index) {
// SAFETY: the mask (`keys`) is checked as specified in the docs of `get`.
Some(unsafe { J::get(&mut self.values, index) })
} else {
None
Expand All @@ -378,6 +393,8 @@ impl<J: Join> std::iter::Iterator for JoinIter<J> {
type Item = J::Type;

fn next(&mut self) -> Option<J::Type> {
// SAFETY: since `idx` is yielded from `keys` (the mask), it is necessarily a part of it.
// Thus, requirements are fulfilled for calling `get`.
self.keys
.next()
.map(|idx| unsafe { J::get(&mut self.values, idx) })
Expand All @@ -395,6 +412,9 @@ macro_rules! define_open {
type Value = ($($from::Value),*,);
type Mask = <($($from::Mask,)*) as BitAnd>::Value;
#[allow(non_snake_case)]

// SAFETY: While we do expose the mask and the values and therefore would allow swapping them,
// this method is `unsafe` and relies on the same invariants.
unsafe fn open(self) -> (Self::Mask, Self::Value) {
let ($($from,)*) = self;
let ($($from,)*) = ($($from.open(),)*);
Expand All @@ -404,6 +424,8 @@ macro_rules! define_open {
)
}

// SAFETY: No invariants to meet and `get` is safe to call as the caller must have checked the mask,
// which only has a key that exists in all of the storages.
#[allow(non_snake_case)]
unsafe fn get(v: &mut Self::Value, i: Index) -> Self::Type {
let &mut ($(ref mut $from,)*) = v;
Expand All @@ -417,6 +439,10 @@ macro_rules! define_open {
unconstrained
}
}

// SAFETY: This is safe to implement since all components implement `ParJoin`.
// If the access of every individual `get` leads to disjoint memory access, calling
// all of them after another does in no case lead to access of common memory.
#[cfg(feature = "parallel")]
unsafe impl<$($from,)*> ParJoin for ($($from),*,)
where $($from: ParJoin),*,
Expand Down Expand Up @@ -463,10 +489,15 @@ macro_rules! immutable_resource_join {
type Type = <&'a T as Join>::Type;
type Value = <&'a T as Join>::Value;
type Mask = <&'a T as Join>::Mask;

// SAFETY: This only wraps `T` and, while exposing the mask and the values,
// requires the same invariants as the original implementation and is thus safe.
unsafe fn open(self) -> (Self::Mask, Self::Value) {
self.deref().open()
}

// SAFETY: The mask of `Self` and `T` are identical, thus a check to `Self`'s mask (which is required)
// is equal to a check of `T`'s mask, which makes `get` safe to call.
unsafe fn get(v: &mut Self::Value, i: Index) -> Self::Type {
<&'a T as Join>::get(v, i)
}
Expand All @@ -477,6 +508,8 @@ macro_rules! immutable_resource_join {
}
}

// SAFETY: This is just a wrapper of `T`'s implementation for `ParJoin` and can
// in no case lead to other memory access patterns.
#[cfg(feature = "parallel")]
unsafe impl<'a, 'b, T> ParJoin for &'a $ty
where
Expand All @@ -498,10 +531,15 @@ macro_rules! mutable_resource_join {
type Type = <&'a mut T as Join>::Type;
type Value = <&'a mut T as Join>::Value;
type Mask = <&'a mut T as Join>::Mask;

// SAFETY: This only wraps `T` and, while exposing the mask and the values,
// requires the same invariants as the original implementation and is thus safe.
unsafe fn open(self) -> (Self::Mask, Self::Value) {
self.deref_mut().open()
}

// SAFETY: The mask of `Self` and `T` are identical, thus a check to `Self`'s mask (which is required)
// is equal to a check of `T`'s mask, which makes `get_mut` safe to call.
unsafe fn get(v: &mut Self::Value, i: Index) -> Self::Type {
<&'a mut T as Join>::get(v, i)
}
Expand All @@ -512,6 +550,8 @@ macro_rules! mutable_resource_join {
}
}

// SAFETY: This is just a wrapper of `T`'s implementation for `ParJoin` and can
// in no case lead to other memory access patterns.
#[cfg(feature = "parallel")]
unsafe impl<'a, 'b, T> ParJoin for &'a mut $ty
where
Expand Down
18 changes: 18 additions & 0 deletions src/join/par_join.rs
Original file line number Diff line number Diff line change
Expand Up @@ -11,6 +11,14 @@ use join::Join;
/// The purpose of the `ParJoin` trait is to provide a way
/// to access multiple storages in parallel at the same time with
/// the merged bit set.
///
/// # Safety
///
/// The implementation of `ParallelIterator` for `ParJoin` makes multiple assumptions on the structure of `Self`.
/// In particular, `<Self as Join>::get` must be callable from multiple threads, simultaneously, without mutating
/// values not exclusively associated with `id`.
// NOTE: This is currently unspecified behavior. It seems very unlikely that it breaks in the future,
// but technically it's not specified as valid Rust code.
pub unsafe trait ParJoin: Join {
/// Create a joined parallel iterator over the contents.
fn par_join(self) -> JoinParIter<Self>
Expand Down Expand Up @@ -45,6 +53,8 @@ where
let (keys, values) = unsafe { self.0.open() };
// Create a bit producer which splits on up to three levels
let producer = BitProducer((&keys).iter(), 3);
// HACK: use `UnsafeCell` to share `values` between threads;
// this is the unspecified behavior referred to above.
let values = UnsafeCell::new(values);

bridge_unindexed(JoinProducer::<J>::new(producer, &values), consumer)
Expand Down Expand Up @@ -74,6 +84,14 @@ where
}
}

// SAFETY: `Send` is safe to implement if all components of `Self` are logically `Send`.
// `keys` already has `Send` implemented, thus no reasoning is required.
// `values` is a reference to an `UnsafeCell` wrapping `J::Value`;
// `J::Value` is constrained to implement `Send`.
// `UnsafeCell` provides interior mutability, but the specification of it allows sharing
// as long as access does not happen simultaneously; this makes it generally safe to `Send`,
// but we are accessing it simultaneously, which is technically not allowed.
// Also see https://github.com/slide-rs/specs/issues/220
unsafe impl<'a, J> Send for JoinProducer<'a, J>
where
J: Join + Send,
Expand Down
2 changes: 2 additions & 0 deletions src/storage/drain.rs
Original file line number Diff line number Diff line change
Expand Up @@ -19,12 +19,14 @@ where
type Value = &'a mut MaskedStorage<T>;
type Mask = BitSet;

// SAFETY: No invariants to meet and no unsafe code.
unsafe fn open(self) -> (Self::Mask, Self::Value) {
let mask = self.data.mask.clone();

(mask, self.data)
}

// SAFETY: No invariants to meet and no unsafe code.
unsafe fn get(value: &mut Self::Value, id: Index) -> T {
value.remove(id).expect("Tried to access same index twice")
}
Expand Down
12 changes: 12 additions & 0 deletions src/storage/entry.rs
Original file line number Diff line number Diff line change
Expand Up @@ -40,7 +40,9 @@ where
if self.entities.is_alive(e) {
unsafe {
let entries = self.entries();
// SAFETY: This is safe since we're not swapping out the mask or the values.
let (_, mut value): (BitSetAll, _) = entries.open();
// SAFETY: We did check the mask, because the mask is `BitSetAll` and every index is part of it.
Ok(Entries::get(&mut value, e.id()))
}
} else {
Expand Down Expand Up @@ -132,10 +134,13 @@ where
type Value = &'a mut Storage<'b, T, D>;
type Mask = BitSetAll;

// SAFETY: No invariants to meet and no unsafe code.
unsafe fn open(self) -> (Self::Mask, Self::Value) {
(BitSetAll, self.0)
}

// SAFETY: We are lengthening the lifetime of `value` to `'a`;
// TODO: how to prove this is safe?
unsafe fn get(value: &mut Self::Value, id: Index) -> Self::Type {
// This is HACK. See implementation of Join for &'a mut Storage<'e, T, D> for
// details why it is necessary.
Expand Down Expand Up @@ -172,6 +177,8 @@ where
{
/// Get a reference to the component associated with the entity.
pub fn get(&self) -> &T {
// SAFETY: This is safe since `OccupiedEntry` is only constructed
// after checking the mask.
unsafe { self.storage.data.inner.get(self.id) }
}
}
Expand All @@ -183,12 +190,16 @@ where
{
/// Get a mutable reference to the component associated with the entity.
pub fn get_mut(&mut self) -> &mut T {
// SAFETY: This is safe since `OccupiedEntry` is only constructed
// after checking the mask.
unsafe { self.storage.data.inner.get_mut(self.id) }
}

/// Converts the `OccupiedEntry` into a mutable reference bounded by
/// the storage's lifetime.
pub fn into_mut(self) -> &'a mut T {
// SAFETY: This is safe since `OccupiedEntry` is only constructed
// after checking the mask.
unsafe { self.storage.data.inner.get_mut(self.id) }
}

Expand Down Expand Up @@ -218,6 +229,7 @@ where
/// Inserts a value into the storage.
pub fn insert(self, component: T) -> &'a mut T {
self.storage.data.mask.add(self.id);
// SAFETY: This is safe since we added `self.id` to the mask.
unsafe {
self.storage.data.inner.insert(self.id, component);
self.storage.data.inner.get_mut(self.id)
Expand Down
Loading

0 comments on commit 75dde3e

Please sign in to comment.