use std::{fmt::Debug, mem::size_of}; use crate::{ error::DecodeError, sys::{ nfgenmsg, nlmsghdr, NFNETLINK_V0, NFNL_MSG_BATCH_BEGIN, NFNL_MSG_BATCH_END, NFNL_SUBSYS_NFTABLES, NLMSG_ALIGNTO, NLM_F_ACK, NLM_F_CREATE, }, MsgType, ProtocolFamily, }; /// /// The largest nf_tables netlink message is the set element message, which contains the /// NFTA_SET_ELEM_LIST_ELEMENTS attribute. This attribute is a nest that describes the set /// elements. Given that the netlink attribute length (nla_len) is 16 bits, the largest message is /// a bit larger than 64 KBytes. pub fn nft_nlmsg_maxsize() -> u32 { u32::from(::std::u16::MAX) + unsafe { libc::sysconf(libc::_SC_PAGESIZE) } as u32 } #[inline] pub const fn pad_netlink_object_with_variable_size(size: usize) -> usize { // align on a 4 bytes boundary (size + (NLMSG_ALIGNTO as usize - 1)) & !(NLMSG_ALIGNTO as usize - 1) } #[inline] pub const fn pad_netlink_object() -> usize { let size = size_of::(); pad_netlink_object_with_variable_size(size) } pub fn get_subsystem_from_nlmsghdr_type(x: u16) -> u8 { ((x & 0xff00) >> 8) as u8 } pub fn get_operation_from_nlmsghdr_type(x: u16) -> u8 { (x & 0x00ff) as u8 } pub struct NfNetlinkWriter<'a> { buf: &'a mut Vec, headers: Option<(usize, usize)>, } impl<'a> NfNetlinkWriter<'a> { pub fn new(buf: &'a mut Vec) -> NfNetlinkWriter<'a> { NfNetlinkWriter { buf, headers: None } } pub fn add_data_zeroed<'b>(&'b mut self, size: usize) -> &'b mut [u8] { let padded_size = pad_netlink_object_with_variable_size(size); let start = self.buf.len(); self.buf.resize(start + padded_size, 0); if let Some((msghdr_idx, _nfgenmsg_idx)) = self.headers { let mut hdr: &mut nlmsghdr = unsafe { std::mem::transmute(self.buf[msghdr_idx..].as_mut_ptr() as *mut nlmsghdr) }; hdr.nlmsg_len += padded_size as u32; } &mut self.buf[start..start + size] } pub fn extract_buffer(self) -> &'a mut Vec { self.buf } // rewrite of `__nftnl_nlmsg_build_hdr` pub fn write_header( &mut self, msg_type: u16, family: ProtocolFamily, flags: u16, seq: u32, ressource_id: Option, ) { if self.headers.is_some() { error!("Calling write_header while still holding headers open!?"); } let nlmsghdr_len = pad_netlink_object::(); let nfgenmsg_len = pad_netlink_object::(); let nlmsghdr_buf = self.add_data_zeroed(nlmsghdr_len); let mut hdr: &mut nlmsghdr = unsafe { std::mem::transmute(nlmsghdr_buf.as_mut_ptr() as *mut nlmsghdr) }; hdr.nlmsg_len = (nlmsghdr_len + nfgenmsg_len) as u32; hdr.nlmsg_type = msg_type; // batch messages are not specific to the nftables subsystem if msg_type != NFNL_MSG_BATCH_BEGIN as u16 && msg_type != NFNL_MSG_BATCH_END as u16 { hdr.nlmsg_type |= (NFNL_SUBSYS_NFTABLES as u16) << 8; } hdr.nlmsg_flags = libc::NLM_F_REQUEST as u16 | flags; hdr.nlmsg_seq = seq; let nfgenmsg_buf = self.add_data_zeroed(nfgenmsg_len); let mut nfgenmsg: &mut nfgenmsg = unsafe { std::mem::transmute(nfgenmsg_buf.as_mut_ptr() as *mut nfgenmsg) }; nfgenmsg.nfgen_family = family as u8; nfgenmsg.version = NFNETLINK_V0 as u8; nfgenmsg.res_id = ressource_id.unwrap_or(0); self.headers = Some(( self.buf.len() - (nlmsghdr_len + nfgenmsg_len), self.buf.len() - nfgenmsg_len, )); } pub fn finalize_writing_object(&mut self) { self.headers = None; } } pub trait AttributeDecoder { fn decode_attribute(&mut self, attr_type: u16, buf: &[u8]) -> Result<(), DecodeError>; } pub trait NfNetlinkDeserializable: Sized { fn deserialize(buf: &[u8]) -> Result<(Self, &[u8]), DecodeError>; } pub trait NfNetlinkObject: Sized + AttributeDecoder + NfNetlinkDeserializable + NfNetlinkAttribute { const MSG_TYPE_ADD: u32; const MSG_TYPE_DEL: u32; fn add_or_remove<'a>(&self, writer: &mut NfNetlinkWriter<'a>, msg_type: MsgType, seq: u32) { let raw_msg_type = match msg_type { MsgType::Add => Self::MSG_TYPE_ADD, MsgType::Del => Self::MSG_TYPE_DEL, } as u16; writer.write_header( raw_msg_type, self.get_family(), (if let MsgType::Add = msg_type { self.get_add_flags() } else { self.get_del_flags() } | NLM_F_ACK) as u16, seq, None, ); let buf = writer.add_data_zeroed(self.get_size()); unsafe { self.write_payload(buf.as_mut_ptr()); } writer.finalize_writing_object(); } fn get_family(&self) -> ProtocolFamily; fn set_family(&mut self, _family: ProtocolFamily) { // the default impl do nothing, because some types are family-agnostic } fn with_family(mut self, family: ProtocolFamily) -> Self { self.set_family(family); self } fn get_add_flags(&self) -> u32 { NLM_F_CREATE } fn get_del_flags(&self) -> u32 { 0 } } pub type NetlinkType = u16; pub trait NfNetlinkAttribute: Debug + Sized { // is it a nested argument that must be marked with a NLA_F_NESTED flag? fn is_nested(&self) -> bool { false } fn get_size(&self) -> usize { size_of::() } // example body: std::ptr::copy_nonoverlapping(self as *const Self as *const u8, addr, self.get_size()); unsafe fn write_payload(&self, addr: *mut u8); }