Skip to content

Commit

Permalink
TcpSocket split
Browse files Browse the repository at this point in the history
  • Loading branch information
blckngm committed Jun 28, 2019
1 parent e4415d9 commit 6f266e2
Show file tree
Hide file tree
Showing 7 changed files with 377 additions and 41 deletions.
2 changes: 1 addition & 1 deletion tokio-tcp/Cargo.toml
Original file line number Diff line number Diff line change
Expand Up @@ -37,4 +37,4 @@ futures-core-preview = { version = "0.3.0-alpha.16", optional = true }
[dev-dependencies]
#env_logger = { version = "0.5", default-features = false }
#net2 = "*"
#tokio = { version = "0.2.0", path = "../tokio" }
tokio = { version = "0.2.0", path = "../tokio" }
1 change: 1 addition & 0 deletions tokio-tcp/src/lib.rs
Original file line number Diff line number Diff line change
Expand Up @@ -34,6 +34,7 @@ macro_rules! ready {
#[cfg(feature = "incoming")]
mod incoming;
mod listener;
pub mod split;
mod stream;

pub use self::listener::TcpListener;
Expand Down
142 changes: 142 additions & 0 deletions tokio-tcp/src/split.rs
Original file line number Diff line number Diff line change
@@ -0,0 +1,142 @@
//! `TcpStream` split support.
//!
//! A `TcpStream` can be split into a `TcpStreamReadHalf` and a
//! `TcpStreamWriteHalf` with the `TcpStream::split` method. `TcpStreamReadHalf`
//! implements `AsyncRead` while `TcpStreamWriteHalf` implements `AsyncWrite`.
//! The two halves can be used concurrently, even from multiple tasks.
//!
//! Compared to the generic split of `AsyncRead + AsyncWrite`, this specialized
//! split gives read and write halves that are faster and smaller, because they
//! do not use locks. They also provide access to the underlying `TcpStream`
//! after split, implementing `AsRef<TcpStream>`. This allows you to call
//! `TcpStream` methods that takes `&self`, e.g., to get local and peer
//! addresses, to get and set socket options, and to shutdown the sockets.
use super::TcpStream;
use bytes::{Buf, BufMut};
use std::error::Error;
use std::fmt;
use std::io;
use std::net::Shutdown;
use std::pin::Pin;
use std::sync::Arc;
use std::task::{Context, Poll};
use tokio_io::{AsyncRead, AsyncWrite};

/// Read half of a `TcpStream`.
#[derive(Debug)]
pub struct TcpStreamReadHalf(Arc<TcpStream>);

/// Write half of a `TcpStream`.
#[derive(Debug)]
pub struct TcpStreamWriteHalf(Arc<TcpStream>);

pub(crate) fn split(stream: TcpStream) -> (TcpStreamReadHalf, TcpStreamWriteHalf) {
let shared = Arc::new(stream);
(
TcpStreamReadHalf(shared.clone()),
TcpStreamWriteHalf(shared),
)
}

/// Error indicating two halves were not from the same stream, and thus could
/// not be `reunite`d.
#[derive(Debug)]
pub struct ReuniteError(pub TcpStreamReadHalf, pub TcpStreamWriteHalf);

impl fmt::Display for ReuniteError {
fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
write!(
f,
"tried to reunite halves that are not from the same stream"
)
}
}

impl Error for ReuniteError {}

impl TcpStreamReadHalf {
/// Attempts to put the two "halves" of a `TcpStream` back together and
/// recover the original stream. Succeeds only if the two "halves"
/// originated from the same call to `TcpStream::split`.
pub fn reunite(self, other: TcpStreamWriteHalf) -> Result<TcpStream, ReuniteError> {
if Arc::ptr_eq(&self.0, &other.0) {
drop(other);
Ok(Arc::try_unwrap(self.0).unwrap())
} else {
Err(ReuniteError(self, other))
}
}
}

impl TcpStreamWriteHalf {
/// Attempts to put the two "halves" of a `TcpStream` back together and
/// recover the original stream. Succeeds only if the two "halves"
/// originated from the same call to `TcpStream::split`.
pub fn reunite(self, other: TcpStreamReadHalf) -> Result<TcpStream, ReuniteError> {
other.reunite(self)
}
}

impl AsRef<TcpStream> for TcpStreamReadHalf {
fn as_ref(&self) -> &TcpStream {
&self.0
}
}

impl AsRef<TcpStream> for TcpStreamWriteHalf {
fn as_ref(&self) -> &TcpStream {
&self.0
}
}

impl AsyncRead for TcpStreamReadHalf {
unsafe fn prepare_uninitialized_buffer(&self, _: &mut [u8]) -> bool {
false
}

fn poll_read(
self: Pin<&mut Self>,
cx: &mut Context<'_>,
buf: &mut [u8],
) -> Poll<io::Result<usize>> {
self.0.poll_read_priv(cx, buf)
}

fn poll_read_buf<B: BufMut>(
self: Pin<&mut Self>,
cx: &mut Context<'_>,
buf: &mut B,
) -> Poll<io::Result<usize>> {
self.0.poll_read_buf_priv(cx, buf)
}
}

impl AsyncWrite for TcpStreamWriteHalf {
fn poll_write(
self: Pin<&mut Self>,
cx: &mut Context<'_>,
buf: &[u8],
) -> Poll<io::Result<usize>> {
self.0.poll_write_priv(cx, buf)
}

#[inline]
fn poll_flush(self: Pin<&mut Self>, _: &mut Context<'_>) -> Poll<io::Result<()>> {
// tcp flush is a no-op
Poll::Ready(Ok(()))
}

// `poll_shutdown` on a write half shutdowns the stream in the "write" direction.
fn poll_shutdown(self: Pin<&mut Self>, _: &mut Context<'_>) -> Poll<io::Result<()>> {
self.0.shutdown(Shutdown::Write).into()
}

fn poll_write_buf<B: Buf>(
self: Pin<&mut Self>,
cx: &mut Context<'_>,
buf: &mut B,
) -> Poll<io::Result<usize>> {
self.0.poll_write_buf_priv(cx, buf)
}
}
147 changes: 108 additions & 39 deletions tokio-tcp/src/stream.rs
Original file line number Diff line number Diff line change
@@ -1,10 +1,11 @@
use crate::split::{split, TcpStreamReadHalf, TcpStreamWriteHalf};
use bytes::{Buf, BufMut};
use iovec::IoVec;
use mio;
use std::convert::TryFrom;
use std::fmt;
use std::future::Future;
use std::io;
use std::io::{self, Read, Write};
use std::mem;
use std::net::{self, Shutdown, SocketAddr};
use std::pin::Pin;
Expand Down Expand Up @@ -712,37 +713,45 @@ impl TcpStream {
let msg = "`TcpStream::try_clone()` is deprecated because it doesn't work as intended";
Err(io::Error::new(io::ErrorKind::Other, msg))
}
}

impl TryFrom<TcpStream> for mio::net::TcpStream {
type Error = io::Error;

/// Consumes value, returning the mio I/O object.
/// Split a `TcpStream` into a read half and a write half, which can be used
/// to read and write the stream concurrently.
///
/// See [`tokio_reactor::PollEvented::into_inner`] for more details about
/// resource deregistration that happens during the call.
fn try_from(value: TcpStream) -> Result<Self, Self::Error> {
value.io.into_inner()
/// See the module level documenation of [`split`](super::split) for more
/// details.
pub fn split(self) -> (TcpStreamReadHalf, TcpStreamWriteHalf) {
split(self)
}
}

// ===== impl Read / Write =====

impl AsyncRead for TcpStream {
unsafe fn prepare_uninitialized_buffer(&self, _: &mut [u8]) -> bool {
false
}

fn poll_read(
mut self: Pin<&mut Self>,
// == Poll IO functions that takes `&self` ==
//
// They are not public because (taken from the doc of `PollEvented`):
//
// While `PollEvented` is `Sync` (if the underlying I/O type is `Sync`), the
// caller must ensure that there are at most two tasks that use a
// `PollEvented` instance concurrently. One for reading and one for writing.
// While violating this requirement is "safe" from a Rust memory model point
// of view, it will result in unexpected behavior in the form of lost
// notifications and tasks hanging.

pub(crate) fn poll_read_priv(
&self,
cx: &mut Context<'_>,
buf: &mut [u8],
) -> Poll<io::Result<usize>> {
Pin::new(&mut self.io).poll_read(cx, buf)
ready!(self.io.poll_read_ready(cx, mio::Ready::readable()))?;

match self.io.get_ref().read(buf) {
Err(ref e) if e.kind() == io::ErrorKind::WouldBlock => {
self.io.clear_read_ready(cx, mio::Ready::readable())?;
Poll::Pending
}
x => Poll::Ready(x),
}
}

fn poll_read_buf<B: BufMut>(
self: Pin<&mut Self>,
pub(crate) fn poll_read_buf_priv<B: BufMut>(
&self,
cx: &mut Context<'_>,
buf: &mut B,
) -> Poll<io::Result<usize>> {
Expand Down Expand Up @@ -804,29 +813,25 @@ impl AsyncRead for TcpStream {
Err(e) => Poll::Ready(Err(e)),
}
}
}

impl AsyncWrite for TcpStream {
fn poll_write(
mut self: Pin<&mut Self>,
pub(crate) fn poll_write_priv(
&self,
cx: &mut Context<'_>,
buf: &[u8],
) -> Poll<io::Result<usize>> {
Pin::new(&mut self.io).poll_write(cx, buf)
}

#[inline]
fn poll_flush(self: Pin<&mut Self>, _: &mut Context<'_>) -> Poll<io::Result<()>> {
// tcp flush is a no-op
Poll::Ready(Ok(()))
}
ready!(self.io.poll_write_ready(cx))?;

fn poll_shutdown(self: Pin<&mut Self>, _: &mut Context<'_>) -> Poll<io::Result<()>> {
Poll::Ready(Ok(()))
match self.io.get_ref().write(buf) {
Err(ref e) if e.kind() == io::ErrorKind::WouldBlock => {
self.io.clear_write_ready(cx)?;
Poll::Pending
}
x => Poll::Ready(x),
}
}

fn poll_write_buf<B: Buf>(
self: Pin<&mut Self>,
pub(crate) fn poll_write_buf_priv<B: Buf>(
&self,
cx: &mut Context<'_>,
buf: &mut B,
) -> Poll<io::Result<usize>> {
Expand Down Expand Up @@ -856,6 +861,70 @@ impl AsyncWrite for TcpStream {
}
}

impl TryFrom<TcpStream> for mio::net::TcpStream {
type Error = io::Error;

/// Consumes value, returning the mio I/O object.
///
/// See [`tokio_reactor::PollEvented::into_inner`] for more details about
/// resource deregistration that happens during the call.
fn try_from(value: TcpStream) -> Result<Self, Self::Error> {
value.io.into_inner()
}
}

// ===== impl Read / Write =====

impl AsyncRead for TcpStream {
unsafe fn prepare_uninitialized_buffer(&self, _: &mut [u8]) -> bool {
false
}

fn poll_read(
self: Pin<&mut Self>,
cx: &mut Context<'_>,
buf: &mut [u8],
) -> Poll<io::Result<usize>> {
self.poll_read_priv(cx, buf)
}

fn poll_read_buf<B: BufMut>(
self: Pin<&mut Self>,
cx: &mut Context<'_>,
buf: &mut B,
) -> Poll<io::Result<usize>> {
self.poll_read_buf_priv(cx, buf)
}
}

impl AsyncWrite for TcpStream {
fn poll_write(
self: Pin<&mut Self>,
cx: &mut Context<'_>,
buf: &[u8],
) -> Poll<io::Result<usize>> {
self.poll_write_priv(cx, buf)
}

#[inline]
fn poll_flush(self: Pin<&mut Self>, _: &mut Context<'_>) -> Poll<io::Result<()>> {
// tcp flush is a no-op
Poll::Ready(Ok(()))
}

fn poll_shutdown(self: Pin<&mut Self>, _: &mut Context<'_>) -> Poll<io::Result<()>> {
Poll::Ready(Ok(()))
}

fn poll_write_buf<B: Buf>(
self: Pin<&mut Self>,
cx: &mut Context<'_>,
buf: &mut B,
) -> Poll<io::Result<usize>> {
self.poll_write_buf_priv(cx, buf)
}
}

impl fmt::Debug for TcpStream {
fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
self.io.get_ref().fmt(f)
Expand Down
25 changes: 25 additions & 0 deletions tokio-tcp/tests/split.rs
Original file line number Diff line number Diff line change
@@ -0,0 +1,25 @@
#![feature(async_await)]

use tokio_tcp::{TcpListener, TcpStream};

#[tokio::test]
async fn split_reunite() -> std::io::Result<()> {
let listener = TcpListener::bind(&"127.0.0.1:0".parse().unwrap())?;
let stream = TcpStream::connect(&listener.local_addr()?).await?;

let (r, w) = stream.split();
assert!(r.reunite(w).is_ok());
Ok(())
}

#[tokio::test]
async fn split_reunite_error() -> std::io::Result<()> {
let listener = TcpListener::bind(&"127.0.0.1:0".parse().unwrap())?;
let stream = TcpStream::connect(&listener.local_addr()?).await?;
let stream1 = TcpStream::connect(&listener.local_addr()?).await?;

let (r, _) = stream.split();
let (_, w) = stream1.split();
assert!(r.reunite(w).is_err());
Ok(())
}
Loading

0 comments on commit 6f266e2

Please sign in to comment.