Skip to content

Commit

Permalink
Unify connection/endpoint drivers
Browse files Browse the repository at this point in the history
  • Loading branch information
Ralith committed Nov 17, 2021
1 parent cf9eb35 commit 01d0a8c
Show file tree
Hide file tree
Showing 4 changed files with 116 additions and 207 deletions.
1 change: 1 addition & 0 deletions quinn/Cargo.toml
Original file line number Diff line number Diff line change
Expand Up @@ -35,6 +35,7 @@ rustls = { version = "0.20", default-features = false, features = ["quic"], opti
thiserror = "1.0.21"
tracing = "0.1.10"
tokio = { version = "1.0.1", features = ["rt", "time"] }
tokio-util = { version = "0.6.9", features = ["time"] }
udp = { package = "quinn-udp", path = "../quinn-udp", version = "0.1.0" }
webpki = { version = "0.22", default-features = false, optional = true }

Expand Down
185 changes: 34 additions & 151 deletions quinn/src/connection.rs
Original file line number Diff line number Diff line change
@@ -1,5 +1,6 @@
use std::{
any::Any,
collections::VecDeque,
fmt,
future::Future,
mem,
Expand All @@ -11,21 +12,21 @@ use std::{
};

use bytes::Bytes;
use futures_channel::{mpsc, oneshot};
use futures_channel::oneshot;
use futures_util::FutureExt;
use fxhash::FxHashMap;
use proto::{ConnectionError, ConnectionHandle, ConnectionStats, Dir, StreamEvent, StreamId};
use thiserror::Error;
use tokio::time::{sleep_until, Instant as TokioInstant, Sleep};
use tokio::{sync::mpsc, time::Instant as TokioInstant};
use tokio_util::time::delay_queue;
use tracing::info_span;
use udp::UdpState;

use crate::{
broadcast::{self, Broadcast},
mutex::Mutex,
recv_stream::RecvStream,
send_stream::{SendStream, WriteError},
EndpointEvent, VarInt,
VarInt,
};
use proto::congestion::Controller;

Expand All @@ -40,23 +41,20 @@ pub struct Connecting {

impl Connecting {
pub(crate) fn new(
dirty: mpsc::UnboundedSender<ConnectionHandle>,
handle: ConnectionHandle,
conn: proto::Connection,
endpoint_events: mpsc::UnboundedSender<(ConnectionHandle, EndpointEvent)>,
udp_state: Arc<UdpState>,
) -> (Connecting, ConnectionRef) {
let (on_handshake_data_send, on_handshake_data_recv) = oneshot::channel();
let (on_connected_send, on_connected_recv) = oneshot::channel();
let conn = ConnectionRef::new(
handle,
conn,
endpoint_events,
dirty,
on_handshake_data_send,
on_connected_send,
udp_state,
);

tokio::spawn(ConnectionDriver(conn.clone()));
(
Connecting {
conn: Some(conn.clone()),
Expand Down Expand Up @@ -101,8 +99,10 @@ impl Connecting {
drop(conn);

if is_ok {
let conn = self.conn.take().unwrap();
Ok((NewConnection::new(conn), ZeroRttAccepted(self.connected)))
Ok((
NewConnection::new(self.conn.take().unwrap()),
ZeroRttAccepted(self.connected),
))
} else {
Err(self)
}
Expand Down Expand Up @@ -253,53 +253,6 @@ impl NewConnection {
}
}

/// A future that drives protocol logic for a connection
///
/// This future handles the protocol logic for a single connection, routing events from the
/// `Connection` API object to the `Endpoint` task and the related stream-related interfaces.
/// It also keeps track of outstanding timeouts for the `Connection`.
///
/// If the connection encounters an error condition, this future will yield an error. It will
/// terminate (yielding `Ok(())`) if the connection was closed without error. Unlike other
/// connection-related futures, this waits for the draining period to complete to ensure that
/// packets still in flight from the peer are handled gracefully.
#[must_use = "connection drivers must be spawned for their connections to function"]
#[derive(Debug)]
struct ConnectionDriver(ConnectionRef);

impl Future for ConnectionDriver {
type Output = ();

#[allow(unused_mut)] // MSRV
fn poll(mut self: Pin<&mut Self>, cx: &mut Context) -> Poll<Self::Output> {
let conn = &mut *self.0.lock("poll");

let span = info_span!("drive", id = conn.handle.0);
let _guard = span.enter();

let mut keep_going = conn.drive_transmit();
// If a timer expires, there might be more to transmit. When we transmit something, we
// might need to reset a timer. Hence, we must loop until neither happens.
keep_going |= conn.drive_timer(cx);
conn.forward_endpoint_events();
conn.forward_app_events();

if !conn.inner.is_drained() {
if keep_going {
// If the connection hasn't processed all tasks, schedule it again
cx.waker().wake_by_ref();
} else {
conn.driver = Some(cx.waker().clone());
}
return Poll::Pending;
}
if conn.error.is_none() {
unreachable!("drained connections always have an error");
}
Poll::Ready(())
}
}

/// A QUIC connection.
///
/// If all references to a connection (including every clone of the `Connection` handle, streams of
Expand Down Expand Up @@ -662,21 +615,21 @@ impl ConnectionRef {
fn new(
handle: ConnectionHandle,
conn: proto::Connection,
endpoint_events: mpsc::UnboundedSender<(ConnectionHandle, EndpointEvent)>,
dirty: mpsc::UnboundedSender<ConnectionHandle>,
on_handshake_data: oneshot::Sender<()>,
on_connected: oneshot::Sender<bool>,
udp_state: Arc<UdpState>,
) -> Self {
Self(Arc::new(Mutex::new(ConnectionInner {
inner: conn,
driver: None,
handle,
span: info_span!("connection", id = handle.0),
is_dirty: false,
dirty,
on_handshake_data: Some(on_handshake_data),
on_connected: Some(on_connected),
connected: false,
timer: None,
timer_handle: None,
timer_deadline: None,
endpoint_events,
blocked_writers: FxHashMap::default(),
blocked_readers: FxHashMap::default(),
uni_opening: Broadcast::new(),
Expand All @@ -688,7 +641,6 @@ impl ConnectionRef {
stopped: FxHashMap::default(),
error: None,
ref_count: 0,
udp_state,
})))
}

Expand Down Expand Up @@ -729,14 +681,15 @@ impl std::ops::Deref for ConnectionRef {

pub struct ConnectionInner {
pub(crate) inner: proto::Connection,
driver: Option<Waker>,
handle: ConnectionHandle,
pub(crate) span: tracing::Span,
pub(crate) is_dirty: bool,
dirty: mpsc::UnboundedSender<ConnectionHandle>,
on_handshake_data: Option<oneshot::Sender<()>>,
on_connected: Option<oneshot::Sender<bool>>,
connected: bool,
timer: Option<Pin<Box<Sleep>>>,
timer_deadline: Option<TokioInstant>,
endpoint_events: mpsc::UnboundedSender<(ConnectionHandle, EndpointEvent)>,
pub(crate) timer_handle: Option<delay_queue::Key>,
pub(crate) timer_deadline: Option<TokioInstant>,
pub(crate) blocked_writers: FxHashMap<StreamId, Waker>,
pub(crate) blocked_readers: FxHashMap<StreamId, Waker>,
uni_opening: Broadcast,
Expand All @@ -750,25 +703,23 @@ pub struct ConnectionInner {
pub(crate) error: Option<ConnectionError>,
/// Number of live handles that can be used to initiate or handle I/O; excludes the driver
ref_count: usize,
udp_state: Arc<UdpState>,
}

impl ConnectionInner {
fn drive_transmit(&mut self) -> bool {
pub(crate) fn drive_transmit(
&mut self,
max_datagrams: usize,
out: &mut VecDeque<proto::Transmit>,
) -> bool {
let now = Instant::now();
let mut transmits = 0;

let max_datagrams = self.udp_state.max_gso_segments();

while let Some(t) = self.inner.poll_transmit(now, max_datagrams) {
transmits += match t.segment_size {
None => 1,
Some(s) => (t.contents.len() + s - 1) / s, // round up
};
// If the endpoint driver is gone, noop.
let _ = self
.endpoint_events
.unbounded_send((self.handle, EndpointEvent::Transmit(t)));
out.push_back(t);

if transmits >= MAX_TRANSMIT_DATAGRAMS {
// TODO: What isn't ideal here yet is that if we don't poll all
Expand All @@ -782,16 +733,7 @@ impl ConnectionInner {
false
}

fn forward_endpoint_events(&mut self) {
while let Some(event) = self.inner.poll_endpoint_events() {
// If the endpoint driver is gone, noop.
let _ = self
.endpoint_events
.unbounded_send((self.handle, EndpointEvent::Proto(event)));
}
}

fn forward_app_events(&mut self) {
pub(crate) fn forward_app_events(&mut self) {
while let Some(event) = self.inner.poll() {
use proto::Event::*;
match event {
Expand Down Expand Up @@ -863,61 +805,14 @@ impl ConnectionInner {
}
}

fn drive_timer(&mut self, cx: &mut Context) -> bool {
// Check whether we need to (re)set the timer. If so, we must poll again to ensure the
// timer is registered with the runtime (and check whether it's already
// expired).
match self.inner.poll_timeout().map(TokioInstant::from_std) {
Some(deadline) => {
if let Some(delay) = &mut self.timer {
// There is no need to reset the tokio timer if the deadline
// did not change
if self
.timer_deadline
.map(|current_deadline| current_deadline != deadline)
.unwrap_or(true)
{
delay.as_mut().reset(deadline);
}
} else {
self.timer = Some(Box::pin(sleep_until(deadline)));
}
// Store the actual expiration time of the timer
self.timer_deadline = Some(deadline);
}
None => {
self.timer_deadline = None;
return false;
}
}

if self.timer_deadline.is_none() {
return false;
}

let delay = self
.timer
.as_mut()
.expect("timer must exist in this state")
.as_mut();
if delay.poll(cx).is_pending() {
// Since there wasn't a timeout event, there is nothing new
// for the connection to do
return false;
}

// A timer expired, so the caller needs to check for
// new transmits, which might cause new timers to be set.
self.inner.handle_timeout(Instant::now());
self.timer_deadline = None;
true
}

/// Wake up a blocked `Driver` task to process I/O
/// Wake up endpoint to process I/O
pub(crate) fn wake(&mut self) {
if let Some(x) = self.driver.take() {
x.wake();
if self.is_dirty {
return;
}
self.is_dirty = true;
// Take no action if the endpoint is gone
let _ = self.dirty.send(self.handle);
}

/// Used to wake up all blocked futures when the connection becomes closed for any reason
Expand Down Expand Up @@ -977,18 +872,6 @@ impl ConnectionInner {
}
}

impl Drop for ConnectionInner {
fn drop(&mut self) {
if !self.inner.is_drained() {
// Ensure the endpoint can tidy up
let _ = self.endpoint_events.unbounded_send((
self.handle,
EndpointEvent::Proto(proto::EndpointEvent::drained()),
));
}
}
}

impl fmt::Debug for ConnectionInner {
fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result {
f.debug_struct("ConnectionInner")
Expand Down
Loading

0 comments on commit 01d0a8c

Please sign in to comment.