diff --git a/transports/quic/src/hole_punching.rs b/transports/quic/src/hole_punching.rs index 2b4dc81ee15..c1e9a447c6d 100644 --- a/transports/quic/src/hole_punching.rs +++ b/transports/quic/src/hole_punching.rs @@ -5,7 +5,6 @@ use futures::future::Either; use rand::{distributions, Rng}; use std::{ - io, net::{SocketAddr, UdpSocket}, time::Duration, }; @@ -33,10 +32,8 @@ async fn punch_holes(socket: UdpSocket, remote_addr: SocketAddr) -> .take(64) .collect(); - if let Err(e) = socket.send_to(&contents, remote_addr) { - if !matches!(e.kind(), io::ErrorKind::WouldBlock) { - return Error::Io(e); - } + if let Err(e) = P::send_to(&socket, &contents, remote_addr).await { + return Error::Io(e); } } } diff --git a/transports/quic/src/provider.rs b/transports/quic/src/provider.rs index 8a2ca62cb44..26e9e35902f 100644 --- a/transports/quic/src/provider.rs +++ b/transports/quic/src/provider.rs @@ -22,6 +22,7 @@ use futures::{future::BoxFuture, Future}; use if_watch::IfEvent; use std::{ io, + net::{SocketAddr, UdpSocket}, task::{Context, Poll}, time::Duration, }; @@ -62,4 +63,11 @@ pub trait Provider: Unpin + Send + Sized + 'static { /// Sleep for specified amount of time. fn sleep(duration: Duration) -> BoxFuture<'static, ()>; + + /// Sends data on the socket to the given address. On success, returns the number of bytes written. + fn send_to<'a>( + udp_socket: &'a UdpSocket, + buf: &'a [u8], + target: SocketAddr, + ) -> BoxFuture<'a, io::Result>; } diff --git a/transports/quic/src/provider/async_std.rs b/transports/quic/src/provider/async_std.rs index da28727aed1..4721cba9b32 100644 --- a/transports/quic/src/provider/async_std.rs +++ b/transports/quic/src/provider/async_std.rs @@ -22,6 +22,7 @@ use async_std::task::spawn; use futures::{future::BoxFuture, Future, FutureExt}; use std::{ io, + net::UdpSocket, task::{Context, Poll}, time::Duration, }; @@ -59,4 +60,16 @@ impl super::Provider for Provider { fn sleep(duration: Duration) -> BoxFuture<'static, ()> { async_std::task::sleep(duration).boxed() } + + fn send_to<'a>( + udp_socket: &'a UdpSocket, + buf: &'a [u8], + target: std::net::SocketAddr, + ) -> BoxFuture<'a, io::Result> { + Box::pin(async move { + async_std::net::UdpSocket::from(udp_socket.try_clone()?) + .send_to(buf, target) + .await + }) + } } diff --git a/transports/quic/src/provider/tokio.rs b/transports/quic/src/provider/tokio.rs index 7accc3ce60b..b32f7ee184f 100644 --- a/transports/quic/src/provider/tokio.rs +++ b/transports/quic/src/provider/tokio.rs @@ -21,6 +21,7 @@ use futures::{future::BoxFuture, Future, FutureExt}; use std::{ io, + net::{SocketAddr, UdpSocket}, task::{Context, Poll}, time::Duration, }; @@ -58,4 +59,16 @@ impl super::Provider for Provider { fn sleep(duration: Duration) -> BoxFuture<'static, ()> { tokio::time::sleep(duration).boxed() } + + fn send_to<'a>( + udp_socket: &'a UdpSocket, + buf: &'a [u8], + target: SocketAddr, + ) -> BoxFuture<'a, io::Result> { + Box::pin(async move { + tokio::net::UdpSocket::from_std(udp_socket.try_clone()?)? + .send_to(buf, target) + .await + }) + } }