// Licensed under the Apache License, Version 2.0 or the MIT license // , at your // option. This file may not be copied, modified, or distributed // except according to those terms. use std::{ ffi::CStr, io::{Error, Read as _, Result, Write as _}, net::IpAddr, num::TryFromIntError, ptr, slice, }; use libc::{ AF_NETLINK, ARPHRD_NONE, IFLA_IFNAME, IFLA_MTU, NETLINK_ROUTE, RT_SCOPE_UNIVERSE, RT_TABLE_MAIN, RTA_DST, RTA_OIF, RTM_GETLINK, RTM_GETROUTE, RTM_NEWLINK, RTM_NEWROUTE, RTN_UNICAST, c_int, }; use static_assertions::{const_assert, const_assert_eq}; use crate::{aligned_by, default_err, routesocket::RouteSocket, unlikely_err}; #[allow( clippy::allow_attributes, clippy::allow_attributes_without_reason, clippy::struct_field_names, non_camel_case_types, clippy::too_many_lines, reason = "Bindgen-generated code" )] mod bindings { include!(env!("BINDINGS")); } use bindings::{ifinfomsg, nlmsghdr, rtattr, rtmsg}; asserted_const_with_type!(AF_INET, u8, libc::AF_INET, i32); asserted_const_with_type!(AF_INET6, u8, libc::AF_INET6, i32); asserted_const_with_type!(AF_UNSPEC, u8, libc::AF_UNSPEC, i32); asserted_const_with_type!(NLM_F_REQUEST, u16, libc::NLM_F_REQUEST, c_int); asserted_const_with_type!(NLM_F_ACK, u16, libc::NLM_F_ACK, c_int); asserted_const_with_type!(NLMSG_ERROR, u16, libc::NLMSG_ERROR, c_int); const_assert!(size_of::() <= u8::MAX as usize); const_assert!(size_of::() <= u8::MAX as usize); const_assert!(size_of::() <= u8::MAX as usize); const_assert!(size_of::() <= u8::MAX as usize); const NETLINK_BUFFER_SIZE: usize = 8192; // See netlink(7) man page. #[repr(C)] enum AddrBytes { V4([u8; 4]), V6([u8; 16]), } impl AddrBytes { const fn new(ip: IpAddr) -> Self { match ip { IpAddr::V4(ip) => Self::V4(ip.octets()), IpAddr::V6(ip) => Self::V6(ip.octets()), } } const fn len(&self) -> usize { match self { Self::V4(_) => 4, Self::V6(_) => 16, } } } impl From for [u8; 16] { fn from(addr: AddrBytes) -> Self { match addr { AddrBytes::V4(bytes) => { let mut v6 = [0; 16]; v6[..4].copy_from_slice(&bytes); v6 } AddrBytes::V6(bytes) => bytes, } } } #[repr(C)] #[derive(Default)] struct IfIndexMsg { nlmsg: nlmsghdr, rtm: rtmsg, rt: rtattr, addr: [u8; 16], } impl IfIndexMsg { fn new(remote: IpAddr, nlmsg_seq: u32) -> Self { let addr = AddrBytes::new(remote); #[expect( clippy::cast_possible_truncation, reason = "Structs lens are <= u8::MAX per `const_assert!`s above; `addr_bytes` is max. 16 for IPv6." )] let nlmsg_len = (size_of::() + size_of::() + size_of::() + addr.len()) as u32; Self { nlmsg: nlmsghdr { nlmsg_len, nlmsg_type: RTM_GETROUTE, nlmsg_flags: NLM_F_REQUEST | NLM_F_ACK, nlmsg_seq, ..Default::default() }, rtm: rtmsg { rtm_family: match remote { IpAddr::V4(_) => AF_INET, IpAddr::V6(_) => AF_INET6, }, rtm_dst_len: match remote { IpAddr::V4(_) => 32, IpAddr::V6(_) => 128, }, rtm_table: RT_TABLE_MAIN, rtm_scope: RT_SCOPE_UNIVERSE, rtm_type: RTN_UNICAST, ..Default::default() }, rt: rtattr { #[expect( clippy::cast_possible_truncation, reason = "Structs len is <= u8::MAX per `const_assert!` above; `addr_bytes` is max. 16 for IPv6." )] rta_len: (size_of::() + addr.len()) as u16, rta_type: RTA_DST, }, addr: addr.into(), } } const fn len(&self) -> usize { let len = self.nlmsg.nlmsg_len as usize; debug_assert!(len <= size_of::()); len } } impl From<&IfIndexMsg> for &[u8] { fn from(value: &IfIndexMsg) -> Self { unsafe { slice::from_raw_parts(ptr::from_ref(value).cast(), value.len()) } } } impl TryFrom<&[u8]> for nlmsghdr { type Error = Error; fn try_from(value: &[u8]) -> Result { if value.len() < size_of::() { return Err(default_err()); } Ok(unsafe { ptr::read_unaligned(value.as_ptr().cast()) }) } } fn parse_c_int(buf: &[u8]) -> Result { if buf.len() < size_of::() { return Err(default_err()); } let bytes = <&[u8] as TryInto<[u8; size_of::()]>>::try_into(&buf[..size_of::()]) .map_err(|_| default_err())?; Ok(c_int::from_ne_bytes(bytes)) } fn read_msg_with_seq(fd: &mut RouteSocket, seq: u32, kind: u16) -> Result<(nlmsghdr, Vec)> { loop { let buf = &mut [0u8; NETLINK_BUFFER_SIZE]; let len = fd.read(buf.as_mut_slice())?; let mut next = &buf[..len]; while size_of::() <= next.len() { let (hdr, mut msg) = next.split_at(size_of::()); let hdr: nlmsghdr = hdr.try_into()?; // `msg` has the remainder of this message plus any following messages. // Strip those it off and assign them to `next`. debug_assert!(size_of::() <= hdr.nlmsg_len as usize); (msg, next) = msg.split_at(hdr.nlmsg_len as usize - size_of::()); if hdr.nlmsg_seq != seq { continue; } if hdr.nlmsg_type == NLMSG_ERROR { // Extract the error code and return it. let err = parse_c_int(msg)?; if err != 0 { return Err(Error::from_raw_os_error(-err)); } } else if hdr.nlmsg_type == kind { // Return the header and the message. return Ok((hdr, msg.to_vec())); } } } } impl TryFrom<&[u8]> for rtattr { type Error = Error; fn try_from(value: &[u8]) -> Result { if value.len() < size_of::() { return Err(default_err()); } Ok(unsafe { ptr::read_unaligned(value.as_ptr().cast()) }) } } struct RtAttr<'a> { hdr: rtattr, msg: &'a [u8], } impl<'a> RtAttr<'a> { fn new(bytes: &'a [u8]) -> Result { debug_assert!(bytes.len() >= size_of::()); let (hdr, mut msg) = bytes.split_at(size_of::()); let hdr: rtattr = hdr.try_into()?; let aligned_len = aligned_by(hdr.rta_len.into(), 4); debug_assert!(size_of::() <= aligned_len); (msg, _) = msg.split_at(aligned_len - size_of::()); Ok(Self { hdr, msg }) } } struct RtAttrs<'a>(&'a [u8]); impl<'a> Iterator for RtAttrs<'a> { type Item = RtAttr<'a>; fn next(&mut self) -> Option { if size_of::() <= self.0.len() { let attr = RtAttr::new(self.0).ok()?; let aligned_len = aligned_by(attr.hdr.rta_len.into(), 4); debug_assert!(self.0.len() >= aligned_len); self.0 = self.0.split_at(aligned_len).1; Some(attr) } else { None } } } fn if_index(remote: IpAddr, fd: &mut RouteSocket) -> Result { // Send RTM_GETROUTE message to get the interface index associated with the destination. let msg_seq = RouteSocket::new_seq(); let msg = IfIndexMsg::new(remote, msg_seq); fd.write_all((&msg).into())?; // Receive RTM_GETROUTE response. let (_hdr, mut buf) = read_msg_with_seq(fd, msg_seq, RTM_NEWROUTE)?; debug_assert!(size_of::() <= buf.len()); let buf = buf.split_off(size_of::()); // Parse through the attributes to find the interface index. for attr in RtAttrs(buf.as_slice()).by_ref() { if attr.hdr.rta_type == RTA_OIF { // We have our interface index. return parse_c_int(attr.msg); } } Err(default_err()) } #[repr(C)] struct IfInfoMsg { nlmsg: nlmsghdr, ifim: ifinfomsg, } impl IfInfoMsg { fn new(if_index: i32, nlmsg_seq: u32) -> Self { #[expect( clippy::cast_possible_truncation, reason = "Structs lens are <= u8::MAX per `const_assert!`s above." )] let nlmsg_len = (size_of::() + size_of::()) as u32; Self { nlmsg: nlmsghdr { nlmsg_len, nlmsg_type: RTM_GETLINK, nlmsg_flags: NLM_F_REQUEST | NLM_F_ACK, nlmsg_seq, ..Default::default() }, ifim: ifinfomsg { ifi_family: AF_UNSPEC, ifi_type: ARPHRD_NONE, ifi_index: if_index, ..Default::default() }, } } const fn len(&self) -> usize { self.nlmsg.nlmsg_len as usize } } impl From<&IfInfoMsg> for &[u8] { fn from(value: &IfInfoMsg) -> Self { debug_assert!(value.len() >= size_of::()); unsafe { slice::from_raw_parts(ptr::from_ref(value).cast(), value.len()) } } } fn if_name_mtu(if_index: i32, fd: &mut RouteSocket) -> Result<(String, usize)> { // Send RTM_GETLINK message to get interface information for the given interface index. let msg_seq = RouteSocket::new_seq(); let msg = IfInfoMsg::new(if_index, msg_seq); fd.write_all((&msg).into())?; // Receive RTM_GETLINK response. let (_hdr, mut buf) = read_msg_with_seq(fd, msg_seq, RTM_NEWLINK)?; debug_assert!(size_of::() <= buf.len()); let buf = buf.split_off(size_of::()); // Parse through the attributes to find the interface name and MTU. let mut ifname = None; let mut mtu = None; for attr in RtAttrs(buf.as_slice()).by_ref() { match attr.hdr.rta_type { IFLA_IFNAME => { let name = CStr::from_bytes_until_nul(attr.msg).map_err(Error::other)?; ifname = Some(name.to_str().map_err(Error::other)?.to_string()); } IFLA_MTU => { mtu = Some( parse_c_int(attr.msg)? .try_into() .map_err(|e: TryFromIntError| unlikely_err(e.to_string()))?, ); } _ => (), } if let (Some(ifname), Some(mtu)) = (ifname.as_ref(), mtu.as_ref()) { return Ok((ifname.clone(), *mtu)); } } Err(default_err()) } pub fn interface_and_mtu_impl(remote: IpAddr) -> Result<(String, usize)> { // Create a netlink socket. let mut fd = RouteSocket::new(AF_NETLINK, NETLINK_ROUTE)?; let if_index = if_index(remote, &mut fd)?; if_name_mtu(if_index, &mut fd) } #[cfg(test)] #[cfg_attr(coverage_nightly, coverage(off))] mod test { use std::net::{Ipv4Addr, Ipv6Addr}; use super::*; #[test] fn nlmsghdr_try_from() { assert!(nlmsghdr::try_from([0u8; 4].as_slice()).is_err()); let mut buf = [0u8; 32]; buf[0..4].copy_from_slice(&20u32.to_ne_bytes()); assert_eq!(nlmsghdr::try_from(buf.as_slice()).unwrap().nlmsg_len, 20); } #[test] fn rtattr_try_from() { assert!(rtattr::try_from([0u8; 2].as_slice()).is_err()); let mut buf = [0u8; 8]; buf[0..2].copy_from_slice(&8u16.to_ne_bytes()); buf[2..4].copy_from_slice(&3u16.to_ne_bytes()); let attr = rtattr::try_from(buf.as_slice()).unwrap(); assert_eq!((attr.rta_len, attr.rta_type), (8, 3)); } #[test] fn rtattrs_iteration() { assert_eq!(RtAttrs(&[]).count(), 0); assert_eq!(RtAttrs(&[0u8; 2]).count(), 0); let mut buf = [0u8; 16]; for (offset, rta_type) in [(0, 1u16), (8, 2u16)] { buf[offset..offset + 2].copy_from_slice(&8u16.to_ne_bytes()); buf[offset + 2..offset + 4].copy_from_slice(&rta_type.to_ne_bytes()); } let types: Vec<_> = RtAttrs(&buf).map(|a| a.hdr.rta_type).collect(); assert_eq!(types, [1, 2]); } #[test] fn addr_bytes() { assert_eq!(AddrBytes::new(IpAddr::V4(Ipv4Addr::LOCALHOST)).len(), 4); assert_eq!(AddrBytes::new(IpAddr::V6(Ipv6Addr::LOCALHOST)).len(), 16); let v4: [u8; 16] = AddrBytes::new(IpAddr::V4(Ipv4Addr::new(1, 2, 3, 4))).into(); assert_eq!(v4, [1, 2, 3, 4, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0]); } #[test] fn parse_c_int_valid_and_invalid() { assert!(parse_c_int(&[0u8; 2]).is_err()); assert_eq!(parse_c_int(&42i32.to_ne_bytes()).unwrap(), 42); } #[test] fn if_index_msg_len_and_slice() { let msg = IfIndexMsg::new(IpAddr::V4(Ipv4Addr::LOCALHOST), 1); assert!(msg.len() > 0); let slice: &[u8] = (&msg).into(); assert_eq!(slice.len(), msg.len()); } #[test] fn if_info_msg_len_and_slice() { let msg = IfInfoMsg::new(1, 1); assert!(msg.len() > 0); let slice: &[u8] = (&msg).into(); assert_eq!(slice.len(), msg.len()); } }