Skip to content

Commit

Permalink
net: rewrite TcpStream::connect with async fn
Browse files Browse the repository at this point in the history
This also removes `TcpStream::connect_std` as the conversion functions
from `std` need to be rethought. A note tracking this has been added
to #1209.
  • Loading branch information
carllerche committed Aug 27, 2019
1 parent 807d536 commit 852db79
Show file tree
Hide file tree
Showing 2 changed files with 27 additions and 108 deletions.
127 changes: 22 additions & 105 deletions tokio-net/src/tcp/stream.rs
Original file line number Diff line number Diff line change
Expand Up @@ -14,9 +14,7 @@ use iovec::IoVec;
use mio;
use std::convert::TryFrom;
use std::fmt;
use std::future::Future;
use std::io::{self, Read, Write};
use std::mem;
use std::net::{self, Shutdown, SocketAddr};
use std::pin::Pin;
use std::task::{Context, Poll};
Expand Down Expand Up @@ -55,19 +53,6 @@ pub struct TcpStream {
io: PollEvented<mio::net::TcpStream>,
}

/// Future returned by `TcpStream::connect` which will resolve to a `TcpStream`
/// when the stream is connected.
#[must_use = "futures do nothing unless you `.await` or poll them"]
struct ConnectFuture {
inner: ConnectFutureState,
}

enum ConnectFutureState {
Waiting(TcpStream),
Error(io::Error),
Empty,
}

impl TcpStream {
/// Create a new TCP stream connected to the specified address.
///
Expand Down Expand Up @@ -96,23 +81,37 @@ impl TcpStream {
/// Ok(())
/// }
/// ```
pub fn connect(addr: &SocketAddr) -> impl Future<Output = io::Result<TcpStream>> {
use self::ConnectFutureState::*;
pub async fn connect(addr: &SocketAddr) -> io::Result<TcpStream> {
let sys = mio::net::TcpStream::connect(addr)?;
let stream = TcpStream::new(sys);

let inner = match mio::net::TcpStream::connect(addr) {
Ok(tcp) => Waiting(TcpStream::new(tcp)),
Err(e) => Error(e),
};
stream.finish_connect().await?;

Ok(stream)
}

async fn finish_connect(&self) -> io::Result<()> {
// Once we've connected, wait for the stream to be writable as
// that's when the actual connection has been initiated. Once we're
// writable we check for `take_socket_error` to see if the connect
// actually hit an error or not.
//
// If all that succeeded then we ship everything on up.
poll_fn(|cx| self.io.poll_write_ready(cx)).await?;

ConnectFuture { inner }
if let Some(e) = self.io.get_ref().take_error()? {
return Err(e);
}

Ok(())
}

pub(crate) fn new(connected: mio::net::TcpStream) -> TcpStream {
let io = PollEvented::new(connected);
TcpStream { io }
}

/// Create a new `TcpStream` from a `net::TcpStream`.
/// Create a new `TcpStream` from a `std::net::TcpStream`.
///
/// This function will convert a TCP stream created by the standard library
/// to a TCP stream ready to be used with the provided event loop handle.
Expand All @@ -137,42 +136,6 @@ impl TcpStream {
Ok(TcpStream { io })
}

/// Creates a new `TcpStream` from the pending socket inside the given
/// `std::net::TcpStream`, connecting it to the address specified.
///
/// This constructor allows configuring the socket before it's actually
/// connected, and this function will transfer ownership to the returned
/// `TcpStream` if successful. An unconnected `TcpStream` can be created
/// with the `net2::TcpBuilder` type (and also configured via that route).
///
/// The platform specific behavior of this function looks like:
///
/// * On Unix, the socket is placed into nonblocking mode and then a
/// `connect` call is issued.
///
/// * On Windows, the address is stored internally and the connect operation
/// is issued when the returned `TcpStream` is registered with an event
/// loop. Note that on Windows you must `bind` a socket before it can be
/// connected, so if a custom `TcpBuilder` is used it should be bound
/// (perhaps to `INADDR_ANY`) before this method is called.
pub fn connect_std(
stream: net::TcpStream,
addr: &SocketAddr,
handle: &Handle,
) -> impl Future<Output = io::Result<TcpStream>> {
use self::ConnectFutureState::*;

let io = mio::net::TcpStream::connect_stream(stream, addr)
.and_then(|io| PollEvented::new_with_handle(io, handle));

let inner = match io {
Ok(io) => Waiting(TcpStream { io }),
Err(e) => Error(e),
};

ConnectFuture { inner }
}

/// Returns the local address that this stream is bound to.
///
/// # Examples
Expand Down Expand Up @@ -846,52 +809,6 @@ impl fmt::Debug for TcpStream {
}
}

impl Future for ConnectFuture {
type Output = io::Result<TcpStream>;

fn poll(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<io::Result<TcpStream>> {
self.inner.poll_inner(|io| io.poll_write_ready(cx))
}
}

impl ConnectFutureState {
fn poll_inner<F>(&mut self, f: F) -> Poll<io::Result<TcpStream>>
where
F: FnOnce(&mut PollEvented<mio::net::TcpStream>) -> Poll<io::Result<mio::Ready>>,
{
{
let stream = match *self {
ConnectFutureState::Waiting(ref mut s) => s,
ConnectFutureState::Error(_) => {
let e = match mem::replace(self, ConnectFutureState::Empty) {
ConnectFutureState::Error(e) => e,
_ => unreachable!(),
};
return Poll::Ready(Err(e));
}
ConnectFutureState::Empty => panic!("can't poll TCP stream twice"),
};

// Once we've connected, wait for the stream to be writable as
// that's when the actual connection has been initiated. Once we're
// writable we check for `take_socket_error` to see if the connect
// actually hit an error or not.
//
// If all that succeeded then we ship everything on up.
ready!(f(&mut stream.io))?;

if let Some(e) = stream.io.get_ref().take_error()? {
return Poll::Ready(Err(e));
}
}

match mem::replace(self, ConnectFutureState::Empty) {
ConnectFutureState::Waiting(stream) => Poll::Ready(Ok(stream)),
_ => unreachable!(),
}
}
}

#[cfg(unix)]
mod sys {
use super::TcpStream;
Expand Down
8 changes: 5 additions & 3 deletions tokio-net/tests/tcp_split.rs
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,8 @@ use tokio_net::tcp::{TcpListener, TcpStream};
#[tokio::test]
async fn split_reunite() -> std::io::Result<()> {
let listener = TcpListener::bind(&"127.0.0.1:0".parse().unwrap())?;
let stream = TcpStream::connect(&listener.local_addr()?).await?;
let addr = listener.local_addr()?;
let stream = TcpStream::connect(&addr).await?;

let (r, w) = stream.split();
assert!(r.reunite(w).is_ok());
Expand All @@ -13,8 +14,9 @@ async fn split_reunite() -> std::io::Result<()> {
#[tokio::test]
async fn split_reunite_error() -> std::io::Result<()> {
let listener = TcpListener::bind(&"127.0.0.1:0".parse().unwrap())?;
let stream = TcpStream::connect(&listener.local_addr()?).await?;
let stream1 = TcpStream::connect(&listener.local_addr()?).await?;
let addr = listener.local_addr()?;
let stream = TcpStream::connect(&addr).await?;
let stream1 = TcpStream::connect(&addr).await?;

let (r, _) = stream.split();
let (_, w) = stream1.split();
Expand Down

0 comments on commit 852db79

Please sign in to comment.