diff --git a/quinn-udp/src/lib.rs b/quinn-udp/src/lib.rs index fd232b023..e4ef45b41 100644 --- a/quinn-udp/src/lib.rs +++ b/quinn-udp/src/lib.rs @@ -68,6 +68,7 @@ pub struct RecvMeta { pub ecn: Option, /// The destination IP address which was encoded in this datagram pub dst_ip: Option, + pub gso_size: Option, } impl Default for RecvMeta { @@ -78,6 +79,7 @@ impl Default for RecvMeta { len: 0, ecn: None, dst_ip: None, + gso_size: None, } } } diff --git a/quinn-udp/src/unix.rs b/quinn-udp/src/unix.rs index a3746ac1f..01395bc25 100644 --- a/quinn-udp/src/unix.rs +++ b/quinn-udp/src/unix.rs @@ -52,6 +52,7 @@ pub struct UdpSocket { impl UdpSocket { pub fn from_std(socket: std::net::UdpSocket) -> io::Result { socket.set_nonblocking(true)?; + init(&socket)?; let now = Instant::now(); Ok(UdpSocket { @@ -134,6 +135,20 @@ fn init(io: &std::net::UdpSocket) -> io::Result<()> { } #[cfg(target_os = "linux")] { + let on: libc::c_int = 1; + let rc = unsafe { + libc::setsockopt( + io.as_raw_fd(), + libc::SOL_UDP, + libc::UDP_GRO, + &on as *const _ as _, + mem::size_of_val(&on) as _, + ) + }; + if rc == -1 { + return Err(io::Error::last_os_error()); + } + if addr.is_ipv4() { let rc = unsafe { libc::setsockopt( @@ -500,6 +515,7 @@ fn decode_recv( let name = unsafe { name.assume_init() }; let mut ecn_bits = 0; let mut dst_ip = None; + let mut gso_size = None; let cmsg_iter = unsafe { cmsg::Iter::new(hdr) }; for cmsg in cmsg_iter { @@ -527,6 +543,10 @@ fn decode_recv( let pktinfo = cmsg::decode::(cmsg); dst_ip = Some(IpAddr::V6(ptr::read(&pktinfo.ipi6_addr as *const _ as _))); }, + #[cfg(target_os = "linux")] + (libc::SOL_UDP, libc::UDP_GRO) => unsafe { + gso_size = Some(cmsg::decode::(cmsg) as usize); + }, _ => {} } } @@ -542,6 +562,7 @@ fn decode_recv( addr, ecn: EcnCodepoint::from_bits(ecn_bits), dst_ip, + gso_size, } } diff --git a/quinn/src/endpoint.rs b/quinn/src/endpoint.rs index 698b706e7..5b0f5d8c7 100644 --- a/quinn/src/endpoint.rs +++ b/quinn/src/endpoint.rs @@ -12,7 +12,7 @@ use std::{ time::Instant, }; -use bytes::Bytes; +use bytes::{Bytes, BytesMut}; use proto::{ self as proto, ClientConfig, ConnectError, ConnectionHandle, DatagramEvent, ServerConfig, }; @@ -346,27 +346,31 @@ impl EndpointInner { Poll::Ready(Ok(msgs)) => { self.recv_limiter.record_work(msgs); for (meta, buf) in metas.iter().zip(iovs.iter()).take(msgs) { - let data = buf[0..meta.len].into(); - match self - .inner - .handle(now, meta.addr, meta.dst_ip, meta.ecn, data) - { - Some((handle, DatagramEvent::NewConnection(conn))) => { - let conn = - self.connections - .insert(handle, conn, self.udp_state.clone()); - self.incoming.push_back(conn); + for buf in buf[0..meta.len].chunks(meta.gso_size.unwrap_or(meta.len)) { + let data: BytesMut = buf.into(); + match self + .inner + .handle(now, meta.addr, meta.dst_ip, meta.ecn, data) + { + Some((handle, DatagramEvent::NewConnection(conn))) => { + let conn = self.connections.insert( + handle, + conn, + self.udp_state.clone(), + ); + self.incoming.push_back(conn); + } + Some((handle, DatagramEvent::ConnectionEvent(event))) => { + // Ignoring errors from dropped connections that haven't yet been cleaned up + let _ = self + .connections + .senders + .get_mut(&handle) + .unwrap() + .send(ConnectionEvent::Proto(event)); + } + None => {} } - Some((handle, DatagramEvent::ConnectionEvent(event))) => { - // Ignoring errors from dropped connections that haven't yet been cleaned up - let _ = self - .connections - .senders - .get_mut(&handle) - .unwrap() - .send(ConnectionEvent::Proto(event)); - } - None => {} } } } @@ -565,8 +569,12 @@ pub(crate) struct EndpointRef(Arc>); impl EndpointRef { pub(crate) fn new(socket: UdpSocket, inner: proto::Endpoint, ipv6: bool) -> Self { + // FIXME: don't hardcode the GRO size let recv_buf = - vec![0; inner.config().get_max_udp_payload_size().min(64 * 1024) as usize * BATCH_SIZE]; + vec![ + 0; + inner.config().get_max_udp_payload_size().min(64 * 1024) as usize * 10 * BATCH_SIZE + ]; let (sender, events) = mpsc::unbounded_channel(); Self(Arc::new(Mutex::new(EndpointInner { socket,