Skip to content

Commit

Permalink
transports/quic: re-use endpoints for dialing
Browse files Browse the repository at this point in the history
  • Loading branch information
elenaf9 committed Jul 10, 2022
1 parent 57743ef commit e5e5b34
Show file tree
Hide file tree
Showing 6 changed files with 119 additions and 77 deletions.
1 change: 1 addition & 0 deletions transports/quic/Cargo.toml
Original file line number Diff line number Diff line change
Expand Up @@ -15,6 +15,7 @@ if-watch = "1.0.0"
libp2p-core = { version = "0.34.0", path = "../../core" }
parking_lot = "0.12.0"
quinn-proto = { version = "0.8.2", default-features = false, features = ["tls-rustls"] }
rand = "0.8.5"
rcgen = "0.9.2"
ring = "0.16.20"
rustls = { version = "0.20.2", default-features = false, features = ["dangerous_configuration"] }
Expand Down
2 changes: 1 addition & 1 deletion transports/quic/src/connection.rs
Original file line number Diff line number Diff line change
Expand Up @@ -134,7 +134,7 @@ impl Connection {
// In a normal case scenario this should not happen, because
// we get want to get a local addr for a server connection only.
tracing::error!("trying to get quinn::local_ip for a client");
endpoint_addr.clone()
*endpoint_addr
})
}

Expand Down
88 changes: 56 additions & 32 deletions transports/quic/src/endpoint.rs
Original file line number Diff line number Diff line change
Expand Up @@ -28,9 +28,7 @@
//! the rest of the code only happens through channels. See the documentation of the
//! [`background_task`] for a thorough description.
use crate::{connection::Connection, tls};

use std::net::{SocketAddr, UdpSocket};
use crate::{connection::Connection, tls, transport};

use futures::{
channel::{mpsc, oneshot},
Expand All @@ -40,9 +38,10 @@ use futures::{
use quinn_proto::{ClientConfig as QuinnClientConfig, ServerConfig as QuinnServerConfig};
use std::{
collections::{HashMap, VecDeque},
fmt, io,
fmt,
net::{Ipv4Addr, SocketAddr, SocketAddrV4, UdpSocket},
sync::{Arc, Weak},
task::Poll,
task::{Poll, Waker},
time::{Duration, Instant},
};

Expand Down Expand Up @@ -99,45 +98,52 @@ pub struct Endpoint {
}

impl Endpoint {
/// Builds a new `Endpoint`.
pub fn new(
/// Builds a new `Endpoint` that is listening on the [`SocketAddr`].
pub fn new_bidirectional(
config: Config,
socket_addr: SocketAddr,
) -> Result<(Arc<Endpoint>, mpsc::Receiver<Connection>), io::Error> {
) -> Result<(Arc<Endpoint>, mpsc::Receiver<Connection>), transport::Error> {
let (new_connections_tx, new_connections_rx) = mpsc::channel(1);
let endpoint = Self::new(config, socket_addr, Some(new_connections_tx))?;
Ok((endpoint, new_connections_rx))
}

/// Builds a new `Endpoint` that only supports outbound connections.
pub fn new_dialer(config: Config) -> Result<Arc<Endpoint>, transport::Error> {
let socket_addr = SocketAddrV4::new(Ipv4Addr::LOCALHOST, 0);
Self::new(config, socket_addr.into(), None)
}

fn new(
config: Config,
socket_addr: SocketAddr,
new_connections: Option<mpsc::Sender<Connection>>,
) -> Result<Arc<Endpoint>, transport::Error> {
// NOT blocking, as per man:bind(2), as we pass an IP address.
let socket = std::net::UdpSocket::bind(&socket_addr)?;
// TODO:
/*let port_is_zero = local_socket_addr.port() == 0;
let local_socket_addr = socket.local_addr()?;
if port_is_zero {
assert_ne!(local_socket_addr.port(), 0);
assert_eq!(multiaddr.pop(), Some(Protocol::Quic));
assert_eq!(multiaddr.pop(), Some(Protocol::Udp(0)));
multiaddr.push(Protocol::Udp(local_socket_addr.port()));
multiaddr.push(Protocol::Quic);
}*/

let (to_endpoint_tx, to_endpoint_rx) = mpsc::channel(32);
let to_endpoint2 = to_endpoint_tx.clone();
let (new_connections_tx, new_connections_rx) = mpsc::channel(1);

let endpoint = Arc::new(Endpoint {
to_endpoint: Mutex::new(to_endpoint_tx),
to_endpoint2,
socket_addr,
socket_addr: socket.local_addr()?,
});

let server_config = new_connections.map(|c| (c, config.server_config.clone()));

// TODO: just for testing, do proper task spawning
async_global_executor::spawn(background_task(
config,
config.endpoint_config,
config.client_config,
server_config,
Arc::downgrade(&endpoint),
async_io::Async::<UdpSocket>::new(socket)?,
new_connections_tx,
to_endpoint_rx.fuse(),
))
.detach();

Ok((endpoint, new_connections_rx))
Ok(endpoint)
}

pub fn socket_addr(&self) -> &SocketAddr {
Expand Down Expand Up @@ -335,17 +341,20 @@ enum ToEndpoint {
/// for as long as any QUIC connection is open.
///
async fn background_task(
config: Config,
endpoint_config: Arc<quinn_proto::EndpointConfig>,
client_config: quinn_proto::ClientConfig,
server_config: Option<(mpsc::Sender<Connection>, Arc<quinn_proto::ServerConfig>)>,
endpoint_weak: Weak<Endpoint>,
udp_socket: async_io::Async<UdpSocket>,
mut new_connections: mpsc::Sender<Connection>,
mut receiver: stream::Fuse<mpsc::Receiver<ToEndpoint>>,
) {
let (mut new_connections, server_config) = match server_config {
Some((a, b)) => (Some(a), Some(b)),
None => (None, None),
};

// The actual QUIC state machine.
let mut endpoint = quinn_proto::Endpoint::new(
config.endpoint_config.clone(),
Some(config.server_config.clone()),
);
let mut endpoint = quinn_proto::Endpoint::new(endpoint_config.clone(), server_config);

// List of all active connections, with a sender to notify them of events.
let mut alive_connections = HashMap::<quinn_proto::ConnectionHandle, mpsc::Sender<_>>::new();
Expand All @@ -365,6 +374,8 @@ async fn background_task(
// code below.
let mut next_packet_out: Option<(SocketAddr, Vec<u8>)> = None;

let mut new_connection_waker: Option<Waker> = None;

// Main loop of the task.
loop {
// Start by flushing `next_packet_out`.
Expand Down Expand Up @@ -409,7 +420,7 @@ async fn background_task(
// name. While we don't use domain names, the underlying rustls library
// is based upon the assumption that we do.
let (connection_id, connection) =
match endpoint.connect(config.client_config.clone(), addr, "l") {
match endpoint.connect(client_config.clone(), addr, "l") {
Ok(c) => c,
Err(err) => {
let _ = result.send(Err(err));
Expand Down Expand Up @@ -474,8 +485,17 @@ async fn background_task(
readiness = {
let active = !queued_new_connections.is_empty();
let new_connections = &mut new_connections;
let new_connection_waker = &mut new_connection_waker;
future::poll_fn(move |cx| {
if active { new_connections.poll_ready(cx) } else { Poll::Pending }
match new_connections.as_mut() {
Some(ref mut c) if active => {
c.poll_ready(cx)
}
_ => {
let _ = new_connection_waker.insert(cx.waker().clone());
Poll::Pending
}
}
})
.fuse()
} => {
Expand All @@ -487,6 +507,7 @@ async fn background_task(

let elem = queued_new_connections.pop_front()
.expect("if queue is empty, the future above is always Pending; qed");
let new_connections = new_connections.as_mut().expect("in case of None, the future above is always Pending; qed");
new_connections.start_send(elem)
.expect("future is waken up only if poll_ready returned Ready; qed");
//endpoint.accept();
Expand Down Expand Up @@ -537,6 +558,9 @@ async fn background_task(
// to the `new_connections` channel. We call `endpoint.accept()` only once
// the element has successfully been sent on `new_connections`.
queued_new_connections.push_back(connection);
if let Some(waker) = new_connection_waker.take() {
waker.wake();
}
},
}
}
Expand Down
9 changes: 5 additions & 4 deletions transports/quic/src/lib.rs
Original file line number Diff line number Diff line change
Expand Up @@ -25,13 +25,14 @@
//! Example:
//!
//! ```
//! use libp2p_quic::{Config, Endpoint};
//! use libp2p_core::Multiaddr;
//! use libp2p_quic::{Config, QuicTransport};
//! use libp2p_core::{Multiaddr, Transport};
//!
//! let keypair = libp2p_core::identity::Keypair::generate_ed25519();
//! let quic_config = Config::new(&keypair).expect("could not make config");
//! let mut quic_transport = QuicTransport::new(quic_config);
//! let addr = "/ip4/127.0.0.1/udp/12345/quic".parse().expect("bad address?");
//! let quic_config = Config::new(&keypair, addr).expect("could not make config");
//! let quic_endpoint = Endpoint::new(quic_config).expect("I/O error");
//! quic_transport.listen_on(addr).expect("listen error.");
//! ```
//!
//! The `Endpoint` struct implements the `Transport` trait of the `core` library. See the
Expand Down
95 changes: 56 additions & 39 deletions transports/quic/src/transport.rs
Original file line number Diff line number Diff line change
Expand Up @@ -55,13 +55,16 @@ pub use quinn_proto::{
pub struct QuicTransport {
config: Config,
listeners: SelectAll<Listener>,
/// Endpoint to use if no listener exists.
dialer: Option<Arc<Endpoint>>,
}

impl QuicTransport {
pub fn new(config: Config) -> Self {
Self {
listeners: SelectAll::new(),
config,
dialer: None,
}
}
}
Expand All @@ -75,12 +78,9 @@ pub enum Error {
/// Error after the remote has been reached.
#[error("{0}")]
Established(Libp2pQuicConnectionError),
/// Error while working with IfWatcher.
#[error("{0}")]
IfWatcher(std::io::Error),

#[error("{0}")]
Socket(std::io::Error),
Io(#[from] std::io::Error),

#[error("Background task crashed.")]
TaskCrashed,
Expand All @@ -94,13 +94,15 @@ impl Transport for QuicTransport {

fn listen_on(&mut self, addr: Multiaddr) -> Result<ListenerId, TransportError<Self::Error>> {
let socket_addr = multiaddr_to_socketaddr(&addr)
.ok_or_else(|| TransportError::MultiaddrNotSupported(addr))?;
let in_addr = InAddr::new(socket_addr.ip());
let (endpoint, new_connections_rx) = Endpoint::new(self.config.clone(), socket_addr)
.map_err(|e| TransportError::Other(Error::Socket(e)))?;
.ok_or(TransportError::MultiaddrNotSupported(addr))?;
let listener_id = ListenerId::new();
let listener = Listener::new(listener_id, endpoint, new_connections_rx, in_addr);
let listener = Listener::new(listener_id, socket_addr, self.config.clone())
.map_err(TransportError::Other)?;
self.listeners.push(listener);
// Drop reference to dialer endpoint so that the endpoint is dropped once the last
// connection that uses it is closed.
// New outbound connections will use a bidirectional (listener) endpoint.
let _ = self.dialer.take();
Ok(listener_id)
}

Expand All @@ -118,26 +120,40 @@ impl Transport for QuicTransport {
}

fn dial(&mut self, addr: Multiaddr) -> Result<Self::Dial, TransportError<Self::Error>> {
todo!()
// let socket_addr = if let Some(socket_addr) = multiaddr_to_socketaddr(&addr) {
// if socket_addr.port() == 0 || socket_addr.ip().is_unspecified() {
// tracing::error!("multiaddr not supported");
// return Err(TransportError::MultiaddrNotSupported(addr));
// }
// socket_addr
// } else {
// tracing::error!("multiaddr not supported");
// return Err(TransportError::MultiaddrNotSupported(addr));
// };

// let endpoint = self.endpoint.clone();

// Ok(async move {
// let connection = endpoint.dial(socket_addr).await.map_err(Error::Reach)?;
// let final_connec = Upgrade::from_connection(connection).await?;
// Ok(final_connec)
// }
// .boxed())
let socket_addr = multiaddr_to_socketaddr(&addr)
.ok_or_else(|| TransportError::MultiaddrNotSupported(addr.clone()))?;
if socket_addr.port() == 0 || socket_addr.ip().is_unspecified() {
tracing::error!("multiaddr not supported");
return Err(TransportError::MultiaddrNotSupported(addr));
}
let endpoint = if self.listeners.is_empty() {
match self.dialer.clone() {
Some(endpoint) => endpoint,
None => {
let endpoint =
Endpoint::new_dialer(self.config.clone()).map_err(TransportError::Other)?;
let _ = self.dialer.insert(endpoint.clone());
endpoint
}
}
} else {
// Pick a random listener to use for dialing.
// TODO: Prefer listeners with same IP version.
let n = rand::random::<usize>() % self.listeners.len();
let listener = self
.listeners
.iter_mut()
.nth(n)
.expect("Can not be out of bound.");
listener.endpoint.clone()
};

Ok(async move {
let connection = endpoint.dial(socket_addr).await.map_err(Error::Reach)?;
let final_connec = Upgrade::from_connection(connection).await?;
Ok(final_connec)
}
.boxed())
}

fn dial_as_listener(
Expand Down Expand Up @@ -169,7 +185,7 @@ struct Listener {
listener_id: ListenerId,

/// Channel where new connections are being sent.
new_connections: mpsc::Receiver<Connection>,
new_connections_rx: mpsc::Receiver<Connection>,

/// The IP addresses of network interfaces on which the listening socket
/// is accepting connections.
Expand All @@ -187,17 +203,18 @@ struct Listener {
impl Listener {
fn new(
listener_id: ListenerId,
endpoint: Arc<Endpoint>,
new_connections: mpsc::Receiver<Connection>,
in_addr: InAddr,
) -> Self {
Listener {
socket_addr: SocketAddr,
config: Config,
) -> Result<Self, Error> {
let in_addr = InAddr::new(socket_addr.ip());
let (endpoint, new_connections_rx) = Endpoint::new_bidirectional(config, socket_addr)?;
Ok(Listener {
endpoint,
listener_id,
new_connections,
new_connections_rx,
in_addr,
report_closed: None,
}
})
}

/// Report the listener as closed in a [`TransportEvent::ListenerClosed`] and
Expand Down Expand Up @@ -261,7 +278,7 @@ impl Listener {
};
Some(TransportEvent::ListenerError {
listener_id: self.listener_id,
error: Error::IfWatcher(err),
error: err.into(),
})
}
}
Expand All @@ -286,7 +303,7 @@ impl Stream for Listener {
if let Some(event) = self.poll_if_addr(cx) {
return Poll::Ready(Some(event));
}
let connection = match futures::ready!(self.new_connections.poll_next_unpin(cx)) {
let connection = match futures::ready!(self.new_connections_rx.poll_next_unpin(cx)) {
Some(c) => c,
None => {
self.close(Err(Error::TaskCrashed));
Expand Down
1 change: 0 additions & 1 deletion transports/quic/tests/smoke.rs
Original file line number Diff line number Diff line change
Expand Up @@ -14,7 +14,6 @@ use libp2p::swarm::{Swarm, SwarmEvent};
use libp2p_quic::{Config as QuicConfig, QuicTransport};
use rand::RngCore;
use std::num::NonZeroU8;
use std::time::Duration;
use std::{io, iter};

fn generate_tls_keypair() -> libp2p::identity::Keypair {
Expand Down

0 comments on commit e5e5b34

Please sign in to comment.