From 976544eebdf28679a618e3e81b528d916a26f0cb Mon Sep 17 00:00:00 2001 From: Chip Senkbeil Date: Fri, 15 Oct 2021 16:26:34 -0500 Subject: [PATCH] Fix inmemory stream getting stuck --- distant-core/src/net/transport/inmemory.rs | 74 ++++++++++++++++------ 1 file changed, 56 insertions(+), 18 deletions(-) diff --git a/distant-core/src/net/transport/inmemory.rs b/distant-core/src/net/transport/inmemory.rs index be7b50ad..b6e7ec3c 100644 --- a/distant-core/src/net/transport/inmemory.rs +++ b/distant-core/src/net/transport/inmemory.rs @@ -1,5 +1,8 @@ use super::{DataStream, PlainCodec, Transport}; +use futures::ready; use std::{ + fmt, + future::Future, pin::Pin, task::{Context, Poll}, }; @@ -116,51 +119,86 @@ impl AsyncRead for InmemoryStreamReadHalf { } // Otherwise, we poll for the next batch to read in - self.rx.poll_recv(cx).map(|x| match x { + match ready!(self.rx.poll_recv(cx)) { Some(mut x) => { if x.len() > buf.remaining() { self.overflow = x.split_off(buf.remaining()); } buf.put_slice(&x); - Ok(()) + Poll::Ready(Ok(())) } - None => Ok(()), - }) + None => Poll::Ready(Ok(())), + } } } /// Write portion of an inmemory channel -#[derive(Debug)] -pub struct InmemoryStreamWriteHalf(mpsc::Sender>); +pub struct InmemoryStreamWriteHalf { + tx: Option>>, + task: Option> + Send + Sync + 'static>>>, +} impl InmemoryStreamWriteHalf { pub fn new(tx: mpsc::Sender>) -> Self { - Self(tx) + Self { + tx: Some(tx), + task: None, + } + } +} + +impl fmt::Debug for InmemoryStreamWriteHalf { + fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result { + f.debug_struct("InmemoryStreamWriteHalf") + .field("tx", &self.tx) + .field( + "task", + &if self.tx.is_some() { + "Some(...)" + } else { + "None" + }, + ) + .finish() } } impl AsyncWrite for InmemoryStreamWriteHalf { fn poll_write( - self: Pin<&mut Self>, + mut self: Pin<&mut Self>, cx: &mut Context<'_>, buf: &[u8], ) -> Poll> { - use futures::FutureExt; - let n = buf.len(); - let f = self.0.send(buf.to_vec()).map(|x| match x { - Ok(_) => Ok(n), - Err(_) => Ok(0), - }); - tokio::pin!(f); - f.poll_unpin(cx) + loop { + match self.task.as_mut() { + Some(task) => { + let res = ready!(task.as_mut().poll(cx)); + self.task.take(); + return Poll::Ready(res); + } + None => match self.tx.as_mut() { + Some(tx) => { + let n = buf.len(); + let tx_2 = tx.clone(); + let data = buf.to_vec(); + let task = + Box::pin(async move { tx_2.send(data).await.map(|_| n).or(Ok(0)) }); + self.task.replace(task); + } + None => return Poll::Ready(Ok(0)), + }, + } + } } fn poll_flush(self: Pin<&mut Self>, _: &mut Context<'_>) -> Poll> { Poll::Ready(Ok(())) } - fn poll_shutdown(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll> { - self.poll_flush(cx) + fn poll_shutdown(mut self: Pin<&mut Self>, _: &mut Context<'_>) -> Poll> { + self.tx.take(); + self.task.take(); + Poll::Ready(Ok(())) } }