diff --git a/quinn-proto/src/connection/mod.rs b/quinn-proto/src/connection/mod.rs index 4989e63932..80ece6eda3 100644 --- a/quinn-proto/src/connection/mod.rs +++ b/quinn-proto/src/connection/mod.rs @@ -2,7 +2,7 @@ use std::{ cmp, collections::{BTreeMap, HashSet, VecDeque}, fmt, io, mem, - net::SocketAddr, + net::{IpAddr, SocketAddr}, sync::Arc, time::{Duration, Instant}, }; @@ -111,6 +111,10 @@ where /// cid length used to decode short packet local_cid_len: usize, + /// The "real" local IP address which was was used to receive the initial packet. + /// This is only populated for the server case, and if known + initial_local_ip: Option, + path: PathData, prev_path: Option, state: State, @@ -208,6 +212,7 @@ where loc_cid: ConnectionId, rem_cid: ConnectionId, remote: SocketAddr, + initial_local_ip: Option, crypto: S, now: Instant, local_cid_len: usize, @@ -244,6 +249,7 @@ where config.congestion_controller_factory.build(now), now, ), + initial_local_ip, prev_path: None, side, state, @@ -1062,6 +1068,14 @@ where self.path.remote } + /// The local IP address which was used when the peer established + /// the connection + /// + /// This is `None` if either not known or if this is the client side of a connection. + pub fn initial_local_ip(&self) -> Option { + self.initial_local_ip + } + /// Current best estimate of this connection's latency (round-trip-time) pub fn rtt(&self) -> Duration { self.path.rtt.get() diff --git a/quinn-proto/src/endpoint.rs b/quinn-proto/src/endpoint.rs index abfd9084c0..ae907bb2e6 100644 --- a/quinn-proto/src/endpoint.rs +++ b/quinn-proto/src/endpoint.rs @@ -2,7 +2,7 @@ use std::{ collections::{HashMap, VecDeque}, convert::TryFrom, fmt, iter, - net::SocketAddr, + net::{IpAddr, SocketAddr}, ops::{Index, IndexMut}, sync::Arc, time::{Duration, Instant, SystemTime}, @@ -154,6 +154,7 @@ where &mut self, now: Instant, remote: SocketAddr, + local_ip: Option, ecn: Option, data: BytesMut, ) -> Option<(ConnectionHandle, DatagramEvent)> { @@ -273,7 +274,7 @@ where let crypto = S::initial_keys(&dst_cid, Side::Server); return match first_decode.finish(Some(&crypto.header.remote)) { Ok(packet) => self - .handle_first_packet(now, remote, ecn, packet, remaining, &crypto) + .handle_first_packet(now, remote, local_ip, ecn, packet, remaining, &crypto) .map(|(ch, conn)| (ch, DatagramEvent::NewConnection(conn))), Err(e) => { trace!("unable to decode initial packet: {}", e); @@ -357,6 +358,7 @@ where remote_id, remote_id, remote, + None, ConnectionOpts::Client { config, server_name: server_name.into(), @@ -399,6 +401,7 @@ where init_cid: ConnectionId, rem_cid: ConnectionId, remote: SocketAddr, + local_ip: Option, opts: ConnectionOpts, now: Instant, ) -> Result<(ConnectionHandle, Connection), ConnectError> { @@ -454,6 +457,7 @@ where loc_cid, rem_cid, remote, + local_ip, tls, now, self.local_cid_generator.cid_len(), @@ -479,6 +483,7 @@ where &mut self, now: Instant, remote: SocketAddr, + local_ip: Option, ecn: Option, mut packet: Packet, rest: Option, @@ -614,6 +619,7 @@ where dst_cid, src_cid, remote, + local_ip, ConnectionOpts::Server { retry_src_cid, orig_dst_cid, diff --git a/quinn-proto/src/tests/mod.rs b/quinn-proto/src/tests/mod.rs index acb2a4d2dd..c12285411b 100644 --- a/quinn-proto/src/tests/mod.rs +++ b/quinn-proto/src/tests/mod.rs @@ -28,6 +28,7 @@ fn version_negotiate_server() { now, client_addr, None, + None, // Long-header packet with reserved version number hex!("80 0a1a2a3a 04 00000000 04 00000000 00")[..].into(), ); @@ -64,6 +65,7 @@ fn version_negotiate_client() { now, server_addr, None, + None, // Version negotiation packet for reserved version hex!( "80 00000000 04 00000000 04 00000000 diff --git a/quinn-proto/src/tests/util.rs b/quinn-proto/src/tests/util.rs index 87574393c4..a16b88c516 100644 --- a/quinn-proto/src/tests/util.rs +++ b/quinn-proto/src/tests/util.rs @@ -226,7 +226,7 @@ impl TestEndpoint { let (_, ecn, packet) = self.inbound.pop_front().unwrap(); if let Some((ch, event)) = self.endpoint - .handle(now, remote, ecn, packet.as_slice().into()) + .handle(now, remote, None, ecn, packet.as_slice().into()) { match event { DatagramEvent::NewConnection(conn) => { diff --git a/quinn/src/connection.rs b/quinn/src/connection.rs index a859ba0a36..838f854690 100644 --- a/quinn/src/connection.rs +++ b/quinn/src/connection.rs @@ -3,7 +3,7 @@ use std::{ fmt, future::Future, mem, - net::SocketAddr, + net::{IpAddr, SocketAddr}, pin::Pin, sync::{Arc, Mutex}, task::{Context, Poll, Waker}, @@ -127,6 +127,17 @@ where .expect("spurious handshake data ready notification") }) } + + /// The local IP address which was used when the peer established + /// the connection + /// + /// This is `None` if either not known or if this is the client side of a connection. + pub fn initial_local_ip(&self) -> Option { + let conn = self.conn.as_ref().unwrap(); + let inner = conn.lock().unwrap(); + + inner.inner.initial_local_ip() + } } impl Future for Connecting @@ -397,6 +408,14 @@ where self.0.lock().unwrap().inner.remote_address() } + /// The local IP address which was used when the peer established + /// the connection + /// + /// This is `None` if either not known or if this is the client side of a connection. + pub fn initial_local_ip(&self) -> Option { + self.0.lock().unwrap().inner.initial_local_ip() + } + /// Current best estimate of this connection's latency (round-trip-time) pub fn rtt(&self) -> Duration { self.0.lock().unwrap().inner.rtt() diff --git a/quinn/src/endpoint.rs b/quinn/src/endpoint.rs index 1ac17bef37..dc355c96a8 100644 --- a/quinn/src/endpoint.rs +++ b/quinn/src/endpoint.rs @@ -279,7 +279,10 @@ where recvd += msgs; for (meta, buf) in metas.iter().zip(iovs.iter()).take(msgs) { let data = buf[0..meta.len].into(); - match self.inner.handle(now, meta.addr, meta.ecn, data) { + match self + .inner + .handle(now, meta.addr, meta.dest_ip, meta.ecn, data) + { Some((handle, DatagramEvent::NewConnection(conn))) => { let conn = self.connections.insert(handle, conn); self.incoming.push_back(conn); diff --git a/quinn/src/platform/fallback.rs b/quinn/src/platform/fallback.rs index d9d3285b5c..b4ca3a80d7 100644 --- a/quinn/src/platform/fallback.rs +++ b/quinn/src/platform/fallback.rs @@ -38,6 +38,7 @@ impl super::UdpExt for UdpSocket { len, addr, ecn: None, + dest_ip: None, }; Ok(1) } diff --git a/quinn/src/platform/unix.rs b/quinn/src/platform/unix.rs index 517514fdcf..c2a58850bd 100644 --- a/quinn/src/platform/unix.rs +++ b/quinn/src/platform/unix.rs @@ -2,7 +2,7 @@ use std::{ io, io::IoSliceMut, mem::{self, MaybeUninit}, - net::{SocketAddr, SocketAddrV4, SocketAddrV6}, + net::{IpAddr, SocketAddr, SocketAddrV4, SocketAddrV6}, os::unix::io::AsRawFd, ptr, }; @@ -29,8 +29,17 @@ impl super::UdpExt for UdpSocket { mem::size_of::(), mem::size_of::() ); + + let mut cmsg_platform_space = 0; + if cfg!(target_os = "linux") { + cmsg_platform_space += + unsafe { libc::CMSG_SPACE(mem::size_of::() as _) as usize }; + } + assert!( - CMSG_LEN >= unsafe { libc::CMSG_SPACE(mem::size_of::() as _) as usize } + CMSG_LEN + >= unsafe { libc::CMSG_SPACE(mem::size_of::() as _) as usize } + + cmsg_platform_space ); assert!( mem::align_of::() <= mem::align_of::>(), @@ -72,6 +81,20 @@ impl super::UdpExt for UdpSocket { if rc == -1 { return Err(io::Error::last_os_error()); } + + let on: libc::c_int = 1; + let rc = unsafe { + libc::setsockopt( + self.as_raw_fd(), + libc::IPPROTO_IP, + libc::IP_PKTINFO, + &on as *const _ as _, + mem::size_of_val(&on) as _, + ) + }; + if rc == -1 { + return Err(io::Error::last_os_error()); + } } else if addr.is_ipv6() { let rc = unsafe { libc::setsockopt( @@ -85,6 +108,20 @@ impl super::UdpExt for UdpSocket { if rc == -1 { return Err(io::Error::last_os_error()); } + + let on: libc::c_int = 1; + let rc = unsafe { + libc::setsockopt( + self.as_raw_fd(), + libc::IPPROTO_IPV6, + libc::IPV6_RECVPKTINFO, + &on as *const _ as _, + mem::size_of_val(&on) as _, + ) + }; + if rc == -1 { + return Err(io::Error::last_os_error()); + } } } if addr.is_ipv6() { @@ -230,7 +267,7 @@ impl super::UdpExt for UdpSocket { } } -const CMSG_LEN: usize = 24; +const CMSG_LEN: usize = 64; fn prepare_msg( transmit: &Transmit, @@ -283,11 +320,15 @@ fn decode_recv( len: usize, ) -> RecvMeta { let name = unsafe { name.assume_init() }; - let ecn_bits = match unsafe { cmsg::Iter::new(&hdr).next() } { - Some(cmsg) => match (cmsg.cmsg_level, cmsg.cmsg_type) { + let mut ecn_bits = 0; + let mut dest_ip = None; + + let cmsg_iter = unsafe { cmsg::Iter::new(&hdr) }; + for cmsg in cmsg_iter { + match (cmsg.cmsg_level, cmsg.cmsg_type) { // FreeBSD uses IP_RECVTOS here, and we can be liberal because cmsgs are opt-in. (libc::IPPROTO_IP, libc::IP_TOS) | (libc::IPPROTO_IP, libc::IP_RECVTOS) => unsafe { - cmsg::decode::(cmsg) + ecn_bits = cmsg::decode::(cmsg); }, (libc::IPPROTO_IPV6, libc::IPV6_TCLASS) => unsafe { // Temporary hack around broken macos ABI. Remove once upstream fixes it. @@ -295,24 +336,34 @@ fn decode_recv( if cfg!(target_os = "macos") && cmsg.cmsg_len as usize == libc::CMSG_LEN(mem::size_of::() as _) as usize { - cmsg::decode::(cmsg) + ecn_bits = cmsg::decode::(cmsg); } else { - cmsg::decode::(cmsg) as u8 + ecn_bits = cmsg::decode::(cmsg) as u8; } }, - _ => 0, - }, - None => 0, - }; + (libc::IPPROTO_IP, libc::IP_PKTINFO) => unsafe { + let pktinfo = cmsg::decode::(cmsg); + dest_ip = Some(IpAddr::V4(ptr::read(&pktinfo.ipi_addr as *const _ as _))); + }, + (libc::IPPROTO_IPV6, libc::IPV6_PKTINFO) => unsafe { + let pktinfo = cmsg::decode::(cmsg); + dest_ip = Some(IpAddr::V6(ptr::read(&pktinfo.ipi6_addr as *const _ as _))); + }, + _ => {} + } + } + let addr = match libc::c_int::from(name.ss_family) { libc::AF_INET => unsafe { SocketAddr::V4(ptr::read(&name as *const _ as _)) }, libc::AF_INET6 => unsafe { SocketAddr::V6(ptr::read(&name as *const _ as _)) }, _ => unreachable!(), }; + RecvMeta { len, addr, ecn: EcnCodepoint::from_bits(ecn_bits), + dest_ip, } } diff --git a/quinn/src/udp.rs b/quinn/src/udp.rs index 82ac45e180..da8e6a07ed 100644 --- a/quinn/src/udp.rs +++ b/quinn/src/udp.rs @@ -1,7 +1,7 @@ use std::{ io, io::IoSliceMut, - net::{Ipv6Addr, SocketAddr}, + net::{IpAddr, Ipv6Addr, SocketAddr}, task::{Context, Poll}, }; @@ -74,6 +74,8 @@ pub struct RecvMeta { pub addr: SocketAddr, pub len: usize, pub ecn: Option, + /// The destination IP address which was encoded in this datagram + pub dest_ip: Option, } impl Default for RecvMeta { @@ -83,6 +85,7 @@ impl Default for RecvMeta { addr: SocketAddr::new(Ipv6Addr::UNSPECIFIED.into(), 0), len: 0, ecn: None, + dest_ip: None, } } }