Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

TcpSocket specialized split #1217

Merged
merged 4 commits into from
Jun 29, 2019
Merged
Show file tree
Hide file tree
Changes from 2 commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion tokio-tcp/Cargo.toml
Original file line number Diff line number Diff line change
Expand Up @@ -37,4 +37,4 @@ futures-core-preview = { version = "0.3.0-alpha.16", optional = true }
[dev-dependencies]
#env_logger = { version = "0.5", default-features = false }
#net2 = "*"
#tokio = { version = "0.2.0", path = "../tokio" }
tokio = { version = "0.2.0", path = "../tokio" }
jonhoo marked this conversation as resolved.
Show resolved Hide resolved
1 change: 1 addition & 0 deletions tokio-tcp/src/lib.rs
Original file line number Diff line number Diff line change
Expand Up @@ -34,6 +34,7 @@ macro_rules! ready {
#[cfg(feature = "incoming")]
mod incoming;
mod listener;
pub mod split;
mod stream;

pub use self::listener::TcpListener;
Expand Down
145 changes: 145 additions & 0 deletions tokio-tcp/src/split.rs
Original file line number Diff line number Diff line change
@@ -0,0 +1,145 @@
//! `TcpStream` split support.
//!
//! A `TcpStream` can be split into a `TcpStreamReadHalf` and a
//! `TcpStreamWriteHalf` with the `TcpStream::split` method. `TcpStreamReadHalf`
//! implements `AsyncRead` while `TcpStreamWriteHalf` implements `AsyncWrite`.
//! The two halves can be used concurrently, even from multiple tasks.
//!
//! Compared to the generic split of `AsyncRead + AsyncWrite`, this specialized
//! split gives read and write halves that are faster and smaller, because they
//! do not use locks. They also provide access to the underlying `TcpStream`
//! after split, implementing `AsRef<TcpStream>`. This allows you to call
//! `TcpStream` methods that takes `&self`, e.g., to get local and peer
//! addresses, to get and set socket options, and to shutdown the sockets.

use super::TcpStream;
use bytes::{Buf, BufMut};
use std::error::Error;
use std::fmt;
use std::io;
use std::net::Shutdown;
use std::pin::Pin;
use std::sync::Arc;
use std::task::{Context, Poll};
use tokio_io::{AsyncRead, AsyncWrite};

/// Read half of a `TcpStream`.
#[derive(Debug)]
pub struct TcpStreamReadHalf(Arc<TcpStream>);

/// Write half of a `TcpStream`.
///
/// Note that in the `AsyncWrite` implemenation of `TcpStreamWriteHalf`,
/// `poll_shutdown` actually shuts down the TCP stream in the write direction.
#[derive(Debug)]
pub struct TcpStreamWriteHalf(Arc<TcpStream>);

pub(crate) fn split(stream: TcpStream) -> (TcpStreamReadHalf, TcpStreamWriteHalf) {
let shared = Arc::new(stream);
(
TcpStreamReadHalf(shared.clone()),
TcpStreamWriteHalf(shared),
)
}

/// Error indicating two halves were not from the same stream, and thus could
/// not be `reunite`d.
#[derive(Debug)]
pub struct ReuniteError(pub TcpStreamReadHalf, pub TcpStreamWriteHalf);

impl fmt::Display for ReuniteError {
fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
write!(
f,
"tried to reunite halves that are not from the same stream"
)
}
}

impl Error for ReuniteError {}

impl TcpStreamReadHalf {
/// Attempts to put the two "halves" of a `TcpStream` back together and
/// recover the original stream. Succeeds only if the two "halves"
/// originated from the same call to `TcpStream::split`.
pub fn reunite(self, other: TcpStreamWriteHalf) -> Result<TcpStream, ReuniteError> {
if Arc::ptr_eq(&self.0, &other.0) {
drop(other);
Ok(Arc::try_unwrap(self.0).unwrap())
jonhoo marked this conversation as resolved.
Show resolved Hide resolved
} else {
Err(ReuniteError(self, other))
}
}
}

impl TcpStreamWriteHalf {
/// Attempts to put the two "halves" of a `TcpStream` back together and
/// recover the original stream. Succeeds only if the two "halves"
/// originated from the same call to `TcpStream::split`.
pub fn reunite(self, other: TcpStreamReadHalf) -> Result<TcpStream, ReuniteError> {
other.reunite(self)
}
}

impl AsRef<TcpStream> for TcpStreamReadHalf {
fn as_ref(&self) -> &TcpStream {
&self.0
}
}

impl AsRef<TcpStream> for TcpStreamWriteHalf {
fn as_ref(&self) -> &TcpStream {
&self.0
}
}

impl AsyncRead for TcpStreamReadHalf {
unsafe fn prepare_uninitialized_buffer(&self, _: &mut [u8]) -> bool {
false
}

fn poll_read(
self: Pin<&mut Self>,
cx: &mut Context<'_>,
buf: &mut [u8],
) -> Poll<io::Result<usize>> {
self.0.poll_read_priv(cx, buf)
}

fn poll_read_buf<B: BufMut>(
self: Pin<&mut Self>,
cx: &mut Context<'_>,
buf: &mut B,
) -> Poll<io::Result<usize>> {
self.0.poll_read_buf_priv(cx, buf)
}
}

impl AsyncWrite for TcpStreamWriteHalf {
fn poll_write(
self: Pin<&mut Self>,
cx: &mut Context<'_>,
buf: &[u8],
) -> Poll<io::Result<usize>> {
self.0.poll_write_priv(cx, buf)
}

#[inline]
fn poll_flush(self: Pin<&mut Self>, _: &mut Context<'_>) -> Poll<io::Result<()>> {
// tcp flush is a no-op
Poll::Ready(Ok(()))
}

// `poll_shutdown` on a write half shutdowns the stream in the "write" direction.
fn poll_shutdown(self: Pin<&mut Self>, _: &mut Context<'_>) -> Poll<io::Result<()>> {
self.0.shutdown(Shutdown::Write).into()
}

fn poll_write_buf<B: Buf>(
self: Pin<&mut Self>,
cx: &mut Context<'_>,
buf: &mut B,
) -> Poll<io::Result<usize>> {
self.0.poll_write_buf_priv(cx, buf)
}
}
147 changes: 108 additions & 39 deletions tokio-tcp/src/stream.rs
Original file line number Diff line number Diff line change
@@ -1,10 +1,11 @@
use crate::split::{split, TcpStreamReadHalf, TcpStreamWriteHalf};
use bytes::{Buf, BufMut};
use iovec::IoVec;
use mio;
use std::convert::TryFrom;
use std::fmt;
use std::future::Future;
use std::io;
use std::io::{self, Read, Write};
use std::mem;
use std::net::{self, Shutdown, SocketAddr};
use std::pin::Pin;
Expand Down Expand Up @@ -712,37 +713,45 @@ impl TcpStream {
let msg = "`TcpStream::try_clone()` is deprecated because it doesn't work as intended";
Err(io::Error::new(io::ErrorKind::Other, msg))
}
}

impl TryFrom<TcpStream> for mio::net::TcpStream {
type Error = io::Error;

/// Consumes value, returning the mio I/O object.
/// Split a `TcpStream` into a read half and a write half, which can be used
/// to read and write the stream concurrently.
///
/// See [`tokio_reactor::PollEvented::into_inner`] for more details about
/// resource deregistration that happens during the call.
fn try_from(value: TcpStream) -> Result<Self, Self::Error> {
value.io.into_inner()
/// See the module level documenation of [`split`](super::split) for more
/// details.
pub fn split(self) -> (TcpStreamReadHalf, TcpStreamWriteHalf) {
split(self)
}
}

// ===== impl Read / Write =====

impl AsyncRead for TcpStream {
unsafe fn prepare_uninitialized_buffer(&self, _: &mut [u8]) -> bool {
false
}

fn poll_read(
mut self: Pin<&mut Self>,
// == Poll IO functions that takes `&self` ==
//
// They are not public because (taken from the doc of `PollEvented`):
//
// While `PollEvented` is `Sync` (if the underlying I/O type is `Sync`), the
// caller must ensure that there are at most two tasks that use a
// `PollEvented` instance concurrently. One for reading and one for writing.
// While violating this requirement is "safe" from a Rust memory model point
// of view, it will result in unexpected behavior in the form of lost
// notifications and tasks hanging.

pub(crate) fn poll_read_priv(
&self,
cx: &mut Context<'_>,
buf: &mut [u8],
) -> Poll<io::Result<usize>> {
Pin::new(&mut self.io).poll_read(cx, buf)
ready!(self.io.poll_read_ready(cx, mio::Ready::readable()))?;

match self.io.get_ref().read(buf) {
Err(ref e) if e.kind() == io::ErrorKind::WouldBlock => {
self.io.clear_read_ready(cx, mio::Ready::readable())?;
Poll::Pending
}
x => Poll::Ready(x),
}
}

fn poll_read_buf<B: BufMut>(
self: Pin<&mut Self>,
pub(crate) fn poll_read_buf_priv<B: BufMut>(
&self,
cx: &mut Context<'_>,
buf: &mut B,
) -> Poll<io::Result<usize>> {
Expand Down Expand Up @@ -804,29 +813,25 @@ impl AsyncRead for TcpStream {
Err(e) => Poll::Ready(Err(e)),
}
}
}

impl AsyncWrite for TcpStream {
fn poll_write(
mut self: Pin<&mut Self>,
pub(crate) fn poll_write_priv(
&self,
cx: &mut Context<'_>,
buf: &[u8],
) -> Poll<io::Result<usize>> {
Pin::new(&mut self.io).poll_write(cx, buf)
}

#[inline]
fn poll_flush(self: Pin<&mut Self>, _: &mut Context<'_>) -> Poll<io::Result<()>> {
// tcp flush is a no-op
Poll::Ready(Ok(()))
}
ready!(self.io.poll_write_ready(cx))?;

fn poll_shutdown(self: Pin<&mut Self>, _: &mut Context<'_>) -> Poll<io::Result<()>> {
Poll::Ready(Ok(()))
match self.io.get_ref().write(buf) {
Err(ref e) if e.kind() == io::ErrorKind::WouldBlock => {
self.io.clear_write_ready(cx)?;
Poll::Pending
}
x => Poll::Ready(x),
}
}

fn poll_write_buf<B: Buf>(
self: Pin<&mut Self>,
pub(crate) fn poll_write_buf_priv<B: Buf>(
&self,
cx: &mut Context<'_>,
buf: &mut B,
) -> Poll<io::Result<usize>> {
Expand Down Expand Up @@ -856,6 +861,70 @@ impl AsyncWrite for TcpStream {
}
}

impl TryFrom<TcpStream> for mio::net::TcpStream {
type Error = io::Error;

/// Consumes value, returning the mio I/O object.
///
/// See [`tokio_reactor::PollEvented::into_inner`] for more details about
/// resource deregistration that happens during the call.
fn try_from(value: TcpStream) -> Result<Self, Self::Error> {
value.io.into_inner()
}
}

// ===== impl Read / Write =====

impl AsyncRead for TcpStream {
unsafe fn prepare_uninitialized_buffer(&self, _: &mut [u8]) -> bool {
false
}

fn poll_read(
self: Pin<&mut Self>,
cx: &mut Context<'_>,
buf: &mut [u8],
) -> Poll<io::Result<usize>> {
self.poll_read_priv(cx, buf)
}

fn poll_read_buf<B: BufMut>(
self: Pin<&mut Self>,
cx: &mut Context<'_>,
buf: &mut B,
) -> Poll<io::Result<usize>> {
self.poll_read_buf_priv(cx, buf)
}
}

impl AsyncWrite for TcpStream {
fn poll_write(
self: Pin<&mut Self>,
cx: &mut Context<'_>,
buf: &[u8],
) -> Poll<io::Result<usize>> {
self.poll_write_priv(cx, buf)
}

#[inline]
fn poll_flush(self: Pin<&mut Self>, _: &mut Context<'_>) -> Poll<io::Result<()>> {
// tcp flush is a no-op
Poll::Ready(Ok(()))
}

fn poll_shutdown(self: Pin<&mut Self>, _: &mut Context<'_>) -> Poll<io::Result<()>> {
Poll::Ready(Ok(()))
}

fn poll_write_buf<B: Buf>(
self: Pin<&mut Self>,
cx: &mut Context<'_>,
buf: &mut B,
) -> Poll<io::Result<usize>> {
self.poll_write_buf_priv(cx, buf)
}
}

impl fmt::Debug for TcpStream {
fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
self.io.get_ref().fmt(f)
Expand Down
25 changes: 25 additions & 0 deletions tokio-tcp/tests/split.rs
Original file line number Diff line number Diff line change
@@ -0,0 +1,25 @@
#![feature(async_await)]

use tokio_tcp::{TcpListener, TcpStream};

#[tokio::test]
async fn split_reunite() -> std::io::Result<()> {
let listener = TcpListener::bind(&"127.0.0.1:0".parse().unwrap())?;
let stream = TcpStream::connect(&listener.local_addr()?).await?;

let (r, w) = stream.split();
assert!(r.reunite(w).is_ok());
Ok(())
}

#[tokio::test]
async fn split_reunite_error() -> std::io::Result<()> {
let listener = TcpListener::bind(&"127.0.0.1:0".parse().unwrap())?;
let stream = TcpStream::connect(&listener.local_addr()?).await?;
let stream1 = TcpStream::connect(&listener.local_addr()?).await?;

let (r, _) = stream.split();
let (_, w) = stream1.split();
assert!(r.reunite(w).is_err());
Ok(())
}
2 changes: 1 addition & 1 deletion tokio/src/net.rs
Original file line number Diff line number Diff line change
Expand Up @@ -40,7 +40,7 @@ pub mod tcp {
//! [`TcpListener`]: struct.TcpListener.html
//! [incoming_method]: struct.TcpListener.html#method.incoming
//! [`Incoming`]: struct.Incoming.html
pub use tokio_tcp::{TcpListener, TcpStream};
pub use tokio_tcp::{split, TcpListener, TcpStream};
}
#[cfg(feature = "tcp")]
pub use self::tcp::{TcpListener, TcpStream};
Expand Down