diff --git a/src/errors.rs b/src/errors.rs index d7960a6..90eef08 100644 --- a/src/errors.rs +++ b/src/errors.rs @@ -1,50 +1,60 @@ -use super::Box; +use super::{Box, Channel}; use core::fmt; use core::mem; +use core::ptr::NonNull; /// An error returned when trying to send on a closed channel. Returned from /// [`Sender::send`] if the corresponding [`Receiver`] has already been dropped. /// /// The message that could not be sent can be retreived again with [`SendError::into_inner`]. pub struct SendError { - channel_ptr: *mut super::Channel, + channel_ptr: NonNull>, } unsafe impl Send for SendError {} +unsafe impl Sync for SendError {} impl SendError { - pub(crate) const fn new(channel_ptr: *mut super::Channel) -> Self { + /// Safety: by calling this function, the caller semantically transfers ownership of the + /// channel's resources to the created `SendError`. Thus the caller must ensure that the + /// pointer is not used in a way which would violate this ownership transfer. + pub(crate) const unsafe fn new(channel_ptr: NonNull>) -> Self { Self { channel_ptr } } /// Consumes the error and returns the message that failed to be sent. #[inline] pub fn into_inner(self) -> T { - // SAFETY: The reference won't be used after it is freed in this method - let channel: &mut super::Channel = unsafe { &mut *self.channel_ptr }; + let channel_ptr = self.channel_ptr; // Don't run destructor if we consumed ourselves. Freeing happens here. mem::forget(self); + // SAFETY: we have ownership of the channel + let channel: &Channel = unsafe { channel_ptr.as_ref() }; + + // SAFETY: we know that the message is initialized according to the safety requirements of + // `new` let message = unsafe { channel.take_message() }; - unsafe { Box::from_raw(channel) }; + unsafe { Box::from_raw(channel_ptr.as_ptr()) }; message } /// Get a reference to the message that failed to be sent. #[inline] pub fn as_inner(&self) -> &T { - unsafe { &*(*self.channel_ptr).message.as_ptr() } + unsafe { self.channel_ptr.as_ref().message().assume_init_ref() } } } impl Drop for SendError { fn drop(&mut self) { - // SAFETY: The reference won't be used after it is freed in this method - let channel: &mut super::Channel = unsafe { &mut *self.channel_ptr }; - - unsafe { channel.drop_message() }; - unsafe { Box::from_raw(channel) }; + // SAFETY: we have ownership of the channel and require that the message is initialized + // upon construction + unsafe { + self.channel_ptr.as_ref().drop_message(); + Box::from_raw(self.channel_ptr.as_ptr()); + } } } diff --git a/src/lib.rs b/src/lib.rs index 6ec45ef..c929ffe 100644 --- a/src/lib.rs +++ b/src/lib.rs @@ -136,12 +136,12 @@ #[cfg(not(loom))] extern crate alloc; -use core::{mem, ptr}; +use core::{mem::{self, MaybeUninit}, ptr::{self, NonNull}, marker::PhantomData}; #[cfg(not(loom))] -use core::sync::atomic::{AtomicU8, Ordering::SeqCst}; +use core::{cell::UnsafeCell, sync::atomic::{AtomicU8, Ordering::SeqCst}}; #[cfg(loom)] -use loom::sync::atomic::{AtomicU8, Ordering::SeqCst}; +use loom::{cell::UnsafeCell, sync::atomic::{AtomicU8, Ordering::SeqCst}}; #[cfg(feature = "async")] use core::{ @@ -186,18 +186,40 @@ pub fn channel() -> (Sender, Receiver) { // Allocate the channel on the heap and get the pointer. // The last endpoint of the channel to be alive is responsible for freeing the channel // and dropping any object that might have been written to it. + let channel_ptr = Box::into_raw(Box::new(Channel::new())); - (Sender { channel_ptr }, Receiver { channel_ptr }) + + // SAFETY: `channel_ptr` came from a Box and thus is not null + let channel_ptr = unsafe { NonNull::new_unchecked(channel_ptr) }; + + (Sender { channel: channel_ptr, _invariant: PhantomData }, Receiver { channel: channel_ptr }) } #[derive(Debug)] pub struct Sender { - channel_ptr: *mut Channel, + channel: NonNull>, + // In reality we want contravariance, however we can't obtain that. + // + // Consider the following scenario: + // ``` + // let (mut tx, rx) = channel::<&'short u8>(); + // let (tx2, rx2) = channel::<&'long u8>(); + // + // tx = tx2; + // + // // Pretend short_ref is some &'short u8 + // tx.send(short_ref).unwrap(); + // let long_ref = rx2.recv().unwrap(); + // ``` + // + // If this type were covariant then we could safely extend lifetimes, which is not okay. + // Hence, we enforce invariance. + _invariant: PhantomData T>, } #[derive(Debug)] pub struct Receiver { - channel_ptr: *mut Channel, + channel: NonNull>, } unsafe impl Send for Sender {} @@ -218,14 +240,15 @@ impl Sender { /// the error involves running any drop implementation on the message type, which might or /// might not be lock-free. pub fn send(self, message: T) -> Result<(), SendError> { - // SAFETY: The channel exists on the heap for the entire duration of this method. - let channel: &mut Channel = unsafe { &mut *self.channel_ptr }; + let channel_ptr = self.channel; // Don't run our Drop implementation if send was called, any cleanup now happens here mem::forget(self); + let channel = unsafe { channel_ptr.as_ref() }; + // Write the message into the channel on the heap. - channel.write_message(message); + unsafe { channel.write_message(message); } // Set the state to signal there is a message on the channel. match channel.state.swap(MESSAGE, SeqCst) { // The receiver is alive and has not started waiting. Send done. @@ -236,7 +259,7 @@ impl Sender { Ok(()) } // The receiver was already dropped. The error is responsible for freeing the channel. - DISCONNECTED => Err(SendError::new(channel)), + DISCONNECTED => Err(unsafe { SendError::new(channel_ptr) }), _ => unreachable!(), } } @@ -245,7 +268,7 @@ impl Sender { impl Drop for Sender { fn drop(&mut self) { // SAFETY: The reference won't be used after the channel is freed in this method - let channel: &mut Channel = unsafe { &mut *self.channel_ptr }; + let channel = unsafe { self.channel.as_ref() }; // Set the channel state to disconnected and read what state the receiver was in match channel.state.swap(DISCONNECTED, SeqCst) { @@ -255,7 +278,7 @@ impl Drop for Sender { RECEIVING => unsafe { channel.take_waker() }.unpark(), // The receiver was already dropped. We are responsible for freeing the channel. DISCONNECTED => { - unsafe { Box::from_raw(channel) }; + unsafe { Box::from_raw(self.channel.as_ptr()) }; } _ => unreachable!(), } @@ -278,7 +301,7 @@ impl Receiver { /// returning it. pub fn try_recv(&self) -> Result { // SAFETY: The channel will not be freed while this method is still running. - let channel: &mut Channel = unsafe { &mut *self.channel_ptr }; + let channel = unsafe { self.channel.as_ref() }; match channel.state.load(SeqCst) { // The sender is alive but has not sent anything yet. @@ -316,12 +339,13 @@ impl Receiver { /// Panics if called after this receiver has been polled asynchronously. #[cfg(feature = "std")] pub fn recv(self) -> Result { - // SAFETY: The reference won't be used after the channel is freed in this method - let channel: &mut Channel = unsafe { &mut *self.channel_ptr }; + let channel_ptr = self.channel; // Don't run our Drop implementation if we are receiving consuming ourselves. mem::forget(self); + let channel = unsafe { channel_ptr.as_ref() }; + match channel.state.load(SeqCst) { // The sender is alive but has not sent anything yet. We prepare to park. EMPTY => { @@ -332,7 +356,7 @@ impl Receiver { std::thread::sleep(std::time::Duration::from_millis(10)); // Write our waker instance to the channel. - channel.write_waker(ReceiverWaker::current_thread()); + unsafe { channel.write_waker(ReceiverWaker::current_thread()); } match channel .state @@ -345,12 +369,12 @@ impl Receiver { // The sender sent the message while we were parked. MESSAGE => { let message = unsafe { channel.take_message() }; - unsafe { Box::from_raw(channel) }; + unsafe { Box::from_raw(channel_ptr.as_ptr()) }; break Ok(message); } // The sender was dropped while we were parked. DISCONNECTED => { - unsafe { Box::from_raw(channel) }; + unsafe { Box::from_raw(channel_ptr.as_ptr()) }; break Err(RecvError); } // State did not change, spurious wakeup, park again. @@ -362,13 +386,13 @@ impl Receiver { Err(MESSAGE) => { unsafe { channel.drop_waker() }; let message = unsafe { channel.take_message() }; - unsafe { Box::from_raw(channel) }; + unsafe { Box::from_raw(channel_ptr.as_ptr()) }; Ok(message) } // The sender was dropped before sending anything while we prepared to park. Err(DISCONNECTED) => { unsafe { channel.drop_waker() }; - unsafe { Box::from_raw(channel) }; + unsafe { Box::from_raw(channel_ptr.as_ptr()) }; Err(RecvError) } _ => unreachable!(), @@ -377,12 +401,12 @@ impl Receiver { // The sender already sent the message. MESSAGE => { let message = unsafe { channel.take_message() }; - unsafe { Box::from_raw(channel) }; + unsafe { Box::from_raw(channel_ptr.as_ptr()) }; Ok(message) } // The sender was dropped before sending anything, or we already received the message. DISCONNECTED => { - unsafe { Box::from_raw(channel) }; + unsafe { Box::from_raw(channel_ptr.as_ptr()) }; Err(RecvError) } // The receiver must have been `Future::poll`ed prior to this call. @@ -404,8 +428,12 @@ impl Receiver { /// Panics if called after this receiver has been polled asynchronously. #[cfg(feature = "std")] pub fn recv_ref(&self) -> Result { - // SAFETY: The channel will not be freed while this method is still running. - let channel: &mut Channel = unsafe { &mut *self.channel_ptr }; + let channel_ptr = self.channel; + + // Don't run our Drop implementation if we are receiving consuming ourselves. + mem::forget(self); + + let channel = unsafe { channel_ptr.as_ref() }; match channel.state.load(SeqCst) { // The sender is alive but has not sent anything yet. We prepare to park. @@ -417,7 +445,7 @@ impl Receiver { std::thread::sleep(std::time::Duration::from_millis(10)); // Write our waker instance to the channel. - channel.write_waker(ReceiverWaker::current_thread()); + unsafe { channel.write_waker(ReceiverWaker::current_thread()); } match channel .state @@ -505,8 +533,12 @@ impl Receiver { /// Panics if called after this receiver has been polled asynchronously. #[cfg(feature = "std")] pub fn recv_deadline(&self, deadline: Instant) -> Result { - // SAFETY: The channel will not be freed while this method is still running. - let channel: &mut Channel = unsafe { &mut *self.channel_ptr }; + let channel_ptr = self.channel; + + // Don't run our Drop implementation if we are receiving consuming ourselves. + mem::forget(self); + + let channel = unsafe { channel_ptr.as_ref() }; match channel.state.load(SeqCst) { // The sender is alive but has not sent anything yet. We prepare to park. @@ -518,7 +550,7 @@ impl Receiver { std::thread::sleep(std::time::Duration::from_millis(10)); // Write our thread instance to the channel. - channel.write_waker(ReceiverWaker::current_thread()); + unsafe { channel.write_waker(ReceiverWaker::current_thread()); } match channel .state @@ -586,13 +618,12 @@ impl Receiver { impl core::future::Future for Receiver { type Output = Result; - fn poll(mut self: Pin<&mut Self>, cx: &mut task::Context<'_>) -> Poll { - // SAFETY: The channel will not be freed while this method is still running. - let channel: &mut Channel = unsafe { &mut *self.channel_ptr }; + fn poll(self: Pin<&mut Self>, cx: &mut task::Context<'_>) -> Poll { + let channel = unsafe { self.channel.as_ref() }; match channel.state.load(SeqCst) { // The sender is alive but has not sent anything yet. - EMPTY => channel.write_async_waker(cx), + EMPTY => unsafe { channel.write_async_waker(cx) }, // We were polled again while waiting for the sender. Replace the waker with the new one. RECEIVING => { match channel @@ -602,7 +633,7 @@ impl core::future::Future for Receiver { // We successfully changed the state back to EMPTY. Replace the waker. Ok(RECEIVING) => { unsafe { channel.drop_waker() }; - channel.write_async_waker(cx) + unsafe { channel.write_async_waker(cx) } } // The sender sent the message while we prepared to replace the waker. // We take the message and mark the channel disconnected. @@ -632,7 +663,7 @@ impl core::future::Future for Receiver { impl Drop for Receiver { fn drop(&mut self) { // SAFETY: The reference won't be used after it is freed in this method - let channel: &mut Channel = unsafe { &mut *self.channel_ptr }; + let channel = unsafe { self.channel.as_ref() }; // Set the channel state to disconnected and read what state the receiver was in match channel.state.swap(DISCONNECTED, SeqCst) { @@ -641,7 +672,7 @@ impl Drop for Receiver { // The sender already sent something. We must drop it, and free the channel. MESSAGE => { unsafe { channel.drop_message() }; - unsafe { Box::from_raw(channel) }; + unsafe { Box::from_raw(self.channel.as_ptr()) }; } // The receiver has been polled. #[cfg(feature = "async")] @@ -650,7 +681,7 @@ impl Drop for Receiver { } // The sender was already dropped. We are responsible for freeing the channel. DISCONNECTED => { - unsafe { Box::from_raw(channel) }; + unsafe { Box::from_raw(self.channel.as_ptr()) }; } _ => unreachable!(), } @@ -682,53 +713,108 @@ use states::*; /// This memory is uninitialized until the receiver starts receiving. struct Channel { state: AtomicU8, - message: mem::MaybeUninit, - waker: mem::MaybeUninit, + message: UnsafeCell>, + waker: UnsafeCell>, } impl Channel { pub fn new() -> Self { Self { state: AtomicU8::new(EMPTY), - message: mem::MaybeUninit::uninit(), - waker: mem::MaybeUninit::uninit(), + message: UnsafeCell::new(MaybeUninit::uninit()), + waker: UnsafeCell::new(MaybeUninit::uninit()), } } #[inline(always)] - fn write_message(&mut self, message: T) { - unsafe { self.message.as_mut_ptr().write(message) }; + unsafe fn message(&self) -> &MaybeUninit { + #[cfg(loom)] + { + self.message.with(|ptr| &*ptr) + } + + #[cfg(not(loom))] + { + &*self.message.get() + } } #[inline(always)] - unsafe fn take_message(&mut self) -> T { - ptr::read(&self.message).assume_init() + unsafe fn message_mut(&self) -> &mut MaybeUninit { + #[cfg(loom)] + { + self.message.with_mut(|ptr| &mut *ptr) + } + + #[cfg(not(loom))] + { + &mut *self.message.get() + } } #[inline(always)] - unsafe fn drop_message(&mut self) { - ptr::drop_in_place(self.message.as_mut_ptr()); + unsafe fn waker_mut(&self) -> &mut MaybeUninit { + #[cfg(loom)] + { + self.waker.with_mut(|ptr| &mut *ptr) + } + + #[cfg(not(loom))] + { + &mut *self.waker.get() + } + } + + #[inline(always)] + unsafe fn write_message(&self, message: T) { + self.message_mut().as_mut_ptr().write(message); + } + + #[inline(always)] + unsafe fn take_message(&self) -> T { + #[cfg(loom)] + { + self.message.with(|ptr| ptr::read(ptr)).assume_init() + } + + #[cfg(not(loom))] + { + ptr::read(self.message.get()).assume_init() + } + } + + #[inline(always)] + unsafe fn drop_message(&self) { + self.message_mut().assume_init_drop(); } #[cfg(any(feature = "std", feature = "async"))] #[inline(always)] - fn write_waker(&mut self, waker: ReceiverWaker) { - unsafe { self.waker.as_mut_ptr().write(waker) }; + unsafe fn write_waker(&self, waker: ReceiverWaker) { + self.waker_mut().as_mut_ptr().write(waker); } #[inline(always)] - unsafe fn take_waker(&mut self) -> ReceiverWaker { - ptr::read(&self.waker).assume_init() + unsafe fn take_waker(&self) -> ReceiverWaker { + #[cfg(loom)] + { + self.waker.with(|ptr| ptr::read(ptr)).assume_init() + } + + #[cfg(not(loom))] + { + ptr::read(self.waker.get()).assume_init() + } } #[cfg(any(feature = "std", feature = "async"))] #[inline(always)] - unsafe fn drop_waker(&mut self) { - ptr::drop_in_place(self.waker.as_mut_ptr()); + unsafe fn drop_waker(&self) { + self.waker_mut().assume_init_drop(); } #[cfg(feature = "async")] - fn write_async_waker(&mut self, cx: &mut task::Context<'_>) -> Poll> { + unsafe fn write_async_waker(&self, cx: &mut task::Context<'_>) -> Poll> { // Write our thread instance to the channel. self.write_waker(ReceiverWaker::task_waker(cx)); @@ -740,15 +826,15 @@ impl Channel { Ok(EMPTY) => Poll::Pending, // The sender was dropped before sending anything while we prepared to park. Err(DISCONNECTED) => { - unsafe { self.drop_waker() }; + self.drop_waker(); Poll::Ready(Err(RecvError)) } // The sender sent the message while we prepared to park. // We take the message and mark the channel disconnected. Err(MESSAGE) => { - unsafe { self.drop_waker() }; + self.drop_waker(); self.state.store(DISCONNECTED, SeqCst); - Poll::Ready(Ok(unsafe { self.take_message() })) + Poll::Ready(Ok(self.take_message())) } _ => unreachable!(), }