Skip to content

Commit

Permalink
Only work-steal in the main loop
Browse files Browse the repository at this point in the history
  • Loading branch information
Zoxc committed Aug 28, 2023
1 parent f192a48 commit b72936d
Show file tree
Hide file tree
Showing 18 changed files with 284 additions and 117 deletions.
1 change: 1 addition & 0 deletions rayon-core/Cargo.toml
Original file line number Diff line number Diff line change
Expand Up @@ -22,6 +22,7 @@ num_cpus = "1.2"
crossbeam-channel = "0.5.0"
crossbeam-deque = "0.8.1"
crossbeam-utils = "0.8.0"
smallvec = "1.11.0"

[dev-dependencies]
rand = "0.8"
Expand Down
24 changes: 21 additions & 3 deletions rayon-core/src/broadcast/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,7 @@ use crate::registry::{Registry, WorkerThread};
use crate::scope::ScopeLatch;
use std::fmt;
use std::marker::PhantomData;
use std::sync::atomic::{AtomicBool, Ordering};
use std::sync::Arc;

mod test;
Expand Down Expand Up @@ -100,13 +101,22 @@ where
OP: Fn(BroadcastContext<'_>) -> R + Sync,
R: Send,
{
let current_thread = WorkerThread::current();
let current_thread_addr = current_thread as usize;
let started = &AtomicBool::new(false);
let f = move |injected: bool| {
debug_assert!(injected);

// Mark as started if we are on the thread that initiated the broadcast.
if current_thread_addr == WorkerThread::current() as usize {
started.store(true, Ordering::Relaxed);
}

BroadcastContext::with(&op)
};

let n_threads = registry.num_threads();
let current_thread = WorkerThread::current().as_ref();
let current_thread = current_thread.as_ref();
let tlv = crate::tlv::get();
let latch = ScopeLatch::with_count(n_threads, current_thread);
let jobs: Vec<_> = (0..n_threads)
Expand All @@ -116,8 +126,16 @@ where

registry.inject_broadcast(job_refs);

let current_thread_job_id = current_thread
.and_then(|worker| (registry.id() == worker.registry.id()).then(|| worker))
.map(|worker| jobs[worker.index].as_job_ref().id());

// Wait for all jobs to complete, then collect the results, maybe propagating a panic.
latch.wait(current_thread);
latch.wait(
current_thread,
|| started.load(Ordering::Relaxed),
|job| Some(job.id()) == current_thread_job_id,
);
jobs.into_iter().map(|job| job.into_result()).collect()
}

Expand All @@ -133,7 +151,7 @@ where
{
let job = ArcJob::new({
let registry = Arc::clone(registry);
move || {
move |_| {
registry.catch_unwind(|| BroadcastContext::with(&op));
registry.terminate(); // (*) permit registry to terminate now
}
Expand Down
2 changes: 2 additions & 0 deletions rayon-core/src/broadcast/test.rs
Original file line number Diff line number Diff line change
Expand Up @@ -63,6 +63,7 @@ fn spawn_broadcast_self() {
}

#[test]
#[ignore]
#[cfg_attr(any(target_os = "emscripten", target_family = "wasm"), ignore)]
fn broadcast_mutual() {
let count = AtomicUsize::new(0);
Expand Down Expand Up @@ -97,6 +98,7 @@ fn spawn_broadcast_mutual() {
}

#[test]
#[ignore]
#[cfg_attr(any(target_os = "emscripten", target_family = "wasm"), ignore)]
fn broadcast_mutual_sleepy() {
let count = AtomicUsize::new(0);
Expand Down
40 changes: 26 additions & 14 deletions rayon-core/src/job.rs
Original file line number Diff line number Diff line change
Expand Up @@ -26,6 +26,11 @@ pub(super) trait Job {
unsafe fn execute(this: *const ());
}

#[derive(PartialEq, Eq, Hash, Copy, Clone, Debug)]
pub(super) struct JobRefId {
pointer: usize,
}

/// Effectively a Job trait object. Each JobRef **must** be executed
/// exactly once, or else data may leak.
///
Expand Down Expand Up @@ -54,11 +59,11 @@ impl JobRef {
}
}

/// Returns an opaque handle that can be saved and compared,
/// without making `JobRef` itself `Copy + Eq`.
#[inline]
pub(super) fn id(&self) -> impl Eq {
(self.pointer, self.execute_fn)
pub(super) fn id(&self) -> JobRefId {
JobRefId {
pointer: self.pointer as usize,
}
}

#[inline]
Expand Down Expand Up @@ -102,8 +107,13 @@ where
JobRef::new(self)
}

pub(super) unsafe fn run_inline(self, stolen: bool) -> R {
self.func.into_inner().unwrap()(stolen)
pub(super) unsafe fn run_inline(&self, stolen: bool) {
let func = (*self.func.get()).take().unwrap();
(*self.result.get()) = match unwind::halt_unwinding(|| func(stolen)) {
Ok(x) => JobResult::Ok(x),
Err(x) => JobResult::Panic(x),
};
Latch::set(&self.latch);
}

pub(super) unsafe fn into_result(self) -> R {
Expand Down Expand Up @@ -136,15 +146,15 @@ where
/// (Probably `StackJob` should be refactored in a similar fashion.)
pub(super) struct HeapJob<BODY>
where
BODY: FnOnce() + Send,
BODY: FnOnce(JobRefId) + Send,
{
job: BODY,
tlv: Tlv,
}

impl<BODY> HeapJob<BODY>
where
BODY: FnOnce() + Send,
BODY: FnOnce(JobRefId) + Send,
{
pub(super) fn new(tlv: Tlv, job: BODY) -> Box<Self> {
Box::new(HeapJob { job, tlv })
Expand All @@ -168,27 +178,28 @@ where

impl<BODY> Job for HeapJob<BODY>
where
BODY: FnOnce() + Send,
BODY: FnOnce(JobRefId) + Send,
{
unsafe fn execute(this: *const ()) {
let pointer = this as usize;
let this = Box::from_raw(this as *mut Self);
tlv::set(this.tlv);
(this.job)();
(this.job)(JobRefId { pointer });
}
}

/// Represents a job stored in an `Arc` -- like `HeapJob`, but may
/// be turned into multiple `JobRef`s and called multiple times.
pub(super) struct ArcJob<BODY>
where
BODY: Fn() + Send + Sync,
BODY: Fn(JobRefId) + Send + Sync,
{
job: BODY,
}

impl<BODY> ArcJob<BODY>
where
BODY: Fn() + Send + Sync,
BODY: Fn(JobRefId) + Send + Sync,
{
pub(super) fn new(job: BODY) -> Arc<Self> {
Arc::new(ArcJob { job })
Expand All @@ -212,11 +223,12 @@ where

impl<BODY> Job for ArcJob<BODY>
where
BODY: Fn() + Send + Sync,
BODY: Fn(JobRefId) + Send + Sync,
{
unsafe fn execute(this: *const ()) {
let pointer = this as usize;
let this = Arc::from_raw(this as *mut Self);
(this.job)();
(this.job)(JobRefId { pointer });
}
}

Expand Down
82 changes: 28 additions & 54 deletions rayon-core/src/join/mod.rs
Original file line number Diff line number Diff line change
@@ -1,9 +1,10 @@
use crate::job::JobRef;
use crate::job::StackJob;
use crate::latch::SpinLatch;
use crate::registry::{self, WorkerThread};
use crate::tlv::{self, Tlv};
use crate::registry;
use crate::tlv;
use crate::unwind;
use std::any::Any;
use std::sync::atomic::{AtomicBool, Ordering};

use crate::FnContext;

Expand Down Expand Up @@ -135,68 +136,41 @@ where
// Create virtual wrapper for task b; this all has to be
// done here so that the stack frame can keep it all live
// long enough.
let job_b = StackJob::new(tlv, call_b(oper_b), SpinLatch::new(worker_thread));
let job_b_started = AtomicBool::new(false);
let job_b = StackJob::new(
tlv,
|migrated| {
job_b_started.store(true, Ordering::Relaxed);
call_b(oper_b)(migrated)
},
SpinLatch::new(worker_thread),
);
let job_b_ref = job_b.as_job_ref();
let job_b_id = job_b_ref.id();
worker_thread.push(job_b_ref);

// Execute task a; hopefully b gets stolen in the meantime.
let status_a = unwind::halt_unwinding(call_a(oper_a, injected));
let result_a = match status_a {
Ok(v) => v,
Err(err) => join_recover_from_panic(worker_thread, &job_b.latch, err, tlv),
};

// Now that task A has finished, try to pop job B from the
// local stack. It may already have been popped by job A; it
// may also have been stolen. There may also be some tasks
// pushed on top of it in the stack, and we will have to pop
// those off to get to it.
while !job_b.latch.probe() {
if let Some(job) = worker_thread.take_local_job() {
if job_b_id == job.id() {
// Found it! Let's run it.
//
// Note that this could panic, but it's ok if we unwind here.

// Restore the TLV since we might have run some jobs overwriting it when waiting for job b.
tlv::set(tlv);

let result_b = job_b.run_inline(injected);
return (result_a, result_b);
} else {
worker_thread.execute(job);
}
} else {
// Local deque is empty. Time to steal from other
// threads.
worker_thread.wait_until(&job_b.latch);
debug_assert!(job_b.latch.probe());
break;
}
}
// Wait for job B or execute it if it's in the local queue.
worker_thread.wait_for_jobs::<_, false>(
&job_b.latch,
|| job_b_started.load(Ordering::Relaxed),
|job| job.id() == job_b_id,
|job: JobRef| {
debug_assert_eq!(job.id(), job_b_id);
job_b.run_inline(injected);
},
);

// Restore the TLV since we might have run some jobs overwriting it when waiting for job b.
tlv::set(tlv);

let result_a = match status_a {
Ok(v) => v,
Err(err) => unwind::resume_unwinding(err),
};

(result_a, job_b.into_result())
})
}

/// If job A panics, we still cannot return until we are sure that job
/// B is complete. This is because it may contain references into the
/// enclosing stack frame(s).
#[cold] // cold path
unsafe fn join_recover_from_panic(
worker_thread: &WorkerThread,
job_b_latch: &SpinLatch<'_>,
err: Box<dyn Any + Send>,
tlv: Tlv,
) -> ! {
worker_thread.wait_until(job_b_latch);

// Restore the TLV since we might have run some jobs overwriting it when waiting for job b.
tlv::set(tlv);

unwind::resume_unwinding(err)
}
1 change: 1 addition & 0 deletions rayon-core/src/join/test.rs
Original file line number Diff line number Diff line change
Expand Up @@ -97,6 +97,7 @@ fn join_context_both() {
}

#[test]
#[ignore]
#[cfg_attr(any(target_os = "emscripten", target_family = "wasm"), ignore)]
fn join_context_neither() {
// If we're already in a 1-thread pool, neither job should be stolen.
Expand Down
5 changes: 0 additions & 5 deletions rayon-core/src/latch.rs
Original file line number Diff line number Diff line change
Expand Up @@ -177,11 +177,6 @@ impl<'r> SpinLatch<'r> {
..SpinLatch::new(thread)
}
}

#[inline]
pub(super) fn probe(&self) -> bool {
self.core_latch.probe()
}
}

impl<'r> AsCoreLatch for SpinLatch<'r> {
Expand Down
Loading

0 comments on commit b72936d

Please sign in to comment.