diff --git a/quinn/Cargo.toml b/quinn/Cargo.toml
index 6fe05a57d5..38d9181565 100644
--- a/quinn/Cargo.toml
+++ b/quinn/Cargo.toml
@@ -35,6 +35,7 @@ rustls = { version = "0.20", default-features = false, features = ["quic"], opti
 thiserror = "1.0.21"
 tracing = "0.1.10"
 tokio = { version = "1.0.1", features = ["rt", "time"] }
+tokio-util = { version = "0.6.9", features = ["time"] }
 udp = { package = "quinn-udp", path = "../quinn-udp", version = "0.1.0" }
 webpki = { version = "0.22", default-features = false, optional = true }
 
diff --git a/quinn/src/connection.rs b/quinn/src/connection.rs
index e2983e568b..f1ee996c2b 100644
--- a/quinn/src/connection.rs
+++ b/quinn/src/connection.rs
@@ -1,5 +1,6 @@
 use std::{
     any::Any,
+    collections::VecDeque,
     fmt,
     future::Future,
     mem,
@@ -11,21 +12,21 @@ use std::{
 };
 
 use bytes::Bytes;
-use futures_channel::{mpsc, oneshot};
+use futures_channel::oneshot;
 use futures_util::FutureExt;
 use fxhash::FxHashMap;
 use proto::{ConnectionError, ConnectionHandle, ConnectionStats, Dir, StreamEvent, StreamId};
 use thiserror::Error;
-use tokio::time::{sleep_until, Instant as TokioInstant, Sleep};
+use tokio::{sync::mpsc, time::Instant as TokioInstant};
+use tokio_util::time::delay_queue;
 use tracing::info_span;
-use udp::UdpState;
 
 use crate::{
     broadcast::{self, Broadcast},
     mutex::Mutex,
     recv_stream::RecvStream,
     send_stream::{SendStream, WriteError},
-    EndpointEvent, VarInt,
+    VarInt,
 };
 
 use proto::congestion::Controller;
@@ -40,23 +41,20 @@ pub struct Connecting {
 
 impl Connecting {
     pub(crate) fn new(
+        dirty: mpsc::UnboundedSender<ConnectionHandle>,
         handle: ConnectionHandle,
         conn: proto::Connection,
-        endpoint_events: mpsc::UnboundedSender<(ConnectionHandle, EndpointEvent)>,
-        udp_state: Arc<UdpState>,
     ) -> (Connecting, ConnectionRef) {
         let (on_handshake_data_send, on_handshake_data_recv) = oneshot::channel();
         let (on_connected_send, on_connected_recv) = oneshot::channel();
         let conn = ConnectionRef::new(
             handle,
             conn,
-            endpoint_events,
+            dirty,
             on_handshake_data_send,
             on_connected_send,
-            udp_state,
         );
-        tokio::spawn(ConnectionDriver(conn.clone()));
         (
             Connecting {
                 conn: Some(conn.clone()),
@@ -101,8 +99,10 @@ impl Connecting {
         drop(conn);
 
         if is_ok {
-            let conn = self.conn.take().unwrap();
-            Ok((NewConnection::new(conn), ZeroRttAccepted(self.connected)))
+            Ok((
+                NewConnection::new(self.conn.take().unwrap()),
+                ZeroRttAccepted(self.connected),
+            ))
         } else {
             Err(self)
         }
@@ -253,53 +253,6 @@ impl NewConnection {
     }
 }
 
-/// A future that drives protocol logic for a connection
-///
-/// This future handles the protocol logic for a single connection, routing events from the
-/// `Connection` API object to the `Endpoint` task and the related stream-related interfaces.
-/// It also keeps track of outstanding timeouts for the `Connection`.
-///
-/// If the connection encounters an error condition, this future will yield an error. It will
-/// terminate (yielding `Ok(())`) if the connection was closed without error. Unlike other
-/// connection-related futures, this waits for the draining period to complete to ensure that
-/// packets still in flight from the peer are handled gracefully.
-#[must_use = "connection drivers must be spawned for their connections to function"] -#[derive(Debug)] -struct ConnectionDriver(ConnectionRef); - -impl Future for ConnectionDriver { - type Output = (); - - #[allow(unused_mut)] // MSRV - fn poll(mut self: Pin<&mut Self>, cx: &mut Context) -> Poll { - let conn = &mut *self.0.lock("poll"); - - let span = info_span!("drive", id = conn.handle.0); - let _guard = span.enter(); - - let mut keep_going = conn.drive_transmit(); - // If a timer expires, there might be more to transmit. When we transmit something, we - // might need to reset a timer. Hence, we must loop until neither happens. - keep_going |= conn.drive_timer(cx); - conn.forward_endpoint_events(); - conn.forward_app_events(); - - if !conn.inner.is_drained() { - if keep_going { - // If the connection hasn't processed all tasks, schedule it again - cx.waker().wake_by_ref(); - } else { - conn.driver = Some(cx.waker().clone()); - } - return Poll::Pending; - } - if conn.error.is_none() { - unreachable!("drained connections always have an error"); - } - Poll::Ready(()) - } -} - /// A QUIC connection. /// /// If all references to a connection (including every clone of the `Connection` handle, streams of @@ -662,21 +615,21 @@ impl ConnectionRef { fn new( handle: ConnectionHandle, conn: proto::Connection, - endpoint_events: mpsc::UnboundedSender<(ConnectionHandle, EndpointEvent)>, + dirty: mpsc::UnboundedSender, on_handshake_data: oneshot::Sender<()>, on_connected: oneshot::Sender, - udp_state: Arc, ) -> Self { Self(Arc::new(Mutex::new(ConnectionInner { inner: conn, - driver: None, handle, + span: info_span!("connection", id = handle.0), + is_dirty: false, + dirty, on_handshake_data: Some(on_handshake_data), on_connected: Some(on_connected), connected: false, - timer: None, + timer_handle: None, timer_deadline: None, - endpoint_events, blocked_writers: FxHashMap::default(), blocked_readers: FxHashMap::default(), uni_opening: Broadcast::new(), @@ -688,7 +641,6 @@ impl ConnectionRef { stopped: FxHashMap::default(), error: None, ref_count: 0, - udp_state, }))) } @@ -729,14 +681,15 @@ impl std::ops::Deref for ConnectionRef { pub struct ConnectionInner { pub(crate) inner: proto::Connection, - driver: Option, handle: ConnectionHandle, + pub(crate) span: tracing::Span, + pub(crate) is_dirty: bool, + dirty: mpsc::UnboundedSender, on_handshake_data: Option>, on_connected: Option>, connected: bool, - timer: Option>>, - timer_deadline: Option, - endpoint_events: mpsc::UnboundedSender<(ConnectionHandle, EndpointEvent)>, + pub(crate) timer_handle: Option, + pub(crate) timer_deadline: Option, pub(crate) blocked_writers: FxHashMap, pub(crate) blocked_readers: FxHashMap, uni_opening: Broadcast, @@ -750,25 +703,23 @@ pub struct ConnectionInner { pub(crate) error: Option, /// Number of live handles that can be used to initiate or handle I/O; excludes the driver ref_count: usize, - udp_state: Arc, } impl ConnectionInner { - fn drive_transmit(&mut self) -> bool { + pub(crate) fn drive_transmit( + &mut self, + max_datagrams: usize, + out: &mut VecDeque, + ) -> bool { let now = Instant::now(); let mut transmits = 0; - let max_datagrams = self.udp_state.max_gso_segments(); - while let Some(t) = self.inner.poll_transmit(now, max_datagrams) { transmits += match t.segment_size { None => 1, Some(s) => (t.contents.len() + s - 1) / s, // round up }; - // If the endpoint driver is gone, noop. 
-            let _ = self
-                .endpoint_events
-                .unbounded_send((self.handle, EndpointEvent::Transmit(t)));
+            out.push_back(t);
 
             if transmits >= MAX_TRANSMIT_DATAGRAMS {
                 // TODO: What isn't ideal here yet is that if we don't poll all
@@ -782,16 +733,7 @@ impl ConnectionInner {
         false
     }
 
-    fn forward_endpoint_events(&mut self) {
-        while let Some(event) = self.inner.poll_endpoint_events() {
-            // If the endpoint driver is gone, noop.
-            let _ = self
-                .endpoint_events
-                .unbounded_send((self.handle, EndpointEvent::Proto(event)));
-        }
-    }
-
-    fn forward_app_events(&mut self) {
+    pub(crate) fn forward_app_events(&mut self) {
         while let Some(event) = self.inner.poll() {
             use proto::Event::*;
             match event {
@@ -863,61 +805,14 @@ impl ConnectionInner {
         }
     }
 
-    fn drive_timer(&mut self, cx: &mut Context) -> bool {
-        // Check whether we need to (re)set the timer. If so, we must poll again to ensure the
-        // timer is registered with the runtime (and check whether it's already
-        // expired).
-        match self.inner.poll_timeout().map(TokioInstant::from_std) {
-            Some(deadline) => {
-                if let Some(delay) = &mut self.timer {
-                    // There is no need to reset the tokio timer if the deadline
-                    // did not change
-                    if self
-                        .timer_deadline
-                        .map(|current_deadline| current_deadline != deadline)
-                        .unwrap_or(true)
-                    {
-                        delay.as_mut().reset(deadline);
-                    }
-                } else {
-                    self.timer = Some(Box::pin(sleep_until(deadline)));
-                }
-                // Store the actual expiration time of the timer
-                self.timer_deadline = Some(deadline);
-            }
-            None => {
-                self.timer_deadline = None;
-                return false;
-            }
-        }
-
-        if self.timer_deadline.is_none() {
-            return false;
-        }
-
-        let delay = self
-            .timer
-            .as_mut()
-            .expect("timer must exist in this state")
-            .as_mut();
-        if delay.poll(cx).is_pending() {
-            // Since there wasn't a timeout event, there is nothing new
-            // for the connection to do
-            return false;
-        }
-
-        // A timer expired, so the caller needs to check for
-        // new transmits, which might cause new timers to be set.
-        self.inner.handle_timeout(Instant::now());
-        self.timer_deadline = None;
-        true
-    }
-
-    /// Wake up a blocked `Driver` task to process I/O
+    /// Wake up endpoint to process I/O
     pub(crate) fn wake(&mut self) {
-        if let Some(x) = self.driver.take() {
-            x.wake();
+        if self.is_dirty {
+            return;
         }
+        self.is_dirty = true;
+        // Take no action if the endpoint is gone
+        let _ = self.dirty.send(self.handle);
     }
 
     /// Used to wake up all blocked futures when the connection becomes closed for any reason
@@ -977,18 +872,6 @@ impl ConnectionInner {
     }
 }
 
-impl Drop for ConnectionInner {
-    fn drop(&mut self) {
-        if !self.inner.is_drained() {
-            // Ensure the endpoint can tidy up
-            let _ = self.endpoint_events.unbounded_send((
-                self.handle,
-                EndpointEvent::Proto(proto::EndpointEvent::drained()),
-            ));
-        }
-    }
-}
-
 impl fmt::Debug for ConnectionInner {
     fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result {
         f.debug_struct("ConnectionInner")
diff --git a/quinn/src/endpoint.rs b/quinn/src/endpoint.rs
index bb7e583f8a..ae760eb69e 100644
--- a/quinn/src/endpoint.rs
+++ b/quinn/src/endpoint.rs
@@ -13,19 +13,19 @@ use std::{
 };
 
 use bytes::Bytes;
-use futures_channel::mpsc;
-use futures_util::StreamExt;
 use fxhash::FxHashMap;
 use proto::{
     self as proto, ClientConfig, ConnectError, ConnectionHandle, DatagramEvent, ServerConfig,
 };
+use tokio::sync::mpsc;
+use tokio_util::time::DelayQueue;
 use udp::{RecvMeta, UdpSocket, UdpState, BATCH_SIZE};
 
 use crate::{
     broadcast::{self, Broadcast},
     connection::{Connecting, ConnectionRef},
     work_limiter::WorkLimiter,
-    EndpointConfig, EndpointEvent, VarInt, IO_LOOP_BOUND, RECV_TIME_BOUND, SEND_TIME_BOUND,
+    EndpointConfig, VarInt, RECV_TIME_BOUND, SEND_TIME_BOUND,
 };
 
 /// A QUIC endpoint.
@@ -146,8 +146,9 @@ impl Endpoint {
             addr
         };
         let (ch, conn) = endpoint.inner.connect(config, addr, server_name)?;
-        let udp_state = endpoint.udp_state.clone();
-        Ok(endpoint.connections.insert(ch, conn, udp_state))
+        let dirty = endpoint.dirty_send.clone();
+        dirty.send(ch).unwrap();
+        Ok(endpoint.connections.insert(dirty, ch, conn))
     }
 
     /// Switch to a new UDP socket
@@ -252,15 +253,76 @@ impl Future for EndpointDriver {
 
     #[allow(unused_mut)] // MSRV
     fn poll(mut self: Pin<&mut Self>, cx: &mut Context) -> Poll<Self::Output> {
-        let mut endpoint = self.0.lock().unwrap();
+        let mut endpoint = &mut *self.0.lock().unwrap();
         if endpoint.driver.is_none() {
             endpoint.driver = Some(cx.waker().clone());
         }
 
-        let now = Instant::now();
         let mut keep_going = false;
-        keep_going |= endpoint.drive_recv(cx, now)?;
-        keep_going |= endpoint.handle_events(cx);
+
+        while let Poll::Ready(Some(result)) = endpoint.timers.poll_expired(cx) {
+            let conn_handle = result.unwrap().into_inner();
+            let conn = match endpoint.connections.connections.get(&conn_handle) {
+                Some(c) => c,
+                None => continue,
+            };
+            let mut conn = &mut *conn.lock("poll timeouts");
+            let _guard = conn.span.clone().entered();
+            conn.inner.handle_timeout(Instant::now());
+            conn.timer_handle = None;
+            conn.timer_deadline = None;
+            conn.wake();
+        }
+
+        let max_datagrams = endpoint.udp_state.max_gso_segments();
+        let mut drained = Vec::new();
+        while let Poll::Ready(Some(conn_handle)) = endpoint.dirty.poll_recv(cx) {
+            let conn = match endpoint.connections.connections.get(&conn_handle) {
+                Some(c) => c,
+                None => continue,
+            };
+            let mut conn = conn.lock("poll dirty");
+            conn.is_dirty = false;
+            let _guard = conn.span.clone().entered();
+            let mut keep_conn_going = conn.drive_transmit(max_datagrams, &mut endpoint.outgoing);
+            if let Some(deadline) = conn.inner.poll_timeout() {
+                let deadline = tokio::time::Instant::from(deadline);
+                if Some(deadline) != conn.timer_deadline {
+                    match conn.timer_handle {
+                        Some(ref key) => endpoint.timers.reset_at(key, deadline.into()),
+                        None => {
+                            conn.timer_handle =
+                                Some(endpoint.timers.insert_at(conn_handle, deadline.into()));
+                        }
+                    }
+                    // endpoint.timers may need to be polled
+                    keep_going = true;
+                }
+            }
+            while let Some(event) = conn.inner.poll_endpoint_events() {
+                if event.is_drained() {
+                    drained.push(conn_handle);
+                }
+                if let Some(event) = endpoint.inner.handle_event(conn_handle, event) {
+                    conn.inner.handle_event(event);
+                    keep_conn_going = true;
+                }
+            }
+            conn.forward_app_events();
+            if keep_conn_going {
+                conn.wake();
+                keep_going = true;
+            }
+        }
+
+        for conn_handle in drained {
+            endpoint.connections.connections.remove(&conn_handle);
+        }
+        if endpoint.connections.is_empty() {
+            endpoint.idle.wake();
+        }
+
+        keep_going |= endpoint.drive_recv(cx, Instant::now())?;
         keep_going |= endpoint.drive_send(cx)?;
 
         if !endpoint.incoming.is_empty() {
@@ -305,7 +367,6 @@ pub(crate) struct EndpointInner {
     driver: Option<Waker>,
     ipv6: bool,
     connections: ConnectionSet,
-    events: mpsc::UnboundedReceiver<(ConnectionHandle, EndpointEvent)>,
     /// Number of live handles that can be used to initiate or handle I/O; excludes the driver
     ref_count: usize,
    driver_lost: bool,
@@ -313,6 +374,9 @@ pub(crate) struct EndpointInner {
     recv_buf: Box<[u8]>,
     send_limiter: WorkLimiter,
     idle: Broadcast,
+    dirty: mpsc::UnboundedReceiver<ConnectionHandle>,
+    dirty_send: mpsc::UnboundedSender<ConnectionHandle>,
+    timers: DelayQueue<ConnectionHandle>,
 }
 
 impl EndpointInner {
@@ -343,7 +407,7 @@ impl EndpointInner {
                             Some((handle, DatagramEvent::NewConnection(conn))) => {
                                 let conn = self
                                     .connections
-                                    .insert(handle, conn, self.udp_state.clone());
+                                    .insert(self.dirty_send.clone(), handle, conn);
                                 self.incoming.push_back(conn);
                             }
                             Some((handle, DatagramEvent::ConnectionEvent(event))) => {
@@ -419,45 +483,11 @@ impl EndpointInner {
         self.send_limiter.finish_cycle();
         result
     }
-
-    fn handle_events(&mut self, cx: &mut Context) -> bool {
-        use EndpointEvent::*;
-
-        for _ in 0..IO_LOOP_BOUND {
-            match self.events.poll_next_unpin(cx) {
-                Poll::Ready(Some((ch, event))) => match event {
-                    Proto(e) => {
-                        if e.is_drained() {
-                            self.connections.connections.remove(&ch);
-                            if self.connections.is_empty() {
-                                self.idle.wake();
-                            }
-                        }
-                        if let Some(event) = self.inner.handle_event(ch, e) {
-                            let conn = self.connections.connections.get(&ch).unwrap();
-                            let mut conn = conn.lock("handle_event");
-                            conn.inner.handle_event(event);
-                            conn.wake();
-                        }
-                    }
-                    Transmit(t) => self.outgoing.push_back(t),
-                },
-                Poll::Ready(None) => unreachable!("EndpointInner owns one sender"),
-                Poll::Pending => {
-                    return false;
-                }
-            }
-        }
-
-        true
-    }
 }
 
 #[derive(Debug)]
 struct ConnectionSet {
     connections: FxHashMap<ConnectionHandle, ConnectionRef>,
-    /// Stored to give out clones to new ConnectionInners
-    sender: mpsc::UnboundedSender<(ConnectionHandle, EndpointEvent)>,
     /// Set if the endpoint has been manually closed
     close: Option<(VarInt, Bytes)>,
 }
@@ -465,11 +495,11 @@ struct ConnectionSet {
 impl ConnectionSet {
     fn insert(
         &mut self,
+        dirty: mpsc::UnboundedSender<ConnectionHandle>,
         handle: ConnectionHandle,
         conn: proto::Connection,
-        udp_state: Arc<UdpState>,
     ) -> Connecting {
-        let (future, conn) = Connecting::new(handle, conn, self.sender.clone(), udp_state);
+        let (future, conn) = Connecting::new(dirty, handle, conn);
         if let Some((error_code, ref reason)) = self.close {
             let mut conn = conn.lock("close");
             conn.close(error_code, reason.clone());
@@ -534,20 +564,18 @@ impl EndpointRef {
     pub(crate) fn new(socket: UdpSocket, inner: proto::Endpoint, ipv6: bool) -> Self {
         let recv_buf =
             vec![0; inner.config().get_max_udp_payload_size().min(64 * 1024) as usize * BATCH_SIZE];
-        let (sender, events) = mpsc::unbounded();
+        let (dirty_send, dirty) = mpsc::unbounded_channel();
         Self(Arc::new(Mutex::new(EndpointInner {
             socket,
             udp_state: Arc::new(UdpState::new()),
             inner,
             ipv6,
-            events,
             outgoing: VecDeque::new(),
             incoming: VecDeque::new(),
             incoming_reader: None,
             driver: None,
             connections: ConnectionSet {
                 connections: FxHashMap::default(),
-                sender,
                 close: None,
             },
             ref_count: 0,
@@ -556,6 +584,9 @@ impl EndpointRef {
             recv_limiter: WorkLimiter::new(RECV_TIME_BOUND),
             send_limiter: WorkLimiter::new(SEND_TIME_BOUND),
             idle: Broadcast::new(),
+            dirty,
+            dirty_send,
+            timers: DelayQueue::new(),
         })))
     }
 }
diff --git a/quinn/src/lib.rs b/quinn/src/lib.rs
index ce2ee637f2..31a83cf22a 100644
--- a/quinn/src/lib.rs
+++ b/quinn/src/lib.rs
@@ -69,12 +69,6 @@ pub use crate::send_stream::{SendStream, StoppedError, WriteError};
 #[cfg(test)]
 mod tests;
 
-#[derive(Debug)]
-enum EndpointEvent {
-    Proto(proto::EndpointEvent),
-    Transmit(proto::Transmit),
-}
-
 /// Maximum number of datagrams processed in send/recv calls to make before moving on to other processing
 ///
 /// This helps ensure we don't starve anything when the CPU is slower than the link.
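Note on the driving model this patch introduces: per-connection driver tasks go away; a connection that has work to do pushes its `ConnectionHandle` onto an unbounded "dirty" channel (deduplicated by the `is_dirty` flag), and the single endpoint driver drains that channel, drives transmits, and keeps at most one `DelayQueue` entry per connection for its next timeout. The standalone sketch below illustrates that pattern in isolation; it is not quinn code — `FakeConn`, the 50 ms deadline, and the two-iteration loop are invented for illustration — and it assumes tokio 1.x with the `macros`/`rt-multi-thread` features, `futures-util` for `StreamExt`, and tokio-util 0.6 (whose `DelayQueue` stream yields `Result`-wrapped items), matching the versions used in the diff.

```rust
use std::collections::HashMap;
use std::time::Duration;

use futures_util::StreamExt;
use tokio::sync::mpsc;
use tokio::time::Instant;
use tokio_util::time::{delay_queue, DelayQueue};

/// Stand-in for a single connection's state (loosely modeled on `ConnectionInner`).
struct FakeConn {
    id: u64,
    /// Deduplicates wake-ups: only the first `wake()` after the driver clears
    /// the flag actually pushes onto the channel.
    is_dirty: bool,
    dirty: mpsc::UnboundedSender<u64>,
    timer_handle: Option<delay_queue::Key>,
    timer_deadline: Option<Instant>,
}

impl FakeConn {
    fn wake(&mut self) {
        if self.is_dirty {
            return;
        }
        self.is_dirty = true;
        // If the driver is gone there is nobody left to notify; ignore the error.
        let _ = self.dirty.send(self.id);
    }
}

#[tokio::main]
async fn main() {
    let (dirty_send, mut dirty) = mpsc::unbounded_channel::<u64>();
    let mut timers: DelayQueue<u64> = DelayQueue::new();
    let mut conns: HashMap<u64, FakeConn> = HashMap::new();

    // Register one connection and mark it dirty so the driver looks at it.
    let mut conn = FakeConn {
        id: 1,
        is_dirty: false,
        dirty: dirty_send.clone(),
        timer_handle: None,
        timer_deadline: None,
    };
    conn.wake();
    conns.insert(conn.id, conn);

    // A single driver loop standing in for `EndpointDriver::poll`.
    for _ in 0..2 {
        tokio::select! {
            Some(id) = dirty.recv() => {
                let conn = conns.get_mut(&id).unwrap();
                conn.is_dirty = false;
                // Pretend driving the connection produced a timeout 50 ms out and
                // (re)arm its single DelayQueue entry, mirroring the patch.
                let deadline = Instant::now() + Duration::from_millis(50);
                if Some(deadline) != conn.timer_deadline {
                    match conn.timer_handle {
                        Some(ref key) => timers.reset_at(key, deadline),
                        None => conn.timer_handle = Some(timers.insert_at(id, deadline)),
                    }
                    conn.timer_deadline = Some(deadline);
                }
                println!("drove connection {}", id);
            }
            // The guard keeps us from polling an empty DelayQueue, which yields None.
            Some(expired) = timers.next(), if !timers.is_empty() => {
                let id = expired.expect("timer error").into_inner();
                let conn = conns.get_mut(&id).unwrap();
                conn.timer_handle = None;
                conn.timer_deadline = None;
                println!("timeout fired for connection {}", id);
                conn.wake(); // a timeout usually produces new work, e.g. retransmits
            }
        }
    }
}
```

The `is_dirty` flag is what keeps the unbounded channel small in practice: a connection may be woken many times between driver passes, but it enqueues its handle at most once per pass, just as `ConnectionInner::wake` does in the patch.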