diff --git a/transports/quic/Cargo.toml b/transports/quic/Cargo.toml index a8dd90a9c55..9425209e68b 100644 --- a/transports/quic/Cargo.toml +++ b/transports/quic/Cargo.toml @@ -15,6 +15,7 @@ if-watch = "1.0.0" libp2p-core = { version = "0.34.0", path = "../../core" } parking_lot = "0.12.0" quinn-proto = { version = "0.8.2", default-features = false, features = ["tls-rustls"] } +rand = "0.8.5" rcgen = "0.9.2" ring = "0.16.20" rustls = { version = "0.20.2", default-features = false, features = ["dangerous_configuration"] } diff --git a/transports/quic/src/connection.rs b/transports/quic/src/connection.rs index cc883a25b7b..0c9217b190a 100644 --- a/transports/quic/src/connection.rs +++ b/transports/quic/src/connection.rs @@ -134,7 +134,7 @@ impl Connection { // In a normal case scenario this should not happen, because // we get want to get a local addr for a server connection only. tracing::error!("trying to get quinn::local_ip for a client"); - endpoint_addr.clone() + *endpoint_addr }) } diff --git a/transports/quic/src/endpoint.rs b/transports/quic/src/endpoint.rs index e474ea54dd8..59804a1bd96 100644 --- a/transports/quic/src/endpoint.rs +++ b/transports/quic/src/endpoint.rs @@ -28,9 +28,7 @@ //! the rest of the code only happens through channels. See the documentation of the //! [`background_task`] for a thorough description. -use crate::{connection::Connection, tls}; - -use std::net::{SocketAddr, UdpSocket}; +use crate::{connection::Connection, tls, transport}; use futures::{ channel::{mpsc, oneshot}, @@ -40,9 +38,10 @@ use futures::{ use quinn_proto::{ClientConfig as QuinnClientConfig, ServerConfig as QuinnServerConfig}; use std::{ collections::{HashMap, VecDeque}, - fmt, io, + fmt, + net::{Ipv4Addr, SocketAddr, SocketAddrV4, UdpSocket}, sync::{Arc, Weak}, - task::Poll, + task::{Poll, Waker}, time::{Duration, Instant}, }; @@ -99,45 +98,52 @@ pub struct Endpoint { } impl Endpoint { - /// Builds a new `Endpoint`. - pub fn new( + /// Builds a new `Endpoint` that is listening on the [`SocketAddr`]. + pub fn new_bidirectional( config: Config, socket_addr: SocketAddr, - ) -> Result<(Arc, mpsc::Receiver), io::Error> { + ) -> Result<(Arc, mpsc::Receiver), transport::Error> { + let (new_connections_tx, new_connections_rx) = mpsc::channel(1); + let endpoint = Self::new(config, socket_addr, Some(new_connections_tx))?; + Ok((endpoint, new_connections_rx)) + } + + /// Builds a new `Endpoint` that only supports outbound connections. + pub fn new_dialer(config: Config) -> Result, transport::Error> { + let socket_addr = SocketAddrV4::new(Ipv4Addr::LOCALHOST, 0); + Self::new(config, socket_addr.into(), None) + } + + fn new( + config: Config, + socket_addr: SocketAddr, + new_connections: Option>, + ) -> Result, transport::Error> { // NOT blocking, as per man:bind(2), as we pass an IP address. let socket = std::net::UdpSocket::bind(&socket_addr)?; - // TODO: - /*let port_is_zero = local_socket_addr.port() == 0; - let local_socket_addr = socket.local_addr()?; - if port_is_zero { - assert_ne!(local_socket_addr.port(), 0); - assert_eq!(multiaddr.pop(), Some(Protocol::Quic)); - assert_eq!(multiaddr.pop(), Some(Protocol::Udp(0))); - multiaddr.push(Protocol::Udp(local_socket_addr.port())); - multiaddr.push(Protocol::Quic); - }*/ - let (to_endpoint_tx, to_endpoint_rx) = mpsc::channel(32); let to_endpoint2 = to_endpoint_tx.clone(); - let (new_connections_tx, new_connections_rx) = mpsc::channel(1); let endpoint = Arc::new(Endpoint { to_endpoint: Mutex::new(to_endpoint_tx), to_endpoint2, - socket_addr, + socket_addr: socket.local_addr()?, }); + let server_config = new_connections.map(|c| (c, config.server_config.clone())); + // TODO: just for testing, do proper task spawning async_global_executor::spawn(background_task( - config, + config.endpoint_config, + config.client_config, + server_config, Arc::downgrade(&endpoint), async_io::Async::::new(socket)?, - new_connections_tx, to_endpoint_rx.fuse(), )) .detach(); - Ok((endpoint, new_connections_rx)) + Ok(endpoint) } pub fn socket_addr(&self) -> &SocketAddr { @@ -335,17 +341,20 @@ enum ToEndpoint { /// for as long as any QUIC connection is open. /// async fn background_task( - config: Config, + endpoint_config: Arc, + client_config: quinn_proto::ClientConfig, + server_config: Option<(mpsc::Sender, Arc)>, endpoint_weak: Weak, udp_socket: async_io::Async, - mut new_connections: mpsc::Sender, mut receiver: stream::Fuse>, ) { + let (mut new_connections, server_config) = match server_config { + Some((a, b)) => (Some(a), Some(b)), + None => (None, None), + }; + // The actual QUIC state machine. - let mut endpoint = quinn_proto::Endpoint::new( - config.endpoint_config.clone(), - Some(config.server_config.clone()), - ); + let mut endpoint = quinn_proto::Endpoint::new(endpoint_config.clone(), server_config); // List of all active connections, with a sender to notify them of events. let mut alive_connections = HashMap::>::new(); @@ -365,6 +374,8 @@ async fn background_task( // code below. let mut next_packet_out: Option<(SocketAddr, Vec)> = None; + let mut new_connection_waker: Option = None; + // Main loop of the task. loop { // Start by flushing `next_packet_out`. @@ -409,7 +420,7 @@ async fn background_task( // name. While we don't use domain names, the underlying rustls library // is based upon the assumption that we do. let (connection_id, connection) = - match endpoint.connect(config.client_config.clone(), addr, "l") { + match endpoint.connect(client_config.clone(), addr, "l") { Ok(c) => c, Err(err) => { let _ = result.send(Err(err)); @@ -474,8 +485,17 @@ async fn background_task( readiness = { let active = !queued_new_connections.is_empty(); let new_connections = &mut new_connections; + let new_connection_waker = &mut new_connection_waker; future::poll_fn(move |cx| { - if active { new_connections.poll_ready(cx) } else { Poll::Pending } + match new_connections.as_mut() { + Some(ref mut c) if active => { + c.poll_ready(cx) + } + _ => { + let _ = new_connection_waker.insert(cx.waker().clone()); + Poll::Pending + } + } }) .fuse() } => { @@ -487,6 +507,7 @@ async fn background_task( let elem = queued_new_connections.pop_front() .expect("if queue is empty, the future above is always Pending; qed"); + let new_connections = new_connections.as_mut().expect("in case of None, the future above is always Pending; qed"); new_connections.start_send(elem) .expect("future is waken up only if poll_ready returned Ready; qed"); //endpoint.accept(); @@ -537,6 +558,9 @@ async fn background_task( // to the `new_connections` channel. We call `endpoint.accept()` only once // the element has successfully been sent on `new_connections`. queued_new_connections.push_back(connection); + if let Some(waker) = new_connection_waker.take() { + waker.wake(); + } }, } } diff --git a/transports/quic/src/lib.rs b/transports/quic/src/lib.rs index 51d3e4df1fb..3dca1d3cbe3 100644 --- a/transports/quic/src/lib.rs +++ b/transports/quic/src/lib.rs @@ -25,13 +25,14 @@ //! Example: //! //! ``` -//! use libp2p_quic::{Config, Endpoint}; -//! use libp2p_core::Multiaddr; +//! use libp2p_quic::{Config, QuicTransport}; +//! use libp2p_core::{Multiaddr, Transport}; //! //! let keypair = libp2p_core::identity::Keypair::generate_ed25519(); +//! let quic_config = Config::new(&keypair).expect("could not make config"); +//! let mut quic_transport = QuicTransport::new(quic_config); //! let addr = "/ip4/127.0.0.1/udp/12345/quic".parse().expect("bad address?"); -//! let quic_config = Config::new(&keypair, addr).expect("could not make config"); -//! let quic_endpoint = Endpoint::new(quic_config).expect("I/O error"); +//! quic_transport.listen_on(addr).expect("listen error."); //! ``` //! //! The `Endpoint` struct implements the `Transport` trait of the `core` library. See the diff --git a/transports/quic/src/transport.rs b/transports/quic/src/transport.rs index f60afdc5f59..e309d53a8f5 100644 --- a/transports/quic/src/transport.rs +++ b/transports/quic/src/transport.rs @@ -55,6 +55,8 @@ pub use quinn_proto::{ pub struct QuicTransport { config: Config, listeners: SelectAll, + /// Endpoint to use if no listener exists. + dialer: Option>, } impl QuicTransport { @@ -62,6 +64,7 @@ impl QuicTransport { Self { listeners: SelectAll::new(), config, + dialer: None, } } } @@ -75,12 +78,9 @@ pub enum Error { /// Error after the remote has been reached. #[error("{0}")] Established(Libp2pQuicConnectionError), - /// Error while working with IfWatcher. - #[error("{0}")] - IfWatcher(std::io::Error), #[error("{0}")] - Socket(std::io::Error), + Io(#[from] std::io::Error), #[error("Background task crashed.")] TaskCrashed, @@ -94,13 +94,15 @@ impl Transport for QuicTransport { fn listen_on(&mut self, addr: Multiaddr) -> Result> { let socket_addr = multiaddr_to_socketaddr(&addr) - .ok_or_else(|| TransportError::MultiaddrNotSupported(addr))?; - let in_addr = InAddr::new(socket_addr.ip()); - let (endpoint, new_connections_rx) = Endpoint::new(self.config.clone(), socket_addr) - .map_err(|e| TransportError::Other(Error::Socket(e)))?; + .ok_or(TransportError::MultiaddrNotSupported(addr))?; let listener_id = ListenerId::new(); - let listener = Listener::new(listener_id, endpoint, new_connections_rx, in_addr); + let listener = Listener::new(listener_id, socket_addr, self.config.clone()) + .map_err(TransportError::Other)?; self.listeners.push(listener); + // Drop reference to dialer endpoint so that the endpoint is dropped once the last + // connection that uses it is closed. + // New outbound connections will use a bidirectional (listener) endpoint. + let _ = self.dialer.take(); Ok(listener_id) } @@ -118,26 +120,40 @@ impl Transport for QuicTransport { } fn dial(&mut self, addr: Multiaddr) -> Result> { - todo!() - // let socket_addr = if let Some(socket_addr) = multiaddr_to_socketaddr(&addr) { - // if socket_addr.port() == 0 || socket_addr.ip().is_unspecified() { - // tracing::error!("multiaddr not supported"); - // return Err(TransportError::MultiaddrNotSupported(addr)); - // } - // socket_addr - // } else { - // tracing::error!("multiaddr not supported"); - // return Err(TransportError::MultiaddrNotSupported(addr)); - // }; - - // let endpoint = self.endpoint.clone(); - - // Ok(async move { - // let connection = endpoint.dial(socket_addr).await.map_err(Error::Reach)?; - // let final_connec = Upgrade::from_connection(connection).await?; - // Ok(final_connec) - // } - // .boxed()) + let socket_addr = multiaddr_to_socketaddr(&addr) + .ok_or_else(|| TransportError::MultiaddrNotSupported(addr.clone()))?; + if socket_addr.port() == 0 || socket_addr.ip().is_unspecified() { + tracing::error!("multiaddr not supported"); + return Err(TransportError::MultiaddrNotSupported(addr)); + } + let endpoint = if self.listeners.is_empty() { + match self.dialer.clone() { + Some(endpoint) => endpoint, + None => { + let endpoint = + Endpoint::new_dialer(self.config.clone()).map_err(TransportError::Other)?; + let _ = self.dialer.insert(endpoint.clone()); + endpoint + } + } + } else { + // Pick a random listener to use for dialing. + // TODO: Prefer listeners with same IP version. + let n = rand::random::() % self.listeners.len(); + let listener = self + .listeners + .iter_mut() + .nth(n) + .expect("Can not be out of bound."); + listener.endpoint.clone() + }; + + Ok(async move { + let connection = endpoint.dial(socket_addr).await.map_err(Error::Reach)?; + let final_connec = Upgrade::from_connection(connection).await?; + Ok(final_connec) + } + .boxed()) } fn dial_as_listener( @@ -169,7 +185,7 @@ struct Listener { listener_id: ListenerId, /// Channel where new connections are being sent. - new_connections: mpsc::Receiver, + new_connections_rx: mpsc::Receiver, /// The IP addresses of network interfaces on which the listening socket /// is accepting connections. @@ -187,17 +203,18 @@ struct Listener { impl Listener { fn new( listener_id: ListenerId, - endpoint: Arc, - new_connections: mpsc::Receiver, - in_addr: InAddr, - ) -> Self { - Listener { + socket_addr: SocketAddr, + config: Config, + ) -> Result { + let in_addr = InAddr::new(socket_addr.ip()); + let (endpoint, new_connections_rx) = Endpoint::new_bidirectional(config, socket_addr)?; + Ok(Listener { endpoint, listener_id, - new_connections, + new_connections_rx, in_addr, report_closed: None, - } + }) } /// Report the listener as closed in a [`TransportEvent::ListenerClosed`] and @@ -261,7 +278,7 @@ impl Listener { }; Some(TransportEvent::ListenerError { listener_id: self.listener_id, - error: Error::IfWatcher(err), + error: err.into(), }) } } @@ -286,7 +303,7 @@ impl Stream for Listener { if let Some(event) = self.poll_if_addr(cx) { return Poll::Ready(Some(event)); } - let connection = match futures::ready!(self.new_connections.poll_next_unpin(cx)) { + let connection = match futures::ready!(self.new_connections_rx.poll_next_unpin(cx)) { Some(c) => c, None => { self.close(Err(Error::TaskCrashed)); diff --git a/transports/quic/tests/smoke.rs b/transports/quic/tests/smoke.rs index c35f8138674..d00cb78cfbb 100644 --- a/transports/quic/tests/smoke.rs +++ b/transports/quic/tests/smoke.rs @@ -14,7 +14,6 @@ use libp2p::swarm::{Swarm, SwarmEvent}; use libp2p_quic::{Config as QuicConfig, QuicTransport}; use rand::RngCore; use std::num::NonZeroU8; -use std::time::Duration; use std::{io, iter}; fn generate_tls_keypair() -> libp2p::identity::Keypair {