diff --git a/Cargo.lock b/Cargo.lock index 7bf2ca44cb..2181edcfaa 100644 --- a/Cargo.lock +++ b/Cargo.lock @@ -1441,9 +1441,9 @@ checksum = "1fd0f2584146f6f2ef48085050886acf353beff7305ebd1ae69500e27c67f64b" [[package]] name = "bytes" -version = "1.5.0" +version = "1.6.0" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "a2bd12c1caf447e69cd4528f47f94d203fd2582878ecb9e9465484c4148a8223" +checksum = "514de17de45fdb8dc022b1a7975556c53c86f9f0aa5f534b98977b171857c2c9" dependencies = [ "serde", ] @@ -4185,6 +4185,18 @@ dependencies = [ "slab", ] +[[package]] +name = "futures_ringbuf" +version = "0.4.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "6628abb6eb1fc74beaeb20cd0670c43d158b0150f7689b38c3eaf663f99bdec7" +dependencies = [ + "futures", + "log", + "ringbuf", + "rustc_version", +] + [[package]] name = "fxhash" version = "0.2.1" @@ -7878,18 +7890,18 @@ dependencies = [ [[package]] name = "pin-project" -version = "1.1.3" +version = "1.1.5" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "fda4ed1c6c173e3fc7a83629421152e01d7b1f9b7f65fb301e490e8cfc656422" +checksum = "b6bf43b791c5b9e34c3d182969b4abb522f9343702850a2e57f460d00d09b4b3" dependencies = [ "pin-project-internal", ] [[package]] name = "pin-project-internal" -version = "1.1.3" +version = "1.1.5" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "4359fd9c9171ec6e8c62926d6faaf553a8dc3f64e1507e76da7911b4f6a04405" +checksum = "2f38a4412a78282e09a2cf38d195ea5420d15ba0602cb375210efbc877243965" dependencies = [ "proc-macro2", "quote", @@ -8708,6 +8720,15 @@ dependencies = [ "windows-sys 0.48.0", ] +[[package]] +name = "ringbuf" +version = "0.3.3" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "79abed428d1fd2a128201cec72c5f6938e2da607c6f3745f769fabea399d950a" +dependencies = [ + "crossbeam-utils", +] + [[package]] name = "ripemd" version = "0.1.3" @@ -11618,6 +11639,24 @@ dependencies = [ "zstd-sys", ] +[[package]] +name = "subspace-cluster-networking" +version = "0.1.0" +dependencies = [ + "async-trait", + "backoff", + "event-listener-primitives", + "futures", + "futures_ringbuf", + "libp2p 0.53.2", + "parity-scale-codec", + "parking_lot 0.12.1", + "pin-project", + "thiserror", + "tokio", + "tracing", +] + [[package]] name = "subspace-core-primitives" version = "0.1.0" diff --git a/crates/subspace-farmer-components/Cargo.toml b/crates/subspace-farmer-components/Cargo.toml index 5157473ded..883a9a96d5 100644 --- a/crates/subspace-farmer-components/Cargo.toml +++ b/crates/subspace-farmer-components/Cargo.toml @@ -18,7 +18,7 @@ bench = false [dependencies] async-lock = "3.3.0" async-trait = "0.1.77" -backoff = { version = "0.4.0", features = ["futures", "tokio"] } +backoff = { version = "0.4.0", features = ["tokio"] } bitvec = "1.0.1" # TODO: Switch to fs4 once https://github.com/al8n/fs4-rs/issues/15 is resolved fs2 = "0.4.3" diff --git a/crates/subspace-farmer/Cargo.toml b/crates/subspace-farmer/Cargo.toml index 226971441d..3384b9c53d 100644 --- a/crates/subspace-farmer/Cargo.toml +++ b/crates/subspace-farmer/Cargo.toml @@ -15,7 +15,7 @@ include = [ anyhow = "1.0.79" async-lock = "3.3.0" async-trait = "0.1.77" -backoff = { version = "0.4.0", features = ["futures", "tokio"] } +backoff = { version = "0.4.0", features = ["tokio"] } base58 = "0.2.0" blake2 = "0.10.6" blake3 = { version = "1.5.0", default-features = false } diff --git a/crates/subspace-networking/Cargo.toml b/crates/subspace-networking/Cargo.toml index 08d91b2477..dd5bca7b92 100644 --- a/crates/subspace-networking/Cargo.toml +++ b/crates/subspace-networking/Cargo.toml @@ -18,8 +18,8 @@ include = [ [dependencies] async-mutex = "1.4.0" async-trait = "0.1.77" -backoff = { version = "0.4.0", features = ["futures", "tokio"] } -bytes = "1.5.0" +backoff = { version = "0.4.0", features = ["tokio"] } +bytes = "1.6.0" clap = { version = "4.4.18", features = ["color", "derive"] } derive_more = "0.99.17" either = "1.8.1" @@ -34,7 +34,7 @@ memmap2 = "0.9.3" nohash-hasher = "0.2.0" parity-scale-codec = "3.6.9" parking_lot = "0.12.1" -pin-project = "1.1.3" +pin-project = "1.1.5" prometheus-client = "0.22.0" rand = "0.8.5" serde = { version = "1.0.195", features = ["derive"] } diff --git a/crates/subspace-networking/src/behavior.rs b/crates/subspace-networking/src/behavior.rs index c11c63fe96..d94f4a4726 100644 --- a/crates/subspace-networking/src/behavior.rs +++ b/crates/subspace-networking/src/behavior.rs @@ -57,7 +57,6 @@ pub(crate) struct BehaviorConfig { #[derive(NetworkBehaviour)] #[behaviour(to_swarm = "Event")] -#[behaviour(event_process = false)] pub(crate) struct Behavior { // TODO: Connection limits must be the first protocol due to https://github.com/libp2p/rust-libp2p/issues/4773 as // suggested in https://github.com/libp2p/rust-libp2p/issues/4898#issuecomment-1818013483 diff --git a/crates/subspace-networking/src/node.rs b/crates/subspace-networking/src/node.rs index 213a445a01..3a7b5a4db0 100644 --- a/crates/subspace-networking/src/node.rs +++ b/crates/subspace-networking/src/node.rs @@ -5,7 +5,6 @@ use crate::utils::multihash::Multihash; use crate::utils::HandlerFn; use bytes::Bytes; use event_listener_primitives::HandlerId; -use futures::channel::mpsc::SendError; use futures::channel::{mpsc, oneshot}; use futures::{SinkExt, Stream, StreamExt}; use libp2p::gossipsub::{Sha256Topic, SubscriptionError}; @@ -43,7 +42,7 @@ impl Stream for TopicSubscription { #[pin_project::pinned_drop] impl PinnedDrop for TopicSubscription { - fn drop(mut self: std::pin::Pin<&mut Self>) { + fn drop(mut self: Pin<&mut Self>) { let topic = self .topic .take() @@ -70,7 +69,7 @@ impl PinnedDrop for TopicSubscription { pub enum GetValueError { /// Failed to send command to the node runner #[error("Failed to send command to the node runner: {0}")] - SendCommand(#[from] SendError), + SendCommand(#[from] mpsc::SendError), /// Node runner was dropped #[error("Node runner was dropped")] NodeRunnerDropped, @@ -87,7 +86,7 @@ impl From for GetValueError { pub enum PutValueError { /// Failed to send command to the node runner #[error("Failed to send command to the node runner: {0}")] - SendCommand(#[from] SendError), + SendCommand(#[from] mpsc::SendError), /// Node runner was dropped #[error("Node runner was dropped")] NodeRunnerDropped, @@ -105,7 +104,7 @@ impl From for PutValueError { pub enum GetClosestPeersError { /// Failed to send command to the node runner #[error("Failed to send command to the node runner: {0}")] - SendCommand(#[from] SendError), + SendCommand(#[from] mpsc::SendError), /// Node runner was dropped #[error("Node runner was dropped")] NodeRunnerDropped, @@ -123,7 +122,7 @@ impl From for GetClosestPeersError { pub enum SubscribeError { /// Failed to send command to the node runner #[error("Failed to send command to the node runner: {0}")] - SendCommand(#[from] SendError), + SendCommand(#[from] mpsc::SendError), /// Node runner was dropped #[error("Node runner was dropped")] NodeRunnerDropped, @@ -143,7 +142,7 @@ impl From for SubscribeError { pub enum PublishError { /// Failed to send command to the node runner #[error("Failed to send command to the node runner: {0}")] - SendCommand(#[from] SendError), + SendCommand(#[from] mpsc::SendError), /// Node runner was dropped #[error("Node runner was dropped")] NodeRunnerDropped, @@ -163,7 +162,7 @@ impl From for PublishError { pub enum GetProvidersError { /// Failed to send command to the node runner #[error("Failed to send command to the node runner: {0}")] - SendCommand(#[from] SendError), + SendCommand(#[from] mpsc::SendError), /// Node runner was dropped #[error("Node runner was dropped")] NodeRunnerDropped, @@ -184,7 +183,7 @@ impl From for GetProvidersError { pub enum SendRequestError { /// Failed to send command to the node runner #[error("Failed to send command to the node runner: {0}")] - SendCommand(#[from] SendError), + SendCommand(#[from] mpsc::SendError), /// Node runner was dropped #[error("Node runner was dropped")] NodeRunnerDropped, @@ -207,7 +206,7 @@ impl From for SendRequestError { pub enum ConnectedPeersError { /// Failed to send command to the node runner #[error("Failed to send command to the node runner: {0}")] - SendCommand(#[from] SendError), + SendCommand(#[from] mpsc::SendError), /// Node runner was dropped #[error("Node runner was dropped")] NodeRunnerDropped, @@ -227,7 +226,7 @@ impl From for ConnectedPeersError { pub enum BootstrapError { /// Failed to send command to the node runner #[error("Failed to send command to the node runner: {0}")] - SendCommand(#[from] SendError), + SendCommand(#[from] mpsc::SendError), /// Node runner was dropped #[error("Node runner was dropped")] NodeRunnerDropped, @@ -471,7 +470,7 @@ impl Node { } /// Ban peer with specified peer ID. - pub async fn ban_peer(&self, peer_id: PeerId) -> Result<(), SendError> { + pub async fn ban_peer(&self, peer_id: PeerId) -> Result<(), mpsc::SendError> { self.shared .command_sender .clone() @@ -483,7 +482,7 @@ impl Node { /// It could be used to test libp2p transports bypassing protocol checks for bootstrap /// or listen-on addresses. #[doc(hidden)] - pub async fn dial(&self, address: Multiaddr) -> Result<(), SendError> { + pub async fn dial(&self, address: Multiaddr) -> Result<(), mpsc::SendError> { self.shared .command_sender .clone() @@ -559,16 +558,16 @@ impl Node { Ok(()) } - /// Callback is called when a peer is disconnected. - pub fn on_disconnected_peer(&self, callback: HandlerFn) -> HandlerId { - self.shared.handlers.disconnected_peer.add(callback) - } - /// Callback is called when a peer is connected. pub fn on_connected_peer(&self, callback: HandlerFn) -> HandlerId { self.shared.handlers.connected_peer.add(callback) } + /// Callback is called when a peer is disconnected. + pub fn on_disconnected_peer(&self, callback: HandlerFn) -> HandlerId { + self.shared.handlers.disconnected_peer.add(callback) + } + /// Callback is called when a routable or unraoutable peer is discovered. pub fn on_discovered_peer(&self, callback: HandlerFn) -> HandlerId { self.shared.handlers.peer_discovered.add(callback) diff --git a/crates/subspace-networking/src/node_runner.rs b/crates/subspace-networking/src/node_runner.rs index 8d9a8ce34c..c6f4a0c34c 100644 --- a/crates/subspace-networking/src/node_runner.rs +++ b/crates/subspace-networking/src/node_runner.rs @@ -472,7 +472,8 @@ where %peer_id, %is_reserved_peer, ?endpoint, - "Connection established [{num_established} from peer]" + %num_established, + "Connection established" ); let maybe_remote_ip = @@ -524,9 +525,12 @@ where return; } }; + debug!( + %peer_id, ?cause, - "Connection closed with peer {peer_id} [{num_established} from peer]" + %num_established, + "Connection closed with peer" ); if num_established == 0 { @@ -1291,7 +1295,10 @@ where } } Ok(false) => { - panic!("Logic error, topic subscription wasn't created, this must never happen"); + panic!( + "Logic error, topic subscription wasn't created, this \ + must never happen" + ); } Err(error) => { let _ = result_sender.send(Err(error)); diff --git a/crates/subspace-networking/src/shared.rs b/crates/subspace-networking/src/shared.rs index 3b92ff7466..6f21952cc3 100644 --- a/crates/subspace-networking/src/shared.rs +++ b/crates/subspace-networking/src/shared.rs @@ -111,8 +111,8 @@ pub(crate) enum Command { pub(crate) struct Handlers { pub(crate) new_listener: Handler, pub(crate) num_established_peer_connections_change: Handler, - pub(crate) disconnected_peer: Handler, pub(crate) connected_peer: Handler, + pub(crate) disconnected_peer: Handler, pub(crate) peer_discovered: Handler, } diff --git a/shared/subspace-cluster-networking/Cargo.toml b/shared/subspace-cluster-networking/Cargo.toml new file mode 100644 index 0000000000..ba6bfb09dc --- /dev/null +++ b/shared/subspace-cluster-networking/Cargo.toml @@ -0,0 +1,29 @@ +[package] +name = "subspace-cluster-networking" +version = "0.1.0" +authors = ["Nazar Mokrynskyi "] +description = "Networking functionality for cluster applications" +edition = "2021" +license = "Apache-2.0" +homepage = "https://subspace.network" +repository = "https://github.com/subspace/subspace" +include = [ + "/src", + "/Cargo.toml", +] + +[dependencies] +async-trait = "0.1.77" +backoff = { version = "0.4.0", features = ["tokio"] } +event-listener-primitives = "2.0.1" +futures = "0.3.29" +libp2p = { version = "0.53.2", features = ["dns", "macros", "metrics", "noise", "request-response", "tcp", "tokio", "yamux"] } +parity-scale-codec = { version = "3.6.9", features = ["derive"] } +parking_lot = "0.12.1" +pin-project = "1.1.5" +thiserror = "1.0.56" +tokio = { version = "1.35.1", features = ["macros", "parking_lot", "rt-multi-thread"] } +tracing = "0.1.40" + +[dev-dependencies] +futures_ringbuf = "0.4.0" diff --git a/shared/subspace-cluster-networking/src/behavior.rs b/shared/subspace-cluster-networking/src/behavior.rs new file mode 100644 index 0000000000..d7fc590f6f --- /dev/null +++ b/shared/subspace-cluster-networking/src/behavior.rs @@ -0,0 +1,61 @@ +use crate::request_response::NoCodec; +use libp2p::request_response::{ + Behaviour as RequestResponse, Config as RequestResponseConfig, Event as RequestResponseEvent, + ProtocolSupport, +}; +use libp2p::swarm::NetworkBehaviour; +use libp2p::StreamProtocol; +use std::iter; +use std::time::Duration; + +#[derive(Debug)] +pub(crate) enum Event { + RequestResponse(RequestResponseEvent, Vec>), +} + +impl From, Vec>> for Event { + fn from(value: RequestResponseEvent, Vec>) -> Self { + Self::RequestResponse(value) + } +} + +pub struct BehaviorConfig { + pub request_response_protocol: &'static str, + /// Maximum allowed size, in bytes, of a request. + /// + /// Any request larger than this value will be declined as a way to avoid allocating too + /// much memory for it. + pub max_request_size: u64, + /// Maximum allowed size, in bytes, of a response. + /// + /// Any response larger than this value will be declined as a way to avoid allocating too + /// much memory for it. + pub max_response_size: u64, + /// Timeout for inbound and outbound requests + pub request_timeout: Duration, + /// Upper bound for the number of concurrent inbound + outbound streams + pub max_concurrent_streams: usize, +} + +#[derive(NetworkBehaviour)] +#[behaviour(to_swarm = "Event")] +pub(crate) struct Behavior { + pub(crate) request_response: RequestResponse, +} + +impl Behavior { + pub(crate) fn new(config: BehaviorConfig) -> Self { + let request_response = RequestResponse::with_codec( + NoCodec::new(config.max_request_size, config.max_response_size), + iter::once(( + StreamProtocol::new(config.request_response_protocol), + ProtocolSupport::Full, + )), + RequestResponseConfig::default() + .with_request_timeout(config.request_timeout) + .with_max_concurrent_streams(config.max_concurrent_streams), + ); + + Self { request_response } + } +} diff --git a/shared/subspace-cluster-networking/src/lib.rs b/shared/subspace-cluster-networking/src/lib.rs new file mode 100644 index 0000000000..a1756213e6 --- /dev/null +++ b/shared/subspace-cluster-networking/src/lib.rs @@ -0,0 +1,10 @@ +#![feature(assert_matches)] + +mod behavior; +pub mod network; +pub mod network_worker; +mod request_response; +mod shared; +#[cfg(test)] +mod tests; +mod utils; diff --git a/shared/subspace-cluster-networking/src/network.rs b/shared/subspace-cluster-networking/src/network.rs new file mode 100644 index 0000000000..49124a3daf --- /dev/null +++ b/shared/subspace-cluster-networking/src/network.rs @@ -0,0 +1,193 @@ +use crate::behavior::{Behavior, BehaviorConfig}; +use crate::network_worker::{InboundRequestsHandler, NetworkWorker}; +use crate::shared::{Command, HandlerFn, Shared}; +use event_listener_primitives::HandlerId; +use futures::channel::{mpsc, oneshot}; +use futures::SinkExt; +use libp2p::identity::Keypair; +use libp2p::metrics::Metrics; +use libp2p::noise::Config as NoiseConfig; +use libp2p::request_response::OutboundFailure; +use libp2p::yamux::Config as YamuxConfig; +use libp2p::{Multiaddr, PeerId, SwarmBuilder}; +use parity_scale_codec::{Decode, Encode}; +use std::error::Error; +use std::marker::PhantomData; +use std::sync::Arc; +use std::time::Duration; +use thiserror::Error; + +/// Generic request with associated response +pub trait GenericRequest: Encode + Decode + Send + Sync + 'static { + /// Response type that corresponds to this request + type Response: Encode + Decode + Send + Sync + 'static; +} + +/// Request sending errors +#[derive(Debug, Error)] +pub enum SendRequestError { + /// Failed to send command to the node runner + #[error("Failed to send command to the node runner: {0}")] + SendCommand(#[from] mpsc::SendError), + /// Worker was dropped + #[error("Worker was dropped")] + WorkerDropped, + /// Underlying protocol returned an error, impossible to get response + #[error("Underlying protocol returned an error: {0}")] + ProtocolFailure(#[from] OutboundFailure), + /// Underlying protocol returned an incorrect format, impossible to get response + #[error("Received incorrectly formatted response: {0}")] + IncorrectResponseFormat(#[from] parity_scale_codec::Error), + /// Unrecognized response + #[error("Unrecognized response")] + UnrecognizedResponse(Box), +} + +impl From for SendRequestError { + #[inline] + fn from(oneshot::Canceled: oneshot::Canceled) -> Self { + Self::WorkerDropped + } +} + +/// Network configuration +pub struct NetworkConfig { + /// Bootstrap nodes + pub bootstrap_nodes: Vec, + /// Multiaddrs to listen on + pub listen_on: Vec, + /// Keypair to use + pub keypair: Keypair, + /// Network key to limit connections to those who know the key + pub network_key: Vec, + /// Behavior config + pub behavior_config: BehaviorConfig, + /// How long to keep a connection alive once it is idling + pub idle_connection_timeout: Duration, + /// Handler for incoming requests + pub request_handler: InboundRequestsHandler, + /// Optional libp2p metrics + pub metrics: Option, +} + +/// Implementation of a network +#[derive(Debug)] +#[must_use = "Network doesn't do anything if dropped"] +pub struct Network { + id: PeerId, + shared: Arc, + phantom: PhantomData<(Requests, Responses)>, +} + +impl Clone for Network { + fn clone(&self) -> Self { + Self { + id: self.id, + shared: Arc::clone(&self.shared), + phantom: PhantomData, + } + } +} + +impl Network +where + Requests: Encode + Decode + Send, + Responses: Encode + Decode + Send + 'static, +{ + pub fn new( + config: NetworkConfig, + ) -> Result<(Self, NetworkWorker), Box> { + let mut swarm = SwarmBuilder::with_existing_identity(config.keypair) + .with_tokio() + .with_tcp( + Default::default(), + |keypair: &Keypair| { + NoiseConfig::new(keypair) + .map(|noise_config| noise_config.with_prologue(config.network_key)) + }, + YamuxConfig::default, + )? + .with_dns()? + .with_behaviour(move |_keypair| Ok(Behavior::new(config.behavior_config))) + .expect("Not fallible; qed") + .with_swarm_config(|swarm_config| { + swarm_config.with_idle_connection_timeout(config.idle_connection_timeout) + }) + .build(); + + // Setup listen_on addresses + for addr in config.listen_on { + swarm.listen_on(addr.clone())?; + } + + let (command_sender, command_receiver) = mpsc::channel(1); + let shared = Arc::new(Shared::new(command_sender)); + let shared_weak = Arc::downgrade(&shared); + + let network = Self { + id: *swarm.local_peer_id(), + shared, + phantom: PhantomData, + }; + let network_worker = NetworkWorker::new( + config.request_handler, + command_receiver, + swarm, + shared_weak, + config.bootstrap_nodes, + config.metrics, + ); + + Ok((network, network_worker)) + } + + /// Node's own local ID. + pub fn id(&self) -> PeerId { + self.id + } + + /// Sends the generic request to the peer at specified address and awaits the result + pub async fn request( + &self, + peer_id: PeerId, + addresses: Vec, + request: Request, + ) -> Result + where + Request: GenericRequest, + Request: Into, + Request::Response: TryFrom, + <::Response as TryFrom>::Error: Into>, + { + let (result_sender, result_receiver) = oneshot::channel(); + let command = Command::Request { + peer_id, + addresses, + request: Into::::into(request).encode(), + result_sender, + }; + + self.shared.command_sender.clone().send(command).await?; + + let result = result_receiver.await??; + + let responses = Responses::decode(&mut result.as_slice())?; + Request::Response::try_from(responses) + .map_err(|error| SendRequestError::UnrecognizedResponse(error.into())) + } + + /// Callback is called when node starts listening on new address. + pub fn on_new_listener(&self, callback: HandlerFn) -> HandlerId { + self.shared.handlers.new_listener.add(callback) + } + + /// Callback is called when a peer is connected. + pub fn on_connected_peer(&self, callback: HandlerFn) -> HandlerId { + self.shared.handlers.connected_peer.add(callback) + } + + /// Callback is called when a peer is disconnected. + pub fn on_disconnected_peer(&self, callback: HandlerFn) -> HandlerId { + self.shared.handlers.disconnected_peer.add(callback) + } +} diff --git a/shared/subspace-cluster-networking/src/network_worker.rs b/shared/subspace-cluster-networking/src/network_worker.rs new file mode 100644 index 0000000000..fc808facd9 --- /dev/null +++ b/shared/subspace-cluster-networking/src/network_worker.rs @@ -0,0 +1,450 @@ +use crate::behavior::{Behavior, Event}; +use crate::shared::{Command, Shared}; +use crate::utils::AsyncJoinOnDrop; +use backoff::backoff::Backoff; +use backoff::ExponentialBackoff; +use futures::channel::{mpsc, oneshot}; +use futures::stream::FuturesUnordered; +use futures::StreamExt; +use libp2p::metrics::{Metrics, Recorder}; +use libp2p::multiaddr::Protocol; +use libp2p::request_response::{ + Event as RequestResponseEvent, InboundRequestId, Message, OutboundFailure, OutboundRequestId, + ResponseChannel, +}; +use libp2p::swarm::dial_opts::{DialOpts, PeerCondition}; +use libp2p::swarm::{DialError, SwarmEvent}; +use libp2p::{Multiaddr, PeerId, Swarm}; +use parity_scale_codec::{Decode, Encode}; +use std::collections::HashMap; +use std::future::Future; +use std::pin::Pin; +use std::sync::Weak; +use tokio::task::yield_now; +use tokio::time::sleep; +use tracing::{debug, error, trace, warn}; + +pub type InboundRequestsHandler = + Box Pin + Send>> + Send>; + +#[derive(Debug)] +struct BootstrapNode { + backoff: ExponentialBackoff, + addresses: Vec, +} + +impl Default for BootstrapNode { + fn default() -> Self { + BootstrapNode { + backoff: ExponentialBackoff { + max_elapsed_time: None, + ..ExponentialBackoff::default() + }, + addresses: vec![], + } + } +} + +pub struct NetworkWorker { + bootstrap_nodes: HashMap, + request_handler: InboundRequestsHandler, + command_receiver: mpsc::Receiver, + swarm: Swarm, + shared_weak: Weak, + redials: FuturesUnordered)>>, + #[allow(clippy::type_complexity)] + inbound_requests: FuturesUnordered< + AsyncJoinOnDrop<(InboundRequestId, PeerId, ResponseChannel>, Vec)>, + >, + #[allow(clippy::type_complexity)] + pending_outbound_requests: + HashMap, oneshot::Sender, OutboundFailure>>)>>, + outbound_requests: + HashMap, OutboundFailure>>>, + metrics: Option, +} + +impl NetworkWorker +where + Requests: Decode + Send, + Responses: Encode + Send + 'static, +{ + pub(crate) fn new( + request_handler: InboundRequestsHandler, + command_receiver: mpsc::Receiver, + swarm: Swarm, + shared_weak: Weak, + bootstrap_nodes: Vec, + metrics: Option, + ) -> Self { + let mut grouped_bootstrap_nodes = HashMap::::new(); + for mut address in bootstrap_nodes { + if let Some(Protocol::P2p(peer_id)) = address.pop() { + grouped_bootstrap_nodes + .entry(peer_id) + .or_default() + .addresses + .push(address); + } + } + + Self { + bootstrap_nodes: grouped_bootstrap_nodes, + request_handler, + command_receiver, + swarm, + shared_weak, + redials: FuturesUnordered::default(), + inbound_requests: FuturesUnordered::default(), + pending_outbound_requests: HashMap::default(), + outbound_requests: HashMap::default(), + metrics, + } + } + + /// Drives the network worker + pub async fn run(&mut self) { + for (peer_id, bootstrap_node) in self.bootstrap_nodes.iter() { + for address in bootstrap_node.addresses.clone() { + self.swarm + .behaviour_mut() + .request_response + .add_address(peer_id, address); + } + if let Err(error) = self.swarm.dial( + DialOpts::peer_id(*peer_id) + .addresses(bootstrap_node.addresses.clone()) + .build(), + ) { + error!(%error, %peer_id, "Failed to dial bootstrap node"); + } + } + + loop { + futures::select! { + swarm_event = self.swarm.next() => { + if let Some(swarm_event) = swarm_event { + self.register_event_metrics(&swarm_event); + self.handle_swarm_event(swarm_event).await; + } else { + break; + } + }, + redial_result = self.redials.select_next_some() => { + match redial_result { + Ok((peer_id, addresses)) => { + if let Err(error) = self.swarm.dial( + DialOpts::peer_id(peer_id) + .addresses(addresses) + .build(), + ) { + error!(%error, %peer_id, "Failed to redial peer"); + } + } + Err(error) => { + error!(%error, "Redial task error"); + } + } + }, + inbound_request_result = self.inbound_requests.select_next_some() => { + match inbound_request_result { + Ok((request_id, peer, channel, response)) => { + self.handle_inbound_request_response(request_id, peer, channel, response); + } + Err(error) => { + error!(%error, "Failed to join inbound request"); + } + } + }, + command = self.command_receiver.next() => { + if let Some(command) = command { + self.handle_command(command); + } else { + break; + } + }, + } + + // Allow to exit from busy loop during graceful shutdown + yield_now().await; + } + } + + async fn handle_swarm_event(&mut self, swarm_event: SwarmEvent) { + match swarm_event { + SwarmEvent::Behaviour(Event::RequestResponse(event)) => { + self.handle_request_response_event(event).await; + } + SwarmEvent::NewListenAddr { address, .. } => { + let shared = match self.shared_weak.upgrade() { + Some(shared) => shared, + None => { + return; + } + }; + shared.listeners.lock().push(address.clone()); + shared.handlers.new_listener.call_simple(&address); + } + SwarmEvent::ConnectionEstablished { + peer_id, + endpoint, + num_established, + .. + } => { + let shared = match self.shared_weak.upgrade() { + Some(shared) => shared, + None => { + return; + } + }; + + debug!( + %peer_id, + ?endpoint, + %num_established, + "Connection established" + ); + + // A new connection + if num_established.get() == 1 { + shared.handlers.connected_peer.call_simple(&peer_id); + } + + // If bootstrap node then reset retries + if let Some(bootstrap_node) = self.bootstrap_nodes.get_mut(&peer_id) { + bootstrap_node.backoff.reset(); + } + + // Process any pending requests for this peer + if let Some(pending_outbound_requests) = + self.pending_outbound_requests.remove(&peer_id) + { + for (request, result_sender) in pending_outbound_requests { + let request_id = self + .swarm + .behaviour_mut() + .request_response + .send_request(&peer_id, request); + self.outbound_requests.insert(request_id, result_sender); + } + } + } + SwarmEvent::ConnectionClosed { + peer_id, + num_established, + cause, + .. + } => { + let shared = match self.shared_weak.upgrade() { + Some(shared) => shared, + None => { + return; + } + }; + + debug!( + %peer_id, + ?cause, + %num_established, + "Connection closed with peer" + ); + + // No more connections + if num_established == 0 { + shared.handlers.disconnected_peer.call_simple(&peer_id); + + // In case of disconnection from bootstrap node reconnect to it + if let Some(bootstrap_node) = self.bootstrap_nodes.get_mut(&peer_id) { + if let Err(error) = self.swarm.dial( + DialOpts::peer_id(peer_id) + .addresses(bootstrap_node.addresses.clone()) + .build(), + ) { + error!(%error, %peer_id, "Failed to dial bootstrap node"); + } + } + } + } + SwarmEvent::OutgoingConnectionError { peer_id, error, .. } => { + if let Some(peer_id) = peer_id { + warn!(%error, %peer_id, "Failed to establish outgoing connection"); + + // If bootstrap node then retry after some delay + if let Some(bootstrap_node) = self.bootstrap_nodes.get_mut(&peer_id) { + if let Some(delay) = bootstrap_node.backoff.next_backoff() { + let addresses = bootstrap_node.addresses.clone(); + + self.redials.push(AsyncJoinOnDrop::new( + tokio::spawn(async move { + sleep(delay).await; + + (peer_id, addresses) + }), + true, + )) + } + } + // Send errors to all pending requests for this peer + if let Some(pending_outbound_requests) = + self.pending_outbound_requests.remove(&peer_id) + { + for (_request, result_sender) in pending_outbound_requests { + let _ = result_sender.send(Err(OutboundFailure::DialFailure)); + } + } + } + } + other => { + trace!("Other swarm event: {:?}", other); + } + } + } + + async fn handle_request_response_event( + &mut self, + event: RequestResponseEvent, Vec>, + ) { + match event { + RequestResponseEvent::Message { peer, message } => match message { + Message::Request { + request_id, + request, + channel, + } => { + let request = match Requests::decode(&mut request.as_slice()) { + Ok(request) => request, + Err(error) => { + warn!(%error, "Failed to decode requests"); + return; + } + }; + let response_fut = (self.request_handler)(request); + + self.inbound_requests.push(AsyncJoinOnDrop::new( + tokio::spawn(async move { + let response = response_fut.await.encode(); + (request_id, peer, channel, response) + }), + true, + )); + } + Message::Response { + request_id, + response, + } => { + if let Some(sender) = self.outbound_requests.remove(&request_id) { + let _ = sender.send(Ok(response)); + } + } + }, + RequestResponseEvent::OutboundFailure { + peer, + request_id, + error, + } => { + debug!( + %peer, + %request_id, + %error, + "Outbound request failed" + ); + + if let Some(sender) = self.outbound_requests.remove(&request_id) { + let _ = sender.send(Err(error)); + } + } + RequestResponseEvent::InboundFailure { + peer, + request_id, + error, + } => { + debug!( + %peer, + %request_id, + %error, + "Inbound request failed" + ); + } + RequestResponseEvent::ResponseSent { .. } => { + // Not interested + } + } + } + + fn handle_inbound_request_response( + &mut self, + request_id: InboundRequestId, + peer: PeerId, + channel: ResponseChannel>, + response: Vec, + ) { + if !channel.is_open() { + trace!(%peer, %request_id, "Response channel already closed"); + return; + } + + if self + .swarm + .behaviour_mut() + .request_response + .send_response(channel, response) + .is_err() + { + debug!(%peer, %request_id, "Response sending failed"); + } + } + + fn handle_command(&mut self, command: Command) { + match command { + Command::Request { + peer_id, + addresses, + request, + result_sender, + } => { + let request_response = &mut self.swarm.behaviour_mut().request_response; + if request_response.is_connected(&peer_id) { + // If already connected - send request right away + let request_id = request_response.send_request(&peer_id, request); + self.outbound_requests.insert(request_id, result_sender); + } else { + // Otherwise try to dial + match self.swarm.dial( + DialOpts::peer_id(peer_id) + .addresses(addresses) + .condition(PeerCondition::DisconnectedAndNotDialing) + .build(), + ) { + Ok(()) | Err(DialError::DialPeerConditionFalse(_)) => { + // In case dial initiated successfully, or it was initiated prior - + // store pending request + self.pending_outbound_requests + .entry(peer_id) + .or_default() + .push((request, result_sender)); + } + Err(error) => { + warn!(%error, %peer_id, "Failed to dial peer on request"); + let _ = result_sender.send(Err(OutboundFailure::DialFailure)); + } + } + } + } + } + } + + fn register_event_metrics(&mut self, swarm_event: &SwarmEvent) { + if let Some(ref mut metrics) = self.metrics { + #[allow(clippy::match_single_binding)] + match swarm_event { + // TODO: implement in the upstream repository + // SwarmEvent::Behaviour(Event::RequestResponse(request_response_event)) => { + // self.metrics.record(request_response_event); + // } + swarm_event => { + metrics.record(swarm_event); + } + } + } + } +} diff --git a/shared/subspace-cluster-networking/src/request_response.rs b/shared/subspace-cluster-networking/src/request_response.rs new file mode 100644 index 0000000000..17b4376f41 --- /dev/null +++ b/shared/subspace-cluster-networking/src/request_response.rs @@ -0,0 +1,146 @@ +use async_trait::async_trait; +use futures::prelude::*; +use libp2p::swarm::StreamProtocol; +use std::io; + +/// A request-response codec using that sends bytes without extra encoding. +#[derive(Debug, Copy, Clone)] +pub struct NoCodec { + /// Maximum allowed size, in bytes, of a request. + /// + /// Any request larger than this value will be declined as a way to avoid allocating too + /// much memory for it. + pub max_request_size: u64, + /// Maximum allowed size, in bytes, of a response. + /// + /// Any response larger than this value will be declined as a way to avoid allocating too + /// much memory for it. + pub max_response_size: u64, +} + +impl NoCodec { + pub fn new(max_request_size: u64, max_response_size: u64) -> Self { + Self { + max_request_size, + max_response_size, + } + } +} + +#[async_trait] +impl libp2p::request_response::Codec for NoCodec { + type Protocol = StreamProtocol; + type Request = Vec; + type Response = Vec; + + async fn read_request(&mut self, _: &Self::Protocol, io: &mut T) -> io::Result> + where + T: AsyncRead + Unpin + Send, + { + let mut vec = Vec::new(); + + let len = io.take(self.max_request_size).read_to_end(&mut vec).await?; + + vec.truncate(len); + + Ok(vec) + } + + async fn read_response(&mut self, _: &Self::Protocol, io: &mut T) -> io::Result> + where + T: AsyncRead + Unpin + Send, + { + let mut vec = Vec::new(); + + let len = io + .take(self.max_response_size) + .read_to_end(&mut vec) + .await?; + + vec.truncate(len); + + Ok(vec) + } + + async fn write_request( + &mut self, + _: &Self::Protocol, + io: &mut T, + request: Vec, + ) -> io::Result<()> + where + T: AsyncWrite + Unpin + Send, + { + io.write_all(&request).await + } + + async fn write_response( + &mut self, + _: &Self::Protocol, + io: &mut T, + response: Vec, + ) -> io::Result<()> + where + T: AsyncWrite + Unpin + Send, + { + io.write_all(&response).await + } +} + +#[cfg(test)] +mod tests { + use super::NoCodec; + use futures::prelude::*; + use futures_ringbuf::Endpoint; + use libp2p::request_response::Codec; + use libp2p::swarm::StreamProtocol; + use parity_scale_codec::{Decode, Encode}; + + #[derive(Debug, Clone, PartialEq, Eq, Encode, Decode)] + struct TestRequest { + payload: String, + } + + #[derive(Debug, Clone, PartialEq, Eq, Encode, Decode)] + struct TestResponse { + payload: String, + } + + #[tokio::test] + async fn test_codec() { + let expected_request = b"test_payload".to_vec(); + let expected_response = b"test_payload".to_vec(); + let protocol = StreamProtocol::new("/test_vec/1"); + let mut codec = NoCodec::new(1024, 1024); + + let (mut a, mut b) = Endpoint::pair(124, 124); + codec + .write_request(&protocol, &mut a, expected_request.clone()) + .await + .expect("Should write request"); + a.close().await.unwrap(); + + let actual_request = codec + .read_request(&protocol, &mut b) + .await + .expect("Should read request"); + b.close().await.unwrap(); + + assert_eq!(actual_request, expected_request); + + let (mut a, mut b) = Endpoint::pair(124, 124); + codec + .write_response(&protocol, &mut a, expected_response.clone()) + .await + .expect("Should write response"); + a.close().await.unwrap(); + + let actual_response = codec + .read_response(&protocol, &mut b) + .await + .expect("Should read response"); + b.close().await.unwrap(); + + assert_eq!(actual_response, expected_response); + } +} diff --git a/shared/subspace-cluster-networking/src/shared.rs b/shared/subspace-cluster-networking/src/shared.rs new file mode 100644 index 0000000000..4e03f6bb1d --- /dev/null +++ b/shared/subspace-cluster-networking/src/shared.rs @@ -0,0 +1,48 @@ +//! Data structures shared between node and node runner, facilitating exchange and creation of +//! queries, subscriptions, various events and shared information. + +use event_listener_primitives::Bag; +use futures::channel::{mpsc, oneshot}; +use libp2p::request_response::OutboundFailure; +use libp2p::{Multiaddr, PeerId}; +use parking_lot::Mutex; +use std::sync::Arc; + +pub(crate) type HandlerFn = Arc; +pub(crate) type Handler = Bag, A>; + +#[derive(Debug)] +pub(crate) enum Command { + Request { + peer_id: PeerId, + addresses: Vec, + request: Vec, + result_sender: oneshot::Sender, OutboundFailure>>, + }, +} + +#[derive(Default, Debug)] +pub(crate) struct Handlers { + pub(crate) new_listener: Handler, + pub(crate) connected_peer: Handler, + pub(crate) disconnected_peer: Handler, +} + +#[derive(Debug)] +pub(crate) struct Shared { + pub(crate) handlers: Handlers, + /// Addresses on which node is listening for incoming requests. + pub(crate) listeners: Mutex>, + /// Sender end of the channel for sending commands to the swarm. + pub(crate) command_sender: mpsc::Sender, +} + +impl Shared { + pub(crate) fn new(command_sender: mpsc::Sender) -> Self { + Self { + handlers: Handlers::default(), + listeners: Mutex::default(), + command_sender, + } + } +} diff --git a/shared/subspace-cluster-networking/src/tests.rs b/shared/subspace-cluster-networking/src/tests.rs new file mode 100644 index 0000000000..50afc86907 --- /dev/null +++ b/shared/subspace-cluster-networking/src/tests.rs @@ -0,0 +1,275 @@ +use crate::behavior::BehaviorConfig; +use crate::network::{GenericRequest, Network, NetworkConfig, SendRequestError}; +use futures::channel::oneshot; +use libp2p::identity::Keypair; +use libp2p::multiaddr::Protocol; +use libp2p::request_response::OutboundFailure; +use parity_scale_codec::{Decode, Encode}; +use parking_lot::Mutex; +use std::assert_matches::assert_matches; +use std::error::Error; +use std::sync::Arc; +use std::time::Duration; + +const MAX_REQUEST_SIZE: u64 = 1024; +const MAX_RESPONSE_SIZE: u64 = 1024; +const REQUEST_TIMEOUT: Duration = Duration::from_secs(10); +const MAX_CONCURRENT_STREAMS: usize = 1024; +const IDLE_CONNECTION_TIMEOUT: Duration = Duration::from_secs(10); + +impl GenericRequest for String { + type Response = String; +} + +#[derive(Debug, Encode, Decode)] +enum Requests { + S(String), +} + +impl From for Requests { + fn from(value: String) -> Self { + Self::S(value) + } +} + +#[derive(Debug, Encode, Decode)] +enum Responses { + S(String), +} + +impl TryFrom for String { + type Error = Box; + + fn try_from(Responses::S(s): Responses) -> Result { + Ok(s) + } +} + +fn typical_behavior_config() -> BehaviorConfig { + BehaviorConfig { + request_response_protocol: "/request_response_protocol", + max_request_size: MAX_REQUEST_SIZE, + max_response_size: MAX_RESPONSE_SIZE, + request_timeout: REQUEST_TIMEOUT, + max_concurrent_streams: MAX_CONCURRENT_STREAMS, + } +} + +#[tokio::test] +async fn basic() { + let (peer_1, mut peer_1_worker) = Network::::new(NetworkConfig { + bootstrap_nodes: vec![], + listen_on: vec!["/ip4/0.0.0.0/tcp/0".parse().unwrap()], + keypair: Keypair::generate_ed25519(), + network_key: vec![], + behavior_config: typical_behavior_config(), + idle_connection_timeout: IDLE_CONNECTION_TIMEOUT, + request_handler: Box::new(|Requests::S(request)| { + Box::pin(async move { Responses::S(format!("response: {request}")) }) + }), + metrics: None, + }) + .unwrap(); + + let peer_1_addr = { + let (peer_1_address_sender, peer_1_address_receiver) = oneshot::channel(); + let _on_new_listener_handler = peer_1.on_new_listener(Arc::new({ + let peer_1_address_sender = Mutex::new(Some(peer_1_address_sender)); + + move |address| { + if matches!(address.iter().next(), Some(Protocol::Ip4(_))) { + if let Some(peer_1_address_sender) = peer_1_address_sender.lock().take() { + peer_1_address_sender.send(address.clone()).unwrap(); + } + } + } + })); + + tokio::spawn(async move { + peer_1_worker.run().await; + }); + + // Wait for first peer to know its address + let mut peer_1_addr = peer_1_address_receiver.await.unwrap(); + peer_1_addr.push(Protocol::P2p(peer_1.id())); + peer_1_addr + }; + + let (peer_2, mut peer_2_worker) = Network::::new(NetworkConfig { + bootstrap_nodes: vec![peer_1_addr.clone()], + listen_on: vec!["/ip4/0.0.0.0/tcp/0".parse().unwrap()], + keypair: Keypair::generate_ed25519(), + network_key: vec![], + behavior_config: typical_behavior_config(), + idle_connection_timeout: IDLE_CONNECTION_TIMEOUT, + request_handler: Box::new(|Requests::S(request)| { + Box::pin(async move { Responses::S(format!("response: {request}")) }) + }), + metrics: None, + }) + .unwrap(); + + let peer_2_addr = { + let (connected_sender, connected_receiver) = oneshot::channel::<()>(); + let connected_sender = Mutex::new(Some(connected_sender)); + let _connected_handler_id = peer_2.on_connected_peer(Arc::new({ + move |_| { + connected_sender.lock().take(); + } + })); + + let (peer_2_address_sender, peer_2_address_receiver) = oneshot::channel(); + let _on_new_listener_handler = peer_2.on_new_listener(Arc::new({ + let peer_2_address_sender = Mutex::new(Some(peer_2_address_sender)); + + move |address| { + if matches!(address.iter().next(), Some(Protocol::Ip4(_))) { + if let Some(peer_2_address_sender) = peer_2_address_sender.lock().take() { + peer_2_address_sender.send(address.clone()).unwrap(); + } + } + } + })); + + tokio::spawn(async move { + peer_2_worker.run().await; + }); + + // Wait for second peer to know its address + let mut peer_2_addr = peer_2_address_receiver.await.unwrap(); + peer_2_addr.push(Protocol::P2p(peer_2.id())); + + // Wait for connection to bootstrap node + let _ = connected_receiver.await; + + // Basic request to bootstrap node succeeds + let response = peer_2 + .request(peer_1.id(), vec![], "hello".to_string()) + .await + .unwrap(); + assert_eq!(response, "response: hello"); + + peer_2_addr + }; + + { + let (peer_3, mut peer_3_worker) = Network::::new(NetworkConfig { + bootstrap_nodes: vec![peer_1_addr.clone()], + listen_on: vec!["/ip4/0.0.0.0/tcp/0".parse().unwrap()], + keypair: Keypair::generate_ed25519(), + network_key: vec![0, 1, 2, 3], + behavior_config: typical_behavior_config(), + idle_connection_timeout: IDLE_CONNECTION_TIMEOUT, + request_handler: Box::new(|_| unreachable!()), + metrics: None, + }) + .unwrap(); + + tokio::spawn(async move { + peer_3_worker.run().await; + }); + + // Network key mismatch results in dial failure + let response = peer_3 + .request(peer_1.id(), vec![peer_1_addr.clone()], "hello".to_string()) + .await; + assert_matches!( + response, + Err(SendRequestError::ProtocolFailure( + OutboundFailure::DialFailure + )) + ); + } + + { + let idle_connection_timeout = Duration::from_millis(10); + + let (peer_4, mut peer_4_worker) = Network::::new(NetworkConfig { + bootstrap_nodes: vec![peer_1_addr.clone()], + listen_on: vec!["/ip4/0.0.0.0/tcp/0".parse().unwrap()], + keypair: Keypair::generate_ed25519(), + network_key: vec![], + behavior_config: typical_behavior_config(), + idle_connection_timeout, + request_handler: Box::new(|_| unreachable!()), + metrics: None, + }) + .unwrap(); + + tokio::spawn(async move { + peer_4_worker.run().await; + }); + + let (disconnected_sender, disconnected_receiver) = oneshot::channel::<()>(); + let disconnected_sender = Mutex::new(Some(disconnected_sender)); + let mut disconnected_receiver = Some(disconnected_receiver); + let _disconnected_handler_id = peer_4.on_disconnected_peer(Arc::new({ + move |_| { + disconnected_sender.lock().take(); + } + })); + + // Try twice with an interval larger than idle connection timeout to make sure it reconnects + // successfully + for _ in 0..2 { + let response = peer_4 + .request(peer_1.id(), vec![peer_1_addr.clone()], "hello".to_string()) + .await + .unwrap(); + + assert_eq!(response, "response: hello"); + + if let Some(disconnected_receiver) = disconnected_receiver.take() { + let _ = disconnected_receiver.await; + } + } + } + + { + let (peer_5, mut peer_5_worker) = Network::::new(NetworkConfig { + bootstrap_nodes: vec![peer_1_addr.clone()], + listen_on: vec!["/ip4/0.0.0.0/tcp/0".parse().unwrap()], + keypair: Keypair::generate_ed25519(), + network_key: vec![], + behavior_config: typical_behavior_config(), + idle_connection_timeout: IDLE_CONNECTION_TIMEOUT, + request_handler: Box::new(|_| unreachable!()), + metrics: None, + }) + .unwrap(); + + tokio::spawn(async move { + peer_5_worker.run().await; + }); + + // Initially not connected to the second peer + { + let response = peer_5 + .request(peer_2.id(), vec![], "hello".to_string()) + .await; + assert_matches!( + response, + Err(SendRequestError::ProtocolFailure( + OutboundFailure::DialFailure + )) + ); + } + + // With explicit address connection succeeds + { + let response = peer_5 + .request(peer_2.id(), vec![peer_2_addr.clone()], "hello".to_string()) + .await + .unwrap(); + assert_eq!(response, "response: hello"); + } + // And also succeeds without address shortly after due to already established connection + { + let response = peer_5 + .request(peer_2.id(), vec![], "hello".to_string()) + .await + .unwrap(); + assert_eq!(response, "response: hello"); + } + } +} diff --git a/shared/subspace-cluster-networking/src/utils.rs b/shared/subspace-cluster-networking/src/utils.rs new file mode 100644 index 0000000000..3e704e9e9d --- /dev/null +++ b/shared/subspace-cluster-networking/src/utils.rs @@ -0,0 +1,50 @@ +use std::future::Future; +use std::pin::Pin; +use std::task::{Context, Poll}; +use tokio::runtime::Handle; +use tokio::task; + +/// Joins async join handle on drop +pub struct AsyncJoinOnDrop { + handle: Option>, + abort_on_drop: bool, +} + +impl Drop for AsyncJoinOnDrop { + fn drop(&mut self) { + if let Some(handle) = self.handle.take() { + if self.abort_on_drop { + handle.abort(); + } + + if !handle.is_finished() { + task::block_in_place(move || { + let _ = Handle::current().block_on(handle); + }); + } + } + } +} + +impl AsyncJoinOnDrop { + /// Create new instance. + pub fn new(handle: task::JoinHandle, abort_on_drop: bool) -> Self { + Self { + handle: Some(handle), + abort_on_drop, + } + } +} + +impl Future for AsyncJoinOnDrop { + type Output = Result; + + fn poll(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll { + Pin::new( + self.handle + .as_mut() + .expect("Only dropped in Drop impl; qed"), + ) + .poll(cx) + } +}