diff options
Diffstat (limited to 'src/nlmsg.rs')
-rw-r--r-- | src/nlmsg.rs | 182 |
1 files changed, 182 insertions, 0 deletions
diff --git a/src/nlmsg.rs b/src/nlmsg.rs new file mode 100644 index 0000000..1c5b519 --- /dev/null +++ b/src/nlmsg.rs @@ -0,0 +1,182 @@ +use std::{fmt::Debug, mem::size_of}; + +use crate::{ + error::DecodeError, + sys::{ + nfgenmsg, nlmsghdr, NFNETLINK_V0, NFNL_MSG_BATCH_BEGIN, NFNL_MSG_BATCH_END, + NFNL_SUBSYS_NFTABLES, NLMSG_ALIGNTO, NLM_F_ACK, NLM_F_CREATE, + }, + MsgType, ProtocolFamily, +}; +/// +/// The largest nf_tables netlink message is the set element message, which contains the +/// NFTA_SET_ELEM_LIST_ELEMENTS attribute. This attribute is a nest that describes the set +/// elements. Given that the netlink attribute length (nla_len) is 16 bits, the largest message is +/// a bit larger than 64 KBytes. +pub fn nft_nlmsg_maxsize() -> u32 { + u32::from(::std::u16::MAX) + unsafe { libc::sysconf(libc::_SC_PAGESIZE) } as u32 +} + +#[inline] +pub const fn pad_netlink_object_with_variable_size(size: usize) -> usize { + // align on a 4 bytes boundary + (size + (NLMSG_ALIGNTO as usize - 1)) & !(NLMSG_ALIGNTO as usize - 1) +} + +#[inline] +pub const fn pad_netlink_object<T>() -> usize { + let size = size_of::<T>(); + pad_netlink_object_with_variable_size(size) +} + +pub fn get_subsystem_from_nlmsghdr_type(x: u16) -> u8 { + ((x & 0xff00) >> 8) as u8 +} + +pub fn get_operation_from_nlmsghdr_type(x: u16) -> u8 { + (x & 0x00ff) as u8 +} + +pub struct NfNetlinkWriter<'a> { + buf: &'a mut Vec<u8>, + headers: Option<(usize, usize)>, +} + +impl<'a> NfNetlinkWriter<'a> { + pub fn new(buf: &'a mut Vec<u8>) -> NfNetlinkWriter<'a> { + NfNetlinkWriter { buf, headers: None } + } + + pub fn add_data_zeroed<'b>(&'b mut self, size: usize) -> &'b mut [u8] { + let padded_size = pad_netlink_object_with_variable_size(size); + let start = self.buf.len(); + self.buf.resize(start + padded_size, 0); + + if let Some((msghdr_idx, _nfgenmsg_idx)) = self.headers { + let mut hdr: &mut nlmsghdr = unsafe { + std::mem::transmute(self.buf[msghdr_idx..].as_mut_ptr() as *mut nlmsghdr) + }; + hdr.nlmsg_len += padded_size as u32; + } + + &mut self.buf[start..start + size] + } + + // rewrite of `__nftnl_nlmsg_build_hdr` + pub fn write_header( + &mut self, + msg_type: u16, + family: ProtocolFamily, + flags: u16, + seq: u32, + ressource_id: Option<u16>, + ) { + if self.headers.is_some() { + error!("Calling write_header while still holding headers open!?"); + } + + let nlmsghdr_len = pad_netlink_object::<nlmsghdr>(); + let nfgenmsg_len = pad_netlink_object::<nfgenmsg>(); + + let nlmsghdr_buf = self.add_data_zeroed(nlmsghdr_len); + let mut hdr: &mut nlmsghdr = + unsafe { std::mem::transmute(nlmsghdr_buf.as_mut_ptr() as *mut nlmsghdr) }; + hdr.nlmsg_len = (nlmsghdr_len + nfgenmsg_len) as u32; + hdr.nlmsg_type = msg_type; + // batch messages are not specific to the nftables subsystem + if msg_type != NFNL_MSG_BATCH_BEGIN as u16 && msg_type != NFNL_MSG_BATCH_END as u16 { + hdr.nlmsg_type |= (NFNL_SUBSYS_NFTABLES as u16) << 8; + } + hdr.nlmsg_flags = libc::NLM_F_REQUEST as u16 | flags; + hdr.nlmsg_seq = seq; + + let nfgenmsg_buf = self.add_data_zeroed(nfgenmsg_len); + let mut nfgenmsg: &mut nfgenmsg = + unsafe { std::mem::transmute(nfgenmsg_buf.as_mut_ptr() as *mut nfgenmsg) }; + nfgenmsg.nfgen_family = family as u8; + nfgenmsg.version = NFNETLINK_V0 as u8; + nfgenmsg.res_id = ressource_id.unwrap_or(0); + + self.headers = Some(( + self.buf.len() - (nlmsghdr_len + nfgenmsg_len), + self.buf.len() - nfgenmsg_len, + )); + } + + pub fn finalize_writing_object(&mut self) { + self.headers = None; + } +} + +pub trait AttributeDecoder { + fn decode_attribute(&mut self, attr_type: u16, buf: &[u8]) -> Result<(), DecodeError>; +} + +pub trait NfNetlinkDeserializable: Sized { + fn deserialize(buf: &[u8]) -> Result<(Self, &[u8]), DecodeError>; +} + +pub trait NfNetlinkObject: + Sized + AttributeDecoder + NfNetlinkDeserializable + NfNetlinkAttribute +{ + const MSG_TYPE_ADD: u32; + const MSG_TYPE_DEL: u32; + + fn add_or_remove<'a>(&self, writer: &mut NfNetlinkWriter<'a>, msg_type: MsgType, seq: u32) { + let raw_msg_type = match msg_type { + MsgType::Add => Self::MSG_TYPE_ADD, + MsgType::Del => Self::MSG_TYPE_DEL, + } as u16; + writer.write_header( + raw_msg_type, + self.get_family(), + (if let MsgType::Add = msg_type { + self.get_add_flags() + } else { + self.get_del_flags() + } | NLM_F_ACK) as u16, + seq, + None, + ); + let buf = writer.add_data_zeroed(self.get_size()); + unsafe { + self.write_payload(buf.as_mut_ptr()); + } + writer.finalize_writing_object(); + } + + fn get_family(&self) -> ProtocolFamily; + + fn set_family(&mut self, _family: ProtocolFamily) { + // the default impl do nothing, because some types are family-agnostic + } + + fn with_family(mut self, family: ProtocolFamily) -> Self { + self.set_family(family); + self + } + + fn get_add_flags(&self) -> u32 { + NLM_F_CREATE + } + + fn get_del_flags(&self) -> u32 { + 0 + } +} + +pub type NetlinkType = u16; + +pub trait NfNetlinkAttribute: Debug + Sized { + // is it a nested argument that must be marked with a NLA_F_NESTED flag? + fn is_nested(&self) -> bool { + false + } + + fn get_size(&self) -> usize { + size_of::<Self>() + } + + // example body: std::ptr::copy_nonoverlapping(self as *const Self as *const u8, addr, self.get_size()); + unsafe fn write_payload(&self, addr: *mut u8); +} |