Unify connection/endpoint drivers
Ralith committed Mar 4, 2023
1 parent 1200f91 commit f8a8234
Showing 4 changed files with 123 additions and 213 deletions.
1 change: 1 addition & 0 deletions quinn/Cargo.toml
@@ -45,6 +45,7 @@ rustls = { version = "0.20.3", default-features = false, features = ["quic"], op
 thiserror = "1.0.21"
 tracing = "0.1.10"
 tokio = { version = "1.13.0", features = ["sync"] }
+tokio-util = { version = "0.6.9", features = ["time"] }
 udp = { package = "quinn-udp", path = "../quinn-udp", version = "0.3", default-features = false }
 webpki = { version = "0.22", default-features = false, optional = true }

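The new tokio-util dependency is here for its "time" feature, which provides `delay_queue`: the unified endpoint driver keeps a single `DelayQueue` for all connection timers instead of one timer future per connection (see the new `timer_handle: Option<delay_queue::Key>` field in the connection.rs diff below). A minimal self-contained sketch of the reset-or-insert pattern that API supports, assuming tokio-util 0.6; the function and names are illustrative, not quinn's:

    use std::time::Duration;
    use tokio_util::time::delay_queue::{DelayQueue, Key};

    /// One queue owns every connection's timer; a stored Key lets a deadline
    /// be moved in place instead of allocating a fresh timer each time.
    fn set_timer(
        timers: &mut DelayQueue<usize>, // value: a connection handle
        key: &mut Option<Key>,
        handle: usize,
        timeout: Duration,
    ) {
        match key {
            Some(k) => timers.reset(k, timeout), // deadline changed: move it
            None => *key = Some(timers.insert(handle, timeout)),
        }
    }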
178 changes: 32 additions & 146 deletions quinn/src/connection.rs
@@ -1,5 +1,6 @@
 use std::{
     any::Any,
+    collections::VecDeque,
     fmt,
     future::Future,
     net::{IpAddr, SocketAddr},
@@ -9,21 +10,24 @@ use std::{
     time::{Duration, Instant},
 };

-use crate::runtime::{AsyncTimer, Runtime};
 use bytes::Bytes;
 use pin_project_lite::pin_project;
 use proto::{ConnectionError, ConnectionHandle, ConnectionStats, Dir, StreamEvent, StreamId};
 use rustc_hash::FxHashMap;
 use thiserror::Error;
-use tokio::sync::{futures::Notified, mpsc, oneshot, Notify};
+use tokio::{
+    sync::{futures::Notified, mpsc, oneshot, Notify},
+    time::Instant as TokioInstant,
+};
+use tokio_util::time::delay_queue;
 use tracing::debug_span;
 use udp::UdpState;

 use crate::{
     mutex::Mutex,
     recv_stream::RecvStream,
     send_stream::{SendStream, WriteError},
-    EndpointEvent, VarInt,
+    VarInt,
 };
 use proto::congestion::Controller;

@@ -38,25 +42,22 @@ pub struct Connecting {

 impl Connecting {
     pub(crate) fn new(
+        dirty: mpsc::UnboundedSender<ConnectionHandle>,
         handle: ConnectionHandle,
         conn: proto::Connection,
-        endpoint_events: mpsc::UnboundedSender<(ConnectionHandle, EndpointEvent)>,
         udp_state: Arc<UdpState>,
-        runtime: Arc<dyn Runtime>,
     ) -> (Connecting, ConnectionRef) {
         let (on_handshake_data_send, on_handshake_data_recv) = oneshot::channel();
         let (on_connected_send, on_connected_recv) = oneshot::channel();
         let conn = ConnectionRef::new(
             handle,
             conn,
-            endpoint_events,
+            dirty,
             on_handshake_data_send,
             on_connected_send,
             udp_state,
-            runtime.clone(),
         );

-        runtime.spawn(Box::pin(ConnectionDriver(conn.clone())));
         (
             Connecting {
                 conn: Some(conn.clone()),
@@ -202,53 +203,6 @@ impl Future for ZeroRttAccepted {
     }
 }

-/// A future that drives protocol logic for a connection
-///
-/// This future handles the protocol logic for a single connection, routing events from the
-/// `Connection` API object to the `Endpoint` task and the related stream-related interfaces.
-/// It also keeps track of outstanding timeouts for the `Connection`.
-///
-/// If the connection encounters an error condition, this future will yield an error. It will
-/// terminate (yielding `Ok(())`) if the connection was closed without error. Unlike other
-/// connection-related futures, this waits for the draining period to complete to ensure that
-/// packets still in flight from the peer are handled gracefully.
-#[must_use = "connection drivers must be spawned for their connections to function"]
-#[derive(Debug)]
-struct ConnectionDriver(ConnectionRef);
-
-impl Future for ConnectionDriver {
-    type Output = ();
-
-    #[allow(unused_mut)] // MSRV
-    fn poll(mut self: Pin<&mut Self>, cx: &mut Context) -> Poll<Self::Output> {
-        let conn = &mut *self.0.state.lock("poll");
-
-        let span = debug_span!("drive", id = conn.handle.0);
-        let _guard = span.enter();
-
-        let mut keep_going = conn.drive_transmit();
-        // If a timer expires, there might be more to transmit. When we transmit something, we
-        // might need to reset a timer. Hence, we must loop until neither happens.
-        keep_going |= conn.drive_timer(cx);
-        conn.forward_endpoint_events();
-        conn.forward_app_events(&self.0.shared);
-
-        if !conn.inner.is_drained() {
-            if keep_going {
-                // If the connection hasn't processed all tasks, schedule it again
-                cx.waker().wake_by_ref();
-            } else {
-                conn.driver = Some(cx.waker().clone());
-            }
-            return Poll::Pending;
-        }
-        if conn.error.is_none() {
-            unreachable!("drained connections always have an error");
-        }
-        Poll::Ready(())
-    }
-}
-
 /// A QUIC connection.
 ///
 /// If all references to a connection (including every clone of the `Connection` handle, streams of
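The removed driver uses the standard self-rescheduling idiom for futures: when work remains after a poll, `cx.waker().wake_by_ref()` asks the runtime to poll again promptly instead of parking the task. A minimal self-contained illustration of that idiom (not quinn code):

    use std::future::Future;
    use std::pin::Pin;
    use std::task::{Context, Poll};

    /// Does one unit of work per poll; reschedules itself while work remains.
    struct Budgeted(u32);

    impl Future for Budgeted {
        type Output = ();

        fn poll(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<()> {
            if self.0 == 0 {
                return Poll::Ready(());
            }
            self.0 -= 1; // one unit of work per poll
            cx.waker().wake_by_ref(); // work remains: request another poll
            Poll::Pending
        }
    }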
@@ -741,31 +695,31 @@ impl ConnectionRef {
     fn new(
         handle: ConnectionHandle,
         conn: proto::Connection,
-        endpoint_events: mpsc::UnboundedSender<(ConnectionHandle, EndpointEvent)>,
+        dirty: mpsc::UnboundedSender<ConnectionHandle>,
         on_handshake_data: oneshot::Sender<()>,
         on_connected: oneshot::Sender<bool>,
         udp_state: Arc<UdpState>,
-        runtime: Arc<dyn Runtime>,
     ) -> Self {
+        let _ = dirty.send(handle);
         Self(Arc::new(ConnectionInner {
             state: Mutex::new(State {
                 inner: conn,
-                driver: None,
                 handle,
+                span: debug_span!("connection", id = handle.0),
+                is_dirty: true,
+                dirty,
                 on_handshake_data: Some(on_handshake_data),
                 on_connected: Some(on_connected),
                 connected: false,
-                timer: None,
+                timer_handle: None,
                 timer_deadline: None,
-                endpoint_events,
                 blocked_writers: FxHashMap::default(),
                 blocked_readers: FxHashMap::default(),
                 finishing: FxHashMap::default(),
                 stopped: FxHashMap::default(),
                 error: None,
                 ref_count: 0,
                 udp_state,
-                runtime,
             }),
             shared: Shared::default(),
         }))
@@ -825,14 +779,18 @@ pub(crate) struct Shared {

 pub(crate) struct State {
     pub(crate) inner: proto::Connection,
-    driver: Option<Waker>,
     handle: ConnectionHandle,
+    pub(crate) span: tracing::Span,
+    /// Whether `handle` has been sent to `dirty` since the last time this connection was driven by
+    /// the endpoint. Ensures `dirty`'s size stays bounded regardless of activity.
+    pub(crate) is_dirty: bool,
+    /// `handle` is sent here when `wake` is called, prompting the endpoint to drive the connection
+    dirty: mpsc::UnboundedSender<ConnectionHandle>,
     on_handshake_data: Option<oneshot::Sender<()>>,
     on_connected: Option<oneshot::Sender<bool>>,
     connected: bool,
-    timer: Option<Pin<Box<dyn AsyncTimer>>>,
-    timer_deadline: Option<Instant>,
-    endpoint_events: mpsc::UnboundedSender<(ConnectionHandle, EndpointEvent)>,
+    pub(crate) timer_handle: Option<delay_queue::Key>,
+    pub(crate) timer_deadline: Option<TokioInstant>,
     pub(crate) blocked_writers: FxHashMap<StreamId, Waker>,
     pub(crate) blocked_readers: FxHashMap<StreamId, Waker>,
     pub(crate) finishing: FxHashMap<StreamId, oneshot::Sender<Option<WriteError>>>,
@@ -842,11 +800,10 @@ pub(crate) struct State {
     /// Number of live handles that can be used to initiate or handle I/O; excludes the driver
     ref_count: usize,
     udp_state: Arc<UdpState>,
-    runtime: Arc<dyn Runtime>,
 }

 impl State {
-    fn drive_transmit(&mut self) -> bool {
+    pub(crate) fn drive_transmit(&mut self, out: &mut VecDeque<proto::Transmit>) -> bool {
         let now = Instant::now();
         let mut transmits = 0;

@@ -857,10 +814,7 @@ impl State {
                 None => 1,
                 Some(s) => (t.contents.len() + s - 1) / s, // round up
             };
-            // If the endpoint driver is gone, noop.
-            let _ = self
-                .endpoint_events
-                .send((self.handle, EndpointEvent::Transmit(t)));
+            out.push_back(t);

             if transmits >= MAX_TRANSMIT_DATAGRAMS {
                 // TODO: What isn't ideal here yet is that if we don't poll all
@@ -874,16 +828,7 @@ impl State {
         false
     }

-    fn forward_endpoint_events(&mut self) {
-        while let Some(event) = self.inner.poll_endpoint_events() {
-            // If the endpoint driver is gone, noop.
-            let _ = self
-                .endpoint_events
-                .send((self.handle, EndpointEvent::Proto(event)));
-        }
-    }
-
-    fn forward_app_events(&mut self, shared: &Shared) {
+    pub(crate) fn forward_app_events(&mut self, shared: &Shared) {
         while let Some(event) = self.inner.poll() {
             use proto::Event::*;
             match event {
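`drive_transmit` above now appends to a caller-supplied queue instead of sending each transmit over the endpoint-events channel, and returning `true` once `MAX_TRANSMIT_DATAGRAMS` is reached tells the caller to drive the connection again. A self-contained sketch of that contract, with an illustrative cap and payload type rather than quinn's actual ones:

    use std::collections::VecDeque;

    const MAX_TRANSMIT_DATAGRAMS: usize = 10; // illustrative cap

    /// Moves up to MAX_TRANSMIT_DATAGRAMS items into `out`; returns true if
    /// cut off early, i.e. the caller should drive this source again.
    fn drive_transmit(pending: &mut Vec<u32>, out: &mut VecDeque<u32>) -> bool {
        let mut transmits = 0;
        while let Some(t) = pending.pop() {
            out.push_back(t);
            transmits += 1;
            if transmits >= MAX_TRANSMIT_DATAGRAMS {
                // Stopped at the cap: more work may remain.
                return true;
            }
        }
        // Ran out of work before hitting the cap.
        false
    }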
@@ -949,61 +894,14 @@ impl State {
         }
     }

-    fn drive_timer(&mut self, cx: &mut Context) -> bool {
-        // Check whether we need to (re)set the timer. If so, we must poll again to ensure the
-        // timer is registered with the runtime (and check whether it's already
-        // expired).
-        match self.inner.poll_timeout() {
-            Some(deadline) => {
-                if let Some(delay) = &mut self.timer {
-                    // There is no need to reset the tokio timer if the deadline
-                    // did not change
-                    if self
-                        .timer_deadline
-                        .map(|current_deadline| current_deadline != deadline)
-                        .unwrap_or(true)
-                    {
-                        delay.as_mut().reset(deadline);
-                    }
-                } else {
-                    self.timer = Some(self.runtime.new_timer(deadline));
-                }
-                // Store the actual expiration time of the timer
-                self.timer_deadline = Some(deadline);
-            }
-            None => {
-                self.timer_deadline = None;
-                return false;
-            }
-        }
-
-        if self.timer_deadline.is_none() {
-            return false;
-        }
-
-        let delay = self
-            .timer
-            .as_mut()
-            .expect("timer must exist in this state")
-            .as_mut();
-        if delay.poll(cx).is_pending() {
-            // Since there wasn't a timeout event, there is nothing new
-            // for the connection to do
-            return false;
-        }
-
-        // A timer expired, so the caller needs to check for
-        // new transmits, which might cause new timers to be set.
-        self.inner.handle_timeout(Instant::now());
-        self.timer_deadline = None;
-        true
-    }
-
-    /// Wake up a blocked `Driver` task to process I/O
+    /// Wake up endpoint to process I/O by marking it as "dirty" for the endpoint
     pub(crate) fn wake(&mut self) {
-        if let Some(x) = self.driver.take() {
-            x.wake();
+        if self.is_dirty {
+            return;
         }
+        self.is_dirty = true;
+        // Take no action if the endpoint is gone
+        let _ = self.dirty.send(self.handle);
     }

     /// Used to wake up all blocked futures when the connection becomes closed for any reason
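The `is_dirty` flag paired with the unbounded channel is what keeps the new scheme bounded: a connection enqueues its handle at most once until the endpoint next drives it, so the channel never holds more entries than there are connections. A self-contained sketch of both halves of the pattern; the endpoint side is assumed, since it is not part of the hunks shown here:

    use tokio::sync::mpsc;

    struct Conn {
        handle: usize,
        is_dirty: bool,
        dirty: mpsc::UnboundedSender<usize>,
    }

    impl Conn {
        fn wake(&mut self) {
            if self.is_dirty {
                return; // already queued for the next drive pass
            }
            self.is_dirty = true;
            let _ = self.dirty.send(self.handle); // endpoint gone: no-op
        }
    }

    /// Endpoint side: clear the flag before driving, so a wake() that lands
    /// while we drive re-queues the connection for another pass.
    fn drive_dirty(rx: &mut mpsc::UnboundedReceiver<usize>, conns: &mut [Conn]) {
        while let Ok(handle) = rx.try_recv() {
            conns[handle].is_dirty = false;
            // ...drive transmits, timers, and events for conns[handle] here...
        }
    }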
@@ -1058,18 +956,6 @@ impl State {
     }
 }

-impl Drop for State {
-    fn drop(&mut self) {
-        if !self.inner.is_drained() {
-            // Ensure the endpoint can tidy up
-            let _ = self.endpoint_events.send((
-                self.handle,
-                EndpointEvent::Proto(proto::EndpointEvent::drained()),
-            ));
-        }
-    }
-}
-
 impl fmt::Debug for State {
     fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result {
         f.debug_struct("State").field("inner", &self.inner).finish()
