Skip to content

Commit

Permalink
Windows: move BSD socket shims to netc
Browse files Browse the repository at this point in the history
  • Loading branch information
ChrisDenton committed Jul 14, 2024
1 parent 0968298 commit b221712
Show file tree
Hide file tree
Showing 4 changed files with 110 additions and 111 deletions.
99 changes: 1 addition & 98 deletions library/std/src/sys/pal/windows/c.rs
Original file line number Diff line number Diff line change
Expand Up @@ -9,7 +9,7 @@ use crate::ffi::CStr;
use crate::mem;
use crate::num::NonZero;
pub use crate::os::raw::c_int;
use crate::os::raw::{c_char, c_long, c_longlong, c_uint, c_ulong, c_ushort, c_void};
use crate::os::raw::{c_long, c_longlong, c_uint, c_ulong, c_ushort, c_void};
use crate::os::windows::io::{AsRawHandle, BorrowedHandle};
use crate::ptr;

Expand All @@ -25,9 +25,7 @@ pub type LARGE_INTEGER = c_longlong;
pub type LONG = c_long;
pub type UINT = c_uint;
pub type WCHAR = u16;
pub type USHORT = c_ushort;
pub type SIZE_T = usize;
pub type CHAR = c_char;
pub type ULONG = c_ulong;

pub type LPCVOID = *const c_void;
Expand All @@ -40,12 +38,6 @@ pub type LPWSTR = *mut WCHAR;
#[cfg(target_vendor = "win7")]
pub type PSRWLOCK = *mut SRWLOCK;

pub type socklen_t = c_int;
pub type ADDRESS_FAMILY = USHORT;
pub use FD_SET as fd_set;
pub use LINGER as linger;
pub use TIMEVAL as timeval;

pub const INVALID_HANDLE_VALUE: HANDLE = ::core::ptr::without_provenance_mut(-1i32 as _);

// https://learn.microsoft.com/en-us/cpp/c-runtime-library/exit-success-exit-failure?view=msvc-170
Expand All @@ -63,20 +55,6 @@ pub const INIT_ONCE_STATIC_INIT: INIT_ONCE = INIT_ONCE { Ptr: ptr::null_mut() };
pub const OBJ_DONT_REPARSE: u32 = windows_sys::OBJ_DONT_REPARSE as u32;
pub const FRS_ERR_SYSVOL_POPULATE_TIMEOUT: u32 =
windows_sys::FRS_ERR_SYSVOL_POPULATE_TIMEOUT as u32;
pub const AF_INET: c_int = windows_sys::AF_INET as c_int;
pub const AF_INET6: c_int = windows_sys::AF_INET6 as c_int;

#[repr(C)]
pub struct ip_mreq {
pub imr_multiaddr: in_addr,
pub imr_interface: in_addr,
}

#[repr(C)]
pub struct ipv6_mreq {
pub ipv6mr_multiaddr: in6_addr,
pub ipv6mr_interface: c_uint,
}

// Equivalent to the `NT_SUCCESS` C preprocessor macro.
// See: https://docs.microsoft.com/en-us/windows-hardware/drivers/kernel/using-ntstatus-values
Expand Down Expand Up @@ -148,45 +126,6 @@ pub struct MOUNT_POINT_REPARSE_BUFFER {
pub PathBuffer: WCHAR,
}

#[repr(C)]
pub struct SOCKADDR_STORAGE_LH {
pub ss_family: ADDRESS_FAMILY,
pub __ss_pad1: [CHAR; 6],
pub __ss_align: i64,
pub __ss_pad2: [CHAR; 112],
}

#[repr(C)]
#[derive(Copy, Clone)]
pub struct sockaddr_in {
pub sin_family: ADDRESS_FAMILY,
pub sin_port: USHORT,
pub sin_addr: in_addr,
pub sin_zero: [CHAR; 8],
}

#[repr(C)]
#[derive(Copy, Clone)]
pub struct sockaddr_in6 {
pub sin6_family: ADDRESS_FAMILY,
pub sin6_port: USHORT,
pub sin6_flowinfo: c_ulong,
pub sin6_addr: in6_addr,
pub sin6_scope_id: c_ulong,
}

#[repr(C)]
#[derive(Copy, Clone)]
pub struct in_addr {
pub s_addr: u32,
}

#[repr(C)]
#[derive(Copy, Clone)]
pub struct in6_addr {
pub s6_addr: [u8; 16],
}

// Desktop specific functions & types
cfg_if::cfg_if! {
if #[cfg(not(target_vendor = "uwp"))] {
Expand Down Expand Up @@ -226,42 +165,6 @@ pub unsafe extern "system" fn ReadFileEx(
)
}

// POSIX compatibility shims.
pub unsafe fn recv(socket: SOCKET, buf: *mut c_void, len: c_int, flags: c_int) -> c_int {
windows_sys::recv(socket, buf.cast::<u8>(), len, flags)
}
pub unsafe fn send(socket: SOCKET, buf: *const c_void, len: c_int, flags: c_int) -> c_int {
windows_sys::send(socket, buf.cast::<u8>(), len, flags)
}
pub unsafe fn recvfrom(
socket: SOCKET,
buf: *mut c_void,
len: c_int,
flags: c_int,
addr: *mut SOCKADDR,
addrlen: *mut c_int,
) -> c_int {
windows_sys::recvfrom(socket, buf.cast::<u8>(), len, flags, addr, addrlen)
}
pub unsafe fn sendto(
socket: SOCKET,
buf: *const c_void,
len: c_int,
flags: c_int,
addr: *const SOCKADDR,
addrlen: c_int,
) -> c_int {
windows_sys::sendto(socket, buf.cast::<u8>(), len, flags, addr, addrlen)
}
pub unsafe fn getaddrinfo(
node: *const c_char,
service: *const c_char,
hints: *const ADDRINFOA,
res: *mut *mut ADDRINFOA,
) -> c_int {
windows_sys::getaddrinfo(node.cast::<u8>(), service.cast::<u8>(), hints, res)
}

cfg_if::cfg_if! {
if #[cfg(not(target_vendor = "uwp"))] {
pub unsafe fn NtReadFile(
Expand Down
1 change: 1 addition & 0 deletions library/std/src/sys/pal/windows/c/bindings.txt
Original file line number Diff line number Diff line change
Expand Up @@ -2059,6 +2059,7 @@ Windows.Win32.Networking.WinSock.SOCK_RDM
Windows.Win32.Networking.WinSock.SOCK_SEQPACKET
Windows.Win32.Networking.WinSock.SOCK_STREAM
Windows.Win32.Networking.WinSock.SOCKADDR
Windows.Win32.Networking.WinSock.SOCKADDR_STORAGE
Windows.Win32.Networking.WinSock.SOCKADDR_UN
Windows.Win32.Networking.WinSock.SOCKET
Windows.Win32.Networking.WinSock.SOCKET_ERROR
Expand Down
8 changes: 8 additions & 0 deletions library/std/src/sys/pal/windows/c/windows_sys.rs
Original file line number Diff line number Diff line change
Expand Up @@ -2890,6 +2890,14 @@ pub struct SOCKADDR {
}
#[repr(C)]
#[derive(Clone, Copy)]
pub struct SOCKADDR_STORAGE {
pub ss_family: ADDRESS_FAMILY,
pub __ss_pad1: [i8; 6],
pub __ss_align: i64,
pub __ss_pad2: [i8; 112],
}
#[repr(C)]
#[derive(Clone, Copy)]
pub struct SOCKADDR_UN {
pub sun_family: ADDRESS_FAMILY,
pub sun_path: [i8; 108],
Expand Down
113 changes: 100 additions & 13 deletions library/std/src/sys/pal/windows/net.rs
Original file line number Diff line number Diff line change
Expand Up @@ -17,14 +17,101 @@ use crate::time::Duration;

use core::ffi::{c_int, c_long, c_ulong, c_ushort};

#[allow(non_camel_case_types)]
pub type wrlen_t = i32;

pub mod netc {
pub use crate::sys::c::ADDRESS_FAMILY as sa_family_t;
pub use crate::sys::c::ADDRINFOA as addrinfo;
pub use crate::sys::c::SOCKADDR as sockaddr;
pub use crate::sys::c::SOCKADDR_STORAGE_LH as sockaddr_storage;
pub use crate::sys::c::*;
//! BSD socket compatibility shim
//!
//! Some Windows API types are not quite what's expected by our cross-platform
//! net code. E.g. naming differences or different pointer types.
use crate::sys::c::{self, ADDRESS_FAMILY, ADDRINFOA, SOCKADDR, SOCKET};
use core::ffi::{c_char, c_int, c_uint, c_ulong, c_ushort, c_void};

// re-exports from Windows API bindings.
pub use crate::sys::c::{
bind, connect, freeaddrinfo, getpeername, getsockname, getsockopt, listen, setsockopt,
ADDRESS_FAMILY as sa_family_t, ADDRINFOA as addrinfo, IPPROTO_IP, IPPROTO_IPV6,
IPV6_ADD_MEMBERSHIP, IPV6_DROP_MEMBERSHIP, IPV6_MULTICAST_LOOP, IPV6_V6ONLY,
IP_ADD_MEMBERSHIP, IP_DROP_MEMBERSHIP, IP_MULTICAST_LOOP, IP_MULTICAST_TTL, IP_TTL,
SOCKADDR as sockaddr, SOCKADDR_STORAGE as sockaddr_storage, SOCK_DGRAM, SOCK_STREAM,
SOL_SOCKET, SO_BROADCAST, SO_RCVTIMEO, SO_SNDTIMEO,
};

#[allow(non_camel_case_types)]
pub type socklen_t = c_int;

pub const AF_INET: i32 = c::AF_INET as i32;
pub const AF_INET6: i32 = c::AF_INET6 as i32;

// The following two structs use a union in the generated bindings but
// our cross-platform code expects a normal field so it's redefined here.
// As a consequence, we also need to redefine other structs that use this struct.
#[repr(C)]
#[derive(Copy, Clone)]
pub struct in_addr {
pub s_addr: u32,
}

#[repr(C)]
#[derive(Copy, Clone)]
pub struct in6_addr {
pub s6_addr: [u8; 16],
}

#[repr(C)]
pub struct ip_mreq {
pub imr_multiaddr: in_addr,
pub imr_interface: in_addr,
}

#[repr(C)]
pub struct ipv6_mreq {
pub ipv6mr_multiaddr: in6_addr,
pub ipv6mr_interface: c_uint,
}

#[repr(C)]
#[derive(Copy, Clone)]
pub struct sockaddr_in {
pub sin_family: ADDRESS_FAMILY,
pub sin_port: c_ushort,
pub sin_addr: in_addr,
pub sin_zero: [c_char; 8],
}

#[repr(C)]
#[derive(Copy, Clone)]
pub struct sockaddr_in6 {
pub sin6_family: ADDRESS_FAMILY,
pub sin6_port: c_ushort,
pub sin6_flowinfo: c_ulong,
pub sin6_addr: in6_addr,
pub sin6_scope_id: c_ulong,
}

pub unsafe fn send(socket: SOCKET, buf: *const c_void, len: c_int, flags: c_int) -> c_int {
c::send(socket, buf.cast::<u8>(), len, flags)
}
pub unsafe fn sendto(
socket: SOCKET,
buf: *const c_void,
len: c_int,
flags: c_int,
addr: *const SOCKADDR,
addrlen: c_int,
) -> c_int {
c::sendto(socket, buf.cast::<u8>(), len, flags, addr, addrlen)
}
pub unsafe fn getaddrinfo(
node: *const c_char,
service: *const c_char,
hints: *const ADDRINFOA,
res: *mut *mut ADDRINFOA,
) -> c_int {
c::getaddrinfo(node.cast::<u8>(), service.cast::<u8>(), hints, res)
}
}

pub struct Socket(OwnedSocket);
Expand Down Expand Up @@ -102,8 +189,8 @@ where
impl Socket {
pub fn new(addr: &SocketAddr, ty: c_int) -> io::Result<Socket> {
let family = match *addr {
SocketAddr::V4(..) => c::AF_INET,
SocketAddr::V6(..) => c::AF_INET6,
SocketAddr::V4(..) => netc::AF_INET,
SocketAddr::V6(..) => netc::AF_INET6,
};
let socket = unsafe {
c::WSASocketW(
Expand Down Expand Up @@ -157,7 +244,7 @@ impl Socket {
return Err(io::Error::ZERO_TIMEOUT);
}

let mut timeout = c::timeval {
let mut timeout = c::TIMEVAL {
tv_sec: cmp::min(timeout.as_secs(), c_long::MAX as u64) as c_long,
tv_usec: timeout.subsec_micros() as c_long,
};
Expand All @@ -167,7 +254,7 @@ impl Socket {
}

let fds = {
let mut fds = unsafe { mem::zeroed::<c::fd_set>() };
let mut fds = unsafe { mem::zeroed::<c::FD_SET>() };
fds.fd_count = 1;
fds.fd_array[0] = self.as_raw();
fds
Expand Down Expand Up @@ -295,8 +382,8 @@ impl Socket {
buf: &mut [u8],
flags: c_int,
) -> io::Result<(usize, SocketAddr)> {
let mut storage = unsafe { mem::zeroed::<c::SOCKADDR_STORAGE_LH>() };
let mut addrlen = mem::size_of_val(&storage) as c::socklen_t;
let mut storage = unsafe { mem::zeroed::<c::SOCKADDR_STORAGE>() };
let mut addrlen = mem::size_of_val(&storage) as netc::socklen_t;
let length = cmp::min(buf.len(), <wrlen_t>::MAX as usize) as wrlen_t;

// On unix when a socket is shut down all further reads return 0, so we
Expand Down Expand Up @@ -399,7 +486,7 @@ impl Socket {
}

pub fn set_linger(&self, linger: Option<Duration>) -> io::Result<()> {
let linger = c::linger {
let linger = c::LINGER {
l_onoff: linger.is_some() as c_ushort,
l_linger: linger.unwrap_or_default().as_secs() as c_ushort,
};
Expand All @@ -408,7 +495,7 @@ impl Socket {
}

pub fn linger(&self) -> io::Result<Option<Duration>> {
let val: c::linger = net::getsockopt(self, c::SOL_SOCKET, c::SO_LINGER)?;
let val: c::LINGER = net::getsockopt(self, c::SOL_SOCKET, c::SO_LINGER)?;

Ok((val.l_onoff != 0).then(|| Duration::from_secs(val.l_linger as u64)))
}
Expand Down

0 comments on commit b221712

Please sign in to comment.