diff --git a/quinn/src/connection.rs b/quinn/src/connection.rs index 4fa5768c0..1c3a65740 100644 --- a/quinn/src/connection.rs +++ b/quinn/src/connection.rs @@ -23,7 +23,7 @@ use crate::{ mutex::Mutex, recv_stream::RecvStream, send_stream::{SendStream, WriteError}, - ConnectionEvent, EndpointEvent, VarInt, + EndpointEvent, VarInt, }; use proto::congestion::Controller; @@ -41,17 +41,15 @@ impl Connecting { handle: ConnectionHandle, conn: proto::Connection, endpoint_events: mpsc::UnboundedSender<(ConnectionHandle, EndpointEvent)>, - conn_events: mpsc::UnboundedReceiver, udp_state: Arc, runtime: Arc, - ) -> Connecting { + ) -> (Connecting, ConnectionRef) { let (on_handshake_data_send, on_handshake_data_recv) = oneshot::channel(); let (on_connected_send, on_connected_recv) = oneshot::channel(); let conn = ConnectionRef::new( handle, conn, endpoint_events, - conn_events, on_handshake_data_send, on_connected_send, udp_state, @@ -59,12 +57,14 @@ impl Connecting { ); runtime.spawn(Box::pin(ConnectionDriver(conn.clone()))); - - Connecting { - conn: Some(conn), - connected: on_connected_recv, - handshake_data_ready: Some(on_handshake_data_recv), - } + ( + Connecting { + conn: Some(conn.clone()), + connected: on_connected_recv, + handshake_data_ready: Some(on_handshake_data_recv), + }, + conn, + ) } /// Convert into a 0-RTT or 0.5-RTT connection at the cost of weakened security @@ -226,10 +226,6 @@ impl Future for ConnectionDriver { let span = debug_span!("drive", id = conn.handle.0); let _guard = span.enter(); - if let Err(e) = conn.process_conn_events(&self.0.shared, cx) { - conn.terminate(e, &self.0.shared); - return Poll::Ready(()); - } let mut keep_going = conn.drive_transmit(); // If a timer expires, there might be more to transmit. When we transmit something, we // might need to reset a timer. Hence, we must loop until neither happens. @@ -746,7 +742,6 @@ impl ConnectionRef { handle: ConnectionHandle, conn: proto::Connection, endpoint_events: mpsc::UnboundedSender<(ConnectionHandle, EndpointEvent)>, - conn_events: mpsc::UnboundedReceiver, on_handshake_data: oneshot::Sender<()>, on_connected: oneshot::Sender, udp_state: Arc, @@ -762,7 +757,6 @@ impl ConnectionRef { connected: false, timer: None, timer_deadline: None, - conn_events, endpoint_events, blocked_writers: FxHashMap::default(), blocked_readers: FxHashMap::default(), @@ -838,7 +832,6 @@ pub(crate) struct State { connected: bool, timer: Option>>, timer_deadline: Option, - conn_events: mpsc::UnboundedReceiver, endpoint_events: mpsc::UnboundedSender<(ConnectionHandle, EndpointEvent)>, pub(crate) blocked_writers: FxHashMap, pub(crate) blocked_readers: FxHashMap, @@ -890,37 +883,6 @@ impl State { } } - /// If this returns `Err`, the endpoint is dead, so the driver should exit immediately. - fn process_conn_events( - &mut self, - shared: &Shared, - cx: &mut Context, - ) -> Result<(), ConnectionError> { - loop { - match self.conn_events.poll_recv(cx) { - Poll::Ready(Some(ConnectionEvent::Ping)) => { - self.inner.ping(); - } - Poll::Ready(Some(ConnectionEvent::Proto(event))) => { - self.inner.handle_event(event); - } - Poll::Ready(Some(ConnectionEvent::Close { reason, error_code })) => { - self.close(error_code, reason, shared); - } - Poll::Ready(None) => { - return Err(ConnectionError::TransportError(proto::TransportError { - code: proto::TransportErrorCode::INTERNAL_ERROR, - frame: None, - reason: "endpoint driver future was dropped".to_string(), - })); - } - Poll::Pending => { - return Ok(()); - } - } - } - } - fn forward_app_events(&mut self, shared: &Shared) { while let Some(event) = self.inner.poll() { use proto::Event::*; @@ -1073,7 +1035,7 @@ impl State { shared.closed.notify_waiters(); } - fn close(&mut self, error_code: VarInt, reason: Bytes, shared: &Shared) { + pub fn close(&mut self, error_code: VarInt, reason: Bytes, shared: &Shared) { self.inner.close(Instant::now(), error_code, reason); self.terminate(ConnectionError::LocallyClosed, shared); self.wake(); diff --git a/quinn/src/endpoint.rs b/quinn/src/endpoint.rs index c134231c7..94718f45a 100644 --- a/quinn/src/endpoint.rs +++ b/quinn/src/endpoint.rs @@ -23,8 +23,9 @@ use tokio::sync::{futures::Notified, mpsc, Notify}; use udp::{RecvMeta, UdpState, BATCH_SIZE}; use crate::{ - connection::Connecting, work_limiter::WorkLimiter, ConnectionEvent, EndpointConfig, - EndpointEvent, VarInt, IO_LOOP_BOUND, RECV_TIME_BOUND, SEND_TIME_BOUND, + connection::{Connecting, ConnectionRef}, + work_limiter::WorkLimiter, + EndpointConfig, EndpointEvent, VarInt, IO_LOOP_BOUND, RECV_TIME_BOUND, SEND_TIME_BOUND, }; /// A QUIC endpoint. @@ -210,9 +211,10 @@ impl Endpoint { inner.ipv6 = addr.is_ipv6(); // Generate some activity so peers notice the rebind - for sender in inner.connections.senders.values() { - // Ignoring errors from dropped connections - let _ = sender.send(ConnectionEvent::Ping); + for conn in inner.connections.refs.values() { + let mut state = conn.state.lock("ping"); + state.inner.ping(); + state.wake(); } Ok(()) @@ -244,12 +246,9 @@ impl Endpoint { let reason = Bytes::copy_from_slice(reason); let mut endpoint = self.inner.state.lock().unwrap(); endpoint.connections.close = Some((error_code, reason.clone())); - for sender in endpoint.connections.senders.values() { - // Ignoring errors from dropped connections - let _ = sender.send(ConnectionEvent::Close { - error_code, - reason: reason.clone(), - }); + for conn in endpoint.connections.refs.values() { + let mut state = conn.state.lock("close"); + state.close(error_code, reason.clone(), &conn.shared); } self.inner.shared.incoming.notify_waiters(); } @@ -333,9 +332,6 @@ impl Drop for EndpointDriver { let mut endpoint = self.0.state.lock().unwrap(); endpoint.driver_lost = true; self.0.shared.incoming.notify_waiters(); - // Drop all outgoing channels, signaling the termination of the endpoint to the associated - // connections. - endpoint.connections.senders.clear(); } } @@ -408,13 +404,10 @@ impl State { self.incoming.push_back(conn); } Some((handle, DatagramEvent::ConnectionEvent(event))) => { - // Ignoring errors from dropped connections that haven't yet been cleaned up - let _ = self - .connections - .senders - .get_mut(&handle) - .unwrap() - .send(ConnectionEvent::Proto(event)); + let conn = self.connections.refs.get(&handle).unwrap(); + let mut state = conn.state.lock("handle_event"); + state.inner.handle_event(event); + state.wake(); } None => {} } @@ -493,19 +486,16 @@ impl State { Poll::Ready(Some((ch, event))) => match event { Proto(e) => { if e.is_drained() { - self.connections.senders.remove(&ch); + self.connections.refs.remove(&ch); if self.connections.is_empty() { shared.idle.notify_waiters(); } } if let Some(event) = self.inner.handle_event(ch, e) { - // Ignoring errors from dropped connections that haven't yet been cleaned up - let _ = self - .connections - .senders - .get_mut(&ch) - .unwrap() - .send(ConnectionEvent::Proto(event)); + let conn = self.connections.refs.get(&ch).unwrap(); + let mut conn = conn.state.lock("handle_event"); + conn.inner.handle_event(event); + conn.wake(); } } Transmit(t) => self.outgoing.push_back(t), @@ -523,8 +513,7 @@ impl State { #[derive(Debug)] struct ConnectionSet { - /// Senders for communicating with the endpoint's connections - senders: FxHashMap>, + refs: FxHashMap, /// Stored to give out clones to new ConnectionInners sender: mpsc::UnboundedSender<(ConnectionHandle, EndpointEvent)>, /// Set if the endpoint has been manually closed @@ -539,20 +528,17 @@ impl ConnectionSet { udp_state: Arc, runtime: Arc, ) -> Connecting { - let (send, recv) = mpsc::unbounded_channel(); + let (future, conn) = Connecting::new(handle, conn, self.sender.clone(), udp_state, runtime); if let Some((error_code, ref reason)) = self.close { - send.send(ConnectionEvent::Close { - error_code, - reason: reason.clone(), - }) - .unwrap(); + let mut state = conn.state.lock("close"); + state.close(error_code, reason.clone(), &conn.shared); } - self.senders.insert(handle, send); - Connecting::new(handle, conn, self.sender.clone(), recv, udp_state, runtime) + self.refs.insert(handle, conn); + future } fn is_empty(&self) -> bool { - self.senders.is_empty() + self.refs.is_empty() } } @@ -632,7 +618,7 @@ impl EndpointRef { incoming: VecDeque::new(), driver: None, connections: ConnectionSet { - senders: FxHashMap::default(), + refs: FxHashMap::default(), sender, close: None, }, diff --git a/quinn/src/lib.rs b/quinn/src/lib.rs index 4aa366474..e3d420384 100644 --- a/quinn/src/lib.rs +++ b/quinn/src/lib.rs @@ -81,16 +81,6 @@ pub use crate::send_stream::{SendStream, StoppedError, WriteError}; #[cfg(test)] mod tests; -#[derive(Debug)] -enum ConnectionEvent { - Close { - error_code: VarInt, - reason: bytes::Bytes, - }, - Proto(proto::ConnectionEvent), - Ping, -} - #[derive(Debug)] enum EndpointEvent { Proto(proto::EndpointEvent),