Skip to content

Commit

Permalink
Allow to lookup the initial local IP address
Browse files Browse the repository at this point in the history
When a listener is bound to multiple network interfaces (e.g. `::0`),
it is not obvious which IP the peer used to send a packet. We however
might need this information to send packets back to the peer with the
same source address.

This problem is described in #508.

This change makes the destination IP address which was used to send
the initial packet available in the `Conneting` and `Connection` types.

The information is far available only on Linux due to missing test on
other platforms.
  • Loading branch information
Matthias Einwag committed Dec 22, 2020
1 parent cd229d3 commit e50dd45
Show file tree
Hide file tree
Showing 10 changed files with 138 additions and 19 deletions.
20 changes: 19 additions & 1 deletion quinn-proto/src/connection/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,7 @@ use std::{
cmp,
collections::{BTreeMap, HashSet, VecDeque},
fmt, io, mem,
net::SocketAddr,
net::{IpAddr, SocketAddr},
sync::Arc,
time::{Duration, Instant},
};
Expand Down Expand Up @@ -111,6 +111,10 @@ where

/// cid length used to decode short packet
local_cid_len: usize,
/// The "real" local IP address which was was used to receive the initial packet.
/// This is only populated for the server case, and if known
local_ip: Option<IpAddr>,

path: PathData,
prev_path: Option<PathData>,
state: State,
Expand Down Expand Up @@ -208,6 +212,7 @@ where
loc_cid: ConnectionId,
rem_cid: ConnectionId,
remote: SocketAddr,
local_ip: Option<IpAddr>,
crypto: S,
now: Instant,
local_cid_len: usize,
Expand Down Expand Up @@ -244,6 +249,7 @@ where
config.congestion_controller_factory.build(now),
now,
),
local_ip,
prev_path: None,
side,
state,
Expand Down Expand Up @@ -1062,6 +1068,18 @@ where
self.path.remote
}

/// The local IP address which was used when the peer established
/// the connection
///
/// This can be different from the address the endpoint is bound to, in case
/// the endpoint is bound to a wildcard address like `0.0.0.0` or `::`.
///
/// This will return `None` for clients, as well for servers if capturing
/// the local IP from incoming packets is not supported on the current platform.
pub fn local_ip(&self) -> Option<IpAddr> {
self.local_ip
}

/// Current best estimate of this connection's latency (round-trip-time)
pub fn rtt(&self) -> Duration {
self.path.rtt.get()
Expand Down
10 changes: 8 additions & 2 deletions quinn-proto/src/endpoint.rs
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,7 @@ use std::{
collections::{HashMap, VecDeque},
convert::TryFrom,
fmt, iter,
net::SocketAddr,
net::{IpAddr, SocketAddr},
ops::{Index, IndexMut},
sync::Arc,
time::{Duration, Instant, SystemTime},
Expand Down Expand Up @@ -154,6 +154,7 @@ where
&mut self,
now: Instant,
remote: SocketAddr,
local_ip: Option<IpAddr>,
ecn: Option<EcnCodepoint>,
data: BytesMut,
) -> Option<(ConnectionHandle, DatagramEvent<S>)> {
Expand Down Expand Up @@ -273,7 +274,7 @@ where
let crypto = S::initial_keys(&dst_cid, Side::Server);
return match first_decode.finish(Some(&crypto.header.remote)) {
Ok(packet) => self
.handle_first_packet(now, remote, ecn, packet, remaining, &crypto)
.handle_first_packet(now, remote, local_ip, ecn, packet, remaining, &crypto)
.map(|(ch, conn)| (ch, DatagramEvent::NewConnection(conn))),
Err(e) => {
trace!("unable to decode initial packet: {}", e);
Expand Down Expand Up @@ -357,6 +358,7 @@ where
remote_id,
remote_id,
remote,
None,
ConnectionOpts::Client {
config,
server_name: server_name.into(),
Expand Down Expand Up @@ -399,6 +401,7 @@ where
init_cid: ConnectionId,
rem_cid: ConnectionId,
remote: SocketAddr,
local_ip: Option<IpAddr>,
opts: ConnectionOpts<S>,
now: Instant,
) -> Result<(ConnectionHandle, Connection<S>), ConnectError> {
Expand Down Expand Up @@ -454,6 +457,7 @@ where
loc_cid,
rem_cid,
remote,
local_ip,
tls,
now,
self.local_cid_generator.cid_len(),
Expand All @@ -479,6 +483,7 @@ where
&mut self,
now: Instant,
remote: SocketAddr,
local_ip: Option<IpAddr>,
ecn: Option<EcnCodepoint>,
mut packet: Packet,
rest: Option<BytesMut>,
Expand Down Expand Up @@ -614,6 +619,7 @@ where
dst_cid,
src_cid,
remote,
local_ip,
ConnectionOpts::Server {
retry_src_cid,
orig_dst_cid,
Expand Down
2 changes: 2 additions & 0 deletions quinn-proto/src/tests/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -28,6 +28,7 @@ fn version_negotiate_server() {
now,
client_addr,
None,
None,
// Long-header packet with reserved version number
hex!("80 0a1a2a3a 04 00000000 04 00000000 00")[..].into(),
);
Expand Down Expand Up @@ -64,6 +65,7 @@ fn version_negotiate_client() {
now,
server_addr,
None,
None,
// Version negotiation packet for reserved version
hex!(
"80 00000000 04 00000000 04 00000000
Expand Down
2 changes: 1 addition & 1 deletion quinn-proto/src/tests/util.rs
Original file line number Diff line number Diff line change
Expand Up @@ -226,7 +226,7 @@ impl TestEndpoint {
let (_, ecn, packet) = self.inbound.pop_front().unwrap();
if let Some((ch, event)) =
self.endpoint
.handle(now, remote, ecn, packet.as_slice().into())
.handle(now, remote, None, ecn, packet.as_slice().into())
{
match event {
DatagramEvent::NewConnection(conn) => {
Expand Down
29 changes: 28 additions & 1 deletion quinn/src/connection.rs
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,7 @@ use std::{
fmt,
future::Future,
mem,
net::SocketAddr,
net::{IpAddr, SocketAddr},
pin::Pin,
sync::{Arc, Mutex},
task::{Context, Poll, Waker},
Expand Down Expand Up @@ -127,6 +127,21 @@ where
.expect("spurious handshake data ready notification")
})
}

/// The local IP address which was used when the peer established
/// the connection
///
/// This can be different from the address the endpoint is bound to, in case
/// the endpoint is bound to a wildcard address like `0.0.0.0` or `::`.
///
/// This will return `None` for clients, as well for servers if capturing
/// the local IP from incoming packets is not supported on the current platform.
pub fn local_ip(&self) -> Option<IpAddr> {
let conn = self.conn.as_ref().unwrap();
let inner = conn.lock().unwrap();

inner.inner.local_ip()
}
}

impl<S> Future for Connecting<S>
Expand Down Expand Up @@ -397,6 +412,18 @@ where
self.0.lock().unwrap().inner.remote_address()
}

/// The local IP address which was used when the peer established
/// the connection
///
/// This can be different from the address the endpoint is bound to, in case
/// the endpoint is bound to a wildcard address like `0.0.0.0` or `::`.
///
/// This will return `None` for clients, as well for servers if capturing
/// the local IP from incoming packets is not supported on the current platform.
pub fn local_ip(&self) -> Option<IpAddr> {
self.0.lock().unwrap().inner.local_ip()
}

/// Current best estimate of this connection's latency (round-trip-time)
pub fn rtt(&self) -> Duration {
self.0.lock().unwrap().inner.rtt()
Expand Down
5 changes: 4 additions & 1 deletion quinn/src/endpoint.rs
Original file line number Diff line number Diff line change
Expand Up @@ -279,7 +279,10 @@ where
recvd += msgs;
for (meta, buf) in metas.iter().zip(iovs.iter()).take(msgs) {
let data = buf[0..meta.len].into();
match self.inner.handle(now, meta.addr, meta.ecn, data) {
match self
.inner
.handle(now, meta.addr, meta.dest_ip, meta.ecn, data)
{
Some((handle, DatagramEvent::NewConnection(conn))) => {
let conn = self.connections.insert(handle, conn);
self.incoming.push_back(conn);
Expand Down
1 change: 1 addition & 0 deletions quinn/src/platform/fallback.rs
Original file line number Diff line number Diff line change
Expand Up @@ -38,6 +38,7 @@ impl super::UdpExt for UdpSocket {
len,
addr,
ecn: None,
dest_ip: None,
};
Ok(1)
}
Expand Down
75 changes: 63 additions & 12 deletions quinn/src/platform/unix.rs
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,7 @@ use std::{
io,
io::IoSliceMut,
mem::{self, MaybeUninit},
net::{SocketAddr, SocketAddrV4, SocketAddrV6},
net::{IpAddr, SocketAddr, SocketAddrV4, SocketAddrV6},
os::unix::io::AsRawFd,
ptr,
};
Expand All @@ -29,8 +29,17 @@ impl super::UdpExt for UdpSocket {
mem::size_of::<SocketAddrV6>(),
mem::size_of::<libc::sockaddr_in6>()
);

let mut cmsg_platform_space = 0;
if cfg!(target_os = "linux") {
cmsg_platform_space +=
unsafe { libc::CMSG_SPACE(mem::size_of::<libc::in6_pktinfo>() as _) as usize };
}

assert!(
CMSG_LEN >= unsafe { libc::CMSG_SPACE(mem::size_of::<libc::c_int>() as _) as usize }
CMSG_LEN
>= unsafe { libc::CMSG_SPACE(mem::size_of::<libc::c_int>() as _) as usize }
+ cmsg_platform_space
);
assert!(
mem::align_of::<libc::cmsghdr>() <= mem::align_of::<cmsg::Aligned<[u8; 0]>>(),
Expand Down Expand Up @@ -72,6 +81,20 @@ impl super::UdpExt for UdpSocket {
if rc == -1 {
return Err(io::Error::last_os_error());
}

let on: libc::c_int = 1;
let rc = unsafe {
libc::setsockopt(
self.as_raw_fd(),
libc::IPPROTO_IP,
libc::IP_PKTINFO,
&on as *const _ as _,
mem::size_of_val(&on) as _,
)
};
if rc == -1 {
return Err(io::Error::last_os_error());
}
} else if addr.is_ipv6() {
let rc = unsafe {
libc::setsockopt(
Expand All @@ -85,6 +108,20 @@ impl super::UdpExt for UdpSocket {
if rc == -1 {
return Err(io::Error::last_os_error());
}

let on: libc::c_int = 1;
let rc = unsafe {
libc::setsockopt(
self.as_raw_fd(),
libc::IPPROTO_IPV6,
libc::IPV6_RECVPKTINFO,
&on as *const _ as _,
mem::size_of_val(&on) as _,
)
};
if rc == -1 {
return Err(io::Error::last_os_error());
}
}
}
if addr.is_ipv6() {
Expand Down Expand Up @@ -230,7 +267,7 @@ impl super::UdpExt for UdpSocket {
}
}

const CMSG_LEN: usize = 24;
const CMSG_LEN: usize = 64;

fn prepare_msg(
transmit: &Transmit,
Expand Down Expand Up @@ -283,36 +320,50 @@ fn decode_recv(
len: usize,
) -> RecvMeta {
let name = unsafe { name.assume_init() };
let ecn_bits = match unsafe { cmsg::Iter::new(&hdr).next() } {
Some(cmsg) => match (cmsg.cmsg_level, cmsg.cmsg_type) {
let mut ecn_bits = 0;
let mut dest_ip = None;

let cmsg_iter = unsafe { cmsg::Iter::new(&hdr) };
for cmsg in cmsg_iter {
match (cmsg.cmsg_level, cmsg.cmsg_type) {
// FreeBSD uses IP_RECVTOS here, and we can be liberal because cmsgs are opt-in.
(libc::IPPROTO_IP, libc::IP_TOS) | (libc::IPPROTO_IP, libc::IP_RECVTOS) => unsafe {
cmsg::decode::<u8>(cmsg)
ecn_bits = cmsg::decode::<u8>(cmsg);
},
(libc::IPPROTO_IPV6, libc::IPV6_TCLASS) => unsafe {
// Temporary hack around broken macos ABI. Remove once upstream fixes it.
// https://bugreport.apple.com/web/?problemID=48761855
if cfg!(target_os = "macos")
&& cmsg.cmsg_len as usize == libc::CMSG_LEN(mem::size_of::<u8>() as _) as usize
{
cmsg::decode::<u8>(cmsg)
ecn_bits = cmsg::decode::<u8>(cmsg);
} else {
cmsg::decode::<libc::c_int>(cmsg) as u8
ecn_bits = cmsg::decode::<libc::c_int>(cmsg) as u8;
}
},
_ => 0,
},
None => 0,
};
(libc::IPPROTO_IP, libc::IP_PKTINFO) => unsafe {
let pktinfo = cmsg::decode::<libc::in_pktinfo>(cmsg);
dest_ip = Some(IpAddr::V4(ptr::read(&pktinfo.ipi_addr as *const _ as _)));
},
(libc::IPPROTO_IPV6, libc::IPV6_PKTINFO) => unsafe {
let pktinfo = cmsg::decode::<libc::in6_pktinfo>(cmsg);
dest_ip = Some(IpAddr::V6(ptr::read(&pktinfo.ipi6_addr as *const _ as _)));
},
_ => {}
}
}

let addr = match libc::c_int::from(name.ss_family) {
libc::AF_INET => unsafe { SocketAddr::V4(ptr::read(&name as *const _ as _)) },
libc::AF_INET6 => unsafe { SocketAddr::V6(ptr::read(&name as *const _ as _)) },
_ => unreachable!(),
};

RecvMeta {
len,
addr,
ecn: EcnCodepoint::from_bits(ecn_bits),
dest_ip,
}
}

Expand Down
8 changes: 8 additions & 0 deletions quinn/src/tests.rs
Original file line number Diff line number Diff line change
Expand Up @@ -376,6 +376,14 @@ fn run_echo(client_addr: SocketAddr, server_addr: SocketAddr) {

let handle = runtime.spawn(async move {
let incoming = server_incoming.next().await.unwrap();

if cfg!(target_os = "linux") {
let local_ip = incoming.local_ip().expect("Local IP must be available");
assert!(local_ip.is_loopback());
} else {
assert_eq!(None, incoming.local_ip());
}

let new_conn = incoming.instrument(info_span!("server")).await.unwrap();
tokio::spawn(
new_conn
Expand Down
5 changes: 4 additions & 1 deletion quinn/src/udp.rs
Original file line number Diff line number Diff line change
@@ -1,7 +1,7 @@
use std::{
io,
io::IoSliceMut,
net::{Ipv6Addr, SocketAddr},
net::{IpAddr, Ipv6Addr, SocketAddr},
task::{Context, Poll},
};

Expand Down Expand Up @@ -74,6 +74,8 @@ pub struct RecvMeta {
pub addr: SocketAddr,
pub len: usize,
pub ecn: Option<EcnCodepoint>,
/// The destination IP address which was encoded in this datagram
pub dest_ip: Option<IpAddr>,
}

impl Default for RecvMeta {
Expand All @@ -83,6 +85,7 @@ impl Default for RecvMeta {
addr: SocketAddr::new(Ipv6Addr::UNSPECIFIED.into(), 0),
len: 0,
ecn: None,
dest_ip: None,
}
}
}

0 comments on commit e50dd45

Please sign in to comment.