diff --git a/Cargo.toml b/Cargo.toml index 1a5f86b..652591e 100644 --- a/Cargo.toml +++ b/Cargo.toml @@ -29,7 +29,7 @@ hex = { version = "0.4.3", optional = true } take-until = { version = " 0.1.0", optional = true } [target.'cfg(target_os = "linux")'.dependencies] -neli = "=0.5.3" +neli = "=0.6.2" libc = "0.2.66" [dev-dependencies] diff --git a/src/linux/attr.rs b/src/linux/attr.rs index 744703b..ea83d39 100644 --- a/src/linux/attr.rs +++ b/src/linux/attr.rs @@ -1,5 +1,5 @@ use neli::consts::genl::NlAttrType; -use neli::impl_var; +use neli::neli_enum; use std::fmt; // As of neli 0.4.3, the NLA_F_NESTED flag needs to be added to newly created @@ -29,31 +29,31 @@ macro_rules! impl_bit_ops_for_nla { }; } -impl_var!( - pub NlaNested, u16, - Unspec => 0, +#[neli_enum(serialized_type = "u16")] +pub enum NlaNested { + Unspec = 0, // neli requires 1 non-zero argument even though WireGuard // does not use it. - Unused => 1 -); + Unused = 1, +} impl NlAttrType for NlaNested {} impl_bit_ops_for_nla!(NlaNested); // https://github.com/WireGuard/WireGuard/blob/62b335b56cc99312ccedfa571500fbef3756a623/src/uapi/wireguard.h#L147 -impl_var!( - pub WgDeviceAttribute, u16, - Unspec => 0, - Ifindex => 1, - Ifname => 2, - PrivateKey => 3, - PublicKey => 4, - Flags => 5, - ListenPort => 6, - Fwmark => 7, - Peers => 8 -); +#[neli_enum(serialized_type = "u16")] +pub enum WgDeviceAttribute { + Unspec = 0, + Ifindex = 1, + Ifname = 2, + PrivateKey = 3, + PublicKey = 4, + Flags = 5, + ListenPort = 6, + Fwmark = 7, + Peers = 8, +} impl NlAttrType for WgDeviceAttribute {} @@ -66,20 +66,20 @@ impl fmt::Display for WgDeviceAttribute { impl_bit_ops_for_nla!(WgDeviceAttribute); // https://github.com/WireGuard/WireGuard/blob/62b335b56cc99312ccedfa571500fbef3756a623/src/uapi/wireguard.h#L165 -impl_var!( - pub WgPeerAttribute, u16, - Unspec => 0, - PublicKey => 1, - PresharedKey => 2, - Flags => 3, - Endpoint => 4, - PersistentKeepaliveInterval => 5, - LastHandshakeTime => 6, - RxBytes => 7, - TxBytes => 8, - AllowedIps => 9, - ProtocolVersion => 10 -); +#[neli_enum(serialized_type = "u16")] +pub enum WgPeerAttribute { + Unspec = 0, + PublicKey = 1, + PresharedKey = 2, + Flags = 3, + Endpoint = 4, + PersistentKeepaliveInterval = 5, + LastHandshakeTime = 6, + RxBytes = 7, + TxBytes = 8, + AllowedIps = 9, + ProtocolVersion = 10, +} impl NlAttrType for WgPeerAttribute {} @@ -92,11 +92,12 @@ impl fmt::Display for WgPeerAttribute { impl_bit_ops_for_nla!(WgPeerAttribute); // https://github.com/WireGuard/WireGuard/blob/62b335b56cc99312ccedfa571500fbef3756a623/src/uapi/wireguard.h#L181 -impl_var!( - pub WgAllowedIpAttribute, u16, - Unspec => 0, - Family => 1, - IpAddr => 2, - CidrMask => 3 -); +#[neli_enum(serialized_type = "u16")] +pub enum WgAllowedIpAttribute { + Unspec = 0, + Family = 1, + IpAddr = 2, + CidrMask = 3, +} + impl NlAttrType for WgAllowedIpAttribute {} diff --git a/src/linux/cmd.rs b/src/linux/cmd.rs index 77bcfd4..6b4b129 100644 --- a/src/linux/cmd.rs +++ b/src/linux/cmd.rs @@ -1,10 +1,11 @@ -use neli::{consts::genl::Cmd, impl_var}; +use neli::consts::genl::Cmd; +use neli::neli_enum; // https://github.com/WireGuard/WireGuard/blob/62b335b56cc99312ccedfa571500fbef3756a623/src/uapi/wireguard.h#L137 -impl_var!( - pub WgCmd, u8, - GetDevice => 0, - SetDevice => 1 -); +#[neli_enum(serialized_type = "u8")] +pub enum WgCmd { + GetDevice = 0, + SetDevice = 1, +} impl Cmd for WgCmd {} diff --git a/src/linux/err/connect_error.rs b/src/linux/err/connect_error.rs index 9f48a9a..2aca27c 100644 --- a/src/linux/err/connect_error.rs +++ b/src/linux/err/connect_error.rs @@ -1,4 +1,9 @@ +use neli::consts::{ + genl::{CtrlAttr, CtrlCmd}, + nl::GenlId, +}; use neli::err::NlError; +use neli::genl::Genlmsghdr; use thiserror::Error; #[derive(Error, Debug)] @@ -7,7 +12,7 @@ pub enum ConnectError { NlError(NlError), #[error("Unable to connect to the WireGuard DKMS. Is WireGuard installed?")] - ResolveFamilyError(#[source] NlError), + ResolveFamilyError(#[source] NlError>), } impl From for ConnectError { diff --git a/src/linux/interface.rs b/src/linux/interface.rs index d944c63..24b1035 100644 --- a/src/linux/interface.rs +++ b/src/linux/interface.rs @@ -25,15 +25,11 @@ impl<'a> TryFrom<&DeviceInterface<'a>> for Nlattr { fn try_from(interface: &DeviceInterface) -> Result { let attr = match interface { &DeviceInterface::Index(ifindex) => { - Nlattr::new(None, false, false, WgDeviceAttribute::Ifindex, ifindex)? + Nlattr::new(false, false, WgDeviceAttribute::Ifindex, ifindex)? + } + DeviceInterface::Name(ifname) => { + Nlattr::new(false, false, WgDeviceAttribute::Ifname, ifname.as_ref())? } - DeviceInterface::Name(ifname) => Nlattr::new( - None, - false, - false, - WgDeviceAttribute::Ifname, - ifname.as_ref(), - )?, }; Ok(attr) } diff --git a/src/linux/set/allowed_ip.rs b/src/linux/set/allowed_ip.rs index 725742f..c2eb666 100644 --- a/src/linux/set/allowed_ip.rs +++ b/src/linux/set/allowed_ip.rs @@ -26,14 +26,13 @@ impl<'a> TryFrom<&AllowedIp<'a>> for Nlattr { fn try_from(allowed_ip: &AllowedIp) -> Result { let mut nested = - Nlattr::new::>(None, false, false, NlaNested::Unspec | NLA_F_NESTED, vec![])?; + Nlattr::new::>(false, false, NlaNested::Unspec | NLA_F_NESTED, vec![])?; let family = match allowed_ip.ipaddr { IpAddr::V4(_) => libc::AF_INET as u16, IpAddr::V6(_) => libc::AF_INET6 as u16, }; nested.add_nested_attribute(&Nlattr::new( - None, false, false, WgAllowedIpAttribute::Family, @@ -45,7 +44,6 @@ impl<'a> TryFrom<&AllowedIp<'a>> for Nlattr { IpAddr::V6(addr) => addr.octets().to_vec(), }; nested.add_nested_attribute(&Nlattr::new( - None, false, false, WgAllowedIpAttribute::IpAddr, @@ -57,7 +55,6 @@ impl<'a> TryFrom<&AllowedIp<'a>> for Nlattr { IpAddr::V6(_) => 128, }); nested.add_nested_attribute(&Nlattr::new( - None, false, false, WgAllowedIpAttribute::CidrMask, diff --git a/src/linux/set/create_set_device_messages.rs b/src/linux/set/create_set_device_messages.rs index 1f8f56a..a45daf0 100644 --- a/src/linux/set/create_set_device_messages.rs +++ b/src/linux/set/create_set_device_messages.rs @@ -11,7 +11,7 @@ use neli::{ genl::{Genlmsghdr, Nlattr}, nl::{NlPayload, Nlmsghdr}, types::{Buffer, GenlBuffer}, - Nl, + Size, }; use std::convert::TryInto; use std::net::SocketAddr; @@ -37,7 +37,8 @@ impl IncubatingDeviceFragment { partial_device: { let mut attrs = GenlBuffer::new(); - let interface_attr = (&device.interface).try_into()?; + let interface_attr: Nlattr = + (&device.interface).try_into()?; attrs.push(interface_attr); if !device.flags.is_empty() { @@ -45,7 +46,6 @@ impl IncubatingDeviceFragment { unique.dedup(); attrs.push(Nlattr::new( - None, false, false, WgDeviceAttribute::Flags, @@ -55,7 +55,6 @@ impl IncubatingDeviceFragment { if let Some(private_key) = device.private_key { attrs.push(Nlattr::new( - None, false, false, WgDeviceAttribute::PrivateKey, @@ -65,7 +64,6 @@ impl IncubatingDeviceFragment { if let Some(listen_port) = device.listen_port { attrs.push(Nlattr::new( - None, false, false, WgDeviceAttribute::ListenPort, @@ -75,7 +73,6 @@ impl IncubatingDeviceFragment { if let Some(fwmark) = device.fwmark { attrs.push(Nlattr::new( - None, false, false, WgDeviceAttribute::Fwmark, @@ -88,8 +85,7 @@ impl IncubatingDeviceFragment { attrs }, - peers: Nlattr::new( - None, + peers: Nlattr::new::>( false, false, WgDeviceAttribute::Peers | NLA_F_NESTED, @@ -101,14 +97,12 @@ impl IncubatingDeviceFragment { } fn from_interface(interface: &DeviceInterface) -> Result { - let mut interface_attr = GenlBuffer::new(); - - interface_attr.push(interface.try_into()?); + let mut partial_device = GenlBuffer::new(); + partial_device.push(interface.try_into()?); Ok(Self { - partial_device: interface_attr, - peers: Nlattr::new( - None, + partial_device, + peers: Nlattr::new::>( false, false, WgDeviceAttribute::Peers | NLA_F_NESTED, @@ -118,16 +112,20 @@ impl IncubatingDeviceFragment { } fn incubating_size(&self) -> usize { - let attrs_size: usize = self.partial_device.iter().map(|attr| attr.asize()).sum(); + let attrs_size: usize = self + .partial_device + .iter() + .map(|attr| attr.padded_size()) + .sum(); - NETLINK_HEADER_SIZE + GENL_HEADER_SIZE + attrs_size + self.peers.asize() + NETLINK_HEADER_SIZE + GENL_HEADER_SIZE + attrs_size + self.peers.padded_size() } fn finalize(self, family_id: NlWgMsgType) -> Result { let mut device_attrs = self.partial_device; // TODO: Condition this behavior on whether peers have ever been added. - if self.peers.size() > GENL_HEADER_SIZE { + if self.peers.unpadded_size() > GENL_HEADER_SIZE { device_attrs.push(self.peers); } @@ -158,10 +156,9 @@ struct IncubatingPeerFragment { impl IncubatingPeerFragment { fn split_off_allowed_ips(peer: Peer<'_>) -> Result<(Self, Vec>), NlError> { let mut partial_peer = - Nlattr::new(None, false, false, NlaNested::Unspec | NLA_F_NESTED, vec![])?; + Nlattr::new::>(false, false, NlaNested::Unspec | NLA_F_NESTED, vec![])?; let public_key = Nlattr::new( - None, false, false, WgPeerAttribute::PublicKey, @@ -174,7 +171,6 @@ impl IncubatingPeerFragment { unique.dedup(); partial_peer.add_nested_attribute(&Nlattr::new( - None, false, false, WgPeerAttribute::Flags, @@ -184,7 +180,6 @@ impl IncubatingPeerFragment { if let Some(preshared_key) = peer.preshared_key { partial_peer.add_nested_attribute(&Nlattr::new( - None, false, false, WgPeerAttribute::PresharedKey, @@ -218,7 +213,6 @@ impl IncubatingPeerFragment { }; partial_peer.add_nested_attribute(&Nlattr::new( - None, false, false, WgPeerAttribute::Endpoint, @@ -228,7 +222,6 @@ impl IncubatingPeerFragment { if let Some(persistent_keepalive_interval) = peer.persistent_keepalive_interval { partial_peer.add_nested_attribute(&Nlattr::new( - None, false, false, WgPeerAttribute::PersistentKeepaliveInterval, @@ -238,7 +231,6 @@ impl IncubatingPeerFragment { if let Some(protocol_version) = peer.protocol_version { partial_peer.add_nested_attribute(&Nlattr::new( - None, false, false, WgPeerAttribute::ProtocolVersion, @@ -251,8 +243,7 @@ impl IncubatingPeerFragment { let incubating_peer_fragment = IncubatingPeerFragment { partial_peer, - allowed_ips: Nlattr::new( - None, + allowed_ips: Nlattr::new::>( false, false, WgPeerAttribute::AllowedIps | NLA_F_NESTED, @@ -265,9 +256,8 @@ impl IncubatingPeerFragment { fn from_public_key(public_key: &[u8; 32]) -> Result { let mut partial_peer = - Nlattr::new(None, false, false, NlaNested::Unspec | NLA_F_NESTED, vec![])?; - let allowed_ips = Nlattr::new( - None, + Nlattr::new::>(false, false, NlaNested::Unspec | NLA_F_NESTED, vec![])?; + let allowed_ips = Nlattr::new::>( false, false, WgPeerAttribute::AllowedIps | NLA_F_NESTED, @@ -275,7 +265,6 @@ impl IncubatingPeerFragment { )?; let public_key = Nlattr::new( - None, false, false, WgPeerAttribute::PublicKey, @@ -290,12 +279,12 @@ impl IncubatingPeerFragment { } fn incubating_size(&self) -> usize { - self.partial_peer.asize() + self.allowed_ips.asize() + self.partial_peer.padded_size() + self.allowed_ips.padded_size() } fn finalize(self) -> Result, NlError> { let mut partial_peer = self.partial_peer; - if self.allowed_ips.size() > GENL_HEADER_SIZE { + if self.allowed_ips.padded_size() > GENL_HEADER_SIZE { partial_peer.add_nested_attribute(&self.allowed_ips)?; } Ok(partial_peer) @@ -333,7 +322,7 @@ pub fn create_set_device_messages( let next_size = incubating_device_fragment.incubating_size() + incubating_peer_fragment.incubating_size() - + allowed_ip_attr.asize(); + + allowed_ip_attr.padded_size(); if next_size > NETLINK_MSG_LIMIT { let peer_fragment = incubating_peer_fragment.finalize()?; incubating_device_fragment diff --git a/src/linux/socket/link_message.rs b/src/linux/socket/link_message.rs index fada6ed..0808d64 100644 --- a/src/linux/socket/link_message.rs +++ b/src/linux/socket/link_message.rs @@ -7,7 +7,7 @@ use neli::{ err::NlError, nl::{NlPayload, Nlmsghdr}, rtnl::{Ifinfomsg, Rtattr}, - types::{Buffer, RtBuffer}, + types::RtBuffer, }; pub enum WireGuardDeviceLinkOperation { @@ -19,17 +19,17 @@ pub fn link_message( ifname: &str, link_operation: WireGuardDeviceLinkOperation, ) -> Result, NlError> { - let ifname = Rtattr::new(None, Ifla::Ifname, Buffer::from(ifname.as_bytes()))?; - let link = { + let rtattrs = { let mut attrs = RtBuffer::new(); + attrs.push(Rtattr::new(None, Ifla::Ifname, ifname.as_bytes())?); - attrs.push(Rtattr::new( - None, - IflaInfo::Kind, - WG_GENL_NAME.as_bytes().to_vec(), - )?); + let mut genl_name = RtBuffer::new(); + genl_name.push(Rtattr::new(None, IflaInfo::Kind, WG_GENL_NAME.as_bytes())?); - Rtattr::new(None, Ifla::Linkinfo, attrs)? + let link = Rtattr::new(None, Ifla::Linkinfo, genl_name)?; + + attrs.push(link); + attrs }; let infomsg = { let ifi_family = RtAddrFamily::Unspecified; @@ -38,12 +38,7 @@ pub fn link_message( let ifi_type = Arphrd::Netrom; let ifi_index = 0; let ifi_flags = IffFlags::empty(); - let mut rtattrs = RtBuffer::new(); let ifi_change = IffFlags::new(&[Iff::Up]); - - rtattrs.push(ifname); - rtattrs.push(link); - Ifinfomsg::new( ifi_family, ifi_type, ifi_index, ifi_flags, ifi_change, rtattrs, ) diff --git a/src/linux/socket/list_device_names_utils.rs b/src/linux/socket/list_device_names_utils.rs index 894396e..337c288 100644 --- a/src/linux/socket/list_device_names_utils.rs +++ b/src/linux/socket/list_device_names_utils.rs @@ -1,7 +1,7 @@ use crate::err::ListDevicesError; use neli::{ consts::{ - nl::{NlTypeWrapper, NlmF, NlmFFlags}, + nl::{NlmF, NlmFFlags, Nlmsg}, rtnl::{Arphrd, Iff, IffFlags, Ifla, IflaInfo, Rtm}, }, nl::{NlPayload, Nlmsghdr}, @@ -12,8 +12,7 @@ use std::convert::TryFrom; pub fn get_list_device_names_msg() -> Nlmsghdr { let infomsg = { - let ifi_family = - neli::consts::rtnl::RtAddrFamily::UnrecognizedVariant(libc::AF_UNSPEC as u8); + let ifi_family = neli::consts::rtnl::RtAddrFamily::Unspecified; // Arphrd::Netrom corresponds to 0. Not sure why 0 is necessary here but this is what the // embedded C library does. let ifi_type = Arphrd::Netrom; @@ -41,19 +40,25 @@ pub struct PotentialWireGuardDeviceName { pub is_wireguard: bool, } -impl TryFrom> for PotentialWireGuardDeviceName { +impl TryFrom> for PotentialWireGuardDeviceName { type Error = ListDevicesError; - fn try_from(response: Nlmsghdr) -> Result { - let mut handle = response.get_payload()?.rtattrs.get_attr_handle(); + fn try_from(response: Nlmsghdr) -> Result { + let payload = response + .nl_payload + .get_payload() + .ok_or(ListDevicesError::Unknown)?; + let mut handle = payload.rtattrs.get_attr_handle(); Ok(PotentialWireGuardDeviceName { - ifname: handle.get_attr_payload_as::(Ifla::Ifname).ok(), + ifname: handle + .get_attr_payload_as_with_len::(Ifla::Ifname) + .ok(), is_wireguard: handle .get_nested_attributes(Ifla::Linkinfo) .map_or(false, |linkinfo| { linkinfo - .get_attr_payload_as::(IflaInfo::Kind) + .get_attr_payload_as_with_len::(IflaInfo::Kind) .map_or(false, |info_kind| { info_kind == crate::linux::consts::WG_GENL_NAME }) diff --git a/src/linux/socket/parse.rs b/src/linux/socket/parse.rs index 3be4550..ead0efe 100644 --- a/src/linux/socket/parse.rs +++ b/src/linux/socket/parse.rs @@ -23,7 +23,7 @@ impl TryFrom> for Device { let mut device_builder = DeviceBuilder::default(); for attr in handle.iter() { - match attr.nla_type & NLA_TYPE_MASK { + match attr.nla_type.nla_type & NLA_TYPE_MASK { WgDeviceAttribute::Unspec => { // The embeddable-wg-library example ignores unspec, so we'll do the same. } @@ -51,7 +51,7 @@ impl TryFrom> for Device { WgDeviceAttribute::Flags => { // This attribute is for set_device. Ignore it for get_device. } - WgDeviceAttribute::UnrecognizedVariant(i) => { + WgDeviceAttribute::UnrecognizedConst(i) => { return Err(ParseDeviceError::UnknownDeviceAttributeError { id: i }) } } @@ -63,14 +63,14 @@ impl TryFrom> for Device { pub fn extend_device( mut device: Device, - handle: AttrHandle, + handle: AttrHandle<'_, WgDeviceAttribute>, ) -> Result { let next_peers = { let peers_attr = handle .iter() - .find(|attr| attr.nla_type & NLA_TYPE_MASK == WgDeviceAttribute::Peers) + .find(|attr| attr.nla_type.nla_type & NLA_TYPE_MASK == WgDeviceAttribute::Peers) .expect("Unable to find additional peers to coalesce."); - let handle = peers_attr.get_attr_handle::()?; + let handle = peers_attr.get_attr_handle()?; handle .iter() @@ -100,11 +100,11 @@ pub fn extend_device( Ok(device) } -pub fn parse_peers(handle: AttrHandle) -> Result, ParseDeviceError> { +pub fn parse_peers(handle: AttrHandle<'_, NlaNested>) -> Result, ParseDeviceError> { let mut peers = vec![]; for peer in handle.iter() { - let handle = peer.get_attr_handle::()?; + let handle = peer.get_attr_handle()?; peers.push(Peer::try_from(handle)?); } @@ -112,12 +112,12 @@ pub fn parse_peers(handle: AttrHandle) -> Result, ParseDevi } pub fn parse_peer_builder( - handle: AttrHandle, + handle: AttrHandle<'_, WgPeerAttribute>, ) -> Result { let mut peer_builder = PeerBuilder::default(); for attr in handle.iter() { - match attr.nla_type & NLA_TYPE_MASK { + match attr.nla_type.nla_type & NLA_TYPE_MASK { WgPeerAttribute::Unspec => {} WgPeerAttribute::Flags => {} WgPeerAttribute::PublicKey => { @@ -144,14 +144,13 @@ pub fn parse_peer_builder( peer_builder.tx_bytes(parse_nla_u64(attr.nla_payload.as_ref())?); } WgPeerAttribute::AllowedIps => { - if let Some(handle) = handle.get_attribute(WgPeerAttribute::AllowedIps) { - peer_builder.allowed_ips(parse_allowedips(handle.get_attr_handle()?)?); - } + let handle = attr.get_attr_handle()?; + peer_builder.allowed_ips(parse_allowedips(handle)?); } WgPeerAttribute::ProtocolVersion => { peer_builder.protocol_version(parse_nla_u32(attr.nla_payload.as_ref())?); } - WgPeerAttribute::UnrecognizedVariant(i) => { + WgPeerAttribute::UnrecognizedConst(i) => { return Err(ParseDeviceError::UnknownPeerAttributeError { id: i }) } } @@ -162,14 +161,15 @@ pub fn parse_peer_builder( impl TryFrom> for Peer { type Error = ParseDeviceError; - fn try_from(handle: AttrHandle<'_, WgPeerAttribute>) -> Result { let peer_builder = parse_peer_builder(handle)?; Ok(peer_builder.build()?) } } -pub fn parse_allowedips(handle: AttrHandle) -> Result, ParseDeviceError> { +pub fn parse_allowedips( + handle: AttrHandle<'_, NlaNested>, +) -> Result, ParseDeviceError> { let mut allowed_ips = vec![]; for allowed_ip in handle.iter() { @@ -189,7 +189,7 @@ impl TryFrom> for AllowedIp { for attr in handle.iter() { let payload = attr.nla_payload.as_ref(); - match attr.nla_type { + match attr.nla_type.nla_type { WgAllowedIpAttribute::Unspec => {} WgAllowedIpAttribute::Family => { allowed_ip_builder.family(parse_nla_u16(payload)?); @@ -209,7 +209,7 @@ impl TryFrom> for AllowedIp { WgAllowedIpAttribute::CidrMask => { allowed_ip_builder.cidr_mask(parse_nla_u8(payload)?); } - WgAllowedIpAttribute::UnrecognizedVariant(i) => { + WgAllowedIpAttribute::UnrecognizedConst(i) => { return Err(ParseDeviceError::UnknownAllowedIpAttributeError { id: i }) } } @@ -354,7 +354,7 @@ mod tests { use super::*; use crate::linux::cmd::WgCmd; use anyhow::Error; - use neli::{err::DeError, genl::Genlmsghdr, Nl}; + use neli::{err::DeError, genl::Genlmsghdr}; // This device comes from the configuration example in "man wg", but with // the third peer removed since it specifies an domain endpoint only valid @@ -427,7 +427,8 @@ mod tests { fn create_test_genlmsghdr( payload: &[u8], ) -> Result, DeError> { - Genlmsghdr::deserialize(payload) + use neli::FromBytesWithInput; + Genlmsghdr::from_bytes_with_input(&mut std::io::Cursor::new(payload), payload.len()) } #[test] diff --git a/src/linux/socket/route_socket.rs b/src/linux/socket/route_socket.rs index 93377b8..12d4acf 100644 --- a/src/linux/socket/route_socket.rs +++ b/src/linux/socket/route_socket.rs @@ -3,12 +3,7 @@ use super::{link_message, WireGuardDeviceLinkOperation}; use crate::err::{ConnectError, LinkDeviceError, ListDevicesError}; use list_device_names_utils::PotentialWireGuardDeviceName; use neli::{ - consts::{ - genl::{CtrlAttr, CtrlCmd}, - nl::{NlTypeWrapper, Nlmsg}, - socket::NlFamily, - }, - genl::Genlmsghdr, + consts::{nl::Nlmsg, socket::NlFamily}, rtnl::Ifinfomsg, socket::NlSocketHandle, }; @@ -20,30 +15,26 @@ pub struct RouteSocket { impl RouteSocket { pub fn connect() -> Result { - let sock = NlSocketHandle::new(NlFamily::Route)?; - // Autoselect a PID let pid = None; let groups = &[]; - sock.bind(pid, groups)?; + let sock = NlSocketHandle::connect(NlFamily::Route, pid, groups)?; Ok(Self { sock }) } pub fn add_device(&mut self, ifname: &str) -> Result<(), LinkDeviceError> { let operation = WireGuardDeviceLinkOperation::Add; - self.sock.send(link_message(ifname, operation)?)?; - self.sock.recv::>()?; + self.sock.recv()?; Ok(()) } pub fn del_device(&mut self, ifname: &str) -> Result<(), LinkDeviceError> { let operation = WireGuardDeviceLinkOperation::Delete; - self.sock.send(link_message(ifname, operation)?)?; - self.sock.recv::>()?; + self.sock.recv()?; Ok(()) } @@ -54,14 +45,14 @@ impl RouteSocket { self.sock .send(list_device_names_utils::get_list_device_names_msg())?; - let mut iter = self.sock.iter::(false); + let mut iter = self.sock.iter::(false); let mut result_names = vec![]; while let Some(Ok(response)) = iter.next() { match response.nl_type { - NlTypeWrapper::Nlmsg(Nlmsg::Error) => return Err(ListDevicesError::Unknown), - NlTypeWrapper::Nlmsg(Nlmsg::Done) => break, + Nlmsg::Error => return Err(ListDevicesError::Unknown), + Nlmsg::Done => break, _ => (), } diff --git a/src/linux/socket/wg_socket.rs b/src/linux/socket/wg_socket.rs index d1def1a..3011fdd 100644 --- a/src/linux/socket/wg_socket.rs +++ b/src/linux/socket/wg_socket.rs @@ -11,8 +11,7 @@ use crate::linux::DeviceInterface; use libc::IFNAMSIZ; use neli::{ consts::{ - genl::{CtrlAttr, CtrlCmd}, - nl::{NlTypeWrapper, NlmF, NlmFFlags, Nlmsg}, + nl::{NlmF, NlmFFlags, Nlmsg}, socket::NlFamily, }, genl::{Genlmsghdr, Nlattr}, @@ -35,12 +34,10 @@ impl WgSocket { .map_err(ConnectError::ResolveFamilyError)? }; - let wgsock = NlSocketHandle::new(NlFamily::Generic)?; - // Autoselect a PID let pid = None; let groups = &[]; - wgsock.bind(pid, groups)?; + let wgsock = NlSocketHandle::connect(NlFamily::Generic, pid, groups)?; Ok(Self { sock: wgsock, @@ -57,10 +54,10 @@ impl WgSocket { Some(name.len()) .filter(|&len| 0 < len && len < IFNAMSIZ) .ok_or(GetDeviceError::InvalidInterfaceName)?; - Nlattr::new(None, false, false, WgDeviceAttribute::Ifname, name.as_ref())? + Nlattr::new(false, false, WgDeviceAttribute::Ifname, name.as_ref())? } DeviceInterface::Index(index) => { - Nlattr::new(None, false, false, WgDeviceAttribute::Ifindex, index)? + Nlattr::new(false, false, WgDeviceAttribute::Ifindex, index)? } }; let genlhdr = { @@ -69,7 +66,6 @@ impl WgSocket { let mut attrs = GenlBuffer::new(); attrs.push(attr); - Genlmsghdr::new(cmd, version, attrs) }; let nlhdr = { @@ -86,13 +82,13 @@ impl WgSocket { let mut iter = self .sock - .iter::>(false); + .iter::>(false); let mut device = None; while let Some(Ok(response)) = iter.next() { match response.nl_type { - NlTypeWrapper::Nlmsg(Nlmsg::Error) => return Err(GetDeviceError::AccessError), - NlTypeWrapper::Nlmsg(Nlmsg::Done) => break, + Nlmsg::Error => return Err(GetDeviceError::AccessError), + Nlmsg::Done => break, _ => (), }; @@ -121,7 +117,7 @@ impl WgSocket { pub fn set_device(&mut self, device: set::Device) -> Result<(), SetDeviceError> { for nl_message in create_set_device_messages(device, self.family_id)? { self.sock.send(nl_message)?; - self.sock.recv::>()?; + self.sock.recv()?; } Ok(())