diff --git a/quinn/Cargo.toml b/quinn/Cargo.toml
index 877b72bfb..36fd00067 100644
--- a/quinn/Cargo.toml
+++ b/quinn/Cargo.toml
@@ -45,6 +45,7 @@ rustls = { version = "0.20.3", default-features = false, features = ["quic"], op
 thiserror = "1.0.21"
 tracing = "0.1.10"
 tokio = { version = "1.13.0", features = ["sync"] }
+tokio-util = { version = "0.6.9", features = ["time"] }
 udp = { package = "quinn-udp", path = "../quinn-udp", version = "0.3", default-features = false }
 webpki = { version = "0.22", default-features = false, optional = true }
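The only new dependency is `tokio-util` with its `time` feature, which provides the `DelayQueue` the endpoint now uses to multiplex every connection's timeout onto one structure instead of one timer future per connection task. Below is a minimal, standalone sketch of the parts of the tokio-util 0.6 API the patch relies on (`insert_at`, `reset_at` via the returned `Key`, and `poll_expired`); the `u32` values merely stand in for `proto::ConnectionHandle`, and note that tokio-util 0.7+ drops the `Result` wrapper around expired entries.

```rust
// Sketch only: illustrates the DelayQueue calls used by the patch, assuming tokio-util 0.6.
use std::time::Duration;

use tokio::time::Instant;
use tokio_util::time::delay_queue::{DelayQueue, Key};

#[tokio::main]
async fn main() {
    let mut timers: DelayQueue<u32> = DelayQueue::new();

    // Arm a deadline for connection 7; keep the Key so the deadline can later be
    // moved with reset_at instead of inserting a duplicate entry.
    let key: Key = timers.insert_at(7, Instant::now() + Duration::from_millis(20));
    timers.reset_at(&key, Instant::now() + Duration::from_millis(5));

    // poll_expired hands back values whose deadlines have passed (Result-wrapped in 0.6).
    let expired = std::future::poll_fn(|cx| timers.poll_expired(cx)).await;
    assert_eq!(expired.unwrap().unwrap().into_inner(), 7);
}
```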
diff --git a/quinn/src/connection.rs b/quinn/src/connection.rs
index 1c3a65740..e36a983b7 100644
--- a/quinn/src/connection.rs
+++ b/quinn/src/connection.rs
@@ -1,5 +1,6 @@
 use std::{
     any::Any,
+    collections::VecDeque,
     fmt,
     future::Future,
     net::{IpAddr, SocketAddr},
@@ -9,13 +10,16 @@ use std::{
     time::{Duration, Instant},
 };

-use crate::runtime::{AsyncTimer, Runtime};
 use bytes::Bytes;
 use pin_project_lite::pin_project;
 use proto::{ConnectionError, ConnectionHandle, ConnectionStats, Dir, StreamEvent, StreamId};
 use rustc_hash::FxHashMap;
 use thiserror::Error;
-use tokio::sync::{futures::Notified, mpsc, oneshot, Notify};
+use tokio::{
+    sync::{futures::Notified, mpsc, oneshot, Notify},
+    time::Instant as TokioInstant,
+};
+use tokio_util::time::delay_queue;
 use tracing::debug_span;
 use udp::UdpState;

@@ -23,7 +27,7 @@ use crate::{
     mutex::Mutex,
     recv_stream::RecvStream,
     send_stream::{SendStream, WriteError},
-    EndpointEvent, VarInt,
+    VarInt,
 };
 use proto::congestion::Controller;

@@ -38,25 +42,22 @@ pub struct Connecting {

 impl Connecting {
     pub(crate) fn new(
+        dirty: mpsc::UnboundedSender<ConnectionHandle>,
         handle: ConnectionHandle,
         conn: proto::Connection,
-        endpoint_events: mpsc::UnboundedSender<(ConnectionHandle, EndpointEvent)>,
         udp_state: Arc<UdpState>,
-        runtime: Arc<dyn Runtime>,
     ) -> (Connecting, ConnectionRef) {
         let (on_handshake_data_send, on_handshake_data_recv) = oneshot::channel();
         let (on_connected_send, on_connected_recv) = oneshot::channel();
         let conn = ConnectionRef::new(
             handle,
             conn,
-            endpoint_events,
+            dirty,
             on_handshake_data_send,
             on_connected_send,
             udp_state,
-            runtime.clone(),
         );

-        runtime.spawn(Box::pin(ConnectionDriver(conn.clone())));
         (
             Connecting {
                 conn: Some(conn.clone()),
@@ -202,53 +203,6 @@ impl Future for ZeroRttAccepted {
     }
 }

-/// A future that drives protocol logic for a connection
-///
-/// This future handles the protocol logic for a single connection, routing events from the
-/// `Connection` API object to the `Endpoint` task and the related stream-related interfaces.
-/// It also keeps track of outstanding timeouts for the `Connection`.
-///
-/// If the connection encounters an error condition, this future will yield an error. It will
-/// terminate (yielding `Ok(())`) if the connection was closed without error. Unlike other
-/// connection-related futures, this waits for the draining period to complete to ensure that
-/// packets still in flight from the peer are handled gracefully.
-#[must_use = "connection drivers must be spawned for their connections to function"]
-#[derive(Debug)]
-struct ConnectionDriver(ConnectionRef);
-
-impl Future for ConnectionDriver {
-    type Output = ();
-
-    #[allow(unused_mut)] // MSRV
-    fn poll(mut self: Pin<&mut Self>, cx: &mut Context) -> Poll<Self::Output> {
-        let conn = &mut *self.0.state.lock("poll");
-
-        let span = debug_span!("drive", id = conn.handle.0);
-        let _guard = span.enter();
-
-        let mut keep_going = conn.drive_transmit();
-        // If a timer expires, there might be more to transmit. When we transmit something, we
-        // might need to reset a timer. Hence, we must loop until neither happens.
-        keep_going |= conn.drive_timer(cx);
-        conn.forward_endpoint_events();
-        conn.forward_app_events(&self.0.shared);
-
-        if !conn.inner.is_drained() {
-            if keep_going {
-                // If the connection hasn't processed all tasks, schedule it again
-                cx.waker().wake_by_ref();
-            } else {
-                conn.driver = Some(cx.waker().clone());
-            }
-            return Poll::Pending;
-        }
-        if conn.error.is_none() {
-            unreachable!("drained connections always have an error");
-        }
-        Poll::Ready(())
-    }
-}
-
 /// A QUIC connection.
 ///
 /// If all references to a connection (including every clone of the `Connection` handle, streams of
@@ -741,23 +695,24 @@ impl ConnectionRef {
     fn new(
         handle: ConnectionHandle,
         conn: proto::Connection,
-        endpoint_events: mpsc::UnboundedSender<(ConnectionHandle, EndpointEvent)>,
+        dirty: mpsc::UnboundedSender<ConnectionHandle>,
         on_handshake_data: oneshot::Sender<()>,
         on_connected: oneshot::Sender<bool>,
         udp_state: Arc<UdpState>,
-        runtime: Arc<dyn Runtime>,
     ) -> Self {
+        let _ = dirty.send(handle);
         Self(Arc::new(ConnectionInner {
             state: Mutex::new(State {
                 inner: conn,
-                driver: None,
                 handle,
+                span: debug_span!("connection", id = handle.0),
+                is_dirty: true,
+                dirty,
                 on_handshake_data: Some(on_handshake_data),
                 on_connected: Some(on_connected),
                 connected: false,
-                timer: None,
+                timer_handle: None,
                 timer_deadline: None,
-                endpoint_events,
                 blocked_writers: FxHashMap::default(),
                 blocked_readers: FxHashMap::default(),
                 finishing: FxHashMap::default(),
@@ -765,7 +720,6 @@ impl ConnectionRef {
                 error: None,
                 ref_count: 0,
                 udp_state,
-                runtime,
             }),
             shared: Shared::default(),
         }))
@@ -825,14 +779,18 @@ pub(crate) struct Shared {
 pub(crate) struct State {
     pub(crate) inner: proto::Connection,
-    driver: Option<Waker>,
     handle: ConnectionHandle,
+    pub(crate) span: tracing::Span,
+    /// Whether `handle` has been sent to `dirty` since the last time this connection was driven by
+    /// the endpoint. Ensures `dirty`'s size stays bounded regardless of activity.
+    pub(crate) is_dirty: bool,
+    /// `handle` is sent here when `wake` is called, prompting the endpoint to drive the connection
+    dirty: mpsc::UnboundedSender<ConnectionHandle>,
     on_handshake_data: Option<oneshot::Sender<()>>,
     on_connected: Option<oneshot::Sender<bool>>,
     connected: bool,
-    timer: Option<Pin<Box<dyn AsyncTimer>>>,
-    timer_deadline: Option<Instant>,
-    endpoint_events: mpsc::UnboundedSender<(ConnectionHandle, EndpointEvent)>,
+    pub(crate) timer_handle: Option<delay_queue::Key>,
+    pub(crate) timer_deadline: Option<TokioInstant>,
     pub(crate) blocked_writers: FxHashMap<StreamId, Waker>,
     pub(crate) blocked_readers: FxHashMap<StreamId, Waker>,
     pub(crate) finishing: FxHashMap<StreamId, oneshot::Sender<Option<WriteError>>>,
@@ -842,11 +800,10 @@ pub(crate) struct State {
     /// Number of live handles that can be used to initiate or handle I/O; excludes the driver
     ref_count: usize,
     udp_state: Arc<UdpState>,
-    runtime: Arc<dyn Runtime>,
 }

 impl State {
-    fn drive_transmit(&mut self) -> bool {
+    pub(crate) fn drive_transmit(&mut self, out: &mut VecDeque<proto::Transmit>) -> bool {
         let now = Instant::now();
         let mut transmits = 0;

@@ -857,10 +814,7 @@ impl State {
                 None => 1,
                 Some(s) => (t.contents.len() + s - 1) / s, // round up
             };
-            // If the endpoint driver is gone, noop.
-            let _ = self
-                .endpoint_events
-                .send((self.handle, EndpointEvent::Transmit(t)));
+            out.push_back(t);

             if transmits >= MAX_TRANSMIT_DATAGRAMS {
                 // TODO: What isn't ideal here yet is that if we don't poll all
@@ -874,16 +828,7 @@ impl State {
         false
     }

-    fn forward_endpoint_events(&mut self) {
-        while let Some(event) = self.inner.poll_endpoint_events() {
-            // If the endpoint driver is gone, noop.
-            let _ = self
-                .endpoint_events
-                .send((self.handle, EndpointEvent::Proto(event)));
-        }
-    }
-
-    fn forward_app_events(&mut self, shared: &Shared) {
+    pub(crate) fn forward_app_events(&mut self, shared: &Shared) {
         while let Some(event) = self.inner.poll() {
             use proto::Event::*;
             match event {
@@ -949,61 +894,14 @@ impl State {
         }
     }

-    fn drive_timer(&mut self, cx: &mut Context) -> bool {
-        // Check whether we need to (re)set the timer. If so, we must poll again to ensure the
-        // timer is registered with the runtime (and check whether it's already
-        // expired).
-        match self.inner.poll_timeout() {
-            Some(deadline) => {
-                if let Some(delay) = &mut self.timer {
-                    // There is no need to reset the tokio timer if the deadline
-                    // did not change
-                    if self
-                        .timer_deadline
-                        .map(|current_deadline| current_deadline != deadline)
-                        .unwrap_or(true)
-                    {
-                        delay.as_mut().reset(deadline);
-                    }
-                } else {
-                    self.timer = Some(self.runtime.new_timer(deadline));
-                }
-                // Store the actual expiration time of the timer
-                self.timer_deadline = Some(deadline);
-            }
-            None => {
-                self.timer_deadline = None;
-                return false;
-            }
-        }
-
-        if self.timer_deadline.is_none() {
-            return false;
-        }
-
-        let delay = self
-            .timer
-            .as_mut()
-            .expect("timer must exist in this state")
-            .as_mut();
-        if delay.poll(cx).is_pending() {
-            // Since there wasn't a timeout event, there is nothing new
-            // for the connection to do
-            return false;
-        }
-
-        // A timer expired, so the caller needs to check for
-        // new transmits, which might cause new timers to be set.
-        self.inner.handle_timeout(Instant::now());
-        self.timer_deadline = None;
-        true
-    }
-
-    /// Wake up a blocked `Driver` task to process I/O
+    /// Wake up endpoint to process I/O by marking it as "dirty" for the endpoint
     pub(crate) fn wake(&mut self) {
-        if let Some(x) = self.driver.take() {
-            x.wake();
+        if self.is_dirty {
+            return;
         }
+        self.is_dirty = true;
+        // Take no action if the endpoint is gone
+        let _ = self.dirty.send(self.handle);
     }

     /// Used to wake up all blocked futures when the connection becomes closed for any reason
@@ -1058,18 +956,6 @@ impl State {
     }
 }

-impl Drop for State {
-    fn drop(&mut self) {
-        if !self.inner.is_drained() {
-            // Ensure the endpoint can tidy up
-            let _ = self.endpoint_events.send((
-                self.handle,
-                EndpointEvent::Proto(proto::EndpointEvent::drained()),
-            ));
-        }
-    }
-}
-
 impl fmt::Debug for State {
     fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result {
         f.debug_struct("State").field("inner", &self.inner).finish()
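With the per-connection `ConnectionDriver` gone, the rewritten `wake` above is the only way a connection asks to be driven: it marks itself dirty and pushes its handle onto an unbounded channel at most once per endpoint pass. Below is a standalone sketch of that debounce pattern, using simplified, hypothetical types (`Handle`, `ConnState`) rather than the crate's own, to show why the channel's length stays bounded by the number of live connections.

```rust
// Sketch only: mirrors the wake()/is_dirty logic in the patch with toy types.
use tokio::sync::mpsc;

#[derive(Clone, Copy, PartialEq, Eq, Debug)]
struct Handle(usize);

struct ConnState {
    handle: Handle,
    is_dirty: bool,
    dirty: mpsc::UnboundedSender<Handle>,
}

impl ConnState {
    /// Ask the endpoint to drive this connection soon.
    fn wake(&mut self) {
        if self.is_dirty {
            return; // already queued since the last endpoint pass
        }
        self.is_dirty = true;
        // If the endpoint is gone there is nobody left to drive us; ignore the error.
        let _ = self.dirty.send(self.handle);
    }
}

#[tokio::main]
async fn main() {
    let (dirty_send, mut dirty_recv) = mpsc::unbounded_channel();
    let mut conn = ConnState { handle: Handle(0), is_dirty: false, dirty: dirty_send };

    conn.wake();
    conn.wake(); // debounced: only one entry is queued
    assert_eq!(dirty_recv.recv().await, Some(Handle(0)));
    assert!(dirty_recv.try_recv().is_err()); // nothing further queued

    // The endpoint clears the flag (under the connection lock) when it drives the
    // connection, re-arming wake() for the next burst of activity.
    conn.is_dirty = false;
    conn.wake();
    assert_eq!(dirty_recv.recv().await, Some(Handle(0)));
}
```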
diff --git a/quinn/src/endpoint.rs b/quinn/src/endpoint.rs
index 94718f45a..1265452d0 100644
--- a/quinn/src/endpoint.rs
+++ b/quinn/src/endpoint.rs
@@ -20,12 +20,13 @@ use proto::{
 };
 use rustc_hash::FxHashMap;
 use tokio::sync::{futures::Notified, mpsc, Notify};
+use tokio_util::time::DelayQueue;
 use udp::{RecvMeta, UdpState, BATCH_SIZE};

 use crate::{
     connection::{Connecting, ConnectionRef},
     work_limiter::WorkLimiter,
-    EndpointConfig, EndpointEvent, VarInt, IO_LOOP_BOUND, RECV_TIME_BOUND, SEND_TIME_BOUND,
+    EndpointConfig, VarInt, RECV_TIME_BOUND, SEND_TIME_BOUND,
 };

 /// A QUIC endpoint.
@@ -119,7 +120,6 @@ impl Endpoint {
             socket,
             proto::Endpoint::new(Arc::new(config), server_config.map(Arc::new)),
             addr.is_ipv6(),
-            runtime.clone(),
         );
         let driver = EndpointDriver(rc.clone());
         runtime.spawn(Box::pin(async {
@@ -192,9 +192,8 @@ impl Endpoint {
         };
         let (ch, conn) = endpoint.inner.connect(config, addr, server_name)?;
         let udp_state = endpoint.udp_state.clone();
-        Ok(endpoint
-            .connections
-            .insert(ch, conn, udp_state, self.runtime.clone()))
+        let dirty = endpoint.dirty_send.clone();
+        Ok(endpoint.connections.insert(dirty, ch, conn, udp_state))
     }

     /// Switch to a new UDP socket
@@ -297,15 +296,13 @@ impl Future for EndpointDriver {

     #[allow(unused_mut)] // MSRV
     fn poll(mut self: Pin<&mut Self>, cx: &mut Context) -> Poll<Self::Output> {
-        let mut endpoint = self.0.state.lock().unwrap();
+        let mut endpoint = &mut *self.0.state.lock().unwrap();
         if endpoint.driver.is_none() {
            endpoint.driver = Some(cx.waker().clone());
         }

-        let now = Instant::now();
-        let mut keep_going = false;
-        keep_going |= endpoint.drive_recv(cx, now)?;
-        keep_going |= endpoint.handle_events(cx, &self.0.shared);
+        let mut keep_going = endpoint.drive_recv(cx, Instant::now())?;
+        keep_going |= endpoint.drive_connections(cx, &self.0.shared);
         keep_going |= endpoint.drive_send(cx)?;

         if !endpoint.incoming.is_empty() {
@@ -315,10 +312,6 @@ impl Future for EndpointDriver {
         if endpoint.ref_count == 0 && endpoint.connections.is_empty() {
             Poll::Ready(Ok(()))
         } else {
-            drop(endpoint);
-            // If there is more work to do schedule the endpoint task again.
-            // `wake_by_ref()` is called outside the lock to minimize
-            // lock contention on a multithreaded runtime.
             if keep_going {
                 cx.waker().wake_by_ref();
             }
@@ -351,14 +344,19 @@ pub(crate) struct State {
     driver: Option<Waker>,
     ipv6: bool,
     connections: ConnectionSet,
-    events: mpsc::UnboundedReceiver<(ConnectionHandle, EndpointEvent)>,
     /// Number of live handles that can be used to initiate or handle I/O; excludes the driver
     ref_count: usize,
     driver_lost: bool,
     recv_limiter: WorkLimiter,
     recv_buf: Box<[u8]>,
     send_limiter: WorkLimiter,
-    runtime: Arc<dyn Runtime>,
+    /// Connections add themselves to this queue when they need to be driven
+    ///
+    /// Occurs e.g. due to application-layer activity
+    dirty_recv: mpsc::UnboundedReceiver<ConnectionHandle>,
+    /// Passed in to connections to enable the above
+    dirty_send: mpsc::UnboundedSender<ConnectionHandle>,
+    timers: DelayQueue<ConnectionHandle>,
 }

 #[derive(Debug)]
@@ -396,10 +394,10 @@ impl State {
                     {
                         Some((handle, DatagramEvent::NewConnection(conn))) => {
                             let conn = self.connections.insert(
+                                self.dirty_send.clone(),
                                 handle,
                                 conn,
                                 self.udp_state.clone(),
-                                self.runtime.clone(),
                             );
                             self.incoming.push_back(conn);
                         }
@@ -478,44 +476,86 @@ impl State {
         result
     }

-    fn handle_events(&mut self, cx: &mut Context, shared: &Shared) -> bool {
-        use EndpointEvent::*;
-
-        for _ in 0..IO_LOOP_BOUND {
-            match self.events.poll_recv(cx) {
-                Poll::Ready(Some((ch, event))) => match event {
-                    Proto(e) => {
-                        if e.is_drained() {
-                            self.connections.refs.remove(&ch);
-                            if self.connections.is_empty() {
-                                shared.idle.notify_waiters();
-                            }
-                        }
-                        if let Some(event) = self.inner.handle_event(ch, e) {
-                            let conn = self.connections.refs.get(&ch).unwrap();
-                            let mut conn = conn.state.lock("handle_event");
-                            conn.inner.handle_event(event);
-                            conn.wake();
+    /// Process connections on which there's been timeouts, packets received, or application
+    /// activity ("dirty" connections)
+    fn drive_connections(&mut self, cx: &mut Context, shared: &Shared) -> bool {
+        let mut keep_going = false;
+
+        while let Poll::Ready(Some(result)) = self.timers.poll_expired(cx) {
+            let conn_handle = result.unwrap().into_inner();
+            let conn = match self.connections.refs.get(&conn_handle) {
+                Some(c) => c,
+                None => continue,
+            };
+            let mut state = &mut *conn.state.lock("poll timeouts");
+            let _guard = state.span.clone().entered();
+            state.inner.handle_timeout(Instant::now());
+            state.timer_handle = None;
+            state.timer_deadline = None;
+            state.wake();
+        }
+
+        let mut dirty_buffer = Vec::new();
+
+        // Buffer the list of initially dirty connections, guaranteeing that the connection
+        // processing loop below takes a predictable amount of time.
+        while let Poll::Ready(Some(conn_handle)) = self.dirty_recv.poll_recv(cx) {
+            dirty_buffer.push(conn_handle);
+        }
+
+        let mut drained = Vec::new();
+        for conn_handle in dirty_buffer {
+            let conn = match self.connections.refs.get(&conn_handle) {
+                Some(c) => c,
+                None => continue,
+            };
+            let mut state = conn.state.lock("poll dirty");
+            state.is_dirty = false;
+            let _guard = state.span.clone().entered();
+            let mut keep_conn_going = state.drive_transmit(&mut self.outgoing);
+            if let Some(deadline) = state.inner.poll_timeout() {
+                let deadline = tokio::time::Instant::from(deadline);
+                if Some(deadline) != state.timer_deadline {
+                    match state.timer_handle {
+                        Some(ref key) => self.timers.reset_at(key, deadline),
+                        None => {
+                            state.timer_handle = Some(self.timers.insert_at(conn_handle, deadline));
                         }
                     }
-                    Transmit(t) => self.outgoing.push_back(t),
-                },
-                Poll::Ready(None) => unreachable!("EndpointInner owns one sender"),
-                Poll::Pending => {
-                    return false;
+                    // self.timers may need to be polled
+                    keep_going = true;
                 }
             }
+            while let Some(event) = state.inner.poll_endpoint_events() {
+                if event.is_drained() {
+                    drained.push(conn_handle);
+                }
+                if let Some(event) = self.inner.handle_event(conn_handle, event) {
+                    state.inner.handle_event(event);
+                    keep_conn_going = true;
+                }
+            }
+            state.forward_app_events(&conn.shared);
+            if keep_conn_going {
+                state.wake();
+                keep_going = true;
+            }
         }
-        true
+        for conn_handle in drained {
+            self.connections.refs.remove(&conn_handle);
+        }
+        if self.connections.is_empty() {
+            shared.idle.notify_waiters();
+        }
+
+        keep_going
     }
 }

 #[derive(Debug)]
 struct ConnectionSet {
     refs: FxHashMap<ConnectionHandle, ConnectionRef>,
-    /// Stored to give out clones to new ConnectionInners
-    sender: mpsc::UnboundedSender<(ConnectionHandle, EndpointEvent)>,
     /// Set if the endpoint has been manually closed
     close: Option<(VarInt, Bytes)>,
 }
@@ -523,12 +563,12 @@ struct ConnectionSet {
 impl ConnectionSet {
     fn insert(
         &mut self,
+        dirty: mpsc::UnboundedSender<ConnectionHandle>,
         handle: ConnectionHandle,
         conn: proto::Connection,
         udp_state: Arc<UdpState>,
-        runtime: Arc<dyn Runtime>,
     ) -> Connecting {
-        let (future, conn) = Connecting::new(handle, conn, self.sender.clone(), udp_state, runtime);
+        let (future, conn) = Connecting::new(dirty, handle, conn, udp_state);
         if let Some((error_code, ref reason)) = self.close {
             let mut state = conn.state.lock("close");
             state.close(error_code, reason.clone(), &conn.shared);
@@ -589,12 +629,7 @@ impl<'a> Future for Accept<'a> {
 pub(crate) struct EndpointRef(Arc<EndpointInner>);

 impl EndpointRef {
-    pub(crate) fn new(
-        socket: Box<dyn AsyncUdpSocket>,
-        inner: proto::Endpoint,
-        ipv6: bool,
-        runtime: Arc<dyn Runtime>,
-    ) -> Self {
+    pub(crate) fn new(socket: Box<dyn AsyncUdpSocket>, inner: proto::Endpoint, ipv6: bool) -> Self {
         let udp_state = Arc::new(UdpState::new());
         let recv_buf = vec![
             0;
@@ -602,7 +637,7 @@ impl EndpointRef {
                 * udp_state.gro_segments()
                 * BATCH_SIZE
         ];
-        let (sender, events) = mpsc::unbounded_channel();
+        let (dirty_send, dirty_recv) = mpsc::unbounded_channel();
         Self(Arc::new(EndpointInner {
             shared: Shared {
                 incoming: Notify::new(),
@@ -613,13 +648,11 @@ impl EndpointRef {
                 udp_state,
                 inner,
                 ipv6,
-                events,
                 outgoing: VecDeque::new(),
                 incoming: VecDeque::new(),
                 driver: None,
                 connections: ConnectionSet {
                     refs: FxHashMap::default(),
-                    sender,
                     close: None,
                 },
                 ref_count: 0,
@@ -627,7 +660,9 @@ impl EndpointRef {
                 recv_buf: recv_buf.into(),
                 recv_limiter: WorkLimiter::new(RECV_TIME_BOUND),
                 send_limiter: WorkLimiter::new(SEND_TIME_BOUND),
-                runtime,
+                dirty_recv,
+                dirty_send,
+                timers: DelayQueue::new(),
             }),
         }))
     }
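The timer-arming branch in `drive_connections` above only touches the `DelayQueue` when `poll_timeout` reports a deadline different from the one already recorded, reusing the stored `Key` to move an existing entry rather than inserting a duplicate. Here is a simplified, hypothetical sketch of that decision in isolation; `TimerSlot` and `arm_timer` are illustrative names, not items from the patch.

```rust
// Sketch only: "reset if changed, insert otherwise", assuming tokio-util 0.6.
use std::time::Duration;

use tokio::time::Instant;
use tokio_util::time::{delay_queue::Key, DelayQueue};

struct TimerSlot {
    handle: u32,
    armed_deadline: Option<Instant>,
    key: Option<Key>,
}

/// Arm or move a connection's single entry in the shared DelayQueue.
fn arm_timer(slot: &mut TimerSlot, timers: &mut DelayQueue<u32>, deadline: Option<Instant>) {
    let deadline = match deadline {
        Some(d) => d,
        // No pending timeout reported; leave any armed entry alone.
        None => return,
    };
    if slot.armed_deadline == Some(deadline) {
        return; // unchanged since the last pass: no DelayQueue traffic
    }
    match slot.key {
        Some(ref key) => timers.reset_at(key, deadline),
        None => slot.key = Some(timers.insert_at(slot.handle, deadline)),
    }
    slot.armed_deadline = Some(deadline);
}

#[tokio::main]
async fn main() {
    let mut timers = DelayQueue::new();
    let mut slot = TimerSlot { handle: 1, armed_deadline: None, key: None };
    let deadline = Instant::now() + Duration::from_millis(10);
    arm_timer(&mut slot, &mut timers, Some(deadline));
    arm_timer(&mut slot, &mut timers, Some(deadline)); // no-op: same deadline
    assert!(slot.key.is_some() && slot.armed_deadline == Some(deadline));
}
```

In the patch the cached values live on the connection's `State` (`timer_handle`, `timer_deadline`), so the comparison happens while the connection lock is already held.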
diff --git a/quinn/src/lib.rs b/quinn/src/lib.rs
index e3d420384..70cff11e8 100644
--- a/quinn/src/lib.rs
+++ b/quinn/src/lib.rs
@@ -81,18 +81,6 @@ pub use crate::send_stream::{SendStream, StoppedError, WriteError};
 #[cfg(test)]
 mod tests;

-#[derive(Debug)]
-enum EndpointEvent {
-    Proto(proto::EndpointEvent),
-    Transmit(proto::Transmit),
-}
-
-/// Maximum number of datagrams processed in send/recv calls to make before moving on to other processing
-///
-/// This helps ensure we don't starve anything when the CPU is slower than the link.
-/// Value is selected by picking a low number which didn't degrade throughput in benchmarks.
-const IO_LOOP_BOUND: usize = 160;
-
 /// The maximum amount of time that should be spent in `recvmsg()` calls per endpoint iteration
 ///
 /// 50us are chosen so that an endpoint iteration with a 50us sendmsg limit blocks
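`IO_LOOP_BOUND` can go away because the endpoint no longer loops over a connection-event channel: `drive_connections` snapshots the dirty queue into a buffer before doing any work, so a connection that re-marks itself dirty while being driven is deferred to the next poll instead of extending the current one. A standalone sketch of that drain-then-process idea follows; `drain_dirty` is a hypothetical helper and `u32` stands in for `ConnectionHandle`.

```rust
// Sketch only: bound per-poll work by draining the queue first, then processing the batch.
use std::task::{Context, Poll};

use tokio::sync::mpsc;

fn drain_dirty(rx: &mut mpsc::UnboundedReceiver<u32>, cx: &mut Context<'_>) -> Vec<u32> {
    let mut batch = Vec::new();
    // Once the queue is empty, poll_recv returns Pending and registers the endpoint's
    // waker, so later wake() calls re-schedule the endpoint driver.
    while let Poll::Ready(Some(handle)) = rx.poll_recv(cx) {
        batch.push(handle);
    }
    batch
}

#[tokio::main]
async fn main() {
    let (tx, mut rx) = mpsc::unbounded_channel();
    for handle in 0..3u32 {
        tx.send(handle).unwrap();
    }
    let batch = std::future::poll_fn(|cx| Poll::Ready(drain_dirty(&mut rx, cx))).await;
    assert_eq!(batch, vec![0, 1, 2]);
}
```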