Skip to content

Commit

Permalink
Merge #1011
Browse files Browse the repository at this point in the history
1011: Use pointers instead of `&self` in `Latch::set` r=cuviper a=cuviper

`Latch::set` can invalidate its own `&self`, because it releases the
owning thread to continue execution, which may then invalidate the latch
by deallocation, reuse, etc. We've known about this problem when it
comes to accessing latch fields too late, but the possibly dangling
reference was still a problem, like rust-lang/rust#55005.

The result of that was rust-lang/rust#98017, omitting the LLVM attribute
`dereferenceable` on references to `!Freeze` types -- those containing
`UnsafeCell`. However, miri's Stacked Borrows implementation is finer-
grained than that, only relaxing for the cell itself in the `!Freeze`
type. For rayon, that solves the dangling reference in atomic calls, but
remains a problem for other fields of a `Latch`.

This easiest fix for rayon is to use a raw pointer instead of `&self`.
We still end up with some temporary references for stuff like atomics,
but those should be fine with the rules above.


Co-authored-by: Josh Stone <[email protected]>
  • Loading branch information
bors[bot] and cuviper authored Jan 21, 2023
2 parents ed98853 + f880d02 commit 8cee824
Show file tree
Hide file tree
Showing 5 changed files with 95 additions and 58 deletions.
5 changes: 4 additions & 1 deletion rayon-core/src/broadcast/mod.rs
Original file line number Diff line number Diff line change
@@ -1,4 +1,5 @@
use crate::job::{ArcJob, StackJob};
use crate::latch::LatchRef;
use crate::registry::{Registry, WorkerThread};
use crate::scope::ScopeLatch;
use std::fmt;
Expand Down Expand Up @@ -107,7 +108,9 @@ where
let n_threads = registry.num_threads();
let current_thread = WorkerThread::current().as_ref();
let latch = ScopeLatch::with_count(n_threads, current_thread);
let jobs: Vec<_> = (0..n_threads).map(|_| StackJob::new(&f, &latch)).collect();
let jobs: Vec<_> = (0..n_threads)
.map(|_| StackJob::new(&f, LatchRef::new(&latch)))
.collect();
let job_refs = jobs.iter().map(|job| job.as_job_ref());

registry.inject_broadcast(job_refs);
Expand Down
2 changes: 1 addition & 1 deletion rayon-core/src/job.rs
Original file line number Diff line number Diff line change
Expand Up @@ -112,7 +112,7 @@ where
let abort = unwind::AbortIfPanic;
let func = (*this.func.get()).take().unwrap();
(*this.result.get()) = JobResult::call(func);
this.latch.set();
Latch::set(&this.latch);
mem::forget(abort);
}
}
Expand Down
90 changes: 62 additions & 28 deletions rayon-core/src/latch.rs
Original file line number Diff line number Diff line change
@@ -1,3 +1,5 @@
use std::marker::PhantomData;
use std::ops::Deref;
use std::sync::atomic::{AtomicUsize, Ordering};
use std::sync::{Arc, Condvar, Mutex};
use std::usize;
Expand Down Expand Up @@ -37,10 +39,15 @@ pub(super) trait Latch {
///
/// Setting a latch triggers other threads to wake up and (in some
/// cases) complete. This may, in turn, cause memory to be
/// allocated and so forth. One must be very careful about this,
/// deallocated and so forth. One must be very careful about this,
/// and it's typically better to read all the fields you will need
/// to access *before* a latch is set!
fn set(&self);
///
/// This function operates on `*const Self` instead of `&self` to allow it
/// to become dangling during this call. The caller must ensure that the
/// pointer is valid upon entry, and not invalidated during the call by any
/// actions other than `set` itself.
unsafe fn set(this: *const Self);
}

pub(super) trait AsCoreLatch {
Expand Down Expand Up @@ -123,8 +130,8 @@ impl CoreLatch {
/// doing some wakeups; those are encapsulated in the surrounding
/// latch code.
#[inline]
fn set(&self) -> bool {
let old_state = self.state.swap(SET, Ordering::AcqRel);
unsafe fn set(this: *const Self) -> bool {
let old_state = (*this).state.swap(SET, Ordering::AcqRel);
old_state == SLEEPING
}

Expand Down Expand Up @@ -186,29 +193,29 @@ impl<'r> AsCoreLatch for SpinLatch<'r> {

impl<'r> Latch for SpinLatch<'r> {
#[inline]
fn set(&self) {
unsafe fn set(this: *const Self) {
let cross_registry;

let registry: &Registry = if self.cross {
let registry: &Registry = if (*this).cross {
// Ensure the registry stays alive while we notify it.
// Otherwise, it would be possible that we set the spin
// latch and the other thread sees it and exits, causing
// the registry to be deallocated, all before we get a
// chance to invoke `registry.notify_worker_latch_is_set`.
cross_registry = Arc::clone(self.registry);
cross_registry = Arc::clone((*this).registry);
&cross_registry
} else {
// If this is not a "cross-registry" spin-latch, then the
// thread which is performing `set` is itself ensuring
// that the registry stays alive. However, that doesn't
// include this *particular* `Arc` handle if the waiting
// thread then exits, so we must completely dereference it.
self.registry
(*this).registry
};
let target_worker_index = self.target_worker_index;
let target_worker_index = (*this).target_worker_index;

// NOTE: Once we `set`, the target may proceed and invalidate `&self`!
if self.core_latch.set() {
// NOTE: Once we `set`, the target may proceed and invalidate `this`!
if CoreLatch::set(&(*this).core_latch) {
// Subtle: at this point, we can no longer read from
// `self`, because the thread owning this spin latch may
// have awoken and deallocated the latch. Therefore, we
Expand Down Expand Up @@ -255,10 +262,10 @@ impl LockLatch {

impl Latch for LockLatch {
#[inline]
fn set(&self) {
let mut guard = self.m.lock().unwrap();
unsafe fn set(this: *const Self) {
let mut guard = (*this).m.lock().unwrap();
*guard = true;
self.v.notify_all();
(*this).v.notify_all();
}
}

Expand Down Expand Up @@ -307,9 +314,9 @@ impl CountLatch {
/// count, then the latch is **set**, and calls to `probe()` will
/// return true. Returns whether the latch was set.
#[inline]
pub(super) fn set(&self) -> bool {
if self.counter.fetch_sub(1, Ordering::SeqCst) == 1 {
self.core_latch.set();
pub(super) unsafe fn set(this: *const Self) -> bool {
if (*this).counter.fetch_sub(1, Ordering::SeqCst) == 1 {
CoreLatch::set(&(*this).core_latch);
true
} else {
false
Expand All @@ -320,8 +327,12 @@ impl CountLatch {
/// the latch is set, then the specific worker thread is tickled,
/// which should be the one that owns this latch.
#[inline]
pub(super) fn set_and_tickle_one(&self, registry: &Registry, target_worker_index: usize) {
if self.set() {
pub(super) unsafe fn set_and_tickle_one(
this: *const Self,
registry: &Registry,
target_worker_index: usize,
) {
if Self::set(this) {
registry.notify_worker_latch_is_set(target_worker_index);
}
}
Expand Down Expand Up @@ -362,19 +373,42 @@ impl CountLockLatch {

impl Latch for CountLockLatch {
#[inline]
fn set(&self) {
if self.counter.fetch_sub(1, Ordering::SeqCst) == 1 {
self.lock_latch.set();
unsafe fn set(this: *const Self) {
if (*this).counter.fetch_sub(1, Ordering::SeqCst) == 1 {
LockLatch::set(&(*this).lock_latch);
}
}
}

impl<'a, L> Latch for &'a L
where
L: Latch,
{
/// `&L` without any implication of `dereferenceable` for `Latch::set`
pub(super) struct LatchRef<'a, L> {
inner: *const L,
marker: PhantomData<&'a L>,
}

impl<L> LatchRef<'_, L> {
pub(super) fn new(inner: &L) -> LatchRef<'_, L> {
LatchRef {
inner,
marker: PhantomData,
}
}
}

unsafe impl<L: Sync> Sync for LatchRef<'_, L> {}

impl<L> Deref for LatchRef<'_, L> {
type Target = L;

fn deref(&self) -> &L {
// SAFETY: if we have &self, the inner latch is still alive
unsafe { &*self.inner }
}
}

impl<L: Latch> Latch for LatchRef<'_, L> {
#[inline]
fn set(&self) {
L::set(self);
unsafe fn set(this: *const Self) {
L::set((*this).inner);
}
}
10 changes: 5 additions & 5 deletions rayon-core/src/registry.rs
Original file line number Diff line number Diff line change
@@ -1,5 +1,5 @@
use crate::job::{JobFifo, JobRef, StackJob};
use crate::latch::{AsCoreLatch, CoreLatch, CountLatch, Latch, LockLatch, SpinLatch};
use crate::latch::{AsCoreLatch, CoreLatch, CountLatch, Latch, LatchRef, LockLatch, SpinLatch};
use crate::log::Event::*;
use crate::log::Logger;
use crate::sleep::Sleep;
Expand Down Expand Up @@ -505,7 +505,7 @@ impl Registry {
assert!(injected && !worker_thread.is_null());
op(&*worker_thread, true)
},
l,
LatchRef::new(l),
);
self.inject(&[job.as_job_ref()]);
job.latch.wait_and_reset(); // Make sure we can use the same latch again next time.
Expand Down Expand Up @@ -575,7 +575,7 @@ impl Registry {
pub(super) fn terminate(&self) {
if self.terminate_count.fetch_sub(1, Ordering::AcqRel) == 1 {
for (i, thread_info) in self.thread_infos.iter().enumerate() {
thread_info.terminate.set_and_tickle_one(self, i);
unsafe { CountLatch::set_and_tickle_one(&thread_info.terminate, self, i) };
}
}
}
Expand Down Expand Up @@ -869,7 +869,7 @@ unsafe fn main_loop(
let registry = &*worker_thread.registry;

// let registry know we are ready to do work
registry.thread_infos[index].primed.set();
Latch::set(&registry.thread_infos[index].primed);

// Worker threads should not panic. If they do, just abort, as the
// internal state of the threadpool is corrupted. Note that if
Expand All @@ -892,7 +892,7 @@ unsafe fn main_loop(
debug_assert!(worker_thread.take_local_job().is_none());

// let registry know we are done
registry.thread_infos[index].stopped.set();
Latch::set(&registry.thread_infos[index].stopped);

// Normal termination, do not abort.
mem::forget(abort_guard);
Expand Down
46 changes: 23 additions & 23 deletions rayon-core/src/scope/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -540,10 +540,10 @@ impl<'scope> Scope<'scope> {
BODY: FnOnce(&Scope<'scope>) + Send + 'scope,
{
let scope_ptr = ScopePtr(self);
let job = HeapJob::new(move || {
let job = HeapJob::new(move || unsafe {
// SAFETY: this job will execute before the scope ends.
let scope = unsafe { scope_ptr.as_ref() };
scope.base.execute_job(move || body(scope))
let scope = scope_ptr.as_ref();
ScopeBase::execute_job(&scope.base, move || body(scope))
});
let job_ref = self.base.heap_job_ref(job);

Expand All @@ -562,12 +562,12 @@ impl<'scope> Scope<'scope> {
BODY: Fn(&Scope<'scope>, BroadcastContext<'_>) + Send + Sync + 'scope,
{
let scope_ptr = ScopePtr(self);
let job = ArcJob::new(move || {
let job = ArcJob::new(move || unsafe {
// SAFETY: this job will execute before the scope ends.
let scope = unsafe { scope_ptr.as_ref() };
let scope = scope_ptr.as_ref();
let body = &body;
let func = move || BroadcastContext::with(move |ctx| body(scope, ctx));
scope.base.execute_job(func);
ScopeBase::execute_job(&scope.base, func)
});
self.base.inject_broadcast(job)
}
Expand Down Expand Up @@ -600,10 +600,10 @@ impl<'scope> ScopeFifo<'scope> {
BODY: FnOnce(&ScopeFifo<'scope>) + Send + 'scope,
{
let scope_ptr = ScopePtr(self);
let job = HeapJob::new(move || {
let job = HeapJob::new(move || unsafe {
// SAFETY: this job will execute before the scope ends.
let scope = unsafe { scope_ptr.as_ref() };
scope.base.execute_job(move || body(scope))
let scope = scope_ptr.as_ref();
ScopeBase::execute_job(&scope.base, move || body(scope))
});
let job_ref = self.base.heap_job_ref(job);

Expand All @@ -628,12 +628,12 @@ impl<'scope> ScopeFifo<'scope> {
BODY: Fn(&ScopeFifo<'scope>, BroadcastContext<'_>) + Send + Sync + 'scope,
{
let scope_ptr = ScopePtr(self);
let job = ArcJob::new(move || {
let job = ArcJob::new(move || unsafe {
// SAFETY: this job will execute before the scope ends.
let scope = unsafe { scope_ptr.as_ref() };
let scope = scope_ptr.as_ref();
let body = &body;
let func = move || BroadcastContext::with(move |ctx| body(scope, ctx));
scope.base.execute_job(func);
ScopeBase::execute_job(&scope.base, func)
});
self.base.inject_broadcast(job)
}
Expand Down Expand Up @@ -688,36 +688,36 @@ impl<'scope> ScopeBase<'scope> {
where
FUNC: FnOnce() -> R,
{
let result = self.execute_job_closure(func);
let result = unsafe { Self::execute_job_closure(self, func) };
self.job_completed_latch.wait(owner);
self.maybe_propagate_panic();
result.unwrap() // only None if `op` panicked, and that would have been propagated
}

/// Executes `func` as a job, either aborting or executing as
/// appropriate.
fn execute_job<FUNC>(&self, func: FUNC)
unsafe fn execute_job<FUNC>(this: *const Self, func: FUNC)
where
FUNC: FnOnce(),
{
let _: Option<()> = self.execute_job_closure(func);
let _: Option<()> = Self::execute_job_closure(this, func);
}

/// Executes `func` as a job in scope. Adjusts the "job completed"
/// counters and also catches any panic and stores it into
/// `scope`.
fn execute_job_closure<FUNC, R>(&self, func: FUNC) -> Option<R>
unsafe fn execute_job_closure<FUNC, R>(this: *const Self, func: FUNC) -> Option<R>
where
FUNC: FnOnce() -> R,
{
match unwind::halt_unwinding(func) {
Ok(r) => {
self.job_completed_latch.set();
Latch::set(&(*this).job_completed_latch);
Some(r)
}
Err(err) => {
self.job_panicked(err);
self.job_completed_latch.set();
(*this).job_panicked(err);
Latch::set(&(*this).job_completed_latch);
None
}
}
Expand Down Expand Up @@ -797,14 +797,14 @@ impl ScopeLatch {
}

impl Latch for ScopeLatch {
fn set(&self) {
match self {
unsafe fn set(this: *const Self) {
match &*this {
ScopeLatch::Stealing {
latch,
registry,
worker_index,
} => latch.set_and_tickle_one(registry, *worker_index),
ScopeLatch::Blocking { latch } => latch.set(),
} => CountLatch::set_and_tickle_one(latch, registry, *worker_index),
ScopeLatch::Blocking { latch } => Latch::set(latch),
}
}
}
Expand Down

0 comments on commit 8cee824

Please sign in to comment.