diff --git a/perf/src/bin/perf_client.rs b/perf/src/bin/perf_client.rs index 62c9810050..9cf371a9dc 100644 --- a/perf/src/bin/perf_client.rs +++ b/perf/src/bin/perf_client.rs @@ -6,6 +6,7 @@ use std::{ use anyhow::{Context, Result}; use bytes::Bytes; +use quinn::TokioRuntime; use structopt::StructOpt; use tokio::sync::Semaphore; use tracing::{debug, error, info}; @@ -103,7 +104,7 @@ async fn run(opt: Opt) -> Result<()> { let socket = bind_socket(bind_addr, opt.send_buffer_size, opt.recv_buffer_size)?; - let (endpoint, _) = quinn::Endpoint::new(Default::default(), None, socket)?; + let (endpoint, _) = quinn::Endpoint::new(Default::default(), None, socket, TokioRuntime)?; let mut crypto = rustls::ClientConfig::builder() .with_cipher_suites(perf::PERF_CIPHER_SUITES) diff --git a/perf/src/bin/perf_server.rs b/perf/src/bin/perf_server.rs index 39d1757193..a304d2623a 100644 --- a/perf/src/bin/perf_server.rs +++ b/perf/src/bin/perf_server.rs @@ -2,6 +2,7 @@ use std::{fs, net::SocketAddr, path::PathBuf, sync::Arc, time::Duration}; use anyhow::{Context, Result}; use bytes::Bytes; +use quinn::TokioRuntime; use structopt::StructOpt; use tracing::{debug, error, info}; @@ -77,9 +78,13 @@ async fn run(opt: Opt) -> Result<()> { let socket = bind_socket(opt.listen, opt.send_buffer_size, opt.recv_buffer_size)?; - let (endpoint, mut incoming) = - quinn::Endpoint::new(Default::default(), Some(server_config), socket) - .context("creating endpoint")?; + let (endpoint, mut incoming) = quinn::Endpoint::new( + Default::default(), + Some(server_config), + socket, + TokioRuntime, + ) + .context("creating endpoint")?; info!("listening on {}", endpoint.local_addr().unwrap()); diff --git a/quinn-udp/Cargo.toml b/quinn-udp/Cargo.toml index 118ddba9c3..201d7c317f 100644 --- a/quinn-udp/Cargo.toml +++ b/quinn-udp/Cargo.toml @@ -20,4 +20,9 @@ libc = "0.2.69" proto = { package = "quinn-proto", path = "../quinn-proto", version = "0.8", default-features = false } socket2 = "0.4" tracing = "0.1.10" -tokio = { version = "1.0.1", features = ["net"] } +tokio = { version = "1.0.1", features = [ "net" ], optional = true } +async-io = { version = "1.6", optional = true } + +[features] +runtime-tokio = [ "tokio" ] +runtime-async-std = [ "async-io" ] diff --git a/quinn-udp/src/fallback.rs b/quinn-udp/src/fallback.rs index e3855c6b8b..22e3fac44a 100644 --- a/quinn-udp/src/fallback.rs +++ b/quinn-udp/src/fallback.rs @@ -5,8 +5,8 @@ use std::{ time::Instant, }; +use crate::runtime::AsyncWrappedUdpSocket; use proto::Transmit; -use tokio::io::ReadBuf; use super::{log_sendmsg_error, RecvMeta, UdpState, IO_ERROR_LOG_INTERVAL}; @@ -16,16 +16,15 @@ use super::{log_sendmsg_error, RecvMeta, UdpState, IO_ERROR_LOG_INTERVAL}; /// platforms. #[derive(Debug)] pub struct UdpSocket { - io: tokio::net::UdpSocket, + io: Box, last_send_error: Instant, } impl UdpSocket { - pub fn from_std(socket: std::net::UdpSocket) -> io::Result { - socket.set_nonblocking(true)?; + pub fn new(socket: Box) -> io::Result { let now = Instant::now(); Ok(UdpSocket { - io: tokio::net::UdpSocket::from_std(socket)?, + io: socket, last_send_error: now.checked_sub(2 * IO_ERROR_LOG_INTERVAL).unwrap_or(now), }) } @@ -38,10 +37,28 @@ impl UdpSocket { ) -> Poll> { let mut sent = 0; for transmit in transmits { - match self - .io - .poll_send_to(cx, &transmit.contents, transmit.destination) - { + let io_res = loop { + let poll_res = self.io.poll_write_ready(cx); + + break match poll_res { + Poll::Ready(Ok(())) => Poll::Ready( + match self + .io + .try_send_to(&transmit.contents, transmit.destination) + { + Err(e) if e.kind() == io::ErrorKind::WouldBlock => { + self.io.clear_write_ready(cx); + continue; // try again + } + res => res, + }, + ), + Poll::Ready(Err(e)) => Poll::Ready(Err(e)), + Poll::Pending => Poll::Pending, + }; + }; + + match io_res { Poll::Ready(Ok(_)) => { sent += 1; } @@ -75,10 +92,21 @@ impl UdpSocket { meta: &mut [RecvMeta], ) -> Poll> { debug_assert!(!bufs.is_empty()); - let mut buf = ReadBuf::new(&mut bufs[0]); - let addr = ready!(self.io.poll_recv_from(cx, &mut buf))?; + + let (len, addr) = loop { + ready!(self.io.poll_read_ready(cx))?; + + break match self.io.try_recv_from(&mut bufs[0]) { + Err(e) if e.kind() == io::ErrorKind::WouldBlock => { + self.io.clear_read_ready(cx); + continue; // try again + } + res => res?, + }; + }; + meta[0] = RecvMeta { - len: buf.filled().len(), + len, addr, ecn: None, dst_ip: None, diff --git a/quinn-udp/src/lib.rs b/quinn-udp/src/lib.rs index fd232b0236..c0d9daddd5 100644 --- a/quinn-udp/src/lib.rs +++ b/quinn-udp/src/lib.rs @@ -28,6 +28,8 @@ mod imp; #[path = "fallback.rs"] mod imp; +pub mod runtime; + pub use imp::UdpSocket; /// Number of UDP packets to send/receive at a time diff --git a/quinn-udp/src/runtime/async_std_runtime.rs b/quinn-udp/src/runtime/async_std_runtime.rs new file mode 100644 index 0000000000..4eea079ced --- /dev/null +++ b/quinn-udp/src/runtime/async_std_runtime.rs @@ -0,0 +1,51 @@ +use super::{AsyncWrappedUdpSocket, Runtime}; +use async_io::Async; +use std::io; +use std::task::{Context, Poll}; + +impl AsyncWrappedUdpSocket for Async { + fn poll_read_ready(&self, cx: &mut Context) -> Poll> { + Async::poll_readable(self, cx) + } + + fn poll_write_ready(&self, cx: &mut Context) -> Poll> { + Async::poll_writable(self, cx) + } + + fn clear_read_ready(&self, _cx: &mut Context) { + // async-std doesn't need this + } + + fn clear_write_ready(&self, _cx: &mut Context) { + // async-std doesn't need this + } + + fn try_recv_from(&self, buf: &mut [u8]) -> io::Result<(usize, std::net::SocketAddr)> { + self.get_ref().recv_from(buf) + } + + fn try_send_to(&self, buf: &[u8], target: std::net::SocketAddr) -> io::Result { + self.get_ref().send_to(buf, target) + } + + fn local_addr(&self) -> io::Result { + self.get_ref().local_addr() + } + + #[cfg(unix)] + fn get_ref(&self) -> &std::net::UdpSocket { + Async::get_ref(self) + } +} + +#[derive(Clone, Debug)] +pub struct AsyncStdRuntime; + +impl Runtime for AsyncStdRuntime { + fn wrap_udp_socket( + &self, + t: std::net::UdpSocket, + ) -> io::Result> { + Ok(Box::new(Async::new(t)?)) + } +} diff --git a/quinn-udp/src/runtime/mod.rs b/quinn-udp/src/runtime/mod.rs new file mode 100644 index 0000000000..75788f08b9 --- /dev/null +++ b/quinn-udp/src/runtime/mod.rs @@ -0,0 +1,39 @@ +#[cfg(feature = "runtime-tokio")] +mod tokio_runtime; +#[cfg(feature = "runtime-tokio")] +pub use tokio_runtime::*; + +#[cfg(feature = "runtime-async-std")] +mod async_std_runtime; +#[cfg(feature = "runtime-async-std")] +pub use async_std_runtime::*; + +use std::fmt::Debug; +use std::io; +use std::task::{Context, Poll}; + +pub trait AsyncWrappedUdpSocket: Send + Debug { + fn poll_read_ready(&self, cx: &mut Context) -> Poll>; + + fn poll_write_ready(&self, cx: &mut Context) -> Poll>; + + fn clear_read_ready(&self, cx: &mut Context); + + fn clear_write_ready(&self, cx: &mut Context); + + fn try_recv_from(&self, buf: &mut [u8]) -> io::Result<(usize, std::net::SocketAddr)>; + + fn try_send_to(&self, buf: &[u8], target: std::net::SocketAddr) -> io::Result; + + fn local_addr(&self) -> io::Result; + + // On Unix we expect to be able to access the underlying std UdpSocket + // to be able to implement more advanced features + #[cfg(unix)] + fn get_ref(&self) -> &std::net::UdpSocket; +} + +pub trait Runtime: Send + Sync + Debug + 'static { + fn wrap_udp_socket(&self, t: std::net::UdpSocket) + -> io::Result>; +} diff --git a/quinn-udp/src/runtime/tokio_runtime.rs b/quinn-udp/src/runtime/tokio_runtime.rs new file mode 100644 index 0000000000..6b1da94aac --- /dev/null +++ b/quinn-udp/src/runtime/tokio_runtime.rs @@ -0,0 +1,99 @@ +use super::{AsyncWrappedUdpSocket, Runtime}; +use std::io; +use std::task::{Context, Poll}; +#[cfg(unix)] +use tokio::io::unix::AsyncFd; + +#[cfg(unix)] +impl AsyncWrappedUdpSocket for AsyncFd { + fn poll_read_ready(&self, cx: &mut Context) -> Poll> { + AsyncFd::poll_read_ready(self, cx).map(|x| x.map(|_| ())) + } + + fn poll_write_ready(&self, cx: &mut Context) -> Poll> { + AsyncFd::poll_write_ready(self, cx).map(|x| x.map(|_| ())) + } + + fn clear_read_ready(&self, cx: &mut Context) { + match self.poll_read_ready(cx) { + Poll::Pending => {} + Poll::Ready(Err(_)) => {} + Poll::Ready(Ok(mut guard)) => guard.clear_ready(), + } + } + + fn clear_write_ready(&self, cx: &mut Context) { + match self.poll_write_ready(cx) { + Poll::Pending => {} + Poll::Ready(Err(_)) => {} + Poll::Ready(Ok(mut guard)) => guard.clear_ready(), + } + } + + fn try_recv_from(&self, buf: &mut [u8]) -> io::Result<(usize, std::net::SocketAddr)> { + self.get_ref().recv_from(buf) + } + + fn try_send_to(&self, buf: &[u8], target: std::net::SocketAddr) -> io::Result { + self.get_ref().send_to(buf, target) + } + + fn local_addr(&self) -> io::Result { + self.get_ref().local_addr() + } + + fn get_ref(&self) -> &std::net::UdpSocket { + AsyncFd::get_ref(self) + } +} + +#[cfg(not(unix))] +impl AsyncWrappedUdpSocket for tokio::net::UdpSocket { + fn poll_read_ready(&self, cx: &mut Context) -> Poll> { + tokio::net::UdpSocket::poll_recv_ready(self, cx) + } + + fn poll_write_ready(&self, cx: &mut Context) -> Poll> { + tokio::net::UdpSocket::poll_send_ready(self, cx) + } + + fn clear_read_ready(&self, _cx: &mut Context) { + // not necessary because tokio::net::UdpSocket::try_recv_from already uses try_io + } + + fn clear_write_ready(&self, _cx: &mut Context) { + // not necessary because tokio::net::UdpSocket::try_send_from already uses try_io + } + + fn try_recv_from(&self, buf: &mut [u8]) -> io::Result<(usize, std::net::SocketAddr)> { + tokio::net::UdpSocket::try_recv_from(self, buf) + } + + fn try_send_to(&self, buf: &[u8], target: std::net::SocketAddr) -> io::Result { + tokio::net::UdpSocket::try_send_to(self, buf, target) + } + + fn local_addr(&self) -> io::Result { + tokio::net::UdpSocket::local_addr(self) + } +} + +#[derive(Debug, Clone)] +pub struct TokioRuntime; + +impl Runtime for TokioRuntime { + fn wrap_udp_socket( + &self, + t: std::net::UdpSocket, + ) -> io::Result> { + t.set_nonblocking(true)?; + #[cfg(unix)] + { + Ok(Box::new(AsyncFd::new(t)?)) + } + #[cfg(not(unix))] + { + Ok(Box::new(tokio::net::UdpSocket::from_std(t)?)) + } + } +} diff --git a/quinn-udp/src/unix.rs b/quinn-udp/src/unix.rs index a3746ac1fb..79a1587878 100644 --- a/quinn-udp/src/unix.rs +++ b/quinn-udp/src/unix.rs @@ -11,9 +11,9 @@ use std::{ }; use proto::{EcnCodepoint, Transmit}; -use tokio::io::unix::AsyncFd; use super::{cmsg, log_sendmsg_error, RecvMeta, UdpState, IO_ERROR_LOG_INTERVAL}; +use crate::runtime::AsyncWrappedUdpSocket; #[cfg(target_os = "freebsd")] type IpTosTy = libc::c_uchar; @@ -45,17 +45,16 @@ fn only_v6(sock: &std::net::UdpSocket) -> io::Result { /// platforms. #[derive(Debug)] pub struct UdpSocket { - io: AsyncFd, + io: Box, last_send_error: Instant, } impl UdpSocket { - pub fn from_std(socket: std::net::UdpSocket) -> io::Result { - socket.set_nonblocking(true)?; - init(&socket)?; + pub fn new(socket: Box) -> io::Result { + init(socket.get_ref())?; let now = Instant::now(); Ok(UdpSocket { - io: AsyncFd::new(socket)?, + io: socket, last_send_error: now.checked_sub(2 * IO_ERROR_LOG_INTERVAL).unwrap_or(now), }) } @@ -68,11 +67,12 @@ impl UdpSocket { ) -> Poll> { loop { let last_send_error = &mut self.last_send_error; - let mut guard = ready!(self.io.poll_write_ready(cx))?; - if let Ok(res) = - guard.try_io(|io| send(state, io.get_ref(), last_send_error, transmits)) - { - return Poll::Ready(res); + ready!(self.io.poll_write_ready(cx))?; + match send(state, self.io.get_ref(), last_send_error, transmits) { + Err(e) if e.kind() == io::ErrorKind::WouldBlock => { + self.io.clear_write_ready(cx); + } + res => break Poll::Ready(res), } } } @@ -85,9 +85,12 @@ impl UdpSocket { ) -> Poll> { debug_assert!(!bufs.is_empty()); loop { - let mut guard = ready!(self.io.poll_read_ready(cx))?; - if let Ok(res) = guard.try_io(|io| recv(io.get_ref(), bufs, meta)) { - return Poll::Ready(res); + ready!(self.io.poll_read_ready(cx))?; + match recv(self.io.get_ref(), bufs, meta) { + Err(e) if e.kind() == io::ErrorKind::WouldBlock => { + self.io.clear_read_ready(cx); + } + res => break Poll::Ready(res), } } } diff --git a/quinn/Cargo.toml b/quinn/Cargo.toml index 240dcdfd9f..d60796b919 100644 --- a/quinn/Cargo.toml +++ b/quinn/Cargo.toml @@ -15,7 +15,7 @@ rust-version = "1.53" all-features = true [features] -default = ["native-certs", "tls-rustls"] +default = ["native-certs", "tls-rustls", "runtime-tokio"] # Records how long locks are held, and warns if they are held >= 1ms lock_tracking = [] # Provides `ClientConfig::with_native_roots()` convenience method @@ -23,12 +23,16 @@ native-certs = ["proto/native-certs"] tls-rustls = ["rustls", "webpki", "proto/tls-rustls", "ring"] # Enables `Endpoint::client` and `Endpoint::server` conveniences ring = ["proto/ring"] +runtime-tokio = [ "tokio/time", "tokio/rt", "udp/runtime-tokio" ] +runtime-async-std = [ "async-io", "async-std", "udp/runtime-async-std" ] [badges] codecov = { repository = "djc/quinn" } maintenance = { status = "experimental" } [dependencies] +async-io = { version = "1.6", optional = true } +async-std = { version = "1.11", optional = true } bytes = "1" # Enables futures::io::{AsyncRead, AsyncWrite} support for streams futures-io = { version = "0.3.19", optional = true } @@ -39,11 +43,12 @@ proto = { package = "quinn-proto", path = "../quinn-proto", version = "0.8", def rustls = { version = "0.20.3", default-features = false, features = ["quic"], optional = true } thiserror = "1.0.21" tracing = "0.1.10" -tokio = { version = "1.0.1", features = ["rt", "time", "sync"] } +tokio = { version = "1.0.1", features = ["sync"] } udp = { package = "quinn-udp", path = "../quinn-udp", version = "0.1.0" } webpki = { version = "0.22", default-features = false, optional = true } [dev-dependencies] +async-std = { version = "1.11", features = [ "attributes" ] } anyhow = "1.0.22" crc = "3" bencher = "0.1.5" diff --git a/quinn/benches/bench.rs b/quinn/benches/bench.rs index f1dc2de720..0ccfef9249 100644 --- a/quinn/benches/bench.rs +++ b/quinn/benches/bench.rs @@ -9,7 +9,7 @@ use tokio::runtime::{Builder, Runtime}; use tracing::error_span; use tracing_futures::Instrument as _; -use quinn::Endpoint; +use quinn::{Endpoint, TokioRuntime}; benchmark_group!( benches, @@ -102,7 +102,7 @@ impl Context { let runtime = rt(); let (_, mut incoming) = { let _guard = runtime.enter(); - Endpoint::new(Default::default(), Some(config), sock).unwrap() + Endpoint::new(Default::default(), Some(config), sock, TokioRuntime).unwrap() }; let handle = runtime.spawn( async move { diff --git a/quinn/src/connection.rs b/quinn/src/connection.rs index dc920f90a6..13fdcd971a 100644 --- a/quinn/src/connection.rs +++ b/quinn/src/connection.rs @@ -10,12 +10,12 @@ use std::{ time::{Duration, Instant}, }; +use crate::runtime::{AsyncTimer, Runtime}; use bytes::Bytes; use proto::{ConnectionError, ConnectionHandle, ConnectionStats, Dir, StreamEvent, StreamId}; use rustc_hash::FxHashMap; use thiserror::Error; use tokio::sync::{mpsc, oneshot, Notify}; -use tokio::time::{sleep_until, Instant as TokioInstant, Sleep}; use tracing::info_span; use udp::UdpState; @@ -44,6 +44,7 @@ impl Connecting { endpoint_events: mpsc::UnboundedSender<(ConnectionHandle, EndpointEvent)>, conn_events: mpsc::UnboundedReceiver, udp_state: Arc, + runtime: Arc>, ) -> Connecting { let (on_handshake_data_send, on_handshake_data_recv) = oneshot::channel(); let (on_connected_send, on_connected_recv) = oneshot::channel(); @@ -55,9 +56,10 @@ impl Connecting { on_handshake_data_send, on_connected_send, udp_state, + runtime.clone(), ); - tokio::spawn(ConnectionDriver(conn.clone())); + runtime.spawn(Box::pin(ConnectionDriver(conn.clone()))); Connecting { conn: Some(conn), @@ -692,6 +694,7 @@ impl futures_core::Stream for Datagrams { pub struct ConnectionRef(Arc>); impl ConnectionRef { + #[allow(clippy::too_many_arguments)] fn new( handle: ConnectionHandle, conn: proto::Connection, @@ -700,6 +703,7 @@ impl ConnectionRef { on_handshake_data: oneshot::Sender<()>, on_connected: oneshot::Sender, udp_state: Arc, + runtime: Arc>, ) -> Self { Self(Arc::new(Mutex::new(ConnectionInner { inner: conn, @@ -723,6 +727,7 @@ impl ConnectionRef { error: None, ref_count: 0, udp_state, + runtime, }))) } @@ -768,8 +773,8 @@ pub struct ConnectionInner { on_handshake_data: Option>, on_connected: Option>, connected: bool, - timer: Option>>, - timer_deadline: Option, + timer: Option>>, + timer_deadline: Option, conn_events: mpsc::UnboundedReceiver, endpoint_events: mpsc::UnboundedSender<(ConnectionHandle, EndpointEvent)>, pub(crate) blocked_writers: FxHashMap, @@ -785,6 +790,7 @@ pub struct ConnectionInner { /// Number of live handles that can be used to initiate or handle I/O; excludes the driver ref_count: usize, udp_state: Arc, + runtime: Arc>, } impl ConnectionInner { @@ -924,7 +930,7 @@ impl ConnectionInner { // Check whether we need to (re)set the timer. If so, we must poll again to ensure the // timer is registered with the runtime (and check whether it's already // expired). - match self.inner.poll_timeout().map(TokioInstant::from_std) { + match self.inner.poll_timeout() { Some(deadline) => { if let Some(delay) = &mut self.timer { // There is no need to reset the tokio timer if the deadline @@ -937,7 +943,7 @@ impl ConnectionInner { delay.as_mut().reset(deadline); } } else { - self.timer = Some(Box::pin(sleep_until(deadline))); + self.timer = Some(self.runtime.new_timer(deadline)); } // Store the actual expiration time of the timer self.timer_deadline = Some(deadline); diff --git a/quinn/src/endpoint.rs b/quinn/src/endpoint.rs index 698b706e78..45d7721a6d 100644 --- a/quinn/src/endpoint.rs +++ b/quinn/src/endpoint.rs @@ -12,6 +12,7 @@ use std::{ time::Instant, }; +use crate::runtime::{make_runtime, Runtime}; use bytes::Bytes; use proto::{ self as proto, ClientConfig, ConnectError, ConnectionHandle, DatagramEvent, ServerConfig, @@ -35,15 +36,13 @@ use crate::{ pub struct Endpoint { pub(crate) inner: EndpointRef, pub(crate) default_client_config: Option, + runtime: Arc>, } impl Endpoint { /// Helper to construct an endpoint for use with outgoing connections only /// - /// Must be called from within a tokio runtime context. Note that `addr` is the *local* address - /// to bind to, which should usually be a wildcard address like `0.0.0.0:0` or `[::]:0`, which - /// allow communication with any reachable IPv4 or IPv6 address respectively from an OS-assigned - /// port. + /// Note that `addr` is the *local* address to bind to, which should usually be a wildcard address like `0.0.0.0:0` or `[::]:0`, which allow communication with any reachable IPv4 or IPv6 address respectively from an OS-assigned port. /// /// Platform defaults for dual-stack sockets vary. For example, any socket bound to a wildcard /// IPv6 address on Windows will not by default be able to communicate with IPv4 @@ -52,13 +51,12 @@ impl Endpoint { #[cfg(feature = "ring")] pub fn client(addr: SocketAddr) -> io::Result { let socket = std::net::UdpSocket::bind(addr)?; - Ok(Self::new(EndpointConfig::default(), None, socket)?.0) + let runtime = make_runtime(); + Ok(Self::new_with_runtime(EndpointConfig::default(), None, socket, runtime)?.0) } /// Helper to construct an endpoint for use with both incoming and outgoing connections /// - /// Must be called from within a tokio runtime context. - /// /// Platform defaults for dual-stack sockets vary. For example, any socket bound to a wildcard /// IPv6 address on Windows will not by default be able to communicate with IPv4 /// addresses. Portable applications should bind an address that matches the family they wish to @@ -66,34 +64,47 @@ impl Endpoint { #[cfg(feature = "ring")] pub fn server(config: ServerConfig, addr: SocketAddr) -> io::Result<(Self, Incoming)> { let socket = std::net::UdpSocket::bind(addr)?; - Self::new(EndpointConfig::default(), Some(config), socket) + let runtime = make_runtime(); + Self::new_with_runtime(EndpointConfig::default(), Some(config), socket, runtime) } /// Construct an endpoint with arbitrary configuration - /// - /// Must be called from within a tokio runtime context. pub fn new( config: EndpointConfig, server_config: Option, socket: std::net::UdpSocket, + runtime: impl Runtime, + ) -> io::Result<(Self, Incoming)> { + let runtime: Box = Box::new(runtime); + Self::new_with_runtime(config, server_config, socket, runtime) + } + + fn new_with_runtime( + config: EndpointConfig, + server_config: Option, + socket: std::net::UdpSocket, + runtime: Box, ) -> io::Result<(Self, Incoming)> { + let runtime = Arc::new(runtime); let addr = socket.local_addr()?; - let socket = UdpSocket::from_std(socket)?; + let socket = UdpSocket::new(runtime.wrap_udp_socket(socket)?)?; let rc = EndpointRef::new( socket, proto::Endpoint::new(Arc::new(config), server_config.map(Arc::new)), addr.is_ipv6(), + runtime.clone(), ); let driver = EndpointDriver(rc.clone()); - tokio::spawn(async { + runtime.spawn(Box::pin(async { if let Err(e) = driver.await { tracing::error!("I/O error: {}", e); } - }); + })); Ok(( Self { inner: rc.clone(), default_client_config: None, + runtime, }, Incoming::new(rc), )) @@ -146,7 +157,9 @@ impl Endpoint { }; let (ch, conn) = endpoint.inner.connect(config, addr, server_name)?; let udp_state = endpoint.udp_state.clone(); - Ok(endpoint.connections.insert(ch, conn, udp_state)) + Ok(endpoint + .connections + .insert(ch, conn, udp_state, self.runtime.clone())) } /// Switch to a new UDP socket @@ -157,7 +170,7 @@ impl Endpoint { /// On error, the old UDP socket is retained. pub fn rebind(&self, socket: std::net::UdpSocket) -> io::Result<()> { let addr = socket.local_addr()?; - let socket = UdpSocket::from_std(socket)?; + let socket = UdpSocket::new(self.runtime.wrap_udp_socket(socket)?)?; let mut inner = self.inner.lock().unwrap(); inner.socket = socket; inner.ipv6 = addr.is_ipv6(); @@ -324,6 +337,7 @@ pub(crate) struct EndpointInner { recv_buf: Box<[u8]>, send_limiter: WorkLimiter, idle: Arc, + runtime: Arc>, } impl EndpointInner { @@ -352,9 +366,12 @@ impl EndpointInner { .handle(now, meta.addr, meta.dst_ip, meta.ecn, data) { Some((handle, DatagramEvent::NewConnection(conn))) => { - let conn = - self.connections - .insert(handle, conn, self.udp_state.clone()); + let conn = self.connections.insert( + handle, + conn, + self.udp_state.clone(), + self.runtime.clone(), + ); self.incoming.push_back(conn); } Some((handle, DatagramEvent::ConnectionEvent(event))) => { @@ -486,6 +503,7 @@ impl ConnectionSet { handle: ConnectionHandle, conn: proto::Connection, udp_state: Arc, + runtime: Arc>, ) -> Connecting { let (send, recv) = mpsc::unbounded_channel(); if let Some((error_code, ref reason)) = self.close { @@ -496,7 +514,7 @@ impl ConnectionSet { .unwrap(); } self.senders.insert(handle, send); - Connecting::new(handle, conn, self.sender.clone(), recv, udp_state) + Connecting::new(handle, conn, self.sender.clone(), recv, udp_state, runtime) } fn is_empty(&self) -> bool { @@ -564,7 +582,12 @@ impl Drop for Incoming { pub(crate) struct EndpointRef(Arc>); impl EndpointRef { - pub(crate) fn new(socket: UdpSocket, inner: proto::Endpoint, ipv6: bool) -> Self { + pub(crate) fn new( + socket: UdpSocket, + inner: proto::Endpoint, + ipv6: bool, + runtime: Arc>, + ) -> Self { let recv_buf = vec![0; inner.config().get_max_udp_payload_size().min(64 * 1024) as usize * BATCH_SIZE]; let (sender, events) = mpsc::unbounded_channel(); @@ -589,6 +612,7 @@ impl EndpointRef { recv_limiter: WorkLimiter::new(RECV_TIME_BOUND), send_limiter: WorkLimiter::new(SEND_TIME_BOUND), idle: Arc::new(Notify::new()), + runtime, }))) } } diff --git a/quinn/src/lib.rs b/quinn/src/lib.rs index 1a964c78ed..a0374ec71f 100644 --- a/quinn/src/lib.rs +++ b/quinn/src/lib.rs @@ -59,6 +59,7 @@ mod connection; mod endpoint; mod mutex; mod recv_stream; +mod runtime; mod send_stream; mod work_limiter; @@ -74,6 +75,10 @@ pub use crate::connection::{ }; pub use crate::endpoint::{Endpoint, Incoming}; pub use crate::recv_stream::{ReadError, ReadExactError, ReadToEndError, RecvStream}; +#[cfg(feature = "runtime-async-std")] +pub use crate::runtime::AsyncStdRuntime; +#[cfg(feature = "runtime-tokio")] +pub use crate::runtime::TokioRuntime; pub use crate::send_stream::{SendStream, StoppedError, WriteError}; #[cfg(test)] diff --git a/quinn/src/runtime/async_std_runtime.rs b/quinn/src/runtime/async_std_runtime.rs new file mode 100644 index 0000000000..ef60643c52 --- /dev/null +++ b/quinn/src/runtime/async_std_runtime.rs @@ -0,0 +1,26 @@ +use super::{AsyncTimer, Runtime}; +use async_io::Timer; +use std::future::Future; +use std::pin::Pin; +use std::task::{Context, Poll}; +use std::time::Instant; + +impl AsyncTimer for Timer { + fn reset(mut self: Pin<&mut Self>, t: Instant) { + self.set_at(t) + } + fn poll(self: Pin<&mut Self>, cx: &mut Context) -> Poll<()> { + Future::poll(self, cx).map(|_| ()) + } +} + +pub use udp::runtime::AsyncStdRuntime; + +impl Runtime for AsyncStdRuntime { + fn new_timer(&self, t: Instant) -> Pin> { + Box::pin(Timer::at(t)) + } + fn spawn(&self, future: Pin + Send>>) { + async_std::task::spawn(future); + } +} diff --git a/quinn/src/runtime/mod.rs b/quinn/src/runtime/mod.rs new file mode 100644 index 0000000000..eff9038995 --- /dev/null +++ b/quinn/src/runtime/mod.rs @@ -0,0 +1,42 @@ +#[cfg(feature = "runtime-tokio")] +mod tokio_runtime; +#[cfg(feature = "runtime-tokio")] +pub use tokio_runtime::*; + +#[cfg(feature = "runtime-async-std")] +mod async_std_runtime; +#[cfg(feature = "runtime-async-std")] +pub use async_std_runtime::*; + +use std::fmt::Debug; +use std::future::Future; +use std::pin::Pin; +use std::task::{Context, Poll}; +use std::time::Instant; + +pub trait AsyncTimer: Send + Debug { + fn reset(self: Pin<&mut Self>, i: Instant); + fn poll(self: Pin<&mut Self>, cx: &mut Context) -> Poll<()>; +} + +pub trait Runtime: udp::runtime::Runtime { + fn new_timer(&self, i: Instant) -> Pin>; + fn spawn(&self, future: Pin + Send>>); +} + +pub fn make_runtime() -> Box { + #[cfg(feature = "runtime-tokio")] + { + if let Ok(_) = tokio::runtime::Handle::try_current() { + return Box::new(crate::TokioRuntime); + } + } + + #[cfg(feature = "runtime-async-std")] + { + return Box::new(crate::AsyncStdRuntime); + } + + #[cfg(not(feature = "runtime-async-std"))] + panic!("No usable runtime found"); +} diff --git a/quinn/src/runtime/tokio_runtime.rs b/quinn/src/runtime/tokio_runtime.rs new file mode 100644 index 0000000000..6d3eb3093a --- /dev/null +++ b/quinn/src/runtime/tokio_runtime.rs @@ -0,0 +1,27 @@ +use super::{AsyncTimer, Runtime}; +use std::future::Future; +use std::pin::Pin; +use std::task::{Context, Poll}; +use std::time::Instant; +use tokio::time::{sleep_until, Sleep}; + +impl AsyncTimer for Sleep { + fn reset(self: Pin<&mut Self>, t: Instant) { + Sleep::reset(self, t.into()) + } + fn poll(self: Pin<&mut Self>, cx: &mut Context) -> Poll<()> { + Future::poll(self, cx) + } +} + +pub use udp::runtime::TokioRuntime; + +impl Runtime for TokioRuntime { + fn new_timer(&self, t: Instant) -> Pin> { + Box::pin(sleep_until(t.into())) + } + + fn spawn(&self, future: Pin + Send>>) { + tokio::spawn(future); + } +} diff --git a/quinn/src/send_stream.rs b/quinn/src/send_stream.rs index 0ed445a73f..e756e3f2a7 100644 --- a/quinn/src/send_stream.rs +++ b/quinn/src/send_stream.rs @@ -227,7 +227,7 @@ impl SendStream { #[cfg(feature = "futures-io")] impl futures_io::AsyncWrite for SendStream { fn poll_write(self: Pin<&mut Self>, cx: &mut Context, buf: &[u8]) -> Poll> { - tokio::io::AsyncWrite::poll_write(self, cx, buf) + SendStream::execute_poll(self.get_mut(), cx, |stream| stream.write(buf)).map_err(Into::into) } fn poll_flush(self: Pin<&mut Self>, _cx: &mut Context) -> Poll> { @@ -235,10 +235,11 @@ impl futures_io::AsyncWrite for SendStream { } fn poll_close(self: Pin<&mut Self>, cx: &mut Context) -> Poll> { - tokio::io::AsyncWrite::poll_shutdown(self, cx) + self.get_mut().poll_finish(cx).map_err(Into::into) } } +#[cfg(feature = "runtime-tokio")] impl tokio::io::AsyncWrite for SendStream { fn poll_write( self: Pin<&mut Self>, diff --git a/quinn/src/tests.rs b/quinn/src/tests.rs index e0259ca759..66a9448ba0 100644 --- a/quinn/src/tests.rs +++ b/quinn/src/tests.rs @@ -8,6 +8,7 @@ use std::{ sync::Arc, }; +use crate::runtime::TokioRuntime; use bytes::Bytes; use rand::{rngs::StdRng, RngCore, SeedableRng}; use tokio::{ @@ -103,7 +104,7 @@ fn local_addr() { let runtime = rt_basic(); let (ep, _) = { let _guard = runtime.enter(); - Endpoint::new(Default::default(), None, socket).unwrap() + Endpoint::new(Default::default(), None, socket, TokioRuntime).unwrap() }; assert_eq!( addr, @@ -454,7 +455,13 @@ fn run_echo(args: EchoArgs) { let server_addr = server_sock.local_addr().unwrap(); let (server, mut server_incoming) = { let _guard = runtime.enter(); - Endpoint::new(Default::default(), Some(server_config), server_sock).unwrap() + Endpoint::new( + Default::default(), + Some(server_config), + server_sock, + TokioRuntime, + ) + .unwrap() }; let mut roots = rustls::RootCertStore::empty();