diff --git a/quinn-proto/src/config.rs b/quinn-proto/src/config.rs index 0b202ff157..ff0bf1c19a 100644 --- a/quinn-proto/src/config.rs +++ b/quinn-proto/src/config.rs @@ -434,6 +434,10 @@ where /// Improves behavior for clients that move between different internet connections or suffer NAT /// rebinding. Enabled by default. pub(crate) migration: bool, + + /// Whether to use the initial local IP address for outgoing packets + /// instead of the IP the the socket is bound to. + pub(crate) send_from_initial_ip: bool, } impl ServerConfig @@ -453,6 +457,7 @@ where concurrent_connections: 100_000, migration: true, + send_from_initial_ip: true, } } @@ -492,6 +497,21 @@ where self.migration = value; self } + + /// Whether to use the initial local IP address for outgoing packets. + /// + /// If set to `true`, the local IP address which was target of the initial + /// QUIC packet will be used as the source IP address for all outoing packets. + /// + /// If set to `false`, the source IP address will be address the socket is + /// bound to. + /// + /// The setting is only having an effect on platforms where + /// [`quinn_proto::Connection::local_ip()`] returns a local IP address. + pub fn send_from_initial_ip(&mut self, value: bool) -> &mut Self { + self.send_from_initial_ip = value; + self + } } #[cfg(feature = "rustls")] @@ -520,6 +540,7 @@ where .field("retry_token_lifetime", &self.retry_token_lifetime) .field("concurrent_connections", &self.concurrent_connections) .field("migration", &self.migration) + .field("send_from_initial_ip", &self.send_from_initial_ip) .finish() } } @@ -552,6 +573,7 @@ where retry_token_lifetime: self.retry_token_lifetime, concurrent_connections: self.concurrent_connections, migration: self.migration, + send_from_initial_ip: self.send_from_initial_ip, } } } diff --git a/quinn-proto/src/connection/mod.rs b/quinn-proto/src/connection/mod.rs index d390aacc4e..068ecb484b 100644 --- a/quinn-proto/src/connection/mod.rs +++ b/quinn-proto/src/connection/mod.rs @@ -379,6 +379,7 @@ where destination, contents: buf, ecn: None, + src_ip: self.src_ip(), }); } } @@ -559,6 +560,7 @@ where } else { None }, + src_ip: self.src_ip(), }) } @@ -717,6 +719,16 @@ where pad } + /// Returns the source IP used in outgoing packets. + /// + /// Returns `None` if no specific IP should be used + fn src_ip(&self) -> Option { + match self.server_config.as_ref()?.send_from_initial_ip { + true => self.local_ip, + false => None, + } + } + /// Indicates whether we're a server that hasn't validated the peer's address and hasn't /// received enough data from the peer to permit additional sending fn anti_amplification_blocked(&self) -> bool { diff --git a/quinn-proto/src/endpoint.rs b/quinn-proto/src/endpoint.rs index 8717dd4c4c..7a5f019e61 100644 --- a/quinn-proto/src/endpoint.rs +++ b/quinn-proto/src/endpoint.rs @@ -189,6 +189,7 @@ where destination: remote, ecn: None, contents: buf, + src_ip: self.src_ip(local_ip), }); return None; } @@ -252,7 +253,7 @@ where if !self.is_server() { debug!("packet for unrecognized connection {}", dst_cid); - self.stateless_reset(datagram_len, remote, &dst_cid); + self.stateless_reset(datagram_len, remote, local_ip, &dst_cid); return None; } @@ -287,7 +288,7 @@ where // if !dst_cid.is_empty() { - self.stateless_reset(datagram_len, remote, &dst_cid); + self.stateless_reset(datagram_len, remote, local_ip, &dst_cid); } else { trace!("dropping unrecognized short packet without ID"); } @@ -298,6 +299,7 @@ where &mut self, inciting_dgram_len: usize, remote: SocketAddr, + local_ip: Option, dst_cid: &ConnectionId, ) { /// Minimum amount of padding for the stateless reset to look like a short-header packet @@ -334,6 +336,7 @@ where destination: remote, ecn: None, contents: buf, + src_ip: self.src_ip(local_ip), }); } @@ -389,6 +392,16 @@ where ConnectionEvent(ConnectionEventInner::NewIdentifiers(ids, now)) } + /// Returns the source IP used in outgoing packets. + /// + /// Returns `None` if no specific IP should be used + fn src_ip(&self, local_ip: Option) -> Option { + match self.server_config.as_ref()?.send_from_initial_ip { + true => local_ip, + false => None, + } + } + fn new_cid(&mut self) -> ConnectionId { loop { let cid = self.local_cid_generator.generate_cid(); @@ -533,6 +546,7 @@ where debug!("refusing connection"); self.initial_close( remote, + local_ip, crypto, &src_cid, &temp_loc_cid, @@ -551,6 +565,7 @@ where ); self.initial_close( remote, + local_ip, crypto, &src_cid, &temp_loc_cid, @@ -587,6 +602,7 @@ where destination: remote, ecn: None, contents: buf, + src_ip: self.src_ip(local_ip), }); return None; } @@ -605,6 +621,7 @@ where debug!("rejecting invalid stateless retry token"); self.initial_close( remote, + local_ip, crypto, &src_cid, &temp_loc_cid, @@ -642,7 +659,7 @@ where debug!("handshake failed: {}", e); self.handle_event(ch, EndpointEvent(EndpointEventInner::Drained)); if let ConnectionError::TransportError(e) = e { - self.initial_close(remote, crypto, &src_cid, &temp_loc_cid, e); + self.initial_close(remote, local_ip, crypto, &src_cid, &temp_loc_cid, e); } None } @@ -652,6 +669,7 @@ where fn initial_close( &mut self, destination: SocketAddr, + local_ip: Option, crypto: &Keys, remote_id: &ConnectionId, local_id: &ConnectionId, @@ -679,6 +697,7 @@ where destination, ecn: None, contents: buf, + src_ip: self.src_ip(local_ip), }) } diff --git a/quinn-proto/src/lib.rs b/quinn-proto/src/lib.rs index 05a41ac62f..f7a8e619c8 100644 --- a/quinn-proto/src/lib.rs +++ b/quinn-proto/src/lib.rs @@ -18,7 +18,13 @@ #![allow(clippy::cognitive_complexity)] #![allow(clippy::too_many_arguments)] -use std::{convert::TryInto, fmt, net::SocketAddr, ops, time::Duration}; +use std::{ + convert::TryInto, + fmt, + net::{IpAddr, SocketAddr}, + ops, + time::Duration, +}; mod cid_queue; #[doc(hidden)] @@ -274,6 +280,8 @@ pub struct Transmit { pub ecn: Option, /// Contents of the datagram pub contents: Vec, + /// Optional source IP address for the datagram + pub src_ip: Option, } // diff --git a/quinn/src/builders.rs b/quinn/src/builders.rs index 35488801e1..6bc0245262 100644 --- a/quinn/src/builders.rs +++ b/quinn/src/builders.rs @@ -168,6 +168,21 @@ where self.config.use_stateless_retry(enabled); self } + + /// Whether to use the initial local IP address for outgoing packets. + /// + /// If set to `true`, the local IP address which was target of the initial + /// QUIC packet will be used as the source IP address for all outoing packets. + /// + /// If set to `false`, the source IP address will be address the socket is + /// bound to. + /// + /// The setting is only having an effect on platforms where + /// [`quinn::Connection::local_ip()`] returns a local IP address. + pub fn send_from_initial_ip(&mut self, value: bool) -> &mut Self { + self.config.send_from_initial_ip(value); + self + } } #[cfg(feature = "rustls")] diff --git a/quinn/src/platform/unix.rs b/quinn/src/platform/unix.rs index 5257985b6d..62dc7cba46 100644 --- a/quinn/src/platform/unix.rs +++ b/quinn/src/platform/unix.rs @@ -2,7 +2,7 @@ use std::{ io, io::IoSliceMut, mem::{self, MaybeUninit}, - net::{IpAddr, SocketAddr, SocketAddrV4, SocketAddrV6}, + net::{IpAddr, Ipv4Addr, Ipv6Addr, SocketAddr, SocketAddrV4, SocketAddrV6}, os::unix::io::AsRawFd, ptr, }; @@ -290,12 +290,40 @@ fn prepare_msg( hdr.msg_control = ctrl.0.as_mut_ptr() as _; hdr.msg_controllen = CMSG_LEN as _; let mut encoder = unsafe { cmsg::Encoder::new(hdr) }; + let ecn = transmit.ecn.map_or(0, |x| x as libc::c_int); if transmit.destination.is_ipv4() { encoder.push(libc::IPPROTO_IP, libc::IP_TOS, ecn as IpTosTy); } else { encoder.push(libc::IPPROTO_IPV6, libc::IPV6_TCLASS, ecn); } + + if let Some(ip) = &transmit.src_ip { + if cfg!(target_os = "linux") { + match ip { + IpAddr::V4(v4) => { + let pktinfo = libc::in_pktinfo { + ipi_ifindex: 0, + ipi_spec_dst: libc::in_addr { s_addr: 0 }, + ipi_addr: unsafe { + *(v4 as *const Ipv4Addr as *const () as *const libc::in_addr) + }, + }; + encoder.push(libc::IPPROTO_IP, libc::IP_PKTINFO, pktinfo); + } + IpAddr::V6(v6) => { + let pktinfo = libc::in6_pktinfo { + ipi6_ifindex: 0, + ipi6_addr: unsafe { + *(v6 as *const Ipv6Addr as *const () as *const libc::in6_addr) + }, + }; + encoder.push(libc::IPPROTO_IPV6, libc::IPV6_PKTINFO, pktinfo); + } + } + } + } + encoder.finish(); }