diff --git a/library/std/src/sys/sgx/abi/mod.rs b/library/std/src/sys/sgx/abi/mod.rs index a5e453034762c..f9536c4203df2 100644 --- a/library/std/src/sys/sgx/abi/mod.rs +++ b/library/std/src/sys/sgx/abi/mod.rs @@ -62,10 +62,12 @@ unsafe extern "C" fn tcs_init(secondary: bool) { extern "C" fn entry(p1: u64, p2: u64, p3: u64, secondary: bool, p4: u64, p5: u64) -> EntryReturn { // FIXME: how to support TLS in library mode? let tls = Box::new(tls::Tls::new()); - let _tls_guard = unsafe { tls.activate() }; + let tls_guard = unsafe { tls.activate() }; if secondary { - super::thread::Thread::entry(); + let join_notifier = super::thread::Thread::entry(); + drop(tls_guard); + drop(join_notifier); EntryReturn(0, 0) } else { diff --git a/library/std/src/sys/sgx/thread.rs b/library/std/src/sys/sgx/thread.rs index 55ef460cc90c5..67e2e8b59d397 100644 --- a/library/std/src/sys/sgx/thread.rs +++ b/library/std/src/sys/sgx/thread.rs @@ -9,26 +9,37 @@ pub struct Thread(task_queue::JoinHandle); pub const DEFAULT_MIN_STACK_SIZE: usize = 4096; +pub use self::task_queue::JoinNotifier; + mod task_queue { - use crate::sync::mpsc; + use super::wait_notify; use crate::sync::{Mutex, MutexGuard, Once}; - pub type JoinHandle = mpsc::Receiver<()>; + pub type JoinHandle = wait_notify::Waiter; + + pub struct JoinNotifier(Option); + + impl Drop for JoinNotifier { + fn drop(&mut self) { + self.0.take().unwrap().notify(); + } + } pub(super) struct Task { p: Box, - done: mpsc::Sender<()>, + done: JoinNotifier, } impl Task { pub(super) fn new(p: Box) -> (Task, JoinHandle) { - let (done, recv) = mpsc::channel(); + let (done, recv) = wait_notify::new(); + let done = JoinNotifier(Some(done)); (Task { p, done }, recv) } - pub(super) fn run(self) { + pub(super) fn run(self) -> JoinNotifier { (self.p)(); - let _ = self.done.send(()); + self.done } } @@ -47,6 +58,48 @@ mod task_queue { } } +/// This module provides a synchronization primitive that does not use thread +/// local variables. This is needed for signaling that a thread has finished +/// execution. The signal is sent once all TLS destructors have finished at +/// which point no new thread locals should be created. +pub mod wait_notify { + use super::super::waitqueue::{SpinMutex, WaitQueue, WaitVariable}; + use crate::sync::Arc; + + pub struct Notifier(Arc>>); + + impl Notifier { + /// Notify the waiter. The waiter is either notified right away (if + /// currently blocked in `Waiter::wait()`) or later when it calls the + /// `Waiter::wait()` method. + pub fn notify(self) { + let mut guard = self.0.lock(); + *guard.lock_var_mut() = true; + let _ = WaitQueue::notify_one(guard); + } + } + + pub struct Waiter(Arc>>); + + impl Waiter { + /// Wait for a notification. If `Notifier::notify()` has already been + /// called, this will return immediately, otherwise the current thread + /// is blocked until notified. + pub fn wait(self) { + let guard = self.0.lock(); + if *guard.lock_var() { + return; + } + WaitQueue::wait(guard, || {}); + } + } + + pub fn new() -> (Notifier, Waiter) { + let inner = Arc::new(SpinMutex::new(WaitVariable::new(false))); + (Notifier(inner.clone()), Waiter(inner)) + } +} + impl Thread { // unsafe: see thread::Builder::spawn_unchecked for safety requirements pub unsafe fn new(_stack: usize, p: Box) -> io::Result { @@ -57,7 +110,7 @@ impl Thread { Ok(Thread(handle)) } - pub(super) fn entry() { + pub(super) fn entry() -> JoinNotifier { let mut pending_tasks = task_queue::lock(); let task = rtunwrap!(Some, pending_tasks.pop()); drop(pending_tasks); // make sure to not hold the task queue lock longer than necessary @@ -78,7 +131,7 @@ impl Thread { } pub fn join(self) { - let _ = self.0.recv(); + self.0.wait(); } } diff --git a/library/std/src/thread/local/tests.rs b/library/std/src/thread/local/tests.rs index 80e6798d847b1..f33d612961931 100644 --- a/library/std/src/thread/local/tests.rs +++ b/library/std/src/thread/local/tests.rs @@ -1,4 +1,5 @@ use crate::cell::{Cell, UnsafeCell}; +use crate::sync::atomic::{AtomicU8, Ordering}; use crate::sync::mpsc::{channel, Sender}; use crate::thread::{self, LocalKey}; use crate::thread_local; @@ -207,3 +208,110 @@ fn dtors_in_dtors_in_dtors_const_init() { }); rx.recv().unwrap(); } + +// This test tests that TLS destructors have run before the thread joins. The +// test has no false positives (meaning: if the test fails, there's actually +// an ordering problem). It may have false negatives, where the test passes but +// join is not guaranteed to be after the TLS destructors. However, false +// negatives should be exceedingly rare due to judicious use of +// thread::yield_now and running the test several times. +#[test] +fn join_orders_after_tls_destructors() { + // We emulate a synchronous MPSC rendezvous channel using only atomics and + // thread::yield_now. We can't use std::mpsc as the implementation itself + // may rely on thread locals. + // + // The basic state machine for an SPSC rendezvous channel is: + // FRESH -> THREAD1_WAITING -> MAIN_THREAD_RENDEZVOUS + // where the first transition is done by the “receiving” thread and the 2nd + // transition is done by the “sending” thread. + // + // We add an additional state `THREAD2_LAUNCHED` between `FRESH` and + // `THREAD1_WAITING` to block until all threads are actually running. + // + // A thread that joins on the “receiving” thread completion should never + // observe the channel in the `THREAD1_WAITING` state. If this does occur, + // we switch to the “poison” state `THREAD2_JOINED` and panic all around. + // (This is equivalent to “sending” from an alternate producer thread.) + const FRESH: u8 = 0; + const THREAD2_LAUNCHED: u8 = 1; + const THREAD1_WAITING: u8 = 2; + const MAIN_THREAD_RENDEZVOUS: u8 = 3; + const THREAD2_JOINED: u8 = 4; + static SYNC_STATE: AtomicU8 = AtomicU8::new(FRESH); + + for _ in 0..10 { + SYNC_STATE.store(FRESH, Ordering::SeqCst); + + let jh = thread::Builder::new() + .name("thread1".into()) + .spawn(move || { + struct TlDrop; + + impl Drop for TlDrop { + fn drop(&mut self) { + let mut sync_state = SYNC_STATE.swap(THREAD1_WAITING, Ordering::SeqCst); + loop { + match sync_state { + THREAD2_LAUNCHED | THREAD1_WAITING => thread::yield_now(), + MAIN_THREAD_RENDEZVOUS => break, + THREAD2_JOINED => panic!( + "Thread 1 still running after thread 2 joined on thread 1" + ), + v => unreachable!("sync state: {}", v), + } + sync_state = SYNC_STATE.load(Ordering::SeqCst); + } + } + } + + thread_local! { + static TL_DROP: TlDrop = TlDrop; + } + + TL_DROP.with(|_| {}); + + loop { + match SYNC_STATE.load(Ordering::SeqCst) { + FRESH => thread::yield_now(), + THREAD2_LAUNCHED => break, + v => unreachable!("sync state: {}", v), + } + } + }) + .unwrap(); + + let jh2 = thread::Builder::new() + .name("thread2".into()) + .spawn(move || { + assert_eq!(SYNC_STATE.swap(THREAD2_LAUNCHED, Ordering::SeqCst), FRESH); + jh.join().unwrap(); + match SYNC_STATE.swap(THREAD2_JOINED, Ordering::SeqCst) { + MAIN_THREAD_RENDEZVOUS => return, + THREAD2_LAUNCHED | THREAD1_WAITING => { + panic!("Thread 2 running after thread 1 join before main thread rendezvous") + } + v => unreachable!("sync state: {:?}", v), + } + }) + .unwrap(); + + loop { + match SYNC_STATE.compare_exchange_weak( + THREAD1_WAITING, + MAIN_THREAD_RENDEZVOUS, + Ordering::SeqCst, + Ordering::SeqCst, + ) { + Ok(_) => break, + Err(FRESH) => thread::yield_now(), + Err(THREAD2_LAUNCHED) => thread::yield_now(), + Err(THREAD2_JOINED) => { + panic!("Main thread rendezvous after thread 2 joined thread 1") + } + v => unreachable!("sync state: {:?}", v), + } + } + jh2.join().unwrap(); + } +}