use std::alloc::Layout;
use std::mem::MaybeUninit;
use std::os::unix::io::AsRawFd;
use std::str;
use libc::EINVAL;
use log::error;
use zerocopy::FromBytes;
use zerocopy::Immutable;
use zerocopy::IntoBytes;
use zerocopy::KnownLayout;
use super::errno_result;
use super::getpid;
use super::Error;
use super::RawDescriptor;
use super::Result;
use crate::alloc::LayoutAllocation;
use crate::descriptor::AsRawDescriptor;
use crate::descriptor::FromRawDescriptor;
use crate::descriptor::SafeDescriptor;
/// Debug-trace hook used throughout this module.
///
/// It currently expands to nothing, so its arguments are neither evaluated nor type-checked.
/// To get traces, change the expansion (for example to forward to a logging macro).
macro_rules! debug_pr {
    ($($args:tt)+) => {};
}
/// Size in bytes of a netlink message header (`NlMsgHdr`).
const NLMSGHDR_SIZE: usize = std::mem::size_of::<NlMsgHdr>();
/// Size in bytes of the generic netlink header (`GenlMsgHdr`) that follows the netlink header.
const GENL_HDRLEN: usize = std::mem::size_of::<GenlMsgHdr>();
/// Size in bytes of a netlink attribute header (`NlAttr`).
const NLA_HDRLEN: usize = std::mem::size_of::<NlAttr>();
/// Attributes are laid out back to back, with each attribute's length rounded up to a
/// multiple of this many bytes to locate the next one.
const NLATTR_ALIGN_TO: usize = 4;
/// Fixed header at the start of every netlink message.
///
/// The field names match Linux `struct nlmsghdr`, and `repr(C)` keeps the same in-memory layout.
#[repr(C)]
#[derive(Copy, Clone, FromBytes, Immutable, IntoBytes, KnownLayout)]
struct NlMsgHdr {
    /// Length of the whole message in bytes, including this header.
    pub nlmsg_len: u32,
    /// Message type (e.g. `GENL_ID_CTRL` for generic netlink controller messages).
    pub nlmsg_type: u16,
    /// Message flags (requests built in this module set `NLM_F_REQUEST`).
    pub nlmsg_flags: u16,
    /// Sequence number.
    pub nlmsg_seq: u32,
    /// Sender port ID; requests built in this module store the process ID here.
    pub nlmsg_pid: u32,
}
/// Header of a netlink attribute; the attribute's payload follows it immediately.
#[repr(C)]
#[derive(Copy, Clone, FromBytes, Immutable, IntoBytes, KnownLayout)]
pub struct NlAttr {
    /// Attribute length in bytes, counting this header and the payload but not the alignment
    /// padding that may follow before the next attribute.
    pub len: u16,
    /// Attribute type.
    pub _type: u16,
}
/// Generic netlink header, placed directly after the `NlMsgHdr` of a generic netlink message.
#[repr(C)]
#[derive(Copy, Clone, FromBytes, Immutable, IntoBytes, KnownLayout)]
pub struct GenlMsgHdr {
    /// Command (e.g. `CTRL_CMD_GETFAMILY`).
    pub cmd: u8,
    /// Interface version; requests built in this module use 1.
    pub version: u8,
    /// Reserved; requests built in this module leave it zero.
    pub reserved: u16,
}
/// One netlink message produced by `NetlinkMessageIter`, borrowing its payload from the
/// receive buffer.
pub struct NetlinkMessage<'a> {
    /// `nlmsg_type` from the message header.
    pub _type: u16,
    /// `nlmsg_flags` from the message header.
    pub flags: u16,
    /// `nlmsg_seq` from the message header.
    pub seq: u32,
    /// `nlmsg_pid` from the message header.
    pub pid: u32,
    /// Bytes after the netlink header, up to the header's `nlmsg_len`.
    pub data: &'a [u8],
}
/// One netlink attribute produced by `NetlinkGenericDataIter`.
pub struct NlAttrWithData<'a> {
    /// Raw `len` field from the attribute header (header plus payload, without padding).
    pub len: u16,
    /// Attribute type.
    pub _type: u16,
    /// Attribute payload: the bytes following the header, up to `len`.
    pub data: &'a [u8],
}
/// Iterator over a run of consecutive netlink attributes.
pub struct NetlinkGenericDataIter<'a> {
    /// Bytes not yet consumed; each attribute starts with an `NlAttr` header.
    data: &'a [u8],
}
impl<'a> Iterator for NetlinkGenericDataIter<'a> {
    type Item = NlAttrWithData<'a>;

    /// Decodes the attribute at the front of the unread bytes and moves past it, including any
    /// alignment padding.
    ///
    /// Yields `None`, without consuming anything, when the remaining bytes cannot hold a header
    /// or the header's length is smaller than the header or runs past the buffer.
    fn next(&mut self) -> Option<Self::Item> {
        let remaining = self.data;
        let header = match NlAttr::read_from_prefix(remaining) {
            Ok((header, _tail)) => header,
            Err(_) => return None,
        };
        let attr_len = usize::from(header.len);
        let payload = remaining.get(NLA_HDRLEN..attr_len)?;

        // The next attribute begins at the aligned end of this one; if that lies beyond the
        // buffer, there is nothing left to read.
        let aligned_len = attr_len.next_multiple_of(NLATTR_ALIGN_TO);
        self.data = match remaining.get(aligned_len..) {
            Some(rest) => rest,
            None => &[],
        };

        Some(NlAttrWithData {
            len: header.len,
            _type: header._type,
            data: payload,
        })
    }
}
/// Iterator over the netlink messages packed into a receive buffer.
pub struct NetlinkMessageIter<'a> {
    /// Bytes not yet consumed; each message starts with an `NlMsgHdr`.
    data: &'a [u8],
}
impl<'a> Iterator for NetlinkMessageIter<'a> {
type Item = NetlinkMessage<'a>;
fn next(&mut self) -> Option<Self::Item> {
let (hdr, _) = NlMsgHdr::read_from_prefix(self.data).ok()?;
let msg_len = hdr.nlmsg_len as usize;
let data = self.data.get(NLMSGHDR_SIZE..msg_len)?;
let next_hdr = msg_len.next_multiple_of(std::mem::align_of::<NlMsgHdr>());
self.data = self.data.get(next_hdr..).unwrap_or(&[]);
Some(NetlinkMessage {
_type: hdr.nlmsg_type,
flags: hdr.nlmsg_flags,
seq: hdr.nlmsg_seq,
pid: hdr.nlmsg_pid,
data,
})
}
}
/// A `NETLINK_GENERIC` socket; see `NetlinkGenericSocket::new`.
pub struct NetlinkGenericSocket {
    /// Owned descriptor of the socket.
    sock: SafeDescriptor,
}
impl AsRawDescriptor for NetlinkGenericSocket {
fn as_raw_descriptor(&self) -> RawDescriptor {
self.sock.as_raw_descriptor()
}
}
impl NetlinkGenericSocket {
pub fn new(nl_groups: u32) -> Result<Self> {
let sock = unsafe {
let fd = libc::socket(
libc::AF_NETLINK,
libc::SOCK_RAW | libc::SOCK_CLOEXEC,
libc::NETLINK_GENERIC,
);
if fd < 0 {
return errno_result();
}
SafeDescriptor::from_raw_descriptor(fd)
};
let mut sa = unsafe { MaybeUninit::<libc::sockaddr_nl>::zeroed().assume_init() };
sa.nl_family = libc::AF_NETLINK as libc::sa_family_t;
sa.nl_groups = nl_groups;
unsafe {
let res = libc::bind(
sock.as_raw_fd(),
&sa as *const libc::sockaddr_nl as *const libc::sockaddr,
std::mem::size_of_val(&sa) as libc::socklen_t,
);
if res < 0 {
return errno_result();
}
}
Ok(NetlinkGenericSocket { sock })
}
pub fn recv(&self) -> Result<NetlinkGenericRead> {
let buf_size = 8192;
let layout = Layout::from_size_align(buf_size, std::mem::align_of::<NlMsgHdr>())
.map_err(|_| Error::new(EINVAL))?;
let allocation = LayoutAllocation::uninitialized(layout);
let bytes_read = unsafe {
let res = libc::recv(self.sock.as_raw_fd(), allocation.as_ptr(), buf_size, 0);
if res < 0 {
return errno_result();
}
res as usize
};
Ok(NetlinkGenericRead {
allocation,
len: bytes_read,
})
}
pub fn family_name_query(&self, family_name: String) -> Result<NetlinkGenericRead> {
let buf_size = 1024;
debug_pr!(
"preparing query for family name {}, len {}",
family_name,
family_name.len()
);
let layout = Layout::from_size_align(buf_size, std::mem::align_of::<NlMsgHdr>())
.map_err(|_| Error::new(EINVAL))
.unwrap();
let mut allocation = LayoutAllocation::zeroed(layout);
let data = unsafe { allocation.as_mut_slice(buf_size) };
let (hdr, genl_hdr) = NlMsgHdr::mut_from_prefix(data).expect("failed to unwrap");
hdr.nlmsg_len = NLMSGHDR_SIZE as u32 + GENL_HDRLEN as u32;
hdr.nlmsg_len += NLA_HDRLEN as u32 + family_name.len() as u32 + 1;
hdr.nlmsg_flags = libc::NLM_F_REQUEST as u16;
hdr.nlmsg_type = libc::GENL_ID_CTRL as u16;
hdr.nlmsg_pid = getpid() as u32;
let (genl_hdr, nlattr) =
GenlMsgHdr::mut_from_prefix(genl_hdr).expect("unable to get GenlMsgHdr from slice");
genl_hdr.cmd = libc::CTRL_CMD_GETFAMILY as u8;
genl_hdr.version = 0x1;
let (nl_attr, payload) =
NlAttr::mut_from_prefix(nlattr).expect("unable to get NlAttr from slice");
nl_attr._type = libc::CTRL_ATTR_FAMILY_NAME as u16;
nl_attr.len = family_name.len() as u16 + 1 + NLA_HDRLEN as u16;
payload[..family_name.len()].copy_from_slice(family_name.as_bytes());
let len = NLMSGHDR_SIZE + GENL_HDRLEN + NLA_HDRLEN + family_name.len() + 1;
unsafe {
let res = libc::send(self.sock.as_raw_fd(), allocation.as_ptr(), len, 0);
if res < 0 {
error!("failed to send get_family_cmd");
return errno_result();
}
};
match self.recv() {
Ok(msg) => Ok(msg),
Err(e) => {
error!("recv get_family returned with error {}", e);
Err(e)
}
}
}
}
/// Scans the nested attributes that describe one multicast group. Returns the group's ID
/// (`CTRL_ATTR_MCAST_GRP_ID`) if its name (`CTRL_ATTR_MCAST_GRP_NAME`) equals `group_name`.
///
/// The ID and name attributes are accepted in either order. An ID attribute whose payload is
/// not exactly four bytes is ignored rather than causing a panic.
fn parse_ctrl_group_name_and_id(
    nested_nl_attr_data: NetlinkGenericDataIter,
    group_name: &str,
) -> Option<u32> {
    let mut mcast_group_id: Option<u32> = None;
    let mut name_matches = false;
    for nested_nl_attr in nested_nl_attr_data {
        debug_pr!(
            "\t\tmcast_grp: nlattr type {}, len {}",
            nested_nl_attr._type,
            nested_nl_attr.len
        );
        if nested_nl_attr._type == libc::CTRL_ATTR_MCAST_GRP_ID as u16 {
            // A wrongly sized ID is malformed input; skip it instead of unwrapping.
            if let Ok(id_bytes) = <[u8; 4]>::try_from(nested_nl_attr.data) {
                mcast_group_id = Some(u32::from_ne_bytes(id_bytes));
            }
        } else if nested_nl_attr._type == libc::CTRL_ATTR_MCAST_GRP_NAME as u16 {
            name_matches = group_name == strip_padding(nested_nl_attr.data);
        }

        // Answer as soon as both the matching name and an ID have been seen.
        if name_matches {
            if let Some(id) = mcast_group_id {
                debug_pr!(
                    "\t\t Got what we were looking for group_id = {} for {}",
                    id,
                    group_name
                );
                return Some(id);
            }
        }
    }
    None
}
/// Walks the payload of a `CTRL_ATTR_MCAST_GROUPS` attribute, where each attribute holds one
/// group's nested attributes. Returns the ID of the first group whose name is `group_name`.
fn parse_ctrl_mcast_group_id(
    mut nl_attr_area: NetlinkGenericDataIter,
    group_name: &str,
) -> Option<u32> {
    nl_attr_area.find_map(|group_attr| {
        debug_pr!(
            "\tmcast_groups: nlattr type(gr_nr) {}, len {}",
            group_attr._type,
            group_attr.len
        );
        let group_fields = NetlinkGenericDataIter {
            data: group_attr.data,
        };
        parse_ctrl_group_name_and_id(group_fields, group_name)
    })
}
/// Interprets `b` as a NUL-padded string and returns the text before the first NUL byte.
///
/// If `b` contains no NUL, the whole slice is used. Netlink string attributes are not
/// guaranteed to be NUL-terminated, so a malformed or unterminated attribute must not be
/// able to trigger a panic. Bytes that are not valid UTF-8 yield an empty string instead of
/// panicking.
fn strip_padding(b: &[u8]) -> &str {
    let end = b.iter().position(|&c| c == 0).unwrap_or(b.len());
    str::from_utf8(&b[..end]).unwrap_or_default()
}
/// Bytes received from a netlink socket by `NetlinkGenericSocket::recv`.
pub struct NetlinkGenericRead {
    /// Receive buffer; only the first `len` bytes were filled in.
    allocation: LayoutAllocation,
    /// Number of bytes `recv(2)` reported writing into `allocation`.
    len: usize,
}
impl NetlinkGenericRead {
    /// Returns an iterator over the netlink messages contained in the received bytes.
    pub fn iter(&self) -> NetlinkMessageIter<'_> {
        // SAFETY: a `NetlinkGenericRead` is only built by `NetlinkGenericSocket::recv`, with
        // `len` set to the byte count `recv(2)` wrote into `allocation`. Those bytes are
        // therefore initialized and within the allocation.
        let data: &[u8] = unsafe { self.allocation.as_slice(self.len) };
        NetlinkMessageIter { data }
    }

    /// Searches a controller reply for the multicast group named `group_name` and returns its
    /// ID.
    ///
    /// Messages are examined in order. The search stops with `None`, logging an error, at the
    /// first message whose type is not `GENL_ID_CTRL`. It also returns `None` when no listed
    /// group matches.
    pub fn get_multicast_group_id(&self, group_name: String) -> Option<u32> {
        for netlink_msg in self.iter() {
            debug_pr!(
                "received type: {}, flags {}, pid {}, data {:?}",
                netlink_msg._type,
                netlink_msg.flags,
                netlink_msg.pid,
                netlink_msg.data
            );
            if netlink_msg._type != libc::GENL_ID_CTRL as u16 {
                error!("Received not a generic netlink controller msg");
                return None;
            }

            // Attributes follow the generic netlink header. A message too short to hold that
            // header carries no attributes, so skip it rather than panic on the slice index.
            let Some(attr_bytes) = netlink_msg.data.get(GENL_HDRLEN..) else {
                continue;
            };

            let found = NetlinkGenericDataIter { data: attr_bytes }
                .filter(|nl_attr| nl_attr._type == libc::CTRL_ATTR_MCAST_GROUPS as u16)
                .find_map(|nl_attr| {
                    parse_ctrl_mcast_group_id(
                        NetlinkGenericDataIter { data: nl_attr.data },
                        &group_name,
                    )
                });
            if found.is_some() {
                return found;
            }
        }
        None
    }
}