diff options
author | Simon THOBY <git@nightmared.fr> | 2023-01-09 18:54:11 +0000 |
---|---|---|
committer | Simon THOBY <git@nightmared.fr> | 2023-01-09 18:54:11 +0000 |
commit | d5b9ec5185a27414286ee303eb3d21ce3069db09 (patch) | |
tree | 369eb90e8a2da307d7cd8f0b15a3318bbdba0003 /src | |
parent | 3e48e7efa516183d623f80d2e4e393cecc2acde9 (diff) | |
parent | c3e3773cccd01f80f2d72a7691e0654d304e6b2d (diff) |
Merge branch 'no_mnl' into 'master'
experimental support for a full-rust rewrite of the codebase (no libnftnl/libmnl anymore)
See merge request rustwall/rustables!16
Diffstat (limited to 'src')
38 files changed, 3668 insertions, 3183 deletions
diff --git a/src/batch.rs b/src/batch.rs index 198e8d0..b5c88b8 100644 --- a/src/batch.rs +++ b/src/batch.rs @@ -1,31 +1,29 @@ -use crate::{MsgType, NlMsg}; -use crate::sys::{self as sys, libc}; -use std::ffi::c_void; -use std::os::raw::c_char; -use std::ptr; +use libc; + use thiserror::Error; +use crate::error::QueryError; +use crate::nlmsg::{NfNetlinkObject, NfNetlinkWriter}; +use crate::sys::NFNL_SUBSYS_NFTABLES; +use crate::{MsgType, ProtocolFamily}; + +use nix::sys::socket::{ + self, AddressFamily, MsgFlags, NetlinkAddr, SockAddr, SockFlag, SockProtocol, SockType, +}; + /// Error while communicating with netlink. #[derive(Error, Debug)] #[error("Error while communicating with netlink")] pub struct NetlinkError(()); -#[cfg(feature = "query")] -/// Check if the kernel supports batched netlink messages to netfilter. -pub fn batch_is_supported() -> std::result::Result<bool, NetlinkError> { - match unsafe { sys::nftnl_batch_is_supported() } { - 1 => Ok(true), - 0 => Ok(false), - _ => Err(NetlinkError(())), - } -} - -/// A batch of netfilter messages to be performed in one atomic operation. Corresponds to -/// `nftnl_batch` in libnftnl. +/// A batch of netfilter messages to be performed in one atomic operation. pub struct Batch { - pub(crate) batch: *mut sys::nftnl_batch, - pub(crate) seq: u32, - pub(crate) is_empty: bool, + buf: Box<Vec<u8>>, + // the 'static lifetime here is a cheat, as the writer can only be used as long + // as `self.buf` exists. This is why this member must never be exposed directly to + // the rest of the crate (let alone publicly). + writer: NfNetlinkWriter<'static>, + seq: u32, } impl Batch { @@ -33,48 +31,40 @@ impl Batch { /// /// [default page size]: fn.default_batch_page_size.html pub fn new() -> Self { - Self::with_page_size(default_batch_page_size()) - } - - pub unsafe fn from_raw(batch: *mut sys::nftnl_batch, seq: u32) -> Self { - Batch { - batch, + // TODO: use a pinned Box ? + let mut buf = Box::new(Vec::with_capacity(default_batch_page_size() as usize)); + let mut writer = NfNetlinkWriter::new(unsafe { + std::mem::transmute(Box::as_mut(&mut buf) as *mut Vec<u8>) + }); + let seq = 0; + writer.write_header( + libc::NFNL_MSG_BATCH_BEGIN as u16, + ProtocolFamily::Unspec, + 0, seq, - // we assume this batch is not empty by default - is_empty: false, + Some(libc::NFNL_SUBSYS_NFTABLES as u16), + ); + writer.finalize_writing_object(); + Batch { + buf, + writer, + seq: seq + 1, } } - /// Creates a new nftnl batch with the given batch size. - pub fn with_page_size(batch_page_size: u32) -> Self { - let batch = try_alloc!(unsafe { - sys::nftnl_batch_alloc(batch_page_size, crate::nft_nlmsg_maxsize()) - }); - let mut this = Batch { - batch, - seq: 0, - is_empty: true, - }; - this.write_begin_msg(); - this - } - /// Adds the given message to this batch. - pub fn add<T: NlMsg>(&mut self, msg: &T, msg_type: MsgType) { + pub fn add<T: NfNetlinkObject>(&mut self, msg: &T, msg_type: MsgType) { trace!("Writing NlMsg with seq {} to batch", self.seq); - unsafe { msg.write(self.current(), self.seq, msg_type) }; - self.is_empty = false; - self.next() + msg.add_or_remove(&mut self.writer, msg_type, self.seq); + self.seq += 1; } - /// Adds all the messages in the given iterator to this batch. If any message fails to be - /// added the error for that failure is returned and all messages up until that message stay - /// added to the batch. - pub fn add_iter<T, I>(&mut self, msg_iter: I, msg_type: MsgType) - where - T: NlMsg, - I: Iterator<Item = T>, - { + /// Adds all the messages in the given iterator to this batch. + pub fn add_iter<T: NfNetlinkObject, I: Iterator<Item = T>>( + &mut self, + msg_iter: I, + msg_type: MsgType, + ) { for msg in msg_iter { self.add(&msg, msg_type); } @@ -86,109 +76,46 @@ impl Batch { /// Return None if there is no object in the batch (this could block forever). /// /// [`FinalizedBatch`]: struct.FinalizedBatch.html - pub fn finalize(mut self) -> Option<FinalizedBatch> { - self.write_end_msg(); - if self.is_empty { - return None; - } - Some(FinalizedBatch { batch: self }) - } - - fn current(&self) -> *mut c_void { - unsafe { sys::nftnl_batch_buffer(self.batch) } - } - - fn next(&mut self) { - if unsafe { sys::nftnl_batch_update(self.batch) } < 0 { - // See try_alloc definition. - std::process::abort(); - } - self.seq += 1; - } - - fn write_begin_msg(&mut self) { - unsafe { sys::nftnl_batch_begin(self.current() as *mut c_char, self.seq) }; - self.next(); - } - - fn write_end_msg(&mut self) { - unsafe { sys::nftnl_batch_end(self.current() as *mut c_char, self.seq) }; - self.next(); - } - - #[cfg(feature = "unsafe-raw-handles")] - /// Returns the raw handle. - pub fn as_ptr(&self) -> *const sys::nftnl_batch { - self.batch as *const sys::nftnl_batch - } - - #[cfg(feature = "unsafe-raw-handles")] - /// Returns a mutable version of the raw handle. - pub fn as_mut_ptr(&mut self) -> *mut sys::nftnl_batch { - self.batch - } -} - -impl Drop for Batch { - fn drop(&mut self) { - unsafe { sys::nftnl_batch_free(self.batch) }; - } -} - -/// A wrapper over [`Batch`], guaranteed to start with a proper batch begin and end with a proper -/// batch end message. Created from [`Batch::finalize`]. -/// -/// Can be turned into an iterator of the byte buffers to send to netlink to execute this batch. -/// -/// [`Batch`]: struct.Batch.html -/// [`Batch::finalize`]: struct.Batch.html#method.finalize -pub struct FinalizedBatch { - batch: Batch, -} - -impl FinalizedBatch { - /// Returns the iterator over byte buffers to send to netlink. - pub fn iter(&mut self) -> Iter<'_> { - let num_pages = unsafe { sys::nftnl_batch_iovec_len(self.batch.batch) as usize }; - let mut iovecs = vec![ - libc::iovec { - iov_base: ptr::null_mut(), - iov_len: 0, - }; - num_pages - ]; - let iovecs_ptr = iovecs.as_mut_ptr(); - unsafe { - sys::nftnl_batch_iovec(self.batch.batch, iovecs_ptr, num_pages as u32); - } - Iter { - iovecs: iovecs.into_iter(), - _marker: ::std::marker::PhantomData, + pub fn finalize(mut self) -> Vec<u8> { + self.writer.write_header( + libc::NFNL_MSG_BATCH_END as u16, + ProtocolFamily::Unspec, + 0, + self.seq, + Some(NFNL_SUBSYS_NFTABLES as u16), + ); + self.writer.finalize_writing_object(); + *self.buf + } + + pub fn send(self) -> Result<(), QueryError> { + use crate::query::{recv_and_process, socket_close_wrapper}; + + let sock = socket::socket( + AddressFamily::Netlink, + SockType::Raw, + SockFlag::empty(), + SockProtocol::NetlinkNetFilter, + ) + .map_err(QueryError::NetlinkOpenError)?; + + let max_seq = self.seq - 1; + + let addr = SockAddr::Netlink(NetlinkAddr::new(0, 0)); + // while this bind() is not strictly necessary, strace have trouble decoding the messages + // if we don't + socket::bind(sock, &addr).expect("bind"); + + let to_send = self.finalize(); + let mut sent = 0; + while sent != to_send.len() { + sent += socket::send(sock, &to_send[sent..], MsgFlags::empty()) + .map_err(QueryError::NetlinkSendError)?; } - } -} - -impl<'a> IntoIterator for &'a mut FinalizedBatch { - type Item = &'a [u8]; - type IntoIter = Iter<'a>; - - fn into_iter(self) -> Iter<'a> { - self.iter() - } -} - -pub struct Iter<'a> { - iovecs: ::std::vec::IntoIter<libc::iovec>, - _marker: ::std::marker::PhantomData<&'a ()>, -} - -impl<'a> Iterator for Iter<'a> { - type Item = &'a [u8]; - fn next(&mut self) -> Option<&'a [u8]> { - self.iovecs.next().map(|iovec| unsafe { - ::std::slice::from_raw_parts(iovec.iov_base as *const u8, iovec.iov_len) - }) + Ok(socket_close_wrapper(sock, move |sock| { + recv_and_process(sock, Some(max_seq), None, &mut ()) + })?) } } diff --git a/src/chain.rs b/src/chain.rs index a942a37..37e4cb3 100644 --- a/src/chain.rs +++ b/src/chain.rs @@ -1,41 +1,85 @@ -use crate::{MsgType, Table}; -use crate::sys::{self as sys, libc}; -#[cfg(feature = "query")] -use std::convert::TryFrom; -use std::{ - ffi::{c_void, CStr, CString}, - fmt, - os::raw::c_char, - rc::Rc, +use libc::{NF_ACCEPT, NF_DROP}; +use rustables_macros::nfnetlink_struct; + +use crate::error::{DecodeError, QueryError}; +use crate::nlmsg::{NfNetlinkAttribute, NfNetlinkDeserializable, NfNetlinkObject}; +use crate::sys::{ + NFTA_CHAIN_FLAGS, NFTA_CHAIN_HOOK, NFTA_CHAIN_NAME, NFTA_CHAIN_POLICY, NFTA_CHAIN_TABLE, + NFTA_CHAIN_TYPE, NFTA_CHAIN_USERDATA, NFTA_HOOK_HOOKNUM, NFTA_HOOK_PRIORITY, NFT_MSG_DELCHAIN, + NFT_MSG_NEWCHAIN, }; +use crate::{Batch, ProtocolFamily, Table}; +use std::fmt::Debug; -pub type Priority = i32; +pub type ChainPriority = i32; /// The netfilter event hooks a chain can register for. -#[derive(Debug, Copy, Clone, Eq, PartialEq, Hash)] -#[repr(u16)] -pub enum Hook { +#[derive(Debug, Copy, Clone, Eq, PartialEq, PartialOrd, Ord, Hash)] +#[repr(i32)] +pub enum HookClass { /// Hook into the pre-routing stage of netfilter. Corresponds to `NF_INET_PRE_ROUTING`. - PreRouting = libc::NF_INET_PRE_ROUTING as u16, + PreRouting = libc::NF_INET_PRE_ROUTING, /// Hook into the input stage of netfilter. Corresponds to `NF_INET_LOCAL_IN`. - In = libc::NF_INET_LOCAL_IN as u16, + In = libc::NF_INET_LOCAL_IN, /// Hook into the forward stage of netfilter. Corresponds to `NF_INET_FORWARD`. - Forward = libc::NF_INET_FORWARD as u16, + Forward = libc::NF_INET_FORWARD, /// Hook into the output stage of netfilter. Corresponds to `NF_INET_LOCAL_OUT`. - Out = libc::NF_INET_LOCAL_OUT as u16, + Out = libc::NF_INET_LOCAL_OUT, /// Hook into the post-routing stage of netfilter. Corresponds to `NF_INET_POST_ROUTING`. - PostRouting = libc::NF_INET_POST_ROUTING as u16, + PostRouting = libc::NF_INET_POST_ROUTING, +} + +#[derive(Clone, PartialEq, Eq, Default, Debug)] +#[nfnetlink_struct(nested = true)] +pub struct Hook { + /// Define the action netfilter will apply to packets processed by this chain, but that did not match any rules in it. + #[field(NFTA_HOOK_HOOKNUM)] + class: u32, + #[field(NFTA_HOOK_PRIORITY)] + priority: u32, +} + +impl Hook { + pub fn new(class: HookClass, priority: ChainPriority) -> Self { + Hook::default() + .with_class(class as u32) + .with_priority(priority as u32) + } } /// A chain policy. Decides what to do with a packet that was processed by the chain but did not /// match any rules. -#[derive(Debug, Copy, Clone, Eq, PartialEq, Hash)] -#[repr(u32)] -pub enum Policy { +#[derive(Debug, Copy, Clone, Eq, PartialEq, PartialOrd, Ord, Hash)] +#[repr(i32)] +pub enum ChainPolicy { /// Accept the packet. - Accept = libc::NF_ACCEPT as u32, + Accept = NF_ACCEPT, /// Drop the packet. - Drop = libc::NF_DROP as u32, + Drop = NF_DROP, +} + +impl NfNetlinkAttribute for ChainPolicy { + fn get_size(&self) -> usize { + (*self as i32).get_size() + } + + unsafe fn write_payload(&self, addr: *mut u8) { + (*self as i32).write_payload(addr); + } +} + +impl NfNetlinkDeserializable for ChainPolicy { + fn deserialize(buf: &[u8]) -> Result<(Self, &[u8]), DecodeError> { + let (v, remaining_data) = i32::deserialize(buf)?; + Ok(( + match v { + NF_ACCEPT => ChainPolicy::Accept, + NF_DROP => ChainPolicy::Accept, + _ => return Err(DecodeError::UnknownChainPolicy), + }, + remaining_data, + )) + } } /// Base chain type. @@ -53,240 +97,117 @@ pub enum ChainType { } impl ChainType { - fn as_c_str(&self) -> &'static [u8] { + fn as_str(&self) -> &'static str { match *self { - ChainType::Filter => b"filter\0", - ChainType::Route => b"route\0", - ChainType::Nat => b"nat\0", + ChainType::Filter => "filter", + ChainType::Route => "route", + ChainType::Nat => "nat", } } } -/// Abstraction of a `nftnl_chain`. Chains reside inside [`Table`]s and they hold [`Rule`]s. -/// -/// There are two types of chains, "base chain" and "regular chain". See [`set_hook`] for more -/// details. +impl NfNetlinkAttribute for ChainType { + fn get_size(&self) -> usize { + self.as_str().len() + } + + unsafe fn write_payload(&self, addr: *mut u8) { + self.as_str().to_string().write_payload(addr); + } +} + +impl NfNetlinkDeserializable for ChainType { + fn deserialize(buf: &[u8]) -> Result<(Self, &[u8]), DecodeError> { + let (s, remaining_data) = String::deserialize(buf)?; + Ok(( + match s.as_str() { + "filter" => ChainType::Filter, + "route" => ChainType::Route, + "nat" => ChainType::Nat, + _ => return Err(DecodeError::UnknownChainType), + }, + remaining_data, + )) + } +} + +/// Abstraction over an nftable chain. Chains reside inside [`Table`]s and they hold [`Rule`]s. /// /// [`Table`]: struct.Table.html /// [`Rule`]: struct.Rule.html -/// [`set_hook`]: #method.set_hook +#[derive(PartialEq, Eq, Default, Debug)] +#[nfnetlink_struct(derive_deserialize = false)] pub struct Chain { - pub(crate) chain: *mut sys::nftnl_chain, - pub(crate) table: Rc<Table>, + family: ProtocolFamily, + #[field(NFTA_CHAIN_TABLE)] + table: String, + #[field(NFTA_CHAIN_NAME)] + name: String, + #[field(NFTA_CHAIN_HOOK)] + hook: Hook, + #[field(NFTA_CHAIN_POLICY)] + policy: ChainPolicy, + #[field(NFTA_CHAIN_TYPE, name_in_functions = "type")] + chain_type: ChainType, + #[field(NFTA_CHAIN_FLAGS)] + flags: u32, + #[field(NFTA_CHAIN_USERDATA)] + userdata: Vec<u8>, } impl Chain { - /// Creates a new chain instance inside the given [`Table`] and with the given name. + /// Creates a new chain instance inside the given [`Table`]. /// /// [`Table`]: struct.Table.html - pub fn new<T: AsRef<CStr>>(name: &T, table: Rc<Table>) -> Chain { - unsafe { - let chain = try_alloc!(sys::nftnl_chain_alloc()); - sys::nftnl_chain_set_u32( - chain, - sys::NFTNL_CHAIN_FAMILY as u16, - table.get_family() as u32, - ); - sys::nftnl_chain_set_str( - chain, - sys::NFTNL_CHAIN_TABLE as u16, - table.get_name().as_ptr(), - ); - sys::nftnl_chain_set_str(chain, sys::NFTNL_CHAIN_NAME as u16, name.as_ref().as_ptr()); - Chain { chain, table } - } - } - - pub unsafe fn from_raw(chain: *mut sys::nftnl_chain, table: Rc<Table>) -> Self { - Chain { chain, table } - } + pub fn new(table: &Table) -> Chain { + let mut chain = Chain::default(); + chain.family = table.get_family(); - /// Sets the hook and priority for this chain. Without calling this method the chain will - /// become a "regular chain" without any hook and will thus not receive any traffic unless - /// some rule forward packets to it via goto or jump verdicts. - /// - /// By calling `set_hook` with a hook the chain that is created will be registered with that - /// hook and is thus a "base chain". A "base chain" is an entry point for packets from the - /// networking stack. - pub fn set_hook(&mut self, hook: Hook, priority: Priority) { - unsafe { - sys::nftnl_chain_set_u32(self.chain, sys::NFTNL_CHAIN_HOOKNUM as u16, hook as u32); - sys::nftnl_chain_set_s32(self.chain, sys::NFTNL_CHAIN_PRIO as u16, priority); + if let Some(table_name) = table.get_name() { + chain.set_table(table_name); } - } - /// Set the type of a base chain. This only applies if the chain has been registered - /// with a hook by calling `set_hook`. - pub fn set_type(&mut self, chain_type: ChainType) { - unsafe { - sys::nftnl_chain_set_str( - self.chain, - sys::NFTNL_CHAIN_TYPE as u16, - chain_type.as_c_str().as_ptr() as *const c_char, - ); - } + chain } - /// Sets the default policy for this chain. That means what action netfilter will apply to - /// packets processed by this chain, but that did not match any rules in it. - pub fn set_policy(&mut self, policy: Policy) { - unsafe { - sys::nftnl_chain_set_u32(self.chain, sys::NFTNL_CHAIN_POLICY as u16, policy as u32); - } - } - - /// Returns the userdata of this chain. - pub fn get_userdata(&self) -> Option<&CStr> { - unsafe { - let ptr = sys::nftnl_chain_get_str(self.chain, sys::NFTNL_CHAIN_USERDATA as u16); - if ptr == std::ptr::null() { - return None; - } - Some(CStr::from_ptr(ptr)) - } - } - - /// Updates the userdata of this chain. - pub fn set_userdata(&self, data: &CStr) { - unsafe { - sys::nftnl_chain_set_str(self.chain, sys::NFTNL_CHAIN_USERDATA as u16, data.as_ptr()); - } - } - - /// Returns the name of this chain. - pub fn get_name(&self) -> &CStr { - unsafe { - let ptr = sys::nftnl_chain_get_str(self.chain, sys::NFTNL_CHAIN_NAME as u16); - if ptr.is_null() { - panic!("Impossible situation: retrieving the name of a chain failed") - } else { - CStr::from_ptr(ptr) - } - } - } - - /// Returns a textual description of the chain. - pub fn get_str(&self) -> CString { - let mut descr_buf = vec![0i8; 4096]; - unsafe { - sys::nftnl_chain_snprintf( - descr_buf.as_mut_ptr() as *mut c_char, - (descr_buf.len() - 1) as u64, - self.chain, - sys::NFTNL_OUTPUT_DEFAULT, - 0, - ); - CStr::from_ptr(descr_buf.as_ptr() as *mut c_char).to_owned() - } - } - - /// Returns a reference to the [`Table`] this chain belongs to. - /// - /// [`Table`]: struct.Table.html - pub fn get_table(&self) -> Rc<Table> { - self.table.clone() - } - - #[cfg(feature = "unsafe-raw-handles")] - /// Returns the raw handle. - pub fn as_ptr(&self) -> *const sys::nftnl_chain { - self.chain as *const sys::nftnl_chain - } - - #[cfg(feature = "unsafe-raw-handles")] - /// Returns a mutable version of the raw handle. - pub fn as_mut_ptr(&mut self) -> *mut sys::nftnl_chain { - self.chain + /// Appends this chain to `batch` + pub fn add_to_batch(self, batch: &mut Batch) -> Self { + batch.add(&self, crate::MsgType::Add); + self } } -impl fmt::Debug for Chain { - /// Returns a string representation of the chain. - fn fmt(&self, fmt: &mut fmt::Formatter<'_>) -> fmt::Result { - write!(fmt, "{:?}", self.get_str()) - } -} +impl NfNetlinkObject for Chain { + const MSG_TYPE_ADD: u32 = NFT_MSG_NEWCHAIN; + const MSG_TYPE_DEL: u32 = NFT_MSG_DELCHAIN; -impl PartialEq for Chain { - fn eq(&self, other: &Self) -> bool { - self.get_table() == other.get_table() && self.get_name() == other.get_name() + fn get_family(&self) -> ProtocolFamily { + self.family } -} -unsafe impl crate::NlMsg for Chain { - unsafe fn write(&self, buf: *mut c_void, seq: u32, msg_type: MsgType) { - let raw_msg_type = match msg_type { - MsgType::Add => libc::NFT_MSG_NEWCHAIN, - MsgType::Del => libc::NFT_MSG_DELCHAIN, - }; - let flags: u16 = match msg_type { - MsgType::Add => (libc::NLM_F_ACK | libc::NLM_F_CREATE) as u16, - MsgType::Del => libc::NLM_F_ACK as u16, - } | libc::NLM_F_ACK as u16; - let header = sys::nftnl_nlmsg_build_hdr( - buf as *mut c_char, - raw_msg_type as u16, - self.table.get_family() as u16, - flags, - seq, - ); - sys::nftnl_chain_nlmsg_build_payload(header, self.chain); + fn set_family(&mut self, family: ProtocolFamily) { + self.family = family; } } -impl Drop for Chain { - fn drop(&mut self) { - unsafe { sys::nftnl_chain_free(self.chain) }; - } -} - -#[cfg(feature = "query")] -pub fn get_chains_cb<'a>( - header: &libc::nlmsghdr, - (table, chains): &mut (&Rc<Table>, &mut Vec<Chain>), -) -> libc::c_int { - unsafe { - let chain = sys::nftnl_chain_alloc(); - if chain == std::ptr::null_mut() { - return mnl::mnl_sys::MNL_CB_ERROR; - } - let err = sys::nftnl_chain_nlmsg_parse(header, chain); - if err < 0 { - error!("Failed to parse nelink chain message - {}", err); - sys::nftnl_chain_free(chain); - return err; - } - - let table_name = CStr::from_ptr(sys::nftnl_chain_get_str( - chain, - sys::NFTNL_CHAIN_TABLE as u16, - )); - let family = sys::nftnl_chain_get_u32(chain, sys::NFTNL_CHAIN_FAMILY as u16); - let family = match crate::ProtoFamily::try_from(family as i32) { - Ok(family) => family, - Err(crate::InvalidProtocolFamily) => { - error!("The netlink table didn't have a valid protocol family !?"); - sys::nftnl_chain_free(chain); - return mnl::mnl_sys::MNL_CB_ERROR; +pub fn list_chains_for_table(table: &Table) -> Result<Vec<Chain>, QueryError> { + let mut result = Vec::new(); + crate::query::list_objects_with_data( + libc::NFT_MSG_GETCHAIN as u16, + &|chain: Chain, (table, chains): &mut (&Table, &mut Vec<Chain>)| { + if chain.get_table() == table.get_name() { + chains.push(chain); + } else { + info!( + "Ignoring chain {:?} because it doesn't map the table {:?}", + chain.get_name(), + table.get_name() + ); } - }; - - if table_name != table.get_name() { - sys::nftnl_chain_free(chain); - return mnl::mnl_sys::MNL_CB_OK; - } - - if family != crate::ProtoFamily::Unspec && family != table.get_family() { - sys::nftnl_chain_free(chain); - return mnl::mnl_sys::MNL_CB_OK; - } - - chains.push(Chain::from_raw(chain, table.clone())); - } - mnl::mnl_sys::MNL_CB_OK -} - -#[cfg(feature = "query")] -pub fn list_chains_for_table(table: Rc<Table>) -> Result<Vec<Chain>, crate::query::Error> { - crate::query::list_objects_with_data(libc::NFT_MSG_GETCHAIN as u16, get_chains_cb, &table, None) + Ok(()) + }, + None, + &mut (&table, &mut result), + )?; + Ok(result) } diff --git a/src/chain_methods.rs b/src/chain_methods.rs deleted file mode 100644 index d384c35..0000000 --- a/src/chain_methods.rs +++ /dev/null @@ -1,40 +0,0 @@ -use crate::{Batch, Chain, Hook, MsgType, Policy, Table}; -use std::ffi::CString; -use std::rc::Rc; - - -/// A helper trait over [`crate::Chain`]. -pub trait ChainMethods { - /// Creates a new Chain instance from a [`crate::Hook`] over a [`crate::Table`]. - fn from_hook(hook: Hook, table: Rc<Table>) -> Self - where Self: std::marker::Sized; - /// Adds a [`crate::Policy`] to the current Chain. - fn verdict(self, policy: Policy) -> Self; - fn add_to_batch(self, batch: &mut Batch) -> Self; -} - - -impl ChainMethods for Chain { - fn from_hook(hook: Hook, table: Rc<Table>) -> Self { - let chain_name = match hook { - Hook::PreRouting => "prerouting", - Hook::Out => "out", - Hook::PostRouting => "postrouting", - Hook::Forward => "forward", - Hook::In => "in", - }; - let chain_name = CString::new(chain_name).unwrap(); - let mut chain = Chain::new(&chain_name, table); - chain.set_hook(hook, 0); - chain - } - fn verdict(mut self, policy: Policy) -> Self { - self.set_policy(policy); - self - } - fn add_to_batch(self, batch: &mut Batch) -> Self { - batch.add(&self, MsgType::Add); - self - } -} - diff --git a/src/data_type.rs b/src/data_type.rs new file mode 100644 index 0000000..43a7f1a --- /dev/null +++ b/src/data_type.rs @@ -0,0 +1,42 @@ +use std::net::{IpAddr, Ipv4Addr, Ipv6Addr}; + +pub trait DataType { + const TYPE: u32; + const LEN: u32; + + fn data(&self) -> Vec<u8>; +} + +impl DataType for Ipv4Addr { + const TYPE: u32 = 7; + const LEN: u32 = 4; + + fn data(&self) -> Vec<u8> { + self.octets().to_vec() + } +} + +impl DataType for Ipv6Addr { + const TYPE: u32 = 8; + const LEN: u32 = 16; + + fn data(&self) -> Vec<u8> { + self.octets().to_vec() + } +} + +impl<const N: usize> DataType for [u8; N] { + const TYPE: u32 = 5; + const LEN: u32 = N as u32; + + fn data(&self) -> Vec<u8> { + self.to_vec() + } +} + +pub fn ip_to_vec(ip: IpAddr) -> Vec<u8> { + match ip { + IpAddr::V4(x) => x.octets().to_vec(), + IpAddr::V6(x) => x.octets().to_vec(), + } +} diff --git a/src/error.rs b/src/error.rs new file mode 100644 index 0000000..f6b6247 --- /dev/null +++ b/src/error.rs @@ -0,0 +1,180 @@ +use std::string::FromUtf8Error; + +use nix::errno::Errno; +use thiserror::Error; + +use crate::sys::nlmsgerr; + +#[derive(Error, Debug)] +pub enum DecodeError { + #[error("The buffer is too small to hold a valid message")] + BufTooSmall, + + #[error("The message is too small")] + NlMsgTooSmall, + + #[error("The message holds unexpected data")] + InvalidDataSize, + + #[error("Invalid subsystem, expected NFTABLES")] + InvalidSubsystem(u8), + + #[error("Invalid version, expected NFNETLINK_V0")] + InvalidVersion(u8), + + #[error("Invalid port ID")] + InvalidPortId(u32), + + #[error("Invalid sequence number")] + InvalidSeq(u32), + + #[error("The generation number was bumped in the kernel while the operation was running, interrupting it")] + ConcurrentGenerationUpdate, + + #[error("Unsupported message type")] + UnsupportedType(u16), + + #[error("Invalid attribute type")] + InvalidAttributeType, + + #[error("Invalid type for a chain")] + UnknownChainType, + + #[error("Invalid policy for a chain")] + UnknownChainPolicy, + + #[error("Unknown type for a Meta expression")] + UnknownMetaType(u32), + + #[error("Unsupported value for an icmp reject type")] + UnknownRejectType(u32), + + #[error("Unsupported value for an icmp code in a reject expression")] + UnknownIcmpCode(u8), + + #[error("Invalid value for a register")] + UnknownRegister(u32), + + #[error("Invalid type for a verdict expression")] + UnknownVerdictType(i32), + + #[error("Invalid type for a nat expression")] + UnknownNatType(i32), + + #[error("Invalid type for a payload expression")] + UnknownPayloadType(u32), + + #[error("Invalid type for a compare expression")] + UnknownCmpOp(u32), + + #[error("Invalid type for a conntrack key")] + UnknownConntrackKey(u32), + + #[error("Unsupported value for a link layer header field")] + UnknownLinkLayerHeaderField(u32, u32), + + #[error("Unsupported value for an IPv4 header field")] + UnknownIPv4HeaderField(u32, u32), + + #[error("Unsupported value for an IPv6 header field")] + UnknownIPv6HeaderField(u32, u32), + + #[error("Unsupported value for a TCP header field")] + UnknownTCPHeaderField(u32, u32), + + #[error("Unsupported value for an UDP header field")] + UnknownUDPHeaderField(u32, u32), + + #[error("Unsupported value for an ICMPv6 header field")] + UnknownICMPv6HeaderField(u32, u32), + + #[error("Missing the 'base' attribute to deserialize the payload object")] + PayloadMissingBase, + + #[error("Missing the 'offset' attribute to deserialize the payload object")] + PayloadMissingOffset, + + #[error("Missing the 'len' attribute to deserialize the payload object")] + PayloadMissingLen, + + #[error("The object does not contain a name for the expression being parsed")] + MissingExpressionName, + + #[error("Unsupported attribute type")] + UnsupportedAttributeType(u16), + + #[error("Unexpected message type")] + UnexpectedType(u16), + + #[error("The decoded String is not UTF8 compliant")] + StringDecodeFailure(#[from] FromUtf8Error), + + #[error("Invalid value for a protocol family")] + UnknownProtocolFamily(i32), + + #[error("A custom error occured")] + Custom(Box<dyn std::error::Error + 'static>), +} + +#[derive(thiserror::Error, Debug)] +pub enum BuilderError { + #[error("The length of the arguments are not compatible with each other")] + IncompatibleLength, + + #[error("The table does not have a name")] + MissingTableName, + + #[error("Missing information in the chain to create a rule")] + MissingChainInformationError, + + #[error("Missing name for the set")] + MissingSetName, + + #[error("The interface name is too long to be written")] + InterfaceNameTooLong, + + #[error("The log prefix string is more than 127 characters long")] + TooLongLogPrefix, +} + +#[derive(thiserror::Error, Debug)] +pub enum QueryError { + #[error("Unable to open netlink socket to netfilter")] + NetlinkOpenError(#[source] nix::Error), + + #[error("Unable to send netlink command to netfilter")] + NetlinkSendError(#[source] nix::Error), + + #[error("Error while reading from netlink socket")] + NetlinkRecvError(#[source] nix::Error), + + #[error("Error while processing an incoming netlink message")] + ProcessNetlinkError(#[from] DecodeError), + + #[error("Error while building netlink objects in Rust")] + BuilderError(#[from] BuilderError), + + #[error("Error received from the kernel")] + NetlinkError(nlmsgerr), + + #[error("Custom error when customizing the query")] + InitError(#[from] Box<dyn std::error::Error + Send + 'static>), + + #[error("Couldn't allocate a netlink object, out of memory ?")] + NetlinkAllocationFailed, + + #[error("This socket is not a netlink socket")] + NotNetlinkSocket, + + #[error("Couldn't retrieve information on a socket")] + RetrievingSocketInfoFailed, + + #[error("Only a part of the message was sent")] + TruncatedSend, + + #[error("Got a message without the NLM_F_MULTI flag, but a maximum sequence number was not specified")] + UndecidableMessageTermination, + + #[error("Couldn't close the socket")] + CloseFailed(#[source] Errno), +} diff --git a/src/expr/bitwise.rs b/src/expr/bitwise.rs index d34d22c..fb40a04 100644 --- a/src/expr/bitwise.rs +++ b/src/expr/bitwise.rs @@ -1,69 +1,47 @@ -use super::{Expression, Rule, ToSlice}; -use crate::sys::{self, libc}; -use std::ffi::c_void; -use std::os::raw::c_char; - -/// Expression for performing bitwise masking and XOR on the data in a register. -pub struct Bitwise<M: ToSlice, X: ToSlice> { - mask: M, - xor: X, +use rustables_macros::nfnetlink_struct; + +use super::{Expression, Register}; +use crate::error::BuilderError; +use crate::parser_impls::NfNetlinkData; +use crate::sys::{ + NFTA_BITWISE_DREG, NFTA_BITWISE_LEN, NFTA_BITWISE_MASK, NFTA_BITWISE_SREG, NFTA_BITWISE_XOR, +}; + +#[derive(Clone, PartialEq, Eq, Default, Debug)] +#[nfnetlink_struct] +pub struct Bitwise { + #[field(NFTA_BITWISE_SREG)] + sreg: Register, + #[field(NFTA_BITWISE_DREG)] + dreg: Register, + #[field(NFTA_BITWISE_LEN)] + len: u32, + #[field(NFTA_BITWISE_MASK)] + mask: NfNetlinkData, + #[field(NFTA_BITWISE_XOR)] + xor: NfNetlinkData, } -impl<M: ToSlice, X: ToSlice> Bitwise<M, X> { - /// Returns a new `Bitwise` instance that first masks the value it's applied to with `mask` and - /// then performs xor with the value in `xor`. - pub fn new(mask: M, xor: X) -> Self { - Self { mask, xor } +impl Expression for Bitwise { + fn get_name() -> &'static str { + "bitwise" } } -impl<M: ToSlice, X: ToSlice> Expression for Bitwise<M, X> { - fn get_raw_name() -> *const c_char { - b"bitwise\0" as *const _ as *const c_char - } - - fn to_expr(&self, _rule: &Rule) -> *mut sys::nftnl_expr { - unsafe { - let expr = try_alloc!(sys::nftnl_expr_alloc(Self::get_raw_name())); - - let mask = self.mask.to_slice(); - let xor = self.xor.to_slice(); - assert!(mask.len() == xor.len()); - let len = mask.len() as u32; - - sys::nftnl_expr_set_u32( - expr, - sys::NFTNL_EXPR_BITWISE_SREG as u16, - libc::NFT_REG_1 as u32, - ); - sys::nftnl_expr_set_u32( - expr, - sys::NFTNL_EXPR_BITWISE_DREG as u16, - libc::NFT_REG_1 as u32, - ); - sys::nftnl_expr_set_u32(expr, sys::NFTNL_EXPR_BITWISE_LEN as u16, len); - - sys::nftnl_expr_set( - expr, - sys::NFTNL_EXPR_BITWISE_MASK as u16, - mask.as_ref() as *const _ as *const c_void, - len, - ); - sys::nftnl_expr_set( - expr, - sys::NFTNL_EXPR_BITWISE_XOR as u16, - xor.as_ref() as *const _ as *const c_void, - len, - ); - - expr +impl Bitwise { + /// Returns a new `Bitwise` instance that first masks the value it's applied to with `mask` and + /// then performs xor with the value in `xor` + pub fn new(mask: impl Into<Vec<u8>>, xor: impl Into<Vec<u8>>) -> Result<Self, BuilderError> { + let mask = mask.into(); + let xor = xor.into(); + if mask.len() != xor.len() { + return Err(BuilderError::IncompatibleLength); } + Ok(Bitwise::default() + .with_sreg(Register::Reg1) + .with_dreg(Register::Reg1) + .with_len(mask.len() as u32) + .with_xor(NfNetlinkData::default().with_value(xor)) + .with_mask(NfNetlinkData::default().with_value(mask))) } } - -#[macro_export] -macro_rules! nft_expr_bitwise { - (mask $mask:expr,xor $xor:expr) => { - $crate::expr::Bitwise::new($mask, $xor) - }; -} diff --git a/src/expr/cmp.rs b/src/expr/cmp.rs index f6ea900..86d3587 100644 --- a/src/expr/cmp.rs +++ b/src/expr/cmp.rs @@ -1,187 +1,64 @@ -use super::{DeserializationError, Expression, Rule, ToSlice}; -use crate::sys::{self, libc}; -use std::{ - borrow::Cow, - ffi::{c_void, CString}, - os::raw::c_char, +use rustables_macros::{nfnetlink_enum, nfnetlink_struct}; + +use crate::{ + parser_impls::NfNetlinkData, + sys::{ + NFTA_CMP_DATA, NFTA_CMP_OP, NFTA_CMP_SREG, NFT_CMP_EQ, NFT_CMP_GT, NFT_CMP_GTE, NFT_CMP_LT, + NFT_CMP_LTE, NFT_CMP_NEQ, + }, }; +use super::{Expression, Register}; + /// Comparison operator. #[derive(Debug, Copy, Clone, Eq, PartialEq)] +#[nfnetlink_enum(u32, nested = true)] pub enum CmpOp { /// Equals. - Eq, + Eq = NFT_CMP_EQ, /// Not equal. - Neq, + Neq = NFT_CMP_NEQ, /// Less than. - Lt, + Lt = NFT_CMP_LT, /// Less than, or equal. - Lte, + Lte = NFT_CMP_LTE, /// Greater than. - Gt, + Gt = NFT_CMP_GT, /// Greater than, or equal. - Gte, -} - -impl CmpOp { - /// Returns the corresponding `NFT_*` constant for this comparison operation. - pub fn to_raw(self) -> u32 { - use self::CmpOp::*; - match self { - Eq => libc::NFT_CMP_EQ as u32, - Neq => libc::NFT_CMP_NEQ as u32, - Lt => libc::NFT_CMP_LT as u32, - Lte => libc::NFT_CMP_LTE as u32, - Gt => libc::NFT_CMP_GT as u32, - Gte => libc::NFT_CMP_GTE as u32, - } - } - - pub fn from_raw(val: u32) -> Result<Self, DeserializationError> { - use self::CmpOp::*; - match val as i32 { - libc::NFT_CMP_EQ => Ok(Eq), - libc::NFT_CMP_NEQ => Ok(Neq), - libc::NFT_CMP_LT => Ok(Lt), - libc::NFT_CMP_LTE => Ok(Lte), - libc::NFT_CMP_GT => Ok(Gt), - libc::NFT_CMP_GTE => Ok(Gte), - _ => Err(DeserializationError::InvalidValue), - } - } + Gte = NFT_CMP_GTE, } /// Comparator expression. Allows comparing the content of the netfilter register with any value. -#[derive(Debug, PartialEq)] -pub struct Cmp<T> { +#[derive(Default, Debug, Clone, PartialEq, Eq)] +#[nfnetlink_struct] +pub struct Cmp { + #[field(NFTA_CMP_SREG)] + sreg: Register, + #[field(NFTA_CMP_OP)] op: CmpOp, - data: T, + #[field(NFTA_CMP_DATA)] + data: NfNetlinkData, } -impl<T: ToSlice> Cmp<T> { +impl Cmp { /// Returns a new comparison expression comparing the value loaded in the register with the /// data in `data` using the comparison operator `op`. - pub fn new(op: CmpOp, data: T) -> Self { - Cmp { op, data } - } -} - -impl<T: ToSlice> Expression for Cmp<T> { - fn get_raw_name() -> *const c_char { - b"cmp\0" as *const _ as *const c_char - } - - fn to_expr(&self, _rule: &Rule) -> *mut sys::nftnl_expr { - unsafe { - let expr = try_alloc!(sys::nftnl_expr_alloc(Self::get_raw_name())); - - let data = self.data.to_slice(); - trace!("Creating a cmp expr comparing with data {:?}", data); - - sys::nftnl_expr_set_u32( - expr, - sys::NFTNL_EXPR_CMP_SREG as u16, - libc::NFT_REG_1 as u32, - ); - sys::nftnl_expr_set_u32(expr, sys::NFTNL_EXPR_CMP_OP as u16, self.op.to_raw()); - sys::nftnl_expr_set( - expr, - sys::NFTNL_EXPR_CMP_DATA as u16, - data.as_ptr() as *const c_void, - data.len() as u32, - ); - - expr - } - } -} - -impl<const N: usize> Expression for Cmp<[u8; N]> { - fn get_raw_name() -> *const c_char { - Cmp::<u8>::get_raw_name() - } - - /// The raw data contained inside `Cmp` expressions can only be deserialized to arrays of - /// bytes, to ensure that the memory layout of retrieved data cannot be violated. It is your - /// responsibility to provide the correct length of the byte data. If the data size is invalid, - /// you will get the error `DeserializationError::InvalidDataSize`. - /// - /// Example (warning, no error checking!): - /// ```rust - /// use std::ffi::CString; - /// use std::net::Ipv4Addr; - /// use std::rc::Rc; - /// - /// use rustables::{Chain, expr::{Cmp, CmpOp}, ProtoFamily, Rule, Table}; - /// - /// let table = Rc::new(Table::new(&CString::new("mytable").unwrap(), ProtoFamily::Inet)); - /// let chain = Rc::new(Chain::new(&CString::new("mychain").unwrap(), table)); - /// let mut rule = Rule::new(chain); - /// rule.add_expr(&Cmp::new(CmpOp::Eq, 1337u16)); - /// for expr in Rc::new(rule).get_exprs() { - /// println!("{:?}", expr.decode_expr::<Cmp<[u8; 2]>>().unwrap()); - /// } - /// ``` - /// These limitations occur because casting bytes to any type of the same size - /// as the raw input would be *extremely* dangerous in terms of memory safety. - fn from_expr(expr: *const sys::nftnl_expr) -> Result<Self, DeserializationError> { - unsafe { - let ref_len = std::mem::size_of::<[u8; N]>() as u32; - let mut data_len = 0; - let data = sys::nftnl_expr_get( - expr, - sys::NFTNL_EXPR_CMP_DATA as u16, - &mut data_len as *mut u32, - ); - - if data.is_null() { - return Err(DeserializationError::NullPointer); - } else if data_len != ref_len { - return Err(DeserializationError::InvalidDataSize); - } - - let data = *(data as *const [u8; N]); - - let op = CmpOp::from_raw(sys::nftnl_expr_get_u32(expr, sys::NFTNL_EXPR_CMP_OP as u16))?; - Ok(Cmp { op, data }) - } - } - - // call to the other implementation to generate the expression - fn to_expr(&self, rule: &Rule) -> *mut sys::nftnl_expr { + pub fn new(op: CmpOp, data: impl Into<Vec<u8>>) -> Self { Cmp { - data: &self.data as &[u8], - op: self.op, + sreg: Some(Register::Reg1), + op: Some(op), + data: Some(NfNetlinkData::default().with_value(data.into())), } - .to_expr(rule) } } -#[macro_export(local_inner_macros)] -macro_rules! nft_expr_cmp { - (@cmp_op ==) => { - $crate::expr::CmpOp::Eq - }; - (@cmp_op !=) => { - $crate::expr::CmpOp::Neq - }; - (@cmp_op <) => { - $crate::expr::CmpOp::Lt - }; - (@cmp_op <=) => { - $crate::expr::CmpOp::Lte - }; - (@cmp_op >) => { - $crate::expr::CmpOp::Gt - }; - (@cmp_op >=) => { - $crate::expr::CmpOp::Gte - }; - ($op:tt $data:expr) => { - $crate::expr::Cmp::new(nft_expr_cmp!(@cmp_op $op), $data) - }; +impl Expression for Cmp { + fn get_name() -> &'static str { + "cmp" + } } +/* /// Can be used to compare the value loaded by [`Meta::IifName`] and [`Meta::OifName`]. Please note /// that it is faster to check interface index than name. /// @@ -207,13 +84,4 @@ impl ToSlice for InterfaceName { Cow::from(bytes) } } - -impl<'a> ToSlice for &'a InterfaceName { - fn to_slice(&self) -> Cow<'_, [u8]> { - let bytes = match *self { - InterfaceName::Exact(ref name) => name.as_bytes_with_nul(), - InterfaceName::StartingWith(ref name) => name.as_bytes(), - }; - Cow::from(bytes) - } -} +*/ diff --git a/src/expr/counter.rs b/src/expr/counter.rs index 4732e85..d22fb8a 100644 --- a/src/expr/counter.rs +++ b/src/expr/counter.rs @@ -1,46 +1,21 @@ -use super::{DeserializationError, Expression, Rule}; +use rustables_macros::nfnetlink_struct; + +use super::Expression; use crate::sys; -use std::os::raw::c_char; /// A counter expression adds a counter to the rule that is incremented to count number of packets /// and number of bytes for all packets that have matched the rule. -#[derive(Debug, PartialEq)] +#[derive(Default, Clone, Debug, PartialEq, Eq)] +#[nfnetlink_struct] pub struct Counter { + #[field(sys::NFTA_COUNTER_BYTES)] pub nb_bytes: u64, + #[field(sys::NFTA_COUNTER_PACKETS)] pub nb_packets: u64, } -impl Counter { - pub fn new() -> Self { - Self { - nb_bytes: 0, - nb_packets: 0, - } - } -} - impl Expression for Counter { - fn get_raw_name() -> *const c_char { - b"counter\0" as *const _ as *const c_char - } - - fn from_expr(expr: *const sys::nftnl_expr) -> Result<Self, DeserializationError> { - unsafe { - let nb_bytes = sys::nftnl_expr_get_u64(expr, sys::NFTNL_EXPR_CTR_BYTES as u16); - let nb_packets = sys::nftnl_expr_get_u64(expr, sys::NFTNL_EXPR_CTR_PACKETS as u16); - Ok(Counter { - nb_bytes, - nb_packets, - }) - } - } - - fn to_expr(&self, _rule: &Rule) -> *mut sys::nftnl_expr { - unsafe { - let expr = try_alloc!(sys::nftnl_expr_alloc(Self::get_raw_name())); - sys::nftnl_expr_set_u64(expr, sys::NFTNL_EXPR_CTR_BYTES as u16, self.nb_bytes); - sys::nftnl_expr_set_u64(expr, sys::NFTNL_EXPR_CTR_PACKETS as u16, self.nb_packets); - expr - } + fn get_name() -> &'static str { + "counter" } } diff --git a/src/expr/ct.rs b/src/expr/ct.rs index 7d6614c..ad76989 100644 --- a/src/expr/ct.rs +++ b/src/expr/ct.rs @@ -1,9 +1,13 @@ -use super::{DeserializationError, Expression, Rule}; -use crate::sys::{self, libc}; -use std::os::raw::c_char; +use rustables_macros::{nfnetlink_enum, nfnetlink_struct}; + +use crate::sys::{ + NFTA_CT_DIRECTION, NFTA_CT_DREG, NFTA_CT_KEY, NFTA_CT_SREG, NFT_CT_MARK, NFT_CT_STATE, +}; + +use super::{Expression, Register}; bitflags::bitflags! { - pub struct States: u32 { + pub struct ConnTrackState: u32 { const INVALID = 1; const ESTABLISHED = 2; const RELATED = 4; @@ -12,76 +16,54 @@ bitflags::bitflags! { } } -pub enum Conntrack { - State, - Mark { set: bool }, +#[derive(Copy, Clone, Debug, PartialEq, Eq)] +#[nfnetlink_enum(u32, nested = true)] +pub enum ConntrackKey { + State = NFT_CT_STATE, + Mark = NFT_CT_MARK, } -impl Conntrack { - fn raw_key(&self) -> u32 { - match *self { - Conntrack::State => libc::NFT_CT_STATE as u32, - Conntrack::Mark { .. } => libc::NFT_CT_MARK as u32, - } - } +#[derive(Default, Clone, Debug, PartialEq, Eq)] +#[nfnetlink_struct(nested = true)] +pub struct Conntrack { + #[field(NFTA_CT_DREG)] + pub dreg: Register, + #[field(NFTA_CT_KEY)] + pub key: ConntrackKey, + #[field(NFTA_CT_DIRECTION)] + pub direction: u8, + #[field(NFTA_CT_SREG)] + pub sreg: Register, } impl Expression for Conntrack { - fn get_raw_name() -> *const c_char { - b"ct\0" as *const _ as *const c_char + fn get_name() -> &'static str { + "ct" } +} - fn from_expr(expr: *const sys::nftnl_expr) -> Result<Self, DeserializationError> - where - Self: Sized, - { - unsafe { - let ct_key = sys::nftnl_expr_get_u32(expr, sys::NFTNL_EXPR_CT_KEY as u16); - let ct_sreg_is_set = sys::nftnl_expr_is_set(expr, sys::NFTNL_EXPR_CT_SREG as u16); - - match ct_key as i32 { - libc::NFT_CT_STATE => Ok(Conntrack::State), - libc::NFT_CT_MARK => Ok(Conntrack::Mark { - set: ct_sreg_is_set, - }), - _ => Err(DeserializationError::InvalidValue), - } - } +impl Conntrack { + pub fn new(key: ConntrackKey) -> Self { + Self::default().with_dreg(Register::Reg1).with_key(key) } - fn to_expr(&self, _rule: &Rule) -> *mut sys::nftnl_expr { - unsafe { - let expr = try_alloc!(sys::nftnl_expr_alloc(Self::get_raw_name())); + pub fn set_mark_value(&mut self, reg: Register) { + self.set_sreg(reg); + self.set_key(ConntrackKey::Mark); + } - if let Conntrack::Mark { set: true } = self { - sys::nftnl_expr_set_u32( - expr, - sys::NFTNL_EXPR_CT_SREG as u16, - libc::NFT_REG_1 as u32, - ); - } else { - sys::nftnl_expr_set_u32( - expr, - sys::NFTNL_EXPR_CT_DREG as u16, - libc::NFT_REG_1 as u32, - ); - } - sys::nftnl_expr_set_u32(expr, sys::NFTNL_EXPR_CT_KEY as u16, self.raw_key()); + pub fn with_mark_value(mut self, reg: Register) -> Self { + self.set_mark_value(reg); + self + } - expr - } + pub fn retrieve_value(&mut self, key: ConntrackKey) { + self.set_key(key); + self.set_dreg(Register::Reg1); } -} -#[macro_export] -macro_rules! nft_expr_ct { - (state) => { - $crate::expr::Conntrack::State - }; - (mark set) => { - $crate::expr::Conntrack::Mark { set: true } - }; - (mark) => { - $crate::expr::Conntrack::Mark { set: false } - }; + pub fn with_retrieve_value(mut self, key: ConntrackKey) -> Self { + self.retrieve_value(key); + self + } } diff --git a/src/expr/immediate.rs b/src/expr/immediate.rs index 71453b3..2fd9bd5 100644 --- a/src/expr/immediate.rs +++ b/src/expr/immediate.rs @@ -1,124 +1,50 @@ -use super::{DeserializationError, Expression, Register, Rule, ToSlice}; -use crate::sys; -use std::ffi::c_void; -use std::os::raw::c_char; - -/// An immediate expression. Used to set immediate data. Verdicts are handled separately by -/// [crate::expr::Verdict]. -#[derive(Debug, Clone, Eq, PartialEq, Hash)] -pub struct Immediate<T> { - pub data: T, - pub register: Register, +use rustables_macros::nfnetlink_struct; + +use super::{Expression, Register, Verdict, VerdictKind, VerdictType}; +use crate::{ + parser_impls::NfNetlinkData, + sys::{NFTA_IMMEDIATE_DATA, NFTA_IMMEDIATE_DREG}, +}; + +#[derive(Clone, PartialEq, Eq, Default, Debug)] +#[nfnetlink_struct] +pub struct Immediate { + #[field(NFTA_IMMEDIATE_DREG)] + dreg: Register, + #[field(NFTA_IMMEDIATE_DATA)] + data: NfNetlinkData, } -impl<T> Immediate<T> { - pub fn new(data: T, register: Register) -> Self { - Self { data, register } +impl Immediate { + pub fn new_data(data: Vec<u8>, register: Register) -> Self { + Immediate::default() + .with_dreg(register) + .with_data(NfNetlinkData::default().with_value(data)) } -} - -impl<T: ToSlice> Expression for Immediate<T> { - fn get_raw_name() -> *const c_char { - b"immediate\0" as *const _ as *const c_char - } - - fn to_expr(&self, _rule: &Rule) -> *mut sys::nftnl_expr { - unsafe { - let expr = try_alloc!(sys::nftnl_expr_alloc(Self::get_raw_name())); - sys::nftnl_expr_set_u32( - expr, - sys::NFTNL_EXPR_IMM_DREG as u16, - self.register.to_raw(), - ); - - let data = self.data.to_slice(); - sys::nftnl_expr_set( - expr, - sys::NFTNL_EXPR_IMM_DATA as u16, - data.as_ptr() as *const c_void, - data.len() as u32, - ); - - expr + pub fn new_verdict(kind: VerdictKind) -> Self { + let code = match kind { + VerdictKind::Drop => VerdictType::Drop, + VerdictKind::Accept => VerdictType::Accept, + VerdictKind::Queue => VerdictType::Queue, + VerdictKind::Continue => VerdictType::Continue, + VerdictKind::Break => VerdictType::Break, + VerdictKind::Jump { .. } => VerdictType::Jump, + VerdictKind::Goto { .. } => VerdictType::Goto, + VerdictKind::Return => VerdictType::Return, + }; + let mut data = Verdict::default().with_code(code); + if let VerdictKind::Jump { chain } | VerdictKind::Goto { chain } = kind { + data.set_chain(chain); } + Immediate::default() + .with_dreg(Register::Verdict) + .with_data(NfNetlinkData::default().with_verdict(data)) } } -impl<const N: usize> Expression for Immediate<[u8; N]> { - fn get_raw_name() -> *const c_char { - Immediate::<u8>::get_raw_name() - } - - /// The raw data contained inside `Immediate` expressions can only be deserialized to arrays of - /// bytes, to ensure that the memory layout of retrieved data cannot be violated. It is your - /// responsibility to provide the correct length of the byte data. If the data size is invalid, - /// you will get the error `DeserializationError::InvalidDataSize`. - /// - /// Example (warning, no error checking!): - /// ```rust - /// use std::ffi::CString; - /// use std::net::Ipv4Addr; - /// use std::rc::Rc; - /// - /// use rustables::{Chain, expr::{Immediate, Register}, ProtoFamily, Rule, Table}; - /// - /// let table = Rc::new(Table::new(&CString::new("mytable").unwrap(), ProtoFamily::Inet)); - /// let chain = Rc::new(Chain::new(&CString::new("mychain").unwrap(), table)); - /// let mut rule = Rule::new(chain); - /// rule.add_expr(&Immediate::new(42u8, Register::Reg1)); - /// for expr in Rc::new(rule).get_exprs() { - /// println!("{:?}", expr.decode_expr::<Immediate<[u8; 1]>>().unwrap()); - /// } - /// ``` - /// These limitations occur because casting bytes to any type of the same size as the raw input - /// would be *extremely* dangerous in terms of memory safety. - // As casting bytes to any type of the same size as the input would be *extremely* dangerous in - // terms of memory safety, rustables only accept to deserialize expressions with variable-size - // data to arrays of bytes, so that the memory layout cannot be invalid. - fn from_expr(expr: *const sys::nftnl_expr) -> Result<Self, DeserializationError> { - unsafe { - let ref_len = std::mem::size_of::<[u8; N]>() as u32; - let mut data_len = 0; - let data = sys::nftnl_expr_get( - expr, - sys::NFTNL_EXPR_IMM_DATA as u16, - &mut data_len as *mut u32, - ); - - if data.is_null() { - return Err(DeserializationError::NullPointer); - } else if data_len != ref_len { - return Err(DeserializationError::InvalidDataSize); - } - - let data = *(data as *const [u8; N]); - - let register = Register::from_raw(sys::nftnl_expr_get_u32( - expr, - sys::NFTNL_EXPR_IMM_DREG as u16, - ))?; - - Ok(Immediate { data, register }) - } - } - - // call to the other implementation to generate the expression - fn to_expr(&self, rule: &Rule) -> *mut sys::nftnl_expr { - Immediate { - register: self.register, - data: &self.data as &[u8], - } - .to_expr(rule) +impl Expression for Immediate { + fn get_name() -> &'static str { + "immediate" } } - -#[macro_export] -macro_rules! nft_expr_immediate { - (data $value:expr) => { - $crate::expr::Immediate { - data: $value, - register: $crate::expr::Register::Reg1, - } - }; -} diff --git a/src/expr/log.rs b/src/expr/log.rs index 8d20b48..cc2728e 100644 --- a/src/expr/log.rs +++ b/src/expr/log.rs @@ -1,112 +1,41 @@ -use super::{DeserializationError, Expression, Rule}; -use crate::sys; -use std::ffi::{CStr, CString}; -use std::os::raw::c_char; -use thiserror::Error; +use rustables_macros::nfnetlink_struct; +use super::Expression; +use crate::{ + error::BuilderError, + sys::{NFTA_LOG_GROUP, NFTA_LOG_PREFIX}, +}; + +#[derive(Clone, PartialEq, Eq, Default, Debug)] +#[nfnetlink_struct] /// A Log expression will log all packets that match the rule. -#[derive(Debug, PartialEq)] pub struct Log { - pub group: Option<LogGroup>, - pub prefix: Option<LogPrefix>, + #[field(NFTA_LOG_GROUP)] + group: u16, + #[field(NFTA_LOG_PREFIX)] + prefix: String, } -impl Expression for Log { - fn get_raw_name() -> *const sys::libc::c_char { - b"log\0" as *const _ as *const c_char - } - - fn from_expr(expr: *const sys::nftnl_expr) -> Result<Self, DeserializationError> - where - Self: Sized, - { - unsafe { - let mut group = None; - if sys::nftnl_expr_is_set(expr, sys::NFTNL_EXPR_LOG_GROUP as u16) { - group = Some(LogGroup(sys::nftnl_expr_get_u32( - expr, - sys::NFTNL_EXPR_LOG_GROUP as u16, - ) as u16)); - } - let mut prefix = None; - if sys::nftnl_expr_is_set(expr, sys::NFTNL_EXPR_LOG_PREFIX as u16) { - let raw_prefix = sys::nftnl_expr_get_str(expr, sys::NFTNL_EXPR_LOG_PREFIX as u16); - if raw_prefix.is_null() { - return Err(DeserializationError::NullPointer); - } else { - prefix = Some(LogPrefix(CStr::from_ptr(raw_prefix).to_owned())); - } - } - Ok(Log { group, prefix }) +impl Log { + pub fn new(group: Option<u16>, prefix: Option<impl Into<String>>) -> Result<Log, BuilderError> { + let mut res = Log::default(); + if let Some(group) = group { + res.set_group(group); } - } - - fn to_expr(&self, _rule: &Rule) -> *mut sys::nftnl_expr { - unsafe { - let expr = try_alloc!(sys::nftnl_expr_alloc(b"log\0" as *const _ as *const c_char)); - if let Some(log_group) = self.group { - sys::nftnl_expr_set_u32(expr, sys::NFTNL_EXPR_LOG_GROUP as u16, log_group.0 as u32); - }; - if let Some(LogPrefix(prefix)) = &self.prefix { - sys::nftnl_expr_set_str(expr, sys::NFTNL_EXPR_LOG_PREFIX as u16, prefix.as_ptr()); - }; + if let Some(prefix) = prefix { + let prefix = prefix.into(); - expr + if prefix.bytes().count() > 127 { + return Err(BuilderError::TooLongLogPrefix); + } + res.set_prefix(prefix); } + Ok(res) } } -#[derive(Error, Debug)] -pub enum LogPrefixError { - #[error("The log prefix string is more than 128 characters long")] - TooLongPrefix, - #[error("The log prefix string contains an invalid Nul character.")] - PrefixContainsANul(#[from] std::ffi::NulError), -} - -/// The NFLOG group that will be assigned to each log line. -#[derive(Debug, Clone, Copy, Eq, PartialEq, Hash)] -pub struct LogGroup(pub u16); - -/// A prefix that will get prepended to each log line. -#[derive(Debug, Clone, PartialEq)] -pub struct LogPrefix(CString); - -impl LogPrefix { - /// Creates a new LogPrefix from a String. Converts it to CString as needed by nftnl. Note that - /// LogPrefix should not be more than 127 characters long. - pub fn new(prefix: &str) -> Result<Self, LogPrefixError> { - if prefix.chars().count() > 127 { - return Err(LogPrefixError::TooLongPrefix); - } - Ok(LogPrefix(CString::new(prefix)?)) +impl Expression for Log { + fn get_name() -> &'static str { + "log" } } - -#[macro_export] -macro_rules! nft_expr_log { - (group $group:ident prefix $prefix:expr) => { - $crate::expr::Log { - group: $group, - prefix: $prefix, - } - }; - (prefix $prefix:expr) => { - $crate::expr::Log { - group: None, - prefix: $prefix, - } - }; - (group $group:ident) => { - $crate::expr::Log { - group: $group, - prefix: None, - } - }; - () => { - $crate::expr::Log { - group: None, - prefix: None, - } - }; -} diff --git a/src/expr/lookup.rs b/src/expr/lookup.rs index a0cc021..2ef830e 100644 --- a/src/expr/lookup.rs +++ b/src/expr/lookup.rs @@ -1,78 +1,40 @@ -use super::{DeserializationError, Expression, Rule}; -use crate::set::Set; -use crate::sys::{self, libc}; -use std::ffi::{CStr, CString}; -use std::os::raw::c_char; +use rustables_macros::nfnetlink_struct; -#[derive(Debug, PartialEq)] +use super::{Expression, Register}; +use crate::error::BuilderError; +use crate::sys::{NFTA_LOOKUP_DREG, NFTA_LOOKUP_SET, NFTA_LOOKUP_SET_ID, NFTA_LOOKUP_SREG}; +use crate::Set; + +#[derive(Clone, PartialEq, Eq, Default, Debug)] +#[nfnetlink_struct] pub struct Lookup { - set_name: CString, + #[field(NFTA_LOOKUP_SET)] + set: String, + #[field(NFTA_LOOKUP_SREG)] + sreg: Register, + #[field(NFTA_LOOKUP_DREG)] + dreg: Register, + #[field(NFTA_LOOKUP_SET_ID)] set_id: u32, } impl Lookup { - /// Creates a new lookup entry. May return None if the set has no name. - pub fn new<K>(set: &Set<K>) -> Option<Self> { - set.get_name().map(|set_name| Lookup { - set_name: set_name.to_owned(), - set_id: set.get_id(), - }) - } -} - -impl Expression for Lookup { - fn get_raw_name() -> *const libc::c_char { - b"lookup\0" as *const _ as *const c_char - } - - fn from_expr(expr: *const sys::nftnl_expr) -> Result<Self, DeserializationError> - where - Self: Sized, - { - unsafe { - let set_name = sys::nftnl_expr_get_str(expr, sys::NFTNL_EXPR_LOOKUP_SET as u16); - let set_id = sys::nftnl_expr_get_u32(expr, sys::NFTNL_EXPR_LOOKUP_SET_ID as u16); - - if set_name.is_null() { - return Err(DeserializationError::NullPointer); - } - - let set_name = CStr::from_ptr(set_name).to_owned(); - - Ok(Lookup { set_id, set_name }) + /// Creates a new lookup entry. May return BuilderError::MissingSetName if the set has no name. + pub fn new(set: &Set) -> Result<Self, BuilderError> { + let mut res = Lookup::default() + .with_set(set.get_name().ok_or(BuilderError::MissingSetName)?) + .with_sreg(Register::Reg1); + + if let Some(id) = set.get_id() { + res.set_set_id(*id); } - } - - fn to_expr(&self, _rule: &Rule) -> *mut sys::nftnl_expr { - unsafe { - let expr = try_alloc!(sys::nftnl_expr_alloc(Self::get_raw_name())); - - sys::nftnl_expr_set_u32( - expr, - sys::NFTNL_EXPR_LOOKUP_SREG as u16, - libc::NFT_REG_1 as u32, - ); - sys::nftnl_expr_set_str( - expr, - sys::NFTNL_EXPR_LOOKUP_SET as u16, - self.set_name.as_ptr() as *const _ as *const c_char, - ); - sys::nftnl_expr_set_u32(expr, sys::NFTNL_EXPR_LOOKUP_SET_ID as u16, self.set_id); - // This code is left here since it's quite likely we need it again when we get further - // if self.reverse { - // sys::nftnl_expr_set_u32(expr, sys::NFTNL_EXPR_LOOKUP_FLAGS as u16, - // libc::NFT_LOOKUP_F_INV as u32); - // } - - expr - } + Ok(res) } } -#[macro_export] -macro_rules! nft_expr_lookup { - ($set:expr) => { - $crate::expr::Lookup::new($set) - }; +impl Expression for Lookup { + fn get_name() -> &'static str { + "lookup" + } } diff --git a/src/expr/masquerade.rs b/src/expr/masquerade.rs index c1a06de..dce787f 100644 --- a/src/expr/masquerade.rs +++ b/src/expr/masquerade.rs @@ -1,24 +1,20 @@ -use super::{DeserializationError, Expression, Rule}; -use crate::sys; -use std::os::raw::c_char; +use rustables_macros::nfnetlink_struct; + +use super::Expression; /// Sets the source IP to that of the output interface. -#[derive(Debug, PartialEq)] +#[derive(Default, Debug, PartialEq, Eq)] +#[nfnetlink_struct(nested = true)] pub struct Masquerade; -impl Expression for Masquerade { - fn get_raw_name() -> *const sys::libc::c_char { - b"masq\0" as *const _ as *const c_char - } - - fn from_expr(_expr: *const sys::nftnl_expr) -> Result<Self, DeserializationError> - where - Self: Sized, - { - Ok(Masquerade) +impl Clone for Masquerade { + fn clone(&self) -> Self { + Masquerade {} } +} - fn to_expr(&self, _rule: &Rule) -> *mut sys::nftnl_expr { - try_alloc!(unsafe { sys::nftnl_expr_alloc(Self::get_raw_name()) }) +impl Expression for Masquerade { + fn get_name() -> &'static str { + "masq" } } diff --git a/src/expr/meta.rs b/src/expr/meta.rs index a015f65..3ecb1d1 100644 --- a/src/expr/meta.rs +++ b/src/expr/meta.rs @@ -1,175 +1,62 @@ -use super::{DeserializationError, Expression, Rule}; -use crate::sys::{self, libc}; -use std::os::raw::c_char; +use rustables_macros::{nfnetlink_enum, nfnetlink_struct}; + +use super::{Expression, Register}; +use crate::sys; /// A meta expression refers to meta data associated with a packet. #[derive(Debug, Copy, Clone, Eq, PartialEq)] +#[nfnetlink_enum(u32)] #[non_exhaustive] -pub enum Meta { +pub enum MetaType { /// Packet ethertype protocol (skb->protocol), invalid in OUTPUT. - Protocol, + Protocol = sys::NFT_META_PROTOCOL, /// Packet mark. - Mark { set: bool }, + Mark = sys::NFT_META_MARK, /// Packet input interface index (dev->ifindex). - Iif, + Iif = sys::NFT_META_IIF, /// Packet output interface index (dev->ifindex). - Oif, + Oif = sys::NFT_META_OIF, /// Packet input interface name (dev->name). - IifName, + IifName = sys::NFT_META_IIFNAME, /// Packet output interface name (dev->name). - OifName, + OifName = sys::NFT_META_OIFNAME, /// Packet input interface type (dev->type). - IifType, + IifType = libc::NFT_META_IIFTYPE, /// Packet output interface type (dev->type). - OifType, + OifType = sys::NFT_META_OIFTYPE, /// Originating socket UID (fsuid). - SkUid, + SkUid = sys::NFT_META_SKUID, /// Originating socket GID (fsgid). - SkGid, + SkGid = sys::NFT_META_SKGID, /// Netfilter protocol (Transport layer protocol). - NfProto, + NfProto = sys::NFT_META_NFPROTO, /// Layer 4 protocol number. - L4Proto, + L4Proto = sys::NFT_META_L4PROTO, /// Socket control group (skb->sk->sk_classid). - Cgroup, + Cgroup = sys::NFT_META_CGROUP, /// A 32bit pseudo-random number. - PRandom, + PRandom = sys::NFT_META_PRANDOM, } -impl Meta { - /// Returns the corresponding `NFT_*` constant for this meta expression. - pub fn to_raw_key(&self) -> u32 { - use Meta::*; - match *self { - Protocol => libc::NFT_META_PROTOCOL as u32, - Mark { .. } => libc::NFT_META_MARK as u32, - Iif => libc::NFT_META_IIF as u32, - Oif => libc::NFT_META_OIF as u32, - IifName => libc::NFT_META_IIFNAME as u32, - OifName => libc::NFT_META_OIFNAME as u32, - IifType => libc::NFT_META_IIFTYPE as u32, - OifType => libc::NFT_META_OIFTYPE as u32, - SkUid => libc::NFT_META_SKUID as u32, - SkGid => libc::NFT_META_SKGID as u32, - NfProto => libc::NFT_META_NFPROTO as u32, - L4Proto => libc::NFT_META_L4PROTO as u32, - Cgroup => libc::NFT_META_CGROUP as u32, - PRandom => libc::NFT_META_PRANDOM as u32, - } - } +#[derive(Clone, PartialEq, Eq, Default, Debug)] +#[nfnetlink_struct] +pub struct Meta { + #[field(sys::NFTA_META_DREG)] + dreg: Register, + #[field(sys::NFTA_META_KEY)] + key: MetaType, + #[field(sys::NFTA_META_SREG)] + sreg: Register, +} - fn from_raw(val: u32) -> Result<Self, DeserializationError> { - match val as i32 { - libc::NFT_META_PROTOCOL => Ok(Self::Protocol), - libc::NFT_META_MARK => Ok(Self::Mark { set: false }), - libc::NFT_META_IIF => Ok(Self::Iif), - libc::NFT_META_OIF => Ok(Self::Oif), - libc::NFT_META_IIFNAME => Ok(Self::IifName), - libc::NFT_META_OIFNAME => Ok(Self::OifName), - libc::NFT_META_IIFTYPE => Ok(Self::IifType), - libc::NFT_META_OIFTYPE => Ok(Self::OifType), - libc::NFT_META_SKUID => Ok(Self::SkUid), - libc::NFT_META_SKGID => Ok(Self::SkGid), - libc::NFT_META_NFPROTO => Ok(Self::NfProto), - libc::NFT_META_L4PROTO => Ok(Self::L4Proto), - libc::NFT_META_CGROUP => Ok(Self::Cgroup), - libc::NFT_META_PRANDOM => Ok(Self::PRandom), - _ => Err(DeserializationError::InvalidValue), - } +impl Meta { + pub fn new(ty: MetaType) -> Self { + Meta::default().with_dreg(Register::Reg1).with_key(ty) } } impl Expression for Meta { - fn get_raw_name() -> *const libc::c_char { - b"meta\0" as *const _ as *const c_char - } - - fn from_expr(expr: *const sys::nftnl_expr) -> Result<Self, DeserializationError> - where - Self: Sized, - { - unsafe { - let mut ret = Self::from_raw(sys::nftnl_expr_get_u32( - expr, - sys::NFTNL_EXPR_META_KEY as u16, - ))?; - - if let Self::Mark { ref mut set } = ret { - *set = sys::nftnl_expr_is_set(expr, sys::NFTNL_EXPR_META_SREG as u16); - } - - Ok(ret) - } + fn get_name() -> &'static str { + "meta" } - - fn to_expr(&self, _rule: &Rule) -> *mut sys::nftnl_expr { - unsafe { - let expr = try_alloc!(sys::nftnl_expr_alloc(Self::get_raw_name())); - - if let Meta::Mark { set: true } = self { - sys::nftnl_expr_set_u32( - expr, - sys::NFTNL_EXPR_META_SREG as u16, - libc::NFT_REG_1 as u32, - ); - } else { - sys::nftnl_expr_set_u32( - expr, - sys::NFTNL_EXPR_META_DREG as u16, - libc::NFT_REG_1 as u32, - ); - } - sys::nftnl_expr_set_u32(expr, sys::NFTNL_EXPR_META_KEY as u16, self.to_raw_key()); - expr - } - } -} - -#[macro_export] -macro_rules! nft_expr_meta { - (proto) => { - $crate::expr::Meta::Protocol - }; - (mark set) => { - $crate::expr::Meta::Mark { set: true } - }; - (mark) => { - $crate::expr::Meta::Mark { set: false } - }; - (iif) => { - $crate::expr::Meta::Iif - }; - (oif) => { - $crate::expr::Meta::Oif - }; - (iifname) => { - $crate::expr::Meta::IifName - }; - (oifname) => { - $crate::expr::Meta::OifName - }; - (iiftype) => { - $crate::expr::Meta::IifType - }; - (oiftype) => { - $crate::expr::Meta::OifType - }; - (skuid) => { - $crate::expr::Meta::SkUid - }; - (skgid) => { - $crate::expr::Meta::SkGid - }; - (nfproto) => { - $crate::expr::Meta::NfProto - }; - (l4proto) => { - $crate::expr::Meta::L4Proto - }; - (cgroup) => { - $crate::expr::Meta::Cgroup - }; - (random) => { - $crate::expr::Meta::PRandom - }; } diff --git a/src/expr/mod.rs b/src/expr/mod.rs index dc59507..058b0cb 100644 --- a/src/expr/mod.rs +++ b/src/expr/mod.rs @@ -3,14 +3,14 @@ //! //! [`Rule`]: struct.Rule.html -use std::borrow::Cow; -use std::net::IpAddr; -use std::net::Ipv4Addr; -use std::net::Ipv6Addr; +use std::fmt::Debug; -use super::rule::Rule; -use crate::sys::{self, libc}; -use thiserror::Error; +use rustables_macros::nfnetlink_struct; + +use crate::error::DecodeError; +use crate::nlmsg::{NfNetlinkAttribute, NfNetlinkDeserializable}; +use crate::parser_impls::NfNetlinkList; +use crate::sys::{self, NFTA_EXPR_DATA, NFTA_EXPR_NAME}; mod bitwise; pub use self::bitwise::*; @@ -46,7 +46,7 @@ mod payload; pub use self::payload::*; mod reject; -pub use self::reject::{IcmpCode, Reject}; +pub use self::reject::{IcmpCode, Reject, RejectType}; mod register; pub use self::register::Register; @@ -54,189 +54,161 @@ pub use self::register::Register; mod verdict; pub use self::verdict::*; -mod wrapper; -pub use self::wrapper::ExpressionWrapper; - -#[derive(Debug, Error)] -pub enum DeserializationError { - #[error("The expected expression type doesn't match the name of the raw expression")] - /// The expected expression type doesn't match the name of the raw expression. - InvalidExpressionKind, - - #[error("Deserializing the requested type isn't implemented yet")] - /// Deserializing the requested type isn't implemented yet. - NotImplemented, - - #[error("The expression value cannot be deserialized to the requested type")] - /// The expression value cannot be deserialized to the requested type. - InvalidValue, - - #[error("A pointer was null while a non-null pointer was expected")] - /// A pointer was null while a non-null pointer was expected. - NullPointer, - - #[error( - "The size of a raw value was incoherent with the expected type of the deserialized value" - )] - /// The size of a raw value was incoherent with the expected type of the deserialized value/ - InvalidDataSize, - - #[error(transparent)] - /// Couldn't find a matching protocol. - InvalidProtolFamily(#[from] super::InvalidProtocolFamily), -} - -/// Trait for every safe wrapper of an nftables expression. pub trait Expression { - /// Returns the raw name used by nftables to identify the rule. - fn get_raw_name() -> *const libc::c_char; - - /// Try to parse the expression from a raw nftables expression, returning a - /// [DeserializationError] if the attempted parsing failed. - fn from_expr(_expr: *const sys::nftnl_expr) -> Result<Self, DeserializationError> - where - Self: Sized, - { - Err(DeserializationError::NotImplemented) - } - - /// Allocates and returns the low level `nftnl_expr` representation of this expression. The - /// caller to this method is responsible for freeing the expression. - fn to_expr(&self, rule: &Rule) -> *mut sys::nftnl_expr; + fn get_name() -> &'static str; } -/// A type that can be converted into a byte buffer. -pub trait ToSlice { - /// Returns the data this type represents. - fn to_slice(&self) -> Cow<'_, [u8]>; +#[derive(Clone, PartialEq, Eq, Default, Debug)] +#[nfnetlink_struct(nested = true, derive_decoder = false)] +pub struct RawExpression { + #[field(NFTA_EXPR_NAME)] + name: String, + #[field(NFTA_EXPR_DATA)] + data: ExpressionVariant, } -impl<'a> ToSlice for &'a [u8] { - fn to_slice(&self) -> Cow<'_, [u8]> { - Cow::Borrowed(self) +impl<T> From<T> for RawExpression +where + T: Expression, + ExpressionVariant: From<T>, +{ + fn from(val: T) -> Self { + RawExpression::default() + .with_name(T::get_name()) + .with_data(ExpressionVariant::from(val)) } } -impl<'a> ToSlice for &'a [u16] { - fn to_slice(&self) -> Cow<'_, [u8]> { - let ptr = self.as_ptr() as *const u8; - let len = self.len() * 2; - Cow::Borrowed(unsafe { std::slice::from_raw_parts(ptr, len) }) - } -} - -impl ToSlice for IpAddr { - fn to_slice(&self) -> Cow<'_, [u8]> { - match *self { - IpAddr::V4(ref addr) => addr.to_slice(), - IpAddr::V6(ref addr) => addr.to_slice(), +macro_rules! create_expr_variant { + ($enum:ident $(, [$name:ident, $type:ty])+) => { + #[derive(Debug, Clone, PartialEq, Eq)] + pub enum $enum { + $( + $name($type), + )+ } - } -} -impl ToSlice for Ipv4Addr { - fn to_slice(&self) -> Cow<'_, [u8]> { - Cow::Owned(self.octets().to_vec()) - } -} - -impl ToSlice for Ipv6Addr { - fn to_slice(&self) -> Cow<'_, [u8]> { - Cow::Owned(self.octets().to_vec()) - } -} + impl $crate::nlmsg::NfNetlinkAttribute for $enum { + fn is_nested(&self) -> bool { + true + } + + fn get_size(&self) -> usize { + match self { + $( + $enum::$name(val) => val.get_size(), + )+ + } + } + + unsafe fn write_payload(&self, addr: *mut u8) { + match self { + $( + $enum::$name(val) => val.write_payload(addr), + )+ + } + } + } -impl ToSlice for u8 { - fn to_slice(&self) -> Cow<'_, [u8]> { - Cow::Owned(vec![*self]) - } + $( + impl From<$type> for $enum { + fn from(val: $type) -> Self { + $enum::$name(val) + } + } + )+ + + impl $crate::nlmsg::AttributeDecoder for RawExpression { + fn decode_attribute( + &mut self, + attr_type: u16, + buf: &[u8], + ) -> Result<(), $crate::error::DecodeError> { + debug!("Decoding attribute {} in an expression", attr_type); + match attr_type { + x if x == sys::NFTA_EXPR_NAME => { + debug!("Calling {}::deserialize()", std::any::type_name::<String>()); + let (val, remaining) = String::deserialize(buf)?; + if remaining.len() != 0 { + return Err($crate::error::DecodeError::InvalidDataSize); + } + self.name = Some(val); + Ok(()) + }, + x if x == sys::NFTA_EXPR_DATA => { + // we can assume we have already the name parsed, as that's how we identify the + // type of expression + let name = self.name.as_ref() + .ok_or($crate::error::DecodeError::MissingExpressionName)?; + match name { + $( + x if x == <$type>::get_name() => { + debug!("Calling {}::deserialize()", std::any::type_name::<$type>()); + let (res, remaining) = <$type>::deserialize(buf)?; + if remaining.len() != 0 { + return Err($crate::error::DecodeError::InvalidDataSize); + } + self.data = Some(ExpressionVariant::from(res)); + Ok(()) + }, + )+ + name => { + info!("Unrecognized expression '{}', generating an ExpressionRaw", name); + self.data = Some(ExpressionVariant::ExpressionRaw(ExpressionRaw::deserialize(buf)?.0)); + Ok(()) + } + } + }, + _ => Err(DecodeError::UnsupportedAttributeType(attr_type)), + } + } + } + }; } -impl ToSlice for u16 { - fn to_slice(&self) -> Cow<'_, [u8]> { - let b0 = (*self & 0x00ff) as u8; - let b1 = (*self >> 8) as u8; - Cow::Owned(vec![b0, b1]) +create_expr_variant!( + ExpressionVariant, + [Bitwise, Bitwise], + [Cmp, Cmp], + [Conntrack, Conntrack], + [Counter, Counter], + [ExpressionRaw, ExpressionRaw], + [Immediate, Immediate], + [Log, Log], + [Lookup, Lookup], + [Masquerade, Masquerade], + [Meta, Meta], + [Nat, Nat], + [Payload, Payload], + [Reject, Reject] +); + +pub type ExpressionList = NfNetlinkList<RawExpression>; + +// default type for expressions that we do not handle yet +#[derive(Debug, Clone, PartialEq, Eq)] +pub struct ExpressionRaw(Vec<u8>); + +impl NfNetlinkAttribute for ExpressionRaw { + fn get_size(&self) -> usize { + self.0.get_size() } -} -impl ToSlice for u32 { - fn to_slice(&self) -> Cow<'_, [u8]> { - let b0 = *self as u8; - let b1 = (*self >> 8) as u8; - let b2 = (*self >> 16) as u8; - let b3 = (*self >> 24) as u8; - Cow::Owned(vec![b0, b1, b2, b3]) + unsafe fn write_payload(&self, addr: *mut u8) { + self.0.write_payload(addr); } } -impl ToSlice for i32 { - fn to_slice(&self) -> Cow<'_, [u8]> { - let b0 = *self as u8; - let b1 = (*self >> 8) as u8; - let b2 = (*self >> 16) as u8; - let b3 = (*self >> 24) as u8; - Cow::Owned(vec![b0, b1, b2, b3]) +impl NfNetlinkDeserializable for ExpressionRaw { + fn deserialize(buf: &[u8]) -> Result<(Self, &[u8]), DecodeError> { + Ok((ExpressionRaw(buf.to_vec()), &[])) } } -impl<'a> ToSlice for &'a str { - fn to_slice(&self) -> Cow<'_, [u8]> { - Cow::from(self.as_bytes()) +// Because we loose the name of the expression when parsing, this is the only expression +// where deserializing a message and then reserializing it is invalid +impl Expression for ExpressionRaw { + fn get_name() -> &'static str { + "unknown_expression" } } - -#[macro_export(local_inner_macros)] -macro_rules! nft_expr { - (bitwise mask $mask:expr,xor $xor:expr) => { - nft_expr_bitwise!(mask $mask, xor $xor) - }; - (cmp $op:tt $data:expr) => { - nft_expr_cmp!($op $data) - }; - (counter) => { - $crate::expr::Counter { nb_bytes: 0, nb_packets: 0} - }; - (ct $key:ident set) => { - nft_expr_ct!($key set) - }; - (ct $key:ident) => { - nft_expr_ct!($key) - }; - (immediate $expr:ident $value:expr) => { - nft_expr_immediate!($expr $value) - }; - (log group $group:ident prefix $prefix:expr) => { - nft_expr_log!(group $group prefix $prefix) - }; - (log group $group:ident) => { - nft_expr_log!(group $group) - }; - (log prefix $prefix:expr) => { - nft_expr_log!(prefix $prefix) - }; - (log) => { - nft_expr_log!() - }; - (lookup $set:expr) => { - nft_expr_lookup!($set) - }; - (masquerade) => { - $crate::expr::Masquerade - }; - (meta $expr:ident set) => { - nft_expr_meta!($expr set) - }; - (meta $expr:ident) => { - nft_expr_meta!($expr) - }; - (payload $proto:ident $field:ident) => { - nft_expr_payload!($proto $field) - }; - (verdict $verdict:ident) => { - nft_expr_verdict!($verdict) - }; - (verdict $verdict:ident $chain:expr) => { - nft_expr_verdict!($verdict $chain) - }; -} diff --git a/src/expr/nat.rs b/src/expr/nat.rs index ce6b881..406b2e6 100644 --- a/src/expr/nat.rs +++ b/src/expr/nat.rs @@ -1,99 +1,37 @@ -use super::{DeserializationError, Expression, Register, Rule}; -use crate::ProtoFamily; -use crate::sys::{self, libc}; -use std::{convert::TryFrom, os::raw::c_char}; +use rustables_macros::{nfnetlink_enum, nfnetlink_struct}; + +use super::{Expression, Register}; +use crate::{ + sys::{self, NFT_NAT_DNAT, NFT_NAT_SNAT}, + ProtocolFamily, +}; #[derive(Debug, Copy, Clone, Eq, PartialEq, Hash)] -#[repr(i32)] +#[nfnetlink_enum(i32)] pub enum NatType { /// Source NAT. Changes the source address of a packet. - SNat = libc::NFT_NAT_SNAT, + SNat = NFT_NAT_SNAT, /// Destination NAT. Changes the destination address of a packet. - DNat = libc::NFT_NAT_DNAT, -} - -impl NatType { - fn from_raw(val: u32) -> Result<Self, DeserializationError> { - match val as i32 { - libc::NFT_NAT_SNAT => Ok(NatType::SNat), - libc::NFT_NAT_DNAT => Ok(NatType::DNat), - _ => Err(DeserializationError::InvalidValue), - } - } + DNat = NFT_NAT_DNAT, } /// A source or destination NAT statement. Modifies the source or destination address (and possibly /// port) of packets. -#[derive(Debug, PartialEq)] +#[derive(Default, Debug, Clone, PartialEq, Eq)] +#[nfnetlink_struct(nested = true)] pub struct Nat { + #[field(sys::NFTA_NAT_TYPE)] pub nat_type: NatType, - pub family: ProtoFamily, + #[field(sys::NFTA_NAT_FAMILY)] + pub family: ProtocolFamily, + #[field(sys::NFTA_NAT_REG_ADDR_MIN)] pub ip_register: Register, - pub port_register: Option<Register>, + #[field(sys::NFTA_NAT_REG_PROTO_MIN)] + pub port_register: Register, } impl Expression for Nat { - fn get_raw_name() -> *const libc::c_char { - b"nat\0" as *const _ as *const c_char - } - - fn from_expr(expr: *const sys::nftnl_expr) -> Result<Self, DeserializationError> - where - Self: Sized, - { - unsafe { - let nat_type = NatType::from_raw(sys::nftnl_expr_get_u32( - expr, - sys::NFTNL_EXPR_NAT_TYPE as u16, - ))?; - - let family = ProtoFamily::try_from(sys::nftnl_expr_get_u32( - expr, - sys::NFTNL_EXPR_NAT_FAMILY as u16, - ) as i32)?; - - let ip_register = Register::from_raw(sys::nftnl_expr_get_u32( - expr, - sys::NFTNL_EXPR_NAT_REG_ADDR_MIN as u16, - ))?; - - let mut port_register = None; - if sys::nftnl_expr_is_set(expr, sys::NFTNL_EXPR_NAT_REG_PROTO_MIN as u16) { - port_register = Some(Register::from_raw(sys::nftnl_expr_get_u32( - expr, - sys::NFTNL_EXPR_NAT_REG_PROTO_MIN as u16, - ))?); - } - - Ok(Nat { - ip_register, - nat_type, - family, - port_register, - }) - } - } - - fn to_expr(&self, _rule: &Rule) -> *mut sys::nftnl_expr { - let expr = try_alloc!(unsafe { sys::nftnl_expr_alloc(Self::get_raw_name()) }); - - unsafe { - sys::nftnl_expr_set_u32(expr, sys::NFTNL_EXPR_NAT_TYPE as u16, self.nat_type as u32); - sys::nftnl_expr_set_u32(expr, sys::NFTNL_EXPR_NAT_FAMILY as u16, self.family as u32); - sys::nftnl_expr_set_u32( - expr, - sys::NFTNL_EXPR_NAT_REG_ADDR_MIN as u16, - self.ip_register.to_raw(), - ); - if let Some(port_register) = self.port_register { - sys::nftnl_expr_set_u32( - expr, - sys::NFTNL_EXPR_NAT_REG_PROTO_MIN as u16, - port_register.to_raw(), - ); - } - } - - expr + fn get_name() -> &'static str { + "nat" } } diff --git a/src/expr/payload.rs b/src/expr/payload.rs index a108fe8..d0b2cea 100644 --- a/src/expr/payload.rs +++ b/src/expr/payload.rs @@ -1,128 +1,96 @@ -use super::{DeserializationError, Expression, Rule}; -use crate::sys::{self, libc}; -use std::os::raw::c_char; +use rustables_macros::nfnetlink_struct; -pub trait HeaderField { - fn offset(&self) -> u32; - fn len(&self) -> u32; +use super::{Expression, Register}; +use crate::{ + error::DecodeError, + sys::{self, NFT_PAYLOAD_LL_HEADER, NFT_PAYLOAD_NETWORK_HEADER, NFT_PAYLOAD_TRANSPORT_HEADER}, +}; + +/// Payload expressions refer to data from the packet's payload. +#[derive(Default, Debug, Copy, Clone, PartialEq, Eq)] +#[nfnetlink_struct(nested = true)] +pub struct Payload { + #[field(sys::NFTA_PAYLOAD_DREG)] + dreg: Register, + #[field(sys::NFTA_PAYLOAD_BASE)] + base: u32, + #[field(sys::NFTA_PAYLOAD_OFFSET)] + offset: u32, + #[field(sys::NFTA_PAYLOAD_LEN)] + len: u32, + #[field(sys::NFTA_PAYLOAD_SREG)] + sreg: Register, +} + +impl Expression for Payload { + fn get_name() -> &'static str { + "payload" + } } /// Payload expressions refer to data from the packet's payload. #[derive(Debug, Copy, Clone, Eq, PartialEq)] -pub enum Payload { +pub enum HighLevelPayload { LinkLayer(LLHeaderField), Network(NetworkHeaderField), Transport(TransportHeaderField), } -impl Payload { - pub fn build(&self) -> RawPayload { +impl HighLevelPayload { + pub fn build(&self) -> Payload { match *self { - Payload::LinkLayer(ref f) => RawPayload::LinkLayer(RawPayloadData { - offset: f.offset(), - len: f.len(), - }), - Payload::Network(ref f) => RawPayload::Network(RawPayloadData { - offset: f.offset(), - len: f.len(), - }), - Payload::Transport(ref f) => RawPayload::Transport(RawPayloadData { - offset: f.offset(), - len: f.len(), - }), + HighLevelPayload::LinkLayer(ref f) => Payload::default() + .with_base(NFT_PAYLOAD_LL_HEADER) + .with_offset(f.offset()) + .with_len(f.len()), + HighLevelPayload::Network(ref f) => Payload::default() + .with_base(NFT_PAYLOAD_NETWORK_HEADER) + .with_offset(f.offset()) + .with_len(f.len()), + HighLevelPayload::Transport(ref f) => Payload::default() + .with_base(NFT_PAYLOAD_TRANSPORT_HEADER) + .with_offset(f.offset()) + .with_len(f.len()), } + .with_dreg(Register::Reg1) } } -impl Expression for Payload { - fn get_raw_name() -> *const libc::c_char { - RawPayload::get_raw_name() - } - - fn to_expr(&self, rule: &Rule) -> *mut sys::nftnl_expr { - self.build().to_expr(rule) - } -} - -#[derive(Debug, Copy, Clone, Eq, PartialEq)] -pub struct RawPayloadData { - offset: u32, - len: u32, -} - -/// Because deserializing a `Payload` expression is not possible (there is not enough information -/// in the expression itself), this enum should be used to deserialize payloads. +/// Payload expressions refer to data from the packet's payload. #[derive(Debug, Copy, Clone, Eq, PartialEq)] -pub enum RawPayload { - LinkLayer(RawPayloadData), - Network(RawPayloadData), - Transport(RawPayloadData), +pub enum PayloadType { + LinkLayer(LLHeaderField), + Network, + Transport, } -impl RawPayload { - fn base(&self) -> u32 { - match self { - Self::LinkLayer(_) => libc::NFT_PAYLOAD_LL_HEADER as u32, - Self::Network(_) => libc::NFT_PAYLOAD_NETWORK_HEADER as u32, - Self::Transport(_) => libc::NFT_PAYLOAD_TRANSPORT_HEADER as u32, +impl PayloadType { + pub fn parse_from_payload(raw: &Payload) -> Result<Self, DecodeError> { + if raw.base.is_none() { + return Err(DecodeError::PayloadMissingBase); } - } -} - -impl HeaderField for RawPayload { - fn offset(&self) -> u32 { - match self { - Self::LinkLayer(ref f) | Self::Network(ref f) | Self::Transport(ref f) => f.offset, + if raw.len.is_none() { + return Err(DecodeError::PayloadMissingLen); } - } - - fn len(&self) -> u32 { - match self { - Self::LinkLayer(ref f) | Self::Network(ref f) | Self::Transport(ref f) => f.len, + if raw.offset.is_none() { + return Err(DecodeError::PayloadMissingOffset); } + Ok(match raw.base { + Some(NFT_PAYLOAD_LL_HEADER) => PayloadType::LinkLayer(LLHeaderField::from_raw_data( + raw.offset.unwrap(), + raw.len.unwrap(), + )?), + Some(NFT_PAYLOAD_NETWORK_HEADER) => PayloadType::Network, + Some(NFT_PAYLOAD_TRANSPORT_HEADER) => PayloadType::Transport, + Some(v) => return Err(DecodeError::UnknownPayloadType(v)), + None => return Err(DecodeError::PayloadMissingBase), + }) } } -impl Expression for RawPayload { - fn get_raw_name() -> *const libc::c_char { - b"payload\0" as *const _ as *const c_char - } - - fn from_expr(expr: *const sys::nftnl_expr) -> Result<Self, DeserializationError> { - unsafe { - let base = sys::nftnl_expr_get_u32(expr, sys::NFTNL_EXPR_PAYLOAD_BASE as u16); - let offset = sys::nftnl_expr_get_u32(expr, sys::NFTNL_EXPR_PAYLOAD_OFFSET as u16); - let len = sys::nftnl_expr_get_u32(expr, sys::NFTNL_EXPR_PAYLOAD_LEN as u16); - match base as i32 { - libc::NFT_PAYLOAD_LL_HEADER => Ok(Self::LinkLayer(RawPayloadData { offset, len })), - libc::NFT_PAYLOAD_NETWORK_HEADER => { - Ok(Self::Network(RawPayloadData { offset, len })) - } - libc::NFT_PAYLOAD_TRANSPORT_HEADER => { - Ok(Self::Transport(RawPayloadData { offset, len })) - } - - _ => return Err(DeserializationError::InvalidValue), - } - } - } - - fn to_expr(&self, _rule: &Rule) -> *mut sys::nftnl_expr { - unsafe { - let expr = try_alloc!(sys::nftnl_expr_alloc(Self::get_raw_name())); - - sys::nftnl_expr_set_u32(expr, sys::NFTNL_EXPR_PAYLOAD_BASE as u16, self.base()); - sys::nftnl_expr_set_u32(expr, sys::NFTNL_EXPR_PAYLOAD_OFFSET as u16, self.offset()); - sys::nftnl_expr_set_u32(expr, sys::NFTNL_EXPR_PAYLOAD_LEN as u16, self.len()); - sys::nftnl_expr_set_u32( - expr, - sys::NFTNL_EXPR_PAYLOAD_DREG as u16, - libc::NFT_REG_1 as u32, - ); - - expr - } - } +pub trait HeaderField { + fn offset(&self) -> u32; + fn len(&self) -> u32; } #[derive(Debug, Copy, Clone, Eq, PartialEq)] @@ -154,58 +122,52 @@ impl HeaderField for LLHeaderField { } impl LLHeaderField { - pub fn from_raw_data(data: &RawPayloadData) -> Result<Self, DeserializationError> { - let off = data.offset; - let len = data.len; - - if off == 0 && len == 6 { - Ok(Self::Daddr) - } else if off == 6 && len == 6 { - Ok(Self::Saddr) - } else if off == 12 && len == 2 { - Ok(Self::EtherType) - } else { - Err(DeserializationError::InvalidValue) - } + pub fn from_raw_data(offset: u32, len: u32) -> Result<Self, DecodeError> { + Ok(match (offset, len) { + (0, 6) => Self::Daddr, + (6, 6) => Self::Saddr, + (12, 2) => Self::EtherType, + _ => return Err(DecodeError::UnknownLinkLayerHeaderField(offset, len)), + }) } } #[derive(Debug, Copy, Clone, Eq, PartialEq)] pub enum NetworkHeaderField { - Ipv4(Ipv4HeaderField), - Ipv6(Ipv6HeaderField), + IPv4(IPv4HeaderField), + IPv6(IPv6HeaderField), } impl HeaderField for NetworkHeaderField { fn offset(&self) -> u32 { use self::NetworkHeaderField::*; match *self { - Ipv4(ref f) => f.offset(), - Ipv6(ref f) => f.offset(), + IPv4(ref f) => f.offset(), + IPv6(ref f) => f.offset(), } } fn len(&self) -> u32 { use self::NetworkHeaderField::*; match *self { - Ipv4(ref f) => f.len(), - Ipv6(ref f) => f.len(), + IPv4(ref f) => f.len(), + IPv6(ref f) => f.len(), } } } #[derive(Debug, Copy, Clone, Eq, PartialEq)] #[non_exhaustive] -pub enum Ipv4HeaderField { +pub enum IPv4HeaderField { Ttl, Protocol, Saddr, Daddr, } -impl HeaderField for Ipv4HeaderField { +impl HeaderField for IPv4HeaderField { fn offset(&self) -> u32 { - use self::Ipv4HeaderField::*; + use self::IPv4HeaderField::*; match *self { Ttl => 8, Protocol => 9, @@ -215,7 +177,7 @@ impl HeaderField for Ipv4HeaderField { } fn len(&self) -> u32 { - use self::Ipv4HeaderField::*; + use self::IPv4HeaderField::*; match *self { Ttl => 1, Protocol => 1, @@ -225,37 +187,30 @@ impl HeaderField for Ipv4HeaderField { } } -impl Ipv4HeaderField { - pub fn from_raw_data(data: &RawPayloadData) -> Result<Self, DeserializationError> { - let off = data.offset; - let len = data.len; - - if off == 8 && len == 1 { - Ok(Self::Ttl) - } else if off == 9 && len == 1 { - Ok(Self::Protocol) - } else if off == 12 && len == 4 { - Ok(Self::Saddr) - } else if off == 16 && len == 4 { - Ok(Self::Daddr) - } else { - Err(DeserializationError::InvalidValue) - } +impl IPv4HeaderField { + pub fn from_raw_data(offset: u32, len: u32) -> Result<Self, DecodeError> { + Ok(match (offset, len) { + (8, 1) => Self::Ttl, + (9, 1) => Self::Protocol, + (12, 4) => Self::Saddr, + (16, 4) => Self::Daddr, + _ => return Err(DecodeError::UnknownIPv4HeaderField(offset, len)), + }) } } #[derive(Debug, Copy, Clone, Eq, PartialEq)] #[non_exhaustive] -pub enum Ipv6HeaderField { +pub enum IPv6HeaderField { NextHeader, HopLimit, Saddr, Daddr, } -impl HeaderField for Ipv6HeaderField { +impl HeaderField for IPv6HeaderField { fn offset(&self) -> u32 { - use self::Ipv6HeaderField::*; + use self::IPv6HeaderField::*; match *self { NextHeader => 6, HopLimit => 7, @@ -265,7 +220,7 @@ impl HeaderField for Ipv6HeaderField { } fn len(&self) -> u32 { - use self::Ipv6HeaderField::*; + use self::IPv6HeaderField::*; match *self { NextHeader => 1, HopLimit => 1, @@ -275,31 +230,24 @@ impl HeaderField for Ipv6HeaderField { } } -impl Ipv6HeaderField { - pub fn from_raw_data(data: &RawPayloadData) -> Result<Self, DeserializationError> { - let off = data.offset; - let len = data.len; - - if off == 6 && len == 1 { - Ok(Self::NextHeader) - } else if off == 7 && len == 1 { - Ok(Self::HopLimit) - } else if off == 8 && len == 16 { - Ok(Self::Saddr) - } else if off == 24 && len == 16 { - Ok(Self::Daddr) - } else { - Err(DeserializationError::InvalidValue) - } +impl IPv6HeaderField { + pub fn from_raw_data(offset: u32, len: u32) -> Result<Self, DecodeError> { + Ok(match (offset, len) { + (6, 1) => Self::NextHeader, + (7, 1) => Self::HopLimit, + (8, 16) => Self::Saddr, + (24, 16) => Self::Daddr, + _ => return Err(DecodeError::UnknownIPv6HeaderField(offset, len)), + }) } } #[derive(Debug, Copy, Clone, Eq, PartialEq)] #[non_exhaustive] pub enum TransportHeaderField { - Tcp(TcpHeaderField), - Udp(UdpHeaderField), - Icmpv6(Icmpv6HeaderField), + Tcp(TCPHeaderField), + Udp(UDPHeaderField), + ICMPv6(ICMPv6HeaderField), } impl HeaderField for TransportHeaderField { @@ -308,7 +256,7 @@ impl HeaderField for TransportHeaderField { match *self { Tcp(ref f) => f.offset(), Udp(ref f) => f.offset(), - Icmpv6(ref f) => f.offset(), + ICMPv6(ref f) => f.offset(), } } @@ -317,21 +265,21 @@ impl HeaderField for TransportHeaderField { match *self { Tcp(ref f) => f.len(), Udp(ref f) => f.len(), - Icmpv6(ref f) => f.len(), + ICMPv6(ref f) => f.len(), } } } #[derive(Debug, Copy, Clone, Eq, PartialEq)] #[non_exhaustive] -pub enum TcpHeaderField { +pub enum TCPHeaderField { Sport, Dport, } -impl HeaderField for TcpHeaderField { +impl HeaderField for TCPHeaderField { fn offset(&self) -> u32 { - use self::TcpHeaderField::*; + use self::TCPHeaderField::*; match *self { Sport => 0, Dport => 2, @@ -339,7 +287,7 @@ impl HeaderField for TcpHeaderField { } fn len(&self) -> u32 { - use self::TcpHeaderField::*; + use self::TCPHeaderField::*; match *self { Sport => 2, Dport => 2, @@ -347,32 +295,27 @@ impl HeaderField for TcpHeaderField { } } -impl TcpHeaderField { - pub fn from_raw_data(data: &RawPayloadData) -> Result<Self, DeserializationError> { - let off = data.offset; - let len = data.len; - - if off == 0 && len == 2 { - Ok(Self::Sport) - } else if off == 2 && len == 2 { - Ok(Self::Dport) - } else { - Err(DeserializationError::InvalidValue) - } +impl TCPHeaderField { + pub fn from_raw_data(offset: u32, len: u32) -> Result<Self, DecodeError> { + Ok(match (offset, len) { + (0, 2) => Self::Sport, + (2, 2) => Self::Dport, + _ => return Err(DecodeError::UnknownTCPHeaderField(offset, len)), + }) } } #[derive(Debug, Copy, Clone, Eq, PartialEq)] #[non_exhaustive] -pub enum UdpHeaderField { +pub enum UDPHeaderField { Sport, Dport, Len, } -impl HeaderField for UdpHeaderField { +impl HeaderField for UDPHeaderField { fn offset(&self) -> u32 { - use self::UdpHeaderField::*; + use self::UDPHeaderField::*; match *self { Sport => 0, Dport => 2, @@ -381,7 +324,7 @@ impl HeaderField for UdpHeaderField { } fn len(&self) -> u32 { - use self::UdpHeaderField::*; + use self::UDPHeaderField::*; match *self { Sport => 2, Dport => 2, @@ -390,34 +333,28 @@ impl HeaderField for UdpHeaderField { } } -impl UdpHeaderField { - pub fn from_raw_data(data: &RawPayloadData) -> Result<Self, DeserializationError> { - let off = data.offset; - let len = data.len; - - if off == 0 && len == 2 { - Ok(Self::Sport) - } else if off == 2 && len == 2 { - Ok(Self::Dport) - } else if off == 4 && len == 2 { - Ok(Self::Len) - } else { - Err(DeserializationError::InvalidValue) - } +impl UDPHeaderField { + pub fn from_raw_data(offset: u32, len: u32) -> Result<Self, DecodeError> { + Ok(match (offset, len) { + (0, 2) => Self::Sport, + (2, 2) => Self::Dport, + (4, 2) => Self::Len, + _ => return Err(DecodeError::UnknownUDPHeaderField(offset, len)), + }) } } #[derive(Debug, Copy, Clone, Eq, PartialEq)] #[non_exhaustive] -pub enum Icmpv6HeaderField { +pub enum ICMPv6HeaderField { Type, Code, Checksum, } -impl HeaderField for Icmpv6HeaderField { +impl HeaderField for ICMPv6HeaderField { fn offset(&self) -> u32 { - use self::Icmpv6HeaderField::*; + use self::ICMPv6HeaderField::*; match *self { Type => 0, Code => 1, @@ -426,7 +363,7 @@ impl HeaderField for Icmpv6HeaderField { } fn len(&self) -> u32 { - use self::Icmpv6HeaderField::*; + use self::ICMPv6HeaderField::*; match *self { Type => 1, Code => 1, @@ -435,97 +372,13 @@ impl HeaderField for Icmpv6HeaderField { } } -impl Icmpv6HeaderField { - pub fn from_raw_data(data: &RawPayloadData) -> Result<Self, DeserializationError> { - let off = data.offset; - let len = data.len; - - if off == 0 && len == 1 { - Ok(Self::Type) - } else if off == 1 && len == 1 { - Ok(Self::Code) - } else if off == 2 && len == 2 { - Ok(Self::Checksum) - } else { - Err(DeserializationError::InvalidValue) - } +impl ICMPv6HeaderField { + pub fn from_raw_data(offset: u32, len: u32) -> Result<Self, DecodeError> { + Ok(match (offset, len) { + (0, 1) => Self::Type, + (1, 1) => Self::Code, + (2, 2) => Self::Checksum, + _ => return Err(DecodeError::UnknownICMPv6HeaderField(offset, len)), + }) } } - -#[macro_export(local_inner_macros)] -macro_rules! nft_expr_payload { - (@ipv4_field ttl) => { - $crate::expr::Ipv4HeaderField::Ttl - }; - (@ipv4_field protocol) => { - $crate::expr::Ipv4HeaderField::Protocol - }; - (@ipv4_field saddr) => { - $crate::expr::Ipv4HeaderField::Saddr - }; - (@ipv4_field daddr) => { - $crate::expr::Ipv4HeaderField::Daddr - }; - - (@ipv6_field nextheader) => { - $crate::expr::Ipv6HeaderField::NextHeader - }; - (@ipv6_field hoplimit) => { - $crate::expr::Ipv6HeaderField::HopLimit - }; - (@ipv6_field saddr) => { - $crate::expr::Ipv6HeaderField::Saddr - }; - (@ipv6_field daddr) => { - $crate::expr::Ipv6HeaderField::Daddr - }; - - (@tcp_field sport) => { - $crate::expr::TcpHeaderField::Sport - }; - (@tcp_field dport) => { - $crate::expr::TcpHeaderField::Dport - }; - - (@udp_field sport) => { - $crate::expr::UdpHeaderField::Sport - }; - (@udp_field dport) => { - $crate::expr::UdpHeaderField::Dport - }; - (@udp_field len) => { - $crate::expr::UdpHeaderField::Len - }; - - (ethernet daddr) => { - $crate::expr::Payload::LinkLayer($crate::expr::LLHeaderField::Daddr) - }; - (ethernet saddr) => { - $crate::expr::Payload::LinkLayer($crate::expr::LLHeaderField::Saddr) - }; - (ethernet ethertype) => { - $crate::expr::Payload::LinkLayer($crate::expr::LLHeaderField::EtherType) - }; - - (ipv4 $field:ident) => { - $crate::expr::Payload::Network($crate::expr::NetworkHeaderField::Ipv4( - nft_expr_payload!(@ipv4_field $field), - )) - }; - (ipv6 $field:ident) => { - $crate::expr::Payload::Network($crate::expr::NetworkHeaderField::Ipv6( - nft_expr_payload!(@ipv6_field $field), - )) - }; - - (tcp $field:ident) => { - $crate::expr::Payload::Transport($crate::expr::TransportHeaderField::Tcp( - nft_expr_payload!(@tcp_field $field), - )) - }; - (udp $field:ident) => { - $crate::expr::Payload::Transport($crate::expr::TransportHeaderField::Udp( - nft_expr_payload!(@udp_field $field), - )) - }; -} diff --git a/src/expr/register.rs b/src/expr/register.rs index a05af7e..9cc1bee 100644 --- a/src/expr/register.rs +++ b/src/expr/register.rs @@ -1,34 +1,17 @@ use std::fmt::Debug; -use crate::sys::libc; +use rustables_macros::nfnetlink_enum; -use super::DeserializationError; +use crate::sys::{NFT_REG_1, NFT_REG_2, NFT_REG_3, NFT_REG_4, NFT_REG_VERDICT}; /// A netfilter data register. The expressions store and read data to and from these when /// evaluating rule statements. #[derive(Debug, Copy, Clone, Eq, PartialEq, Hash)] -#[repr(i32)] +#[nfnetlink_enum(u32)] pub enum Register { - Verdict = libc::NFT_REG_VERDICT, - Reg1 = libc::NFT_REG_1, - Reg2 = libc::NFT_REG_2, - Reg3 = libc::NFT_REG_3, - Reg4 = libc::NFT_REG_4, -} - -impl Register { - pub fn to_raw(self) -> u32 { - self as u32 - } - - pub fn from_raw(val: u32) -> Result<Self, DeserializationError> { - match val as i32 { - libc::NFT_REG_VERDICT => Ok(Self::Verdict), - libc::NFT_REG_1 => Ok(Self::Reg1), - libc::NFT_REG_2 => Ok(Self::Reg2), - libc::NFT_REG_3 => Ok(Self::Reg3), - libc::NFT_REG_4 => Ok(Self::Reg4), - _ => Err(DeserializationError::InvalidValue), - } - } + Verdict = NFT_REG_VERDICT, + Reg1 = NFT_REG_1, + Reg2 = NFT_REG_2, + Reg3 = NFT_REG_3, + Reg4 = NFT_REG_4, } diff --git a/src/expr/reject.rs b/src/expr/reject.rs index 19752ce..83fd843 100644 --- a/src/expr/reject.rs +++ b/src/expr/reject.rs @@ -1,95 +1,40 @@ -use super::{DeserializationError, Expression, Rule}; -use crate::ProtoFamily; -use crate::sys::{self, libc::{self, c_char}}; +use rustables_macros::{nfnetlink_enum, nfnetlink_struct}; -/// A reject expression that defines the type of rejection message sent when discarding a packet. -#[derive(Debug, Clone, Copy, Eq, PartialEq, Hash)] -pub enum Reject { - /// Returns an ICMP unreachable packet. - Icmp(IcmpCode), - /// Rejects by sending a TCP RST packet. - TcpRst, -} +use crate::sys; -impl Reject { - fn to_raw(&self, family: ProtoFamily) -> u32 { - use libc::*; - let value = match *self { - Self::Icmp(..) => match family { - ProtoFamily::Bridge | ProtoFamily::Inet => NFT_REJECT_ICMPX_UNREACH, - _ => NFT_REJECT_ICMP_UNREACH, - }, - Self::TcpRst => NFT_REJECT_TCP_RST, - }; - value as u32 - } -} +use super::Expression; impl Expression for Reject { - fn get_raw_name() -> *const libc::c_char { - b"reject\0" as *const _ as *const c_char + fn get_name() -> &'static str { + "reject" } +} - fn from_expr(expr: *const sys::nftnl_expr) -> Result<Self, DeserializationError> - where - Self: Sized, - { - unsafe { - if sys::nftnl_expr_get_u32(expr, sys::NFTNL_EXPR_REJECT_TYPE as u16) - == libc::NFT_REJECT_TCP_RST as u32 - { - Ok(Self::TcpRst) - } else { - Ok(Self::Icmp(IcmpCode::from_raw(sys::nftnl_expr_get_u8( - expr, - sys::NFTNL_EXPR_REJECT_CODE as u16, - ))?)) - } - } - } - - fn to_expr(&self, rule: &Rule) -> *mut sys::nftnl_expr { - let family = rule.get_chain().get_table().get_family(); - - unsafe { - let expr = try_alloc!(sys::nftnl_expr_alloc(Self::get_raw_name())); - - sys::nftnl_expr_set_u32( - expr, - sys::NFTNL_EXPR_REJECT_TYPE as u16, - self.to_raw(family), - ); - - let reject_code = match *self { - Reject::Icmp(code) => code as u8, - Reject::TcpRst => 0, - }; - - sys::nftnl_expr_set_u8(expr, sys::NFTNL_EXPR_REJECT_CODE as u16, reject_code); - - expr - } - } +#[derive(Clone, PartialEq, Eq, Default, Debug)] +#[nfnetlink_struct] +/// A reject expression that defines the type of rejection message sent when discarding a packet. +pub struct Reject { + #[field(sys::NFTA_REJECT_TYPE, name_in_functions = "type")] + reject_type: RejectType, + #[field(sys::NFTA_REJECT_ICMP_CODE)] + icmp_code: IcmpCode, } /// An ICMP reject code. #[derive(Debug, Clone, Copy, Eq, PartialEq, Hash)] -#[repr(u8)] -pub enum IcmpCode { - NoRoute = libc::NFT_REJECT_ICMPX_NO_ROUTE as u8, - PortUnreach = libc::NFT_REJECT_ICMPX_PORT_UNREACH as u8, - HostUnreach = libc::NFT_REJECT_ICMPX_HOST_UNREACH as u8, - AdminProhibited = libc::NFT_REJECT_ICMPX_ADMIN_PROHIBITED as u8, +#[nfnetlink_enum(u32)] +pub enum RejectType { + IcmpUnreach = sys::NFT_REJECT_ICMP_UNREACH, + TcpRst = sys::NFT_REJECT_TCP_RST, + IcmpxUnreach = sys::NFT_REJECT_ICMPX_UNREACH, } -impl IcmpCode { - fn from_raw(code: u8) -> Result<Self, DeserializationError> { - match code as i32 { - libc::NFT_REJECT_ICMPX_NO_ROUTE => Ok(Self::NoRoute), - libc::NFT_REJECT_ICMPX_PORT_UNREACH => Ok(Self::PortUnreach), - libc::NFT_REJECT_ICMPX_HOST_UNREACH => Ok(Self::HostUnreach), - libc::NFT_REJECT_ICMPX_ADMIN_PROHIBITED => Ok(Self::AdminProhibited), - _ => Err(DeserializationError::InvalidValue), - } - } +/// An ICMP reject code. +#[derive(Debug, Clone, Copy, Eq, PartialEq, Hash)] +#[nfnetlink_enum(u8)] +pub enum IcmpCode { + NoRoute = sys::NFT_REJECT_ICMPX_NO_ROUTE, + PortUnreach = sys::NFT_REJECT_ICMPX_PORT_UNREACH, + HostUnreach = sys::NFT_REJECT_ICMPX_HOST_UNREACH, + AdminProhibited = sys::NFT_REJECT_ICMPX_ADMIN_PROHIBITED, } diff --git a/src/expr/verdict.rs b/src/expr/verdict.rs index 3c4c374..7edf7cd 100644 --- a/src/expr/verdict.rs +++ b/src/expr/verdict.rs @@ -1,11 +1,39 @@ -use super::{DeserializationError, Expression, Rule}; -use crate::sys::{self, libc::{self, c_char}}; -use std::ffi::{CStr, CString}; +use std::fmt::Debug; + +use libc::{NF_ACCEPT, NF_DROP, NF_QUEUE}; +use rustables_macros::{nfnetlink_enum, nfnetlink_struct}; + +use crate::sys::{ + NFTA_VERDICT_CHAIN, NFTA_VERDICT_CHAIN_ID, NFTA_VERDICT_CODE, NFT_BREAK, NFT_CONTINUE, + NFT_GOTO, NFT_JUMP, NFT_RETURN, +}; + +#[derive(Debug, Copy, Clone, Eq, PartialEq, Hash)] +#[nfnetlink_enum(i32)] +pub enum VerdictType { + Drop = NF_DROP, + Accept = NF_ACCEPT, + Queue = NF_QUEUE, + Continue = NFT_CONTINUE, + Break = NFT_BREAK, + Jump = NFT_JUMP, + Goto = NFT_GOTO, + Return = NFT_RETURN, +} + +#[derive(Clone, PartialEq, Eq, Default, Debug)] +#[nfnetlink_struct(nested = true)] +pub struct Verdict { + #[field(NFTA_VERDICT_CODE)] + code: VerdictType, + #[field(NFTA_VERDICT_CHAIN)] + chain: String, + #[field(NFTA_VERDICT_CHAIN_ID)] + chain_id: u32, +} -/// A verdict expression. In the background, this is usually an "Immediate" expression in nftnl -/// terms, but here it is simplified to only represent a verdict. #[derive(Debug, Clone, Eq, PartialEq, Hash)] -pub enum Verdict { +pub enum VerdictKind { /// Silently drop the packet. Drop, /// Accept the packet and let it pass. @@ -14,135 +42,10 @@ pub enum Verdict { Continue, Break, Jump { - chain: CString, + chain: String, }, Goto { - chain: CString, + chain: String, }, Return, } - -impl Verdict { - fn chain(&self) -> Option<&CStr> { - match *self { - Verdict::Jump { ref chain } => Some(chain.as_c_str()), - Verdict::Goto { ref chain } => Some(chain.as_c_str()), - _ => None, - } - } -} - -impl Expression for Verdict { - fn get_raw_name() -> *const libc::c_char { - b"immediate\0" as *const _ as *const c_char - } - - fn from_expr(expr: *const sys::nftnl_expr) -> Result<Self, DeserializationError> { - unsafe { - let mut chain = None; - if sys::nftnl_expr_is_set(expr, sys::NFTNL_EXPR_IMM_CHAIN as u16) { - let raw_chain = sys::nftnl_expr_get_str(expr, sys::NFTNL_EXPR_IMM_CHAIN as u16); - - if raw_chain.is_null() { - return Err(DeserializationError::NullPointer); - } - chain = Some(CStr::from_ptr(raw_chain).to_owned()); - } - - let verdict = sys::nftnl_expr_get_u32(expr, sys::NFTNL_EXPR_IMM_VERDICT as u16); - - match verdict as i32 { - libc::NF_DROP => Ok(Verdict::Drop), - libc::NF_ACCEPT => Ok(Verdict::Accept), - libc::NF_QUEUE => Ok(Verdict::Queue), - libc::NFT_CONTINUE => Ok(Verdict::Continue), - libc::NFT_BREAK => Ok(Verdict::Break), - libc::NFT_JUMP => { - if let Some(chain) = chain { - Ok(Verdict::Jump { chain }) - } else { - Err(DeserializationError::InvalidValue) - } - } - libc::NFT_GOTO => { - if let Some(chain) = chain { - Ok(Verdict::Goto { chain }) - } else { - Err(DeserializationError::InvalidValue) - } - } - libc::NFT_RETURN => Ok(Verdict::Return), - _ => Err(DeserializationError::InvalidValue), - } - } - } - - fn to_expr(&self, _rule: &Rule) -> *mut sys::nftnl_expr { - let immediate_const = match *self { - Verdict::Drop => libc::NF_DROP, - Verdict::Accept => libc::NF_ACCEPT, - Verdict::Queue => libc::NF_QUEUE, - Verdict::Continue => libc::NFT_CONTINUE, - Verdict::Break => libc::NFT_BREAK, - Verdict::Jump { .. } => libc::NFT_JUMP, - Verdict::Goto { .. } => libc::NFT_GOTO, - Verdict::Return => libc::NFT_RETURN, - }; - unsafe { - let expr = try_alloc!(sys::nftnl_expr_alloc( - b"immediate\0" as *const _ as *const c_char - )); - - sys::nftnl_expr_set_u32( - expr, - sys::NFTNL_EXPR_IMM_DREG as u16, - libc::NFT_REG_VERDICT as u32, - ); - - if let Some(chain) = self.chain() { - sys::nftnl_expr_set_str(expr, sys::NFTNL_EXPR_IMM_CHAIN as u16, chain.as_ptr()); - } - sys::nftnl_expr_set_u32( - expr, - sys::NFTNL_EXPR_IMM_VERDICT as u16, - immediate_const as u32, - ); - - expr - } - } -} - -#[macro_export] -macro_rules! nft_expr_verdict { - (drop) => { - $crate::expr::Verdict::Drop - }; - (accept) => { - $crate::expr::Verdict::Accept - }; - (reject icmp $code:expr) => { - $crate::expr::Verdict::Reject(RejectionType::Icmp($code)) - }; - (reject tcp-rst) => { - $crate::expr::Verdict::Reject(RejectionType::TcpRst) - }; - (queue) => { - $crate::expr::Verdict::Queue - }; - (continue) => { - $crate::expr::Verdict::Continue - }; - (break) => { - $crate::expr::Verdict::Break - }; - (jump $chain:expr) => { - $crate::expr::Verdict::Jump { chain: $chain } - }; - (goto $chain:expr) => { - $crate::expr::Verdict::Goto { chain: $chain } - }; - (return) => { - $crate::expr::Verdict::Return - }; -} diff --git a/src/expr/wrapper.rs b/src/expr/wrapper.rs deleted file mode 100644 index 12ef60b..0000000 --- a/src/expr/wrapper.rs +++ /dev/null @@ -1,61 +0,0 @@ -use std::ffi::CStr; -use std::ffi::CString; -use std::fmt::Debug; -use std::rc::Rc; -use std::os::raw::c_char; - -use super::{DeserializationError, Expression}; -use crate::{sys, Rule}; - -pub struct ExpressionWrapper { - pub(crate) expr: *const sys::nftnl_expr, - // we also need the rule here to ensure that the rule lives as long as the `expr` pointer - #[allow(dead_code)] - pub(crate) rule: Rc<Rule>, -} - -impl Debug for ExpressionWrapper { - fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { - write!(f, "{:?}", self.get_str()) - } -} - -impl ExpressionWrapper { - /// Retrieves a textual description of the expression. - pub fn get_str(&self) -> CString { - let mut descr_buf = vec![0i8; 4096]; - unsafe { - sys::nftnl_expr_snprintf( - descr_buf.as_mut_ptr() as *mut c_char, - (descr_buf.len() - 1) as u64, - self.expr, - sys::NFTNL_OUTPUT_DEFAULT, - 0, - ); - CStr::from_ptr(descr_buf.as_ptr() as *mut c_char).to_owned() - } - } - - /// Retrieves the type of expression ("log", "counter", ...). - pub fn get_kind(&self) -> Option<&CStr> { - unsafe { - let ptr = sys::nftnl_expr_get_str(self.expr, sys::NFTNL_EXPR_NAME as u16); - if !ptr.is_null() { - Some(CStr::from_ptr(ptr)) - } else { - None - } - } - } - - /// Attempts to decode the expression as the type T. - pub fn decode_expr<T: Expression>(&self) -> Result<T, DeserializationError> { - if let Some(kind) = self.get_kind() { - let raw_name = unsafe { CStr::from_ptr(T::get_raw_name()) }; - if kind == raw_name { - return T::from_expr(self.expr); - } - } - Err(DeserializationError::InvalidExpressionKind) - } -} @@ -1,4 +1,4 @@ -// Copyryght (c) 2021 GPL lafleur@boum.org and Simon Thoby +// Copyryght (c) 2021-2022 GPL lafleur@boum.org and Simon Thoby // // This file is free software: you may copy, redistribute and/or modify it // under the terms of the GNU General Public License as published by the @@ -24,106 +24,70 @@ // option. This file may not be copied, modified, or distributed // except according to those terms. -//! Safe abstraction for [`libnftnl`]. Provides userspace access to the in-kernel nf_tables -//! subsystem. Can be used to create and remove tables, chains, sets and rules from the nftables +//! Safe abstraction for userspace access to the in-kernel nf_tables subsystem. +//! Can be used to create and remove tables, chains, sets and rules from the nftables //! firewall, the successor to iptables. //! //! This library currently has quite rough edges and does not make adding and removing netfilter //! entries super easy and elegant. That is partly because the library needs more work, but also //! partly because nftables is super low level and extremely customizable, making it hard, and //! probably wrong, to try and create a too simple/limited wrapper. See examples for inspiration. -//! One can also look at how the original project this crate was developed to support uses it: -//! [Mullvad VPN app](https://github.com/mullvad/mullvadvpn-app) //! -//! Understanding how to use [`libnftnl`] and implementing this crate has mostly been done by -//! reading the source code for the [`nftables`] program and attaching debuggers to the `nft` -//! binary. Since the implementation is mostly based on trial and error, there might of course be -//! a number of places where the underlying library is used in an invalid or not intended way. -//! Large portions of [`libnftnl`] are also not covered yet. Contributions are welcome! +//! Understanding how to use the netlink subsystem and implementing this crate has mostly been done by +//! reading the source code for the [`nftables`] userspace program and its corresponding kernel code, +//! as well as attaching debuggers to the `nft` binary. +//! Since the implementation is mostly based on trial and error, there might of course be +//! a number of places where the forged netlink messages are used in an invalid or not intended way. +//! Contributions are welcome! //! -//! # Supported versions of `libnftnl` -//! -//! This crate will automatically link to the currently installed version of libnftnl upon build. -//! It requires libnftnl version 1.0.6 or higher. See how the low level FFI bindings to the C -//! library are generated in [`build.rs`]. -//! -//! # Access to raw handles -//! -//! Retrieving raw handles is considered unsafe and should only ever be enabled if you absolutely -//! need it. It is disabled by default and hidden behind the feature gate `unsafe-raw-handles`. -//! The reason for that special treatment is we cannot guarantee the lack of aliasing. For -//! example, a program using a const handle to a object in a thread and writing through a mutable -//! handle in another could reach all kind of undefined (and dangerous!) behaviors. By enabling -//! that feature flag, you acknowledge that guaranteeing the respect of safety invariants is now -//! your responsibility! Despite these shortcomings, that feature is still available because it -//! may allow you to perform manipulations that this library doesn't currently expose. If that is -//! your case, we would be very happy to hear from you and maybe help you get the necessary -//! functionality upstream. -//! -//! Our current lack of confidence in our availability to provide a safe abstraction over the use -//! of raw handles in the face of concurrency is the reason we decided to settly on `Rc` pointers -//! instead of `Arc` (besides, this should gives us some nice performance boost, not that it -//! matters much of course) and why we do not declare the types exposed by the library as `Send` -//! nor `Sync`. -//! -//! [`libnftnl`]: https://netfilter.org/projects/libnftnl/ //! [`nftables`]: https://netfilter.org/projects/nftables/ -//! [`build.rs`]: https://gitlab.com/rustwall/rustables/-/blob/master/build.rs - -use thiserror::Error; #[macro_use] extern crate log; -pub mod sys; -use std::{convert::TryFrom, ffi::c_void, ops::Deref}; -use sys::libc; - -macro_rules! try_alloc { - ($e:expr) => {{ - let ptr = $e; - if ptr.is_null() { - // OOM, and the tried allocation was likely very small, - // so we are in a very tight situation. We do what libstd does, aborts. - std::process::abort(); - } - ptr - }}; -} +use libc; + +use rustables_macros::nfnetlink_enum; +use std::convert::TryFrom; mod batch; -#[cfg(feature = "query")] -pub use batch::{batch_is_supported, default_batch_page_size}; -pub use batch::{Batch, FinalizedBatch, NetlinkError}; +pub use batch::{default_batch_page_size, Batch}; -pub mod expr; +pub mod data_type; -pub mod table; +mod table; +pub use table::list_tables; pub use table::Table; -#[cfg(feature = "query")] -pub use table::{get_tables_cb, list_tables}; mod chain; -#[cfg(feature = "query")] -pub use chain::{get_chains_cb, list_chains_for_table}; -pub use chain::{Chain, ChainType, Hook, Policy, Priority}; +pub use chain::list_chains_for_table; +pub use chain::{Chain, ChainPolicy, ChainPriority, ChainType, Hook, HookClass}; -mod chain_methods; -pub use chain_methods::ChainMethods; +pub mod error; pub mod query; +pub(crate) mod nlmsg; +pub(crate) mod parser; +pub(crate) mod parser_impls; + mod rule; +pub use rule::list_rules_for_chain; pub use rule::Rule; -#[cfg(feature = "query")] -pub use rule::{get_rules_cb, list_rules_for_chain}; + +pub mod expr; mod rule_methods; -pub use rule_methods::{iface_index, Protocol, RuleMethods, Error as MatchError}; +pub use rule_methods::{iface_index, Protocol}; pub mod set; pub use set::Set; +pub mod sys; + +#[cfg(test)] +mod tests; + /// The type of the message as it's sent to netfilter. A message consists of an object, such as a /// [`Table`], [`Chain`] or [`Rule`] for example, and a [`MsgType`] to describe what to do with /// that object. If a [`Table`] object is sent with `MsgType::Add` then that table will be added @@ -133,7 +97,7 @@ pub use set::Set; /// [`Chain`]: struct.Chain.html /// [`Rule`]: struct.Rule.html /// [`MsgType`]: enum.MsgType.html -#[derive(Debug, Copy, Clone, Eq, PartialEq, Hash)] +#[derive(Debug, Copy, Clone, Eq, PartialEq)] pub enum MsgType { /// Add the object to netfilter. Add, @@ -142,69 +106,22 @@ pub enum MsgType { } /// Denotes a protocol. Used to specify which protocol a table or set belongs to. -#[derive(Debug, Copy, Clone, Eq, PartialEq, Hash)] -#[repr(u16)] -pub enum ProtoFamily { - Unspec = libc::NFPROTO_UNSPEC as u16, +#[derive(Debug, Copy, Clone, PartialEq, Eq)] +#[nfnetlink_enum(i32)] +pub enum ProtocolFamily { + Unspec = libc::NFPROTO_UNSPEC, /// Inet - Means both IPv4 and IPv6 - Inet = libc::NFPROTO_INET as u16, - Ipv4 = libc::NFPROTO_IPV4 as u16, - Arp = libc::NFPROTO_ARP as u16, - NetDev = libc::NFPROTO_NETDEV as u16, - Bridge = libc::NFPROTO_BRIDGE as u16, - Ipv6 = libc::NFPROTO_IPV6 as u16, - DecNet = libc::NFPROTO_DECNET as u16, -} -#[derive(Error, Debug)] -#[error("Couldn't find a matching protocol")] -pub struct InvalidProtocolFamily; - -impl TryFrom<i32> for ProtoFamily { - type Error = InvalidProtocolFamily; - fn try_from(value: i32) -> Result<Self, Self::Error> { - match value { - libc::NFPROTO_UNSPEC => Ok(ProtoFamily::Unspec), - libc::NFPROTO_INET => Ok(ProtoFamily::Inet), - libc::NFPROTO_IPV4 => Ok(ProtoFamily::Ipv4), - libc::NFPROTO_ARP => Ok(ProtoFamily::Arp), - libc::NFPROTO_NETDEV => Ok(ProtoFamily::NetDev), - libc::NFPROTO_BRIDGE => Ok(ProtoFamily::Bridge), - libc::NFPROTO_IPV6 => Ok(ProtoFamily::Ipv6), - libc::NFPROTO_DECNET => Ok(ProtoFamily::DecNet), - _ => Err(InvalidProtocolFamily), - } - } + Inet = libc::NFPROTO_INET, + Ipv4 = libc::NFPROTO_IPV4, + Arp = libc::NFPROTO_ARP, + NetDev = libc::NFPROTO_NETDEV, + Bridge = libc::NFPROTO_BRIDGE, + Ipv6 = libc::NFPROTO_IPV6, + DecNet = libc::NFPROTO_DECNET, } -/// Trait for all types in this crate that can serialize to a Netlink message. -/// -/// # Unsafe -/// -/// This trait is unsafe to implement because it must never serialize to anything larger than the -/// largest possible netlink message. Internally the `nft_nlmsg_maxsize()` function is used to -/// make sure the `buf` pointer passed to `write` always has room for the largest possible Netlink -/// message. -pub unsafe trait NlMsg { - /// Serializes the Netlink message to the buffer at `buf`. `buf` must have space for at least - /// `nft_nlmsg_maxsize()` bytes. This is not checked by the compiler, which is why this method - /// is unsafe. - unsafe fn write(&self, buf: *mut c_void, seq: u32, msg_type: MsgType); -} - -unsafe impl<T, R> NlMsg for T -where - T: Deref<Target = R>, - R: NlMsg, -{ - unsafe fn write(&self, buf: *mut c_void, seq: u32, msg_type: MsgType) { - self.deref().write(buf, seq, msg_type); +impl Default for ProtocolFamily { + fn default() -> Self { + ProtocolFamily::Unspec } } - -/// The largest nf_tables netlink message is the set element message, which contains the -/// NFTA_SET_ELEM_LIST_ELEMENTS attribute. This attribute is a nest that describes the set -/// elements. Given that the netlink attribute length (nla_len) is 16 bits, the largest message is -/// a bit larger than 64 KBytes. -pub fn nft_nlmsg_maxsize() -> u32 { - u32::from(::std::u16::MAX) + unsafe { libc::sysconf(libc::_SC_PAGESIZE) } as u32 -} diff --git a/src/nlmsg.rs b/src/nlmsg.rs new file mode 100644 index 0000000..1c5b519 --- /dev/null +++ b/src/nlmsg.rs @@ -0,0 +1,182 @@ +use std::{fmt::Debug, mem::size_of}; + +use crate::{ + error::DecodeError, + sys::{ + nfgenmsg, nlmsghdr, NFNETLINK_V0, NFNL_MSG_BATCH_BEGIN, NFNL_MSG_BATCH_END, + NFNL_SUBSYS_NFTABLES, NLMSG_ALIGNTO, NLM_F_ACK, NLM_F_CREATE, + }, + MsgType, ProtocolFamily, +}; +/// +/// The largest nf_tables netlink message is the set element message, which contains the +/// NFTA_SET_ELEM_LIST_ELEMENTS attribute. This attribute is a nest that describes the set +/// elements. Given that the netlink attribute length (nla_len) is 16 bits, the largest message is +/// a bit larger than 64 KBytes. +pub fn nft_nlmsg_maxsize() -> u32 { + u32::from(::std::u16::MAX) + unsafe { libc::sysconf(libc::_SC_PAGESIZE) } as u32 +} + +#[inline] +pub const fn pad_netlink_object_with_variable_size(size: usize) -> usize { + // align on a 4 bytes boundary + (size + (NLMSG_ALIGNTO as usize - 1)) & !(NLMSG_ALIGNTO as usize - 1) +} + +#[inline] +pub const fn pad_netlink_object<T>() -> usize { + let size = size_of::<T>(); + pad_netlink_object_with_variable_size(size) +} + +pub fn get_subsystem_from_nlmsghdr_type(x: u16) -> u8 { + ((x & 0xff00) >> 8) as u8 +} + +pub fn get_operation_from_nlmsghdr_type(x: u16) -> u8 { + (x & 0x00ff) as u8 +} + +pub struct NfNetlinkWriter<'a> { + buf: &'a mut Vec<u8>, + headers: Option<(usize, usize)>, +} + +impl<'a> NfNetlinkWriter<'a> { + pub fn new(buf: &'a mut Vec<u8>) -> NfNetlinkWriter<'a> { + NfNetlinkWriter { buf, headers: None } + } + + pub fn add_data_zeroed<'b>(&'b mut self, size: usize) -> &'b mut [u8] { + let padded_size = pad_netlink_object_with_variable_size(size); + let start = self.buf.len(); + self.buf.resize(start + padded_size, 0); + + if let Some((msghdr_idx, _nfgenmsg_idx)) = self.headers { + let mut hdr: &mut nlmsghdr = unsafe { + std::mem::transmute(self.buf[msghdr_idx..].as_mut_ptr() as *mut nlmsghdr) + }; + hdr.nlmsg_len += padded_size as u32; + } + + &mut self.buf[start..start + size] + } + + // rewrite of `__nftnl_nlmsg_build_hdr` + pub fn write_header( + &mut self, + msg_type: u16, + family: ProtocolFamily, + flags: u16, + seq: u32, + ressource_id: Option<u16>, + ) { + if self.headers.is_some() { + error!("Calling write_header while still holding headers open!?"); + } + + let nlmsghdr_len = pad_netlink_object::<nlmsghdr>(); + let nfgenmsg_len = pad_netlink_object::<nfgenmsg>(); + + let nlmsghdr_buf = self.add_data_zeroed(nlmsghdr_len); + let mut hdr: &mut nlmsghdr = + unsafe { std::mem::transmute(nlmsghdr_buf.as_mut_ptr() as *mut nlmsghdr) }; + hdr.nlmsg_len = (nlmsghdr_len + nfgenmsg_len) as u32; + hdr.nlmsg_type = msg_type; + // batch messages are not specific to the nftables subsystem + if msg_type != NFNL_MSG_BATCH_BEGIN as u16 && msg_type != NFNL_MSG_BATCH_END as u16 { + hdr.nlmsg_type |= (NFNL_SUBSYS_NFTABLES as u16) << 8; + } + hdr.nlmsg_flags = libc::NLM_F_REQUEST as u16 | flags; + hdr.nlmsg_seq = seq; + + let nfgenmsg_buf = self.add_data_zeroed(nfgenmsg_len); + let mut nfgenmsg: &mut nfgenmsg = + unsafe { std::mem::transmute(nfgenmsg_buf.as_mut_ptr() as *mut nfgenmsg) }; + nfgenmsg.nfgen_family = family as u8; + nfgenmsg.version = NFNETLINK_V0 as u8; + nfgenmsg.res_id = ressource_id.unwrap_or(0); + + self.headers = Some(( + self.buf.len() - (nlmsghdr_len + nfgenmsg_len), + self.buf.len() - nfgenmsg_len, + )); + } + + pub fn finalize_writing_object(&mut self) { + self.headers = None; + } +} + +pub trait AttributeDecoder { + fn decode_attribute(&mut self, attr_type: u16, buf: &[u8]) -> Result<(), DecodeError>; +} + +pub trait NfNetlinkDeserializable: Sized { + fn deserialize(buf: &[u8]) -> Result<(Self, &[u8]), DecodeError>; +} + +pub trait NfNetlinkObject: + Sized + AttributeDecoder + NfNetlinkDeserializable + NfNetlinkAttribute +{ + const MSG_TYPE_ADD: u32; + const MSG_TYPE_DEL: u32; + + fn add_or_remove<'a>(&self, writer: &mut NfNetlinkWriter<'a>, msg_type: MsgType, seq: u32) { + let raw_msg_type = match msg_type { + MsgType::Add => Self::MSG_TYPE_ADD, + MsgType::Del => Self::MSG_TYPE_DEL, + } as u16; + writer.write_header( + raw_msg_type, + self.get_family(), + (if let MsgType::Add = msg_type { + self.get_add_flags() + } else { + self.get_del_flags() + } | NLM_F_ACK) as u16, + seq, + None, + ); + let buf = writer.add_data_zeroed(self.get_size()); + unsafe { + self.write_payload(buf.as_mut_ptr()); + } + writer.finalize_writing_object(); + } + + fn get_family(&self) -> ProtocolFamily; + + fn set_family(&mut self, _family: ProtocolFamily) { + // the default impl do nothing, because some types are family-agnostic + } + + fn with_family(mut self, family: ProtocolFamily) -> Self { + self.set_family(family); + self + } + + fn get_add_flags(&self) -> u32 { + NLM_F_CREATE + } + + fn get_del_flags(&self) -> u32 { + 0 + } +} + +pub type NetlinkType = u16; + +pub trait NfNetlinkAttribute: Debug + Sized { + // is it a nested argument that must be marked with a NLA_F_NESTED flag? + fn is_nested(&self) -> bool { + false + } + + fn get_size(&self) -> usize { + size_of::<Self>() + } + + // example body: std::ptr::copy_nonoverlapping(self as *const Self as *const u8, addr, self.get_size()); + unsafe fn write_payload(&self, addr: *mut u8); +} diff --git a/src/parser.rs b/src/parser.rs new file mode 100644 index 0000000..6ea34c1 --- /dev/null +++ b/src/parser.rs @@ -0,0 +1,216 @@ +use std::{ + fmt::{Debug, DebugStruct}, + mem::{size_of, transmute}, +}; + +use crate::{ + error::DecodeError, + nlmsg::{ + get_operation_from_nlmsghdr_type, get_subsystem_from_nlmsghdr_type, pad_netlink_object, + pad_netlink_object_with_variable_size, AttributeDecoder, NetlinkType, NfNetlinkAttribute, + }, + sys::{ + nfgenmsg, nlattr, nlmsgerr, nlmsghdr, NFNETLINK_V0, NFNL_MSG_BATCH_BEGIN, + NFNL_MSG_BATCH_END, NFNL_SUBSYS_NFTABLES, NLA_F_NESTED, NLA_TYPE_MASK, NLMSG_DONE, + NLMSG_ERROR, NLMSG_MIN_TYPE, NLMSG_NOOP, NLM_F_DUMP_INTR, + }, +}; + +pub fn get_nlmsghdr(buf: &[u8]) -> Result<nlmsghdr, DecodeError> { + let size_of_hdr = size_of::<nlmsghdr>(); + + if buf.len() < size_of_hdr { + return Err(DecodeError::BufTooSmall); + } + + let nlmsghdr_ptr = buf[0..size_of_hdr].as_ptr() as *const nlmsghdr; + let nlmsghdr = unsafe { *nlmsghdr_ptr }; + + if nlmsghdr.nlmsg_len as usize > buf.len() || (nlmsghdr.nlmsg_len as usize) < size_of_hdr { + return Err(DecodeError::NlMsgTooSmall); + } + + if nlmsghdr.nlmsg_flags & NLM_F_DUMP_INTR as u16 != 0 { + return Err(DecodeError::ConcurrentGenerationUpdate); + } + + Ok(nlmsghdr) +} + +#[derive(Debug, Clone, PartialEq)] +pub enum NlMsg<'a> { + Done, + Noop, + Error(nlmsgerr), + NfGenMsg(nfgenmsg, &'a [u8]), +} + +pub fn parse_nlmsg<'a>(buf: &'a [u8]) -> Result<(nlmsghdr, NlMsg<'a>), DecodeError> { + // in theory the message is composed of the following parts: + // - nlmsghdr (contains the message size and type) + // - struct nlmsgerr OR nfgenmsg (nftables header that describes the message family) + // - the raw value that we want to validate (if the previous part is nfgenmsg) + let hdr = get_nlmsghdr(buf)?; + + let size_of_hdr = pad_netlink_object::<nlmsghdr>(); + + if hdr.nlmsg_type < NLMSG_MIN_TYPE as u16 { + match hdr.nlmsg_type as u32 { + x if x == NLMSG_NOOP => return Ok((hdr, NlMsg::Noop)), + x if x == NLMSG_ERROR => { + if (hdr.nlmsg_len as usize) < size_of_hdr + size_of::<nlmsgerr>() { + return Err(DecodeError::NlMsgTooSmall); + } + let mut err = unsafe { + *(buf[size_of_hdr..size_of_hdr + size_of::<nlmsgerr>()].as_ptr() + as *const nlmsgerr) + }; + // some APIs return negative values, while other return positive values + err.error = err.error.abs(); + return Ok((hdr, NlMsg::Error(err))); + } + x if x == NLMSG_DONE => return Ok((hdr, NlMsg::Done)), + x => return Err(DecodeError::UnsupportedType(x as u16)), + } + } + + // batch messages are not specific to the nftables subsystem + if hdr.nlmsg_type != NFNL_MSG_BATCH_BEGIN as u16 && hdr.nlmsg_type != NFNL_MSG_BATCH_END as u16 + { + // verify that we are decoding nftables messages + let subsys = get_subsystem_from_nlmsghdr_type(hdr.nlmsg_type); + if subsys != NFNL_SUBSYS_NFTABLES as u8 { + return Err(DecodeError::InvalidSubsystem(subsys)); + } + } + + let size_of_nfgenmsg = pad_netlink_object::<nfgenmsg>(); + if hdr.nlmsg_len as usize > buf.len() + || (hdr.nlmsg_len as usize) < size_of_hdr + size_of_nfgenmsg + { + return Err(DecodeError::NlMsgTooSmall); + } + + let nfgenmsg_ptr = buf[size_of_hdr..size_of_hdr + size_of_nfgenmsg].as_ptr() as *const nfgenmsg; + let nfgenmsg = unsafe { *nfgenmsg_ptr }; + + if nfgenmsg.version != NFNETLINK_V0 as u8 { + return Err(DecodeError::InvalidVersion(nfgenmsg.version)); + } + + let raw_value = &buf[size_of_hdr + size_of_nfgenmsg..hdr.nlmsg_len as usize]; + + Ok((hdr, NlMsg::NfGenMsg(nfgenmsg, raw_value))) +} + +/// Write the attribute, preceded by a `libc::nlattr` +// rewrite of `mnl_attr_put` +pub unsafe fn write_attribute<'a>( + ty: NetlinkType, + obj: &impl NfNetlinkAttribute, + mut buf: *mut u8, +) { + let header_len = pad_netlink_object::<libc::nlattr>(); + // copy the header + *(buf as *mut nlattr) = nlattr { + // nla_len contains the header size + the unpadded attribute length + nla_len: (header_len + obj.get_size() as usize) as u16, + nla_type: if obj.is_nested() { + ty | NLA_F_NESTED as u16 + } else { + ty + }, + }; + buf = buf.offset(pad_netlink_object::<nlattr>() as isize); + // copy the attribute data itself + obj.write_payload(buf); +} + +pub(crate) fn read_attributes<T: AttributeDecoder + Default>(buf: &[u8]) -> Result<T, DecodeError> { + debug!( + "Calling <{} as NfNetlinkDeserialize>::deserialize()", + std::any::type_name::<T>() + ); + let mut remaining_size = buf.len(); + let mut pos = 0; + let mut res = T::default(); + while remaining_size > pad_netlink_object::<nlattr>() { + let nlattr = unsafe { *transmute::<*const u8, *const nlattr>(buf[pos..].as_ptr()) }; + // ignore the byteorder and nested attributes + let nla_type = nlattr.nla_type & NLA_TYPE_MASK as u16; + + pos += pad_netlink_object::<nlattr>(); + let attr_remaining_size = nlattr.nla_len as usize - pad_netlink_object::<nlattr>(); + match T::decode_attribute(&mut res, nla_type, &buf[pos..pos + attr_remaining_size]) { + Ok(()) => {} + Err(DecodeError::UnsupportedAttributeType(t)) => info!( + "Ignoring unsupported attribute type {} for type {}", + t, + std::any::type_name::<T>() + ), + Err(e) => return Err(e), + } + pos += pad_netlink_object_with_variable_size(attr_remaining_size); + + remaining_size -= pad_netlink_object_with_variable_size(nlattr.nla_len as usize); + } + + if remaining_size != 0 { + Err(DecodeError::InvalidDataSize) + } else { + Ok(res) + } +} + +pub trait InnerFormat { + fn inner_format_struct<'a, 'b: 'a>( + &'a self, + s: DebugStruct<'a, 'b>, + ) -> Result<DebugStruct<'a, 'b>, std::fmt::Error>; +} + +pub trait Parsable +where + Self: Sized, +{ + fn parse_object( + buf: &[u8], + add_obj: u32, + del_obj: u32, + ) -> Result<(Self, nfgenmsg, &[u8]), DecodeError>; +} + +impl<T> Parsable for T +where + T: AttributeDecoder + Default + Sized, +{ + fn parse_object( + buf: &[u8], + add_obj: u32, + del_obj: u32, + ) -> Result<(Self, nfgenmsg, &[u8]), DecodeError> { + debug!("parse_object() started"); + let (hdr, msg) = parse_nlmsg(buf)?; + + let op = get_operation_from_nlmsghdr_type(hdr.nlmsg_type) as u32; + + if op != add_obj && op != del_obj { + return Err(DecodeError::UnexpectedType(hdr.nlmsg_type)); + } + + let obj_size = hdr.nlmsg_len as usize + - pad_netlink_object_with_variable_size(size_of::<nlmsghdr>() + size_of::<nfgenmsg>()); + + let remaining_data_offset = pad_netlink_object_with_variable_size(hdr.nlmsg_len as usize); + let remaining_data = &buf[remaining_data_offset..]; + + let (nfgenmsg, res) = match msg { + NlMsg::NfGenMsg(nfgenmsg, content) => { + (nfgenmsg, read_attributes(&content[..obj_size])?) + } + _ => return Err(DecodeError::UnexpectedType(hdr.nlmsg_type)), + }; + + Ok((res, nfgenmsg, remaining_data)) + } +} diff --git a/src/parser_impls.rs b/src/parser_impls.rs new file mode 100644 index 0000000..b2681bb --- /dev/null +++ b/src/parser_impls.rs @@ -0,0 +1,243 @@ +use std::{fmt::Debug, mem::transmute}; + +use rustables_macros::nfnetlink_struct; + +use crate::{ + error::DecodeError, + expr::Verdict, + nlmsg::{ + pad_netlink_object, pad_netlink_object_with_variable_size, NfNetlinkAttribute, + NfNetlinkDeserializable, NfNetlinkObject, + }, + parser::{write_attribute, Parsable}, + sys::{nlattr, NFTA_DATA_VALUE, NFTA_DATA_VERDICT, NFTA_LIST_ELEM, NLA_TYPE_MASK}, + ProtocolFamily, +}; + +impl NfNetlinkAttribute for u8 { + unsafe fn write_payload(&self, addr: *mut u8) { + *addr = *self; + } +} + +impl NfNetlinkDeserializable for u8 { + fn deserialize(buf: &[u8]) -> Result<(Self, &[u8]), DecodeError> { + Ok((buf[0], &buf[1..])) + } +} + +impl NfNetlinkAttribute for u16 { + unsafe fn write_payload(&self, addr: *mut u8) { + *(addr as *mut Self) = self.to_be(); + } +} + +impl NfNetlinkDeserializable for u16 { + fn deserialize(buf: &[u8]) -> Result<(Self, &[u8]), DecodeError> { + Ok((u16::from_be_bytes([buf[0], buf[1]]), &buf[2..])) + } +} + +impl NfNetlinkAttribute for i32 { + unsafe fn write_payload(&self, addr: *mut u8) { + *(addr as *mut Self) = self.to_be(); + } +} + +impl NfNetlinkDeserializable for i32 { + fn deserialize(buf: &[u8]) -> Result<(Self, &[u8]), DecodeError> { + Ok(( + i32::from_be_bytes([buf[0], buf[1], buf[2], buf[3]]), + &buf[4..], + )) + } +} + +impl NfNetlinkAttribute for u32 { + unsafe fn write_payload(&self, addr: *mut u8) { + *(addr as *mut Self) = self.to_be(); + } +} + +impl NfNetlinkDeserializable for u32 { + fn deserialize(buf: &[u8]) -> Result<(Self, &[u8]), DecodeError> { + Ok(( + u32::from_be_bytes([buf[0], buf[1], buf[2], buf[3]]), + &buf[4..], + )) + } +} + +impl NfNetlinkAttribute for u64 { + unsafe fn write_payload(&self, addr: *mut u8) { + *(addr as *mut Self) = self.to_be(); + } +} + +impl NfNetlinkDeserializable for u64 { + fn deserialize(buf: &[u8]) -> Result<(Self, &[u8]), DecodeError> { + Ok(( + u64::from_be_bytes([ + buf[0], buf[1], buf[2], buf[3], buf[4], buf[5], buf[6], buf[7], + ]), + &buf[8..], + )) + } +} + +impl NfNetlinkAttribute for String { + fn get_size(&self) -> usize { + self.len() + } + + unsafe fn write_payload(&self, addr: *mut u8) { + std::ptr::copy_nonoverlapping(self.as_bytes().as_ptr(), addr, self.len()); + } +} + +impl NfNetlinkDeserializable for String { + fn deserialize(mut buf: &[u8]) -> Result<(Self, &[u8]), DecodeError> { + // ignore the NULL byte terminator, if any + if buf.len() > 0 && buf[buf.len() - 1] == 0 { + buf = &buf[..buf.len() - 1]; + } + Ok((String::from_utf8(buf.to_vec())?, &[])) + } +} + +impl NfNetlinkAttribute for Vec<u8> { + fn get_size(&self) -> usize { + self.len() + } + + unsafe fn write_payload(&self, addr: *mut u8) { + std::ptr::copy_nonoverlapping(self.as_ptr(), addr, self.len()); + } +} + +impl NfNetlinkDeserializable for Vec<u8> { + fn deserialize(buf: &[u8]) -> Result<(Self, &[u8]), DecodeError> { + Ok((buf.to_vec(), &[])) + } +} +#[derive(Clone, PartialEq, Eq, Default, Debug)] +#[nfnetlink_struct(nested = true)] +pub struct NfNetlinkData { + #[field(NFTA_DATA_VALUE)] + value: Vec<u8>, + #[field(NFTA_DATA_VERDICT)] + verdict: Verdict, +} + +#[derive(Debug, Clone, PartialEq, Eq, Default)] +pub struct NfNetlinkList<T> +where + T: NfNetlinkDeserializable + NfNetlinkAttribute + Debug + Clone + Eq + Default, +{ + objs: Vec<T>, +} + +impl<T> NfNetlinkList<T> +where + T: NfNetlinkDeserializable + NfNetlinkAttribute + Clone + Eq + Default, +{ + pub fn add_value(&mut self, e: impl Into<T>) { + self.objs.push(e.into()); + } + + pub fn with_value(mut self, e: impl Into<T>) -> Self { + self.add_value(e); + self + } + + pub fn iter<'a>(&'a self) -> impl Iterator<Item = &'a T> { + self.objs.iter() + } +} + +impl<T> NfNetlinkAttribute for NfNetlinkList<T> +where + T: NfNetlinkDeserializable + NfNetlinkAttribute + Clone + Eq + Default, +{ + fn is_nested(&self) -> bool { + true + } + + fn get_size(&self) -> usize { + // one nlattr LIST_ELEM per object + self.objs.iter().fold(0, |acc, item| { + acc + item.get_size() + pad_netlink_object::<nlattr>() + }) + } + + unsafe fn write_payload(&self, mut addr: *mut u8) { + for item in &self.objs { + write_attribute(NFTA_LIST_ELEM, item, addr); + addr = addr.offset((pad_netlink_object::<nlattr>() + item.get_size()) as isize); + } + } +} + +impl<T> NfNetlinkDeserializable for NfNetlinkList<T> +where + T: NfNetlinkDeserializable + NfNetlinkAttribute + Clone + Eq + Default, +{ + fn deserialize(buf: &[u8]) -> Result<(Self, &[u8]), DecodeError> { + let mut objs = Vec::new(); + + let mut pos = 0; + while buf.len() - pos > pad_netlink_object::<nlattr>() { + let nlattr = unsafe { *transmute::<*const u8, *const nlattr>(buf[pos..].as_ptr()) }; + // ignore the byteorder and nested attributes + let nla_type = nlattr.nla_type & NLA_TYPE_MASK as u16; + + if nla_type != NFTA_LIST_ELEM { + return Err(DecodeError::UnsupportedAttributeType(nla_type)); + } + + let (obj, remaining) = T::deserialize( + &buf[pos + pad_netlink_object::<nlattr>()..pos + nlattr.nla_len as usize], + )?; + if remaining.len() != 0 { + return Err(DecodeError::InvalidDataSize); + } + objs.push(obj); + + pos += pad_netlink_object_with_variable_size(nlattr.nla_len as usize); + } + + if pos != buf.len() { + Err(DecodeError::InvalidDataSize) + } else { + Ok((Self { objs }, &[])) + } + } +} + +impl<O, T> From<Vec<O>> for NfNetlinkList<T> +where + T: From<O>, + T: NfNetlinkDeserializable + NfNetlinkAttribute + Clone + Eq + Default, +{ + fn from(v: Vec<O>) -> Self { + NfNetlinkList { + objs: v.into_iter().map(T::from).collect(), + } + } +} + +impl<T> NfNetlinkDeserializable for T +where + T: NfNetlinkObject + Parsable, +{ + fn deserialize(buf: &[u8]) -> Result<(Self, &[u8]), DecodeError> { + let (mut obj, nfgenmsg, remaining_data) = Self::parse_object( + buf, + <T as NfNetlinkObject>::MSG_TYPE_ADD, + <T as NfNetlinkObject>::MSG_TYPE_DEL, + )?; + obj.set_family(ProtocolFamily::try_from(nfgenmsg.nfgen_family as i32)?); + + Ok((obj, remaining_data)) + } +} diff --git a/src/query.rs b/src/query.rs index bc1d02e..7cf5050 100644 --- a/src/query.rs +++ b/src/query.rs @@ -1,129 +1,178 @@ -use crate::{nft_nlmsg_maxsize, sys, ProtoFamily}; -use sys::libc; - -/// Returns a buffer containing a netlink message which requests a list of all the netfilter -/// matching objects (e.g. tables, chains, rules, ...). -/// Supply the type of objects to retrieve (e.g. libc::NFT_MSG_GETTABLE), and optionally a callback -/// to execute on the header, to set parameters for example. -/// To pass arbitrary data inside that callback, please use a closure. -pub fn get_list_of_objects<Error>( - seq: u32, - target: u16, - setup_cb: Option<&dyn Fn(&mut libc::nlmsghdr) -> Result<(), Error>>, -) -> Result<Vec<u8>, Error> { - let mut buffer = vec![0; nft_nlmsg_maxsize() as usize]; - let hdr = unsafe { - &mut *sys::nftnl_nlmsg_build_hdr( - buffer.as_mut_ptr() as *mut libc::c_char, - target, - ProtoFamily::Unspec as u16, - (libc::NLM_F_ROOT | libc::NLM_F_MATCH) as u16, - seq, - ) - }; - if let Some(cb) = setup_cb { - cb(hdr)?; - } - Ok(buffer) -} - -#[cfg(feature = "query")] -mod inner { - use crate::FinalizedBatch; - - use super::*; - - #[derive(thiserror::Error, Debug)] - pub enum Error { - #[error("Unable to open netlink socket to netfilter")] - NetlinkOpenError(#[source] std::io::Error), - - #[error("Unable to send netlink command to netfilter")] - NetlinkSendError(#[source] std::io::Error), - - #[error("Error while reading from netlink socket")] - NetlinkRecvError(#[source] std::io::Error), +use std::os::unix::prelude::RawFd; + +use nix::sys::socket::{self, AddressFamily, MsgFlags, SockFlag, SockProtocol, SockType}; + +use crate::{ + error::QueryError, + nlmsg::{ + nft_nlmsg_maxsize, pad_netlink_object_with_variable_size, NfNetlinkAttribute, + NfNetlinkObject, NfNetlinkWriter, + }, + parser::{parse_nlmsg, NlMsg}, + sys::{NLM_F_DUMP, NLM_F_MULTI}, + ProtocolFamily, +}; + +pub(crate) fn recv_and_process<'a, T>( + sock: RawFd, + max_seq: Option<u32>, + cb: Option<&dyn Fn(&[u8], &mut T) -> Result<(), QueryError>>, + working_data: &'a mut T, +) -> Result<(), QueryError> { + let mut msg_buffer = vec![0; 2 * nft_nlmsg_maxsize() as usize]; + let mut buf_start = 0; + let mut end_pos = 0; + + loop { + let nb_recv = socket::recv(sock, &mut msg_buffer[end_pos..], MsgFlags::empty()) + .map_err(QueryError::NetlinkRecvError)?; + if nb_recv <= 0 { + return Ok(()); + } + end_pos += nb_recv; + loop { + let buf = &msg_buffer.as_slice()[buf_start..end_pos]; + // exit the loop and try to receive further messages when we consumed all the buffer + if buf.len() == 0 { + break; + } - #[error("Error while processing an incoming netlink message")] - ProcessNetlinkError(#[source] std::io::Error), + debug!("Calling parse_nlmsg"); + let (nlmsghdr, msg) = parse_nlmsg(&buf)?; + debug!("Got a valid netlink message: {:?} {:?}", nlmsghdr, msg); + + match msg { + NlMsg::Done => { + return Ok(()); + } + NlMsg::Error(e) => { + if e.error != 0 { + return Err(QueryError::NetlinkError(e)); + } + } + NlMsg::Noop => {} + NlMsg::NfGenMsg(_genmsg, _data) => { + if let Some(cb) = cb { + cb(&buf[0..nlmsghdr.nlmsg_len as usize], working_data)?; + } + } + } - #[error("Custom error when customizing the query")] - InitError(#[from] Box<dyn std::error::Error + 'static>), + // we cannot know when a sequence of messages will end if the messages do not end + // with an NlMsg::Done marker while if a maximum sequence number wasn't specified + if max_seq.is_none() && nlmsghdr.nlmsg_flags & NLM_F_MULTI as u16 == 0 { + return Err(QueryError::UndecidableMessageTermination); + } - #[error("Couldn't allocate a netlink object, out of memory ?")] - NetlinkAllocationFailed, - } + // retrieve the next message + if let Some(max_seq) = max_seq { + if nlmsghdr.nlmsg_seq >= max_seq { + return Ok(()); + } + } - /// Lists objects of a certain type (e.g. libc::NFT_MSG_GETTABLE) with the help of a helper - /// function called by mnl::cb_run2. - /// The callback expects a tuple of additional data (supplied as an argument to this function) - /// and of the output vector, to which it should append the parsed object it received. - pub fn list_objects_with_data<'a, A, T>( - data_type: u16, - cb: fn(&libc::nlmsghdr, &mut (&'a A, &mut Vec<T>)) -> libc::c_int, - additional_data: &'a A, - req_hdr_customize: Option<&dyn Fn(&mut libc::nlmsghdr) -> Result<(), Error>>, - ) -> Result<Vec<T>, Error> - where - T: 'a, - { - debug!("listing objects of kind {}", data_type); - let socket = mnl::Socket::new(mnl::Bus::Netfilter).map_err(Error::NetlinkOpenError)?; - - let seq = 0; - let portid = 0; - - let chains_buf = get_list_of_objects(seq, data_type, req_hdr_customize)?; - socket.send(&chains_buf).map_err(Error::NetlinkSendError)?; - - let mut res = Vec::new(); - - let mut msg_buffer = vec![0; nft_nlmsg_maxsize() as usize]; - while socket - .recv(&mut msg_buffer) - .map_err(Error::NetlinkRecvError)? - > 0 - { - if let mnl::CbResult::Stop = mnl::cb_run2( - &msg_buffer, - seq, - portid, - cb, - &mut (additional_data, &mut res), - ) - .map_err(Error::ProcessNetlinkError)? - { - break; + // netlink messages are 4bytes aligned + let aligned_length = pad_netlink_object_with_variable_size(nlmsghdr.nlmsg_len as usize); + buf_start += aligned_length; + } + // Ensure that we always have nft_nlmsg_maxsize() free space available in the buffer. + // We achieve this by relocating the buffer content at the beginning of the buffer + if end_pos >= nft_nlmsg_maxsize() as usize { + if buf_start < end_pos { + unsafe { + std::ptr::copy( + msg_buffer[buf_start..end_pos].as_ptr(), + msg_buffer.as_mut_ptr(), + end_pos - buf_start, + ); + } } + end_pos = end_pos - buf_start; + buf_start = 0; } - - Ok(res) } +} - pub fn send_batch(batch: &mut FinalizedBatch) -> Result<(), Error> { - let socket = mnl::Socket::new(mnl::Bus::Netfilter).map_err(Error::NetlinkOpenError)?; +pub(crate) fn socket_close_wrapper<E>( + sock: RawFd, + cb: impl FnOnce(RawFd) -> Result<(), E>, +) -> Result<(), QueryError> +where + QueryError: From<E>, +{ + let ret = cb(sock); - let seq = 0; - let portid = socket.portid(); + // we don't need to shutdown the socket (in fact, Linux doesn't support that operation; + // and return EOPNOTSUPP if we try) + nix::unistd::close(sock).map_err(QueryError::CloseFailed)?; - socket.send_all(batch).map_err(Error::NetlinkSendError)?; - debug!("sent"); + Ok(ret?) +} - let mut msg_buffer = vec![0; nft_nlmsg_maxsize() as usize]; - while socket - .recv(&mut msg_buffer) - .map_err(Error::NetlinkRecvError)? - > 0 - { - if let mnl::CbResult::Stop = - mnl::cb_run(&msg_buffer, seq, portid).map_err(Error::ProcessNetlinkError)? - { - break; - } +/// Returns a buffer containing a netlink message which requests a list of all the netfilter +/// matching objects (e.g. tables, chains, rules, ...). +/// Supply the type of objects to retrieve (e.g. libc::NFT_MSG_GETTABLE), and a search filter. +pub fn get_list_of_objects<T: NfNetlinkAttribute>( + msg_type: u16, + seq: u32, + filter: Option<&T>, +) -> Result<Vec<u8>, QueryError> { + let mut buffer = Vec::new(); + let mut writer = NfNetlinkWriter::new(&mut buffer); + writer.write_header( + msg_type, + ProtocolFamily::Unspec, + NLM_F_DUMP as u16, + seq, + None, + ); + if let Some(filter) = filter { + let buf = writer.add_data_zeroed(filter.get_size()); + unsafe { + filter.write_payload(buf.as_mut_ptr()); } - Ok(()) } + writer.finalize_writing_object(); + Ok(buffer) } -#[cfg(feature = "query")] -pub use inner::*; +/// Lists objects of a certain type (e.g. libc::NFT_MSG_GETTABLE) with the help of a helper +/// function called by mnl::cb_run2. +/// The callback expects a tuple of additional data (supplied as an argument to this function) +/// and of the output vector, to which it should append the parsed object it received. +pub fn list_objects_with_data<'a, Object, Accumulator>( + data_type: u16, + cb: &dyn Fn(Object, &mut Accumulator) -> Result<(), QueryError>, + filter: Option<&Object>, + working_data: &'a mut Accumulator, +) -> Result<(), QueryError> +where + Object: NfNetlinkObject + NfNetlinkAttribute, +{ + debug!("Listing objects of kind {}", data_type); + let sock = socket::socket( + AddressFamily::Netlink, + SockType::Raw, + SockFlag::empty(), + SockProtocol::NetlinkNetFilter, + ) + .map_err(QueryError::NetlinkOpenError)?; + + let seq = 0; + + let chains_buf = get_list_of_objects(data_type, seq, filter)?; + socket::send(sock, &chains_buf, MsgFlags::empty()).map_err(QueryError::NetlinkSendError)?; + + socket_close_wrapper(sock, move |sock| { + // the kernel should return NLM_F_MULTI objects + recv_and_process( + sock, + None, + Some(&|buf: &[u8], working_data: &mut Accumulator| { + debug!("Calling Object::deserialize()"); + cb(Object::deserialize(buf)?.0, working_data) + }), + working_data, + ) + }) +} diff --git a/src/rule.rs b/src/rule.rs index 2ee5308..858b9ce 100644 --- a/src/rule.rs +++ b/src/rule.rs @@ -1,341 +1,111 @@ -use crate::expr::ExpressionWrapper; -use crate::{chain::Chain, expr::Expression, MsgType}; -use crate::sys::{self, libc}; -use std::ffi::{c_void, CStr, CString}; use std::fmt::Debug; -use std::os::raw::c_char; -use std::rc::Rc; + +use rustables_macros::nfnetlink_struct; + +use crate::chain::Chain; +use crate::error::{BuilderError, QueryError}; +use crate::expr::{ExpressionList, RawExpression}; +use crate::nlmsg::NfNetlinkObject; +use crate::query::list_objects_with_data; +use crate::sys::{ + NFTA_RULE_CHAIN, NFTA_RULE_EXPRESSIONS, NFTA_RULE_HANDLE, NFTA_RULE_ID, NFTA_RULE_POSITION, + NFTA_RULE_TABLE, NFTA_RULE_USERDATA, NFT_MSG_DELRULE, NFT_MSG_NEWRULE, NLM_F_APPEND, + NLM_F_CREATE, +}; +use crate::{Batch, ProtocolFamily}; /// A nftables firewall rule. +#[derive(Clone, PartialEq, Eq, Default, Debug)] +#[nfnetlink_struct(derive_deserialize = false)] pub struct Rule { - pub(crate) rule: *mut sys::nftnl_rule, - pub(crate) chain: Rc<Chain>, + family: ProtocolFamily, + #[field(NFTA_RULE_TABLE)] + table: String, + #[field(NFTA_RULE_CHAIN)] + chain: String, + #[field(NFTA_RULE_HANDLE)] + handle: u64, + #[field(NFTA_RULE_EXPRESSIONS)] + expressions: ExpressionList, + #[field(NFTA_RULE_POSITION)] + position: u64, + #[field(NFTA_RULE_USERDATA)] + userdata: Vec<u8>, + #[field(NFTA_RULE_ID)] + id: u32, } impl Rule { /// Creates a new rule object in the given [`Chain`]. /// /// [`Chain`]: struct.Chain.html - pub fn new(chain: Rc<Chain>) -> Rule { - unsafe { - let rule = try_alloc!(sys::nftnl_rule_alloc()); - sys::nftnl_rule_set_u32( - rule, - sys::NFTNL_RULE_FAMILY as u16, - chain.get_table().get_family() as u32, - ); - sys::nftnl_rule_set_str( - rule, - sys::NFTNL_RULE_TABLE as u16, - chain.get_table().get_name().as_ptr(), - ); - sys::nftnl_rule_set_str( - rule, - sys::NFTNL_RULE_CHAIN as u16, - chain.get_name().as_ptr(), - ); - - Rule { rule, chain } - } - } - - pub unsafe fn from_raw(rule: *mut sys::nftnl_rule, chain: Rc<Chain>) -> Self { - Rule { rule, chain } - } - - pub fn get_position(&self) -> u64 { - unsafe { sys::nftnl_rule_get_u64(self.rule, sys::NFTNL_RULE_POSITION as u16) } - } - - /// Sets the position of this rule within the chain it lives in. By default a new rule is added - /// to the end of the chain. - pub fn set_position(&mut self, position: u64) { - unsafe { - sys::nftnl_rule_set_u64(self.rule, sys::NFTNL_RULE_POSITION as u16, position); - } - } - - pub fn get_handle(&self) -> u64 { - unsafe { sys::nftnl_rule_get_u64(self.rule, sys::NFTNL_RULE_HANDLE as u16) } - } - - pub fn set_handle(&mut self, handle: u64) { - unsafe { - sys::nftnl_rule_set_u64(self.rule, sys::NFTNL_RULE_HANDLE as u16, handle); - } - } - - /// Adds an expression to this rule. Expressions are evaluated from first to last added. - /// As soon as an expression does not match the packet it's being evaluated for, evaluation - /// stops and the packet is evaluated against the next rule in the chain. - pub fn add_expr(&mut self, expr: &impl Expression) { - unsafe { sys::nftnl_rule_add_expr(self.rule, expr.to_expr(self)) } - } - - /// Returns a reference to the [`Chain`] this rule lives in. - /// - /// [`Chain`]: struct.Chain.html - pub fn get_chain(&self) -> Rc<Chain> { - self.chain.clone() - } - - /// Returns the userdata of this chain. - pub fn get_userdata(&self) -> Option<&CStr> { - unsafe { - let ptr = sys::nftnl_rule_get_str(self.rule, sys::NFTNL_RULE_USERDATA as u16); - if !ptr.is_null() { - Some(CStr::from_ptr(ptr)) - } else { - None - } - } - } - - /// Updates the userdata of this chain. - pub fn set_userdata(&self, data: &CStr) { - unsafe { - sys::nftnl_rule_set_str(self.rule, sys::NFTNL_RULE_USERDATA as u16, data.as_ptr()); - } - } - - /// Returns a textual description of the rule. - pub fn get_str(&self) -> CString { - let mut descr_buf = vec![0i8; 4096]; - unsafe { - sys::nftnl_rule_snprintf( - descr_buf.as_mut_ptr() as *mut c_char, - (descr_buf.len() - 1) as u64, - self.rule, - sys::NFTNL_OUTPUT_DEFAULT, - 0, - ); - CStr::from_ptr(descr_buf.as_ptr() as *mut c_char).to_owned() - } - } - - /// Retrieves an iterator to loop over the expressions of the rule. - pub fn get_exprs(self: &Rc<Self>) -> RuleExprsIter { - RuleExprsIter::new(self.clone()) - } - - #[cfg(feature = "unsafe-raw-handles")] - /// Returns the raw handle. - pub fn as_ptr(&self) -> *const sys::nftnl_rule { - self.rule as *const sys::nftnl_rule - } - - #[cfg(feature = "unsafe-raw-handles")] - /// Returns a mutable version of the raw handle. - pub fn as_mut_ptr(&mut self) -> *mut sys::nftnl_rule { - self.rule - } - - /// Performs a deep comparizon of rules, by checking they have the same expressions inside. - /// This is not enabled by default in our PartialEq implementation because of the difficulty to - /// compare an expression generated by the library with the expressions returned by the kernel - /// when iterating over the currently in-use rules. The kernel-returned expressions may have - /// additional attributes despite being generated from the same rule. This is particularly true - /// for the 'nat' expression). - pub fn deep_eq(&self, other: &Self) -> bool { - if self != other { - return false; - } - - let self_exprs = - try_alloc!(unsafe { sys::nftnl_expr_iter_create(self.rule as *const sys::nftnl_rule) }); - let other_exprs = try_alloc!(unsafe { - sys::nftnl_expr_iter_create(other.rule as *const sys::nftnl_rule) - }); - - loop { - let self_next = unsafe { sys::nftnl_expr_iter_next(self_exprs) }; - let other_next = unsafe { sys::nftnl_expr_iter_next(other_exprs) }; - if self_next.is_null() && other_next.is_null() { - return true; - } else if self_next.is_null() || other_next.is_null() { - return false; - } - - // we are falling back on comparing the strings, because there is no easy mechanism to - // perform a memcmp() between the two expressions :/ - let mut self_str = [0; 256]; - let mut other_str = [0; 256]; - unsafe { - sys::nftnl_expr_snprintf( - self_str.as_mut_ptr(), - (self_str.len() - 1) as u64, - self_next, - sys::NFTNL_OUTPUT_DEFAULT, - 0, - ); - sys::nftnl_expr_snprintf( - other_str.as_mut_ptr(), - (other_str.len() - 1) as u64, - other_next, - sys::NFTNL_OUTPUT_DEFAULT, - 0, - ); + pub fn new(chain: &Chain) -> Result<Rule, BuilderError> { + Ok(Rule::default() + .with_family(chain.get_family()) + .with_table( + chain + .get_table() + .ok_or(BuilderError::MissingChainInformationError)?, + ) + .with_chain( + chain + .get_name() + .ok_or(BuilderError::MissingChainInformationError)?, + )) + } + + pub fn add_expr(&mut self, e: impl Into<RawExpression>) { + let exprs = match self.get_mut_expressions() { + Some(x) => x, + None => { + self.set_expressions(ExpressionList::default()); + self.get_mut_expressions().unwrap() } - - if self_str != other_str { - return false; - } - } - } -} - -impl Debug for Rule { - fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { - write!(f, "{:?}", self.get_str()) - } -} - -impl PartialEq for Rule { - fn eq(&self, other: &Self) -> bool { - if self.get_chain() != other.get_chain() { - return false; - } - - unsafe { - if sys::nftnl_rule_is_set(self.rule, sys::NFTNL_RULE_HANDLE as u16) - && sys::nftnl_rule_is_set(other.rule, sys::NFTNL_RULE_HANDLE as u16) - { - if self.get_handle() != other.get_handle() { - return false; - } - } - if sys::nftnl_rule_is_set(self.rule, sys::NFTNL_RULE_POSITION as u16) - && sys::nftnl_rule_is_set(other.rule, sys::NFTNL_RULE_POSITION as u16) - { - if self.get_position() != other.get_position() { - return false; - } - } - } - - return false; - } -} - -unsafe impl crate::NlMsg for Rule { - unsafe fn write(&self, buf: *mut c_void, seq: u32, msg_type: MsgType) { - let type_ = match msg_type { - MsgType::Add => libc::NFT_MSG_NEWRULE, - MsgType::Del => libc::NFT_MSG_DELRULE, }; - let flags: u16 = match msg_type { - MsgType::Add => (libc::NLM_F_CREATE | libc::NLM_F_APPEND | libc::NLM_F_EXCL) as u16, - MsgType::Del => 0u16, - } | libc::NLM_F_ACK as u16; - let header = sys::nftnl_nlmsg_build_hdr( - buf as *mut c_char, - type_ as u16, - self.chain.get_table().get_family() as u16, - flags, - seq, - ); - sys::nftnl_rule_nlmsg_build_payload(header, self.rule); + exprs.add_value(e); } -} -impl Drop for Rule { - fn drop(&mut self) { - unsafe { sys::nftnl_rule_free(self.rule) }; + pub fn with_expr(mut self, e: impl Into<RawExpression>) -> Self { + self.add_expr(e); + self } -} - -pub struct RuleExprsIter { - rule: Rc<Rule>, - iter: *mut sys::nftnl_expr_iter, -} -impl RuleExprsIter { - fn new(rule: Rc<Rule>) -> Self { - let iter = - try_alloc!(unsafe { sys::nftnl_expr_iter_create(rule.rule as *const sys::nftnl_rule) }); - RuleExprsIter { rule, iter } + /// Appends this rule to `batch` + pub fn add_to_batch(self, batch: &mut Batch) -> Self { + batch.add(&self, crate::MsgType::Add); + self } } -impl Iterator for RuleExprsIter { - type Item = ExpressionWrapper; +impl NfNetlinkObject for Rule { + const MSG_TYPE_ADD: u32 = NFT_MSG_NEWRULE; + const MSG_TYPE_DEL: u32 = NFT_MSG_DELRULE; - fn next(&mut self) -> Option<Self::Item> { - let next = unsafe { sys::nftnl_expr_iter_next(self.iter) }; - if next.is_null() { - trace!("RulesExprsIter iterator ending"); - None - } else { - trace!("RulesExprsIter returning new expression"); - Some(ExpressionWrapper { - expr: next, - rule: self.rule.clone(), - }) - } + fn get_family(&self) -> ProtocolFamily { + self.family } -} -impl Drop for RuleExprsIter { - fn drop(&mut self) { - unsafe { sys::nftnl_expr_iter_destroy(self.iter) }; + fn set_family(&mut self, family: ProtocolFamily) { + self.family = family; } -} - -#[cfg(feature = "query")] -pub fn get_rules_cb( - header: &libc::nlmsghdr, - (chain, rules): &mut (&Rc<Chain>, &mut Vec<Rule>), -) -> libc::c_int { - unsafe { - let rule = sys::nftnl_rule_alloc(); - if rule == std::ptr::null_mut() { - return mnl::mnl_sys::MNL_CB_ERROR; - } - let err = sys::nftnl_rule_nlmsg_parse(header, rule); - if err < 0 { - error!("Failed to parse nelink rule message - {}", err); - sys::nftnl_rule_free(rule); - return err; - } - rules.push(Rule::from_raw(rule, chain.clone())); + // append at the end of the chain, instead of the beginning + fn get_add_flags(&self) -> u32 { + NLM_F_CREATE | NLM_F_APPEND } - mnl::mnl_sys::MNL_CB_OK } -#[cfg(feature = "query")] -pub fn list_rules_for_chain(chain: &Rc<Chain>) -> Result<Vec<Rule>, crate::query::Error> { - crate::query::list_objects_with_data( +pub fn list_rules_for_chain(chain: &Chain) -> Result<Vec<Rule>, QueryError> { + let mut result = Vec::new(); + list_objects_with_data( libc::NFT_MSG_GETRULE as u16, - get_rules_cb, - &chain, - // only retrieve rules from the currently targetted chain - Some(&|hdr| unsafe { - let rule = sys::nftnl_rule_alloc(); - if rule as *const _ == std::ptr::null() { - return Err(crate::query::Error::NetlinkAllocationFailed); - } - - sys::nftnl_rule_set_str( - rule, - sys::NFTNL_RULE_TABLE as u16, - chain.get_table().get_name().as_ptr(), - ); - sys::nftnl_rule_set_u32( - rule, - sys::NFTNL_RULE_FAMILY as u16, - chain.get_table().get_family() as u32, - ); - sys::nftnl_rule_set_str( - rule, - sys::NFTNL_RULE_CHAIN as u16, - chain.get_name().as_ptr(), - ); - - sys::nftnl_rule_nlmsg_build_payload(hdr, rule); - - sys::nftnl_rule_free(rule); + &|rule: Rule, rules: &mut Vec<Rule>| { + rules.push(rule); Ok(()) - }), - ) + }, + // only retrieve rules from the currently targetted chain + Some(&Rule::new(chain)?), + &mut result, + )?; + Ok(result) } diff --git a/src/rule_methods.rs b/src/rule_methods.rs index d7145d7..dff9bf6 100644 --- a/src/rule_methods.rs +++ b/src/rule_methods.rs @@ -1,230 +1,211 @@ -use crate::{Batch, Rule, nft_expr, sys::libc}; -use crate::expr::{LogGroup, LogPrefix}; -use ipnetwork::IpNetwork; -use std::ffi::{CString, NulError}; +use std::ffi::CString; use std::net::IpAddr; -use std::num::ParseIntError; - -#[derive(thiserror::Error, Debug)] -pub enum Error { - #[error("Unable to open netlink socket to netfilter")] - NetlinkOpenError(#[source] std::io::Error), - #[error("Firewall is already started")] - AlreadyDone, - #[error("Error converting from a C string to a string")] - CStringError(#[from] NulError), - #[error("no interface found under that name")] - NoSuchIface, - #[error("Error converting from a string to an integer")] - ParseError(#[from] ParseIntError), - #[error("the interface name is too long")] - NameTooLong, -} +use ipnetwork::IpNetwork; +use crate::data_type::ip_to_vec; +use crate::error::BuilderError; +use crate::expr::ct::{ConnTrackState, Conntrack, ConntrackKey}; +use crate::expr::{ + Bitwise, Cmp, CmpOp, HighLevelPayload, IPv4HeaderField, IPv6HeaderField, Immediate, Meta, + MetaType, NetworkHeaderField, TCPHeaderField, TransportHeaderField, UDPHeaderField, + VerdictKind, +}; +use crate::Rule; /// Simple protocol description. Note that it does not implement other layer 4 protocols as /// IGMP et al. See [`Rule::igmp`] for a workaround. -#[derive(Debug, Clone)] +#[derive(Debug, Copy, Clone, PartialEq, Eq)] pub enum Protocol { TCP, - UDP + UDP, } -/// A RuleMethods trait over [`crate::Rule`], to make it match some criteria, and give it a -/// verdict. Mostly adapted from [talpid-core's firewall]. -/// All methods return the rule itself, allowing them to be chained. Usage example : -/// ```rust -/// use rustables::{Batch, Chain, ChainMethods, Protocol, ProtoFamily, Rule, RuleMethods, Table, MsgType, Hook}; -/// use std::ffi::CString; -/// use std::rc::Rc; -/// let table = Rc::new(Table::new(&CString::new("main_table").unwrap(), ProtoFamily::Inet)); -/// let mut batch = Batch::new(); -/// batch.add(&table, MsgType::Add); -/// let inbound = Rc::new(Chain::from_hook(Hook::In, Rc::clone(&table)) -/// .add_to_batch(&mut batch)); -/// let rule = Rule::new(inbound) -/// .dport("80", &Protocol::TCP).unwrap() -/// .accept() -/// .add_to_batch(&mut batch); -/// ``` -/// [talpid-core's firewall]: -/// https://github.com/mullvad/mullvadvpn-app/blob/d92376b4d1df9b547930c68aa9bae9640ff2a022/talpid-core/src/firewall/linux.rs -pub trait RuleMethods { - /// Matches ICMP packets. - fn icmp(self) -> Self; - /// Matches IGMP packets. - fn igmp(self) -> Self; - /// Matches packets to destination `port` and `protocol`. - fn dport(self, port: &str, protocol: &Protocol) -> Result<Self, Error> - where Self: std::marker::Sized; - /// Matches packets on `protocol`. - fn protocol(self, protocol: Protocol) -> Result<Self, Error> - where Self: std::marker::Sized; - /// Matches packets in an already established connection. - fn established(self) -> Self where Self: std::marker::Sized; - /// Matches packets going through `iface_index`. Interface indexes can be queried with - /// `iface_index()`. - fn iface_id(self, iface_index: libc::c_uint) -> Result<Self, Error> - where Self: std::marker::Sized; - /// Matches packets going through `iface_name`, an interface name, as in "wlan0" or "lo". - fn iface(self, iface_name: &str) -> Result<Self, Error> - where Self: std::marker::Sized; - /// Adds a log instruction to the rule. `group` is the NFLog group, `prefix` is a prefix - /// appended to each log line. - fn log(self, group: Option<LogGroup>, prefix: Option<LogPrefix>) -> Self; - /// Matches packets whose source IP address is `saddr`. - fn saddr(self, ip: IpAddr) -> Self; - /// Matches packets whose source network is `snet`. - fn snetwork(self, ip: IpNetwork) -> Self; - /// Adds the `Accept` verdict to the rule. The packet will be sent to destination. - fn accept(self) -> Self; - /// Adds the `Drop` verdict to the rule. The packet will be dropped. - fn drop(self) -> Self; - /// Appends this rule to `batch`. - fn add_to_batch(self, batch: &mut Batch) -> Self; -} - -/// A trait to add helper functions to match some criterium over `crate::Rule`. -impl RuleMethods for Rule { - fn icmp(mut self) -> Self { - self.add_expr(&nft_expr!(meta l4proto)); - //self.add_expr(&nft_expr!(cmp == libc::IPPROTO_ICMPV6 as u8)); - self.add_expr(&nft_expr!(cmp == libc::IPPROTO_ICMP as u8)); - self - } - fn igmp(mut self) -> Self { - self.add_expr(&nft_expr!(meta l4proto)); - self.add_expr(&nft_expr!(cmp == libc::IPPROTO_IGMP as u8)); +impl Rule { + fn match_port(mut self, port: u16, protocol: Protocol, source: bool) -> Self { + self = self.protocol(protocol); + self.add_expr( + HighLevelPayload::Transport(match protocol { + Protocol::TCP => TransportHeaderField::Tcp(if source { + TCPHeaderField::Sport + } else { + TCPHeaderField::Dport + }), + Protocol::UDP => TransportHeaderField::Udp(if source { + UDPHeaderField::Sport + } else { + UDPHeaderField::Dport + }), + }) + .build(), + ); + self.add_expr(Cmp::new(CmpOp::Eq, port.to_be_bytes())); self } - fn dport(mut self, port: &str, protocol: &Protocol) -> Result<Self, Error> { - self.add_expr(&nft_expr!(meta l4proto)); - match protocol { - &Protocol::TCP => { - self.add_expr(&nft_expr!(cmp == libc::IPPROTO_TCP as u8)); - self.add_expr(&nft_expr!(payload tcp dport)); - }, - &Protocol::UDP => { - self.add_expr(&nft_expr!(cmp == libc::IPPROTO_UDP as u8)); - self.add_expr(&nft_expr!(payload udp dport)); - } - } - // Convert the port to Big-Endian number spelling. - // See https://github.com/mullvad/mullvadvpn-app/blob/d92376b4d1df9b547930c68aa9bae9640ff2a022/talpid-core/src/firewall/linux.rs#L969 - self.add_expr(&nft_expr!(cmp == port.parse::<u16>()?.to_be())); - Ok(self) - } - fn protocol(mut self, protocol: Protocol) -> Result<Self, Error> { - self.add_expr(&nft_expr!(meta l4proto)); - match protocol { - Protocol::TCP => { - self.add_expr(&nft_expr!(cmp == libc::IPPROTO_TCP as u8)); - }, - Protocol::UDP => { - self.add_expr(&nft_expr!(cmp == libc::IPPROTO_UDP as u8)); - } - } - Ok(self) - } - fn established(mut self) -> Self { - let allowed_states = crate::expr::ct::States::ESTABLISHED.bits(); - self.add_expr(&nft_expr!(ct state)); - self.add_expr(&nft_expr!(bitwise mask allowed_states, xor 0u32)); - self.add_expr(&nft_expr!(cmp != 0u32)); - self - } - fn iface_id(mut self, iface_index: libc::c_uint) -> Result<Self, Error> { - self.add_expr(&nft_expr!(meta iif)); - self.add_expr(&nft_expr!(cmp == iface_index)); - Ok(self) - } - fn iface(mut self, iface_name: &str) -> Result<Self, Error> { - if iface_name.len() >= libc::IFNAMSIZ { - return Err(Error::NameTooLong); - } - let mut name_arr = [0u8; libc::IFNAMSIZ]; - for (pos, i) in iface_name.bytes().enumerate() { - name_arr[pos] = i; - } - self.add_expr(&nft_expr!(meta iifname)); - self.add_expr(&nft_expr!(cmp == name_arr.as_ref())); - Ok(self) - } - fn saddr(mut self, ip: IpAddr) -> Self { - self.add_expr(&nft_expr!(meta nfproto)); + pub fn match_ip(mut self, ip: IpAddr, source: bool) -> Self { + self.add_expr(Meta::new(MetaType::NfProto)); match ip { IpAddr::V4(addr) => { - self.add_expr(&nft_expr!(cmp == libc::NFPROTO_IPV4 as u8)); - self.add_expr(&nft_expr!(payload ipv4 saddr)); - self.add_expr(&nft_expr!(cmp == addr)) - }, + self.add_expr(Cmp::new(CmpOp::Eq, [libc::NFPROTO_IPV4 as u8])); + self.add_expr( + HighLevelPayload::Network(NetworkHeaderField::IPv4(if source { + IPv4HeaderField::Saddr + } else { + IPv4HeaderField::Daddr + })) + .build(), + ); + self.add_expr(Cmp::new(CmpOp::Eq, addr.octets())); + } IpAddr::V6(addr) => { - self.add_expr(&nft_expr!(cmp == libc::NFPROTO_IPV6 as u8)); - self.add_expr(&nft_expr!(payload ipv6 saddr)); - self.add_expr(&nft_expr!(cmp == addr)) + self.add_expr(Cmp::new(CmpOp::Eq, [libc::NFPROTO_IPV6 as u8])); + self.add_expr( + HighLevelPayload::Network(NetworkHeaderField::IPv6(if source { + IPv6HeaderField::Saddr + } else { + IPv6HeaderField::Daddr + })) + .build(), + ); + self.add_expr(Cmp::new(CmpOp::Eq, addr.octets())); } } self } - fn snetwork(mut self, net: IpNetwork) -> Self { - self.add_expr(&nft_expr!(meta nfproto)); + + pub fn match_network(mut self, net: IpNetwork, source: bool) -> Result<Self, BuilderError> { + self.add_expr(Meta::new(MetaType::NfProto)); match net { IpNetwork::V4(_) => { - self.add_expr(&nft_expr!(cmp == libc::NFPROTO_IPV4 as u8)); - self.add_expr(&nft_expr!(payload ipv4 saddr)); - self.add_expr(&nft_expr!(bitwise mask net.mask(), xor 0u32)); - self.add_expr(&nft_expr!(cmp == net.network())); - }, + self.add_expr(Cmp::new(CmpOp::Eq, [libc::NFPROTO_IPV4 as u8])); + self.add_expr( + HighLevelPayload::Network(NetworkHeaderField::IPv4(if source { + IPv4HeaderField::Saddr + } else { + IPv4HeaderField::Daddr + })) + .build(), + ); + self.add_expr(Bitwise::new(ip_to_vec(net.mask()), 0u32.to_be_bytes())?); + } IpNetwork::V6(_) => { - self.add_expr(&nft_expr!(cmp == libc::NFPROTO_IPV6 as u8)); - self.add_expr(&nft_expr!(payload ipv6 saddr)); - self.add_expr(&nft_expr!(bitwise mask net.mask(), xor &[0u16; 8][..])); - self.add_expr(&nft_expr!(cmp == net.network())); + self.add_expr(Cmp::new(CmpOp::Eq, [libc::NFPROTO_IPV6 as u8])); + self.add_expr( + HighLevelPayload::Network(NetworkHeaderField::IPv6(if source { + IPv6HeaderField::Saddr + } else { + IPv6HeaderField::Daddr + })) + .build(), + ); + self.add_expr(Bitwise::new(ip_to_vec(net.mask()), 0u128.to_be_bytes())?); } } + self.add_expr(Cmp::new(CmpOp::Eq, ip_to_vec(net.network()))); + Ok(self) + } +} + +impl Rule { + /// Matches ICMP packets. + pub fn icmp(mut self) -> Self { + // quid of icmpv6? + self.add_expr(Meta::new(MetaType::L4Proto)); + self.add_expr(Cmp::new(CmpOp::Eq, [libc::IPPROTO_ICMP as u8])); self } - fn log(mut self, group: Option<LogGroup>, prefix: Option<LogPrefix>) -> Self { - match (group.is_some(), prefix.is_some()) { - (true, true) => { - self.add_expr(&nft_expr!(log group group prefix prefix)); - }, - (false, true) => { - self.add_expr(&nft_expr!(log prefix prefix)); - }, - (true, false) => { - self.add_expr(&nft_expr!(log group group)); - }, - (false, false) => { - self.add_expr(&nft_expr!(log)); - } - } + /// Matches IGMP packets. + pub fn igmp(mut self) -> Self { + self.add_expr(Meta::new(MetaType::L4Proto)); + self.add_expr(Cmp::new(CmpOp::Eq, [libc::IPPROTO_IGMP as u8])); self } - fn accept(mut self) -> Self { - self.add_expr(&nft_expr!(verdict accept)); + /// Matches packets from source `port` and `protocol`. + pub fn sport(self, port: u16, protocol: Protocol) -> Self { + self.match_port(port, protocol, false) + } + /// Matches packets to destination `port` and `protocol`. + pub fn dport(self, port: u16, protocol: Protocol) -> Self { + self.match_port(port, protocol, false) + } + /// Matches packets on `protocol`. + pub fn protocol(mut self, protocol: Protocol) -> Self { + self.add_expr(Meta::new(MetaType::L4Proto)); + self.add_expr(Cmp::new( + CmpOp::Eq, + [match protocol { + Protocol::TCP => libc::IPPROTO_TCP, + Protocol::UDP => libc::IPPROTO_UDP, + } as u8], + )); + self + } + /// Matches packets in an already established connection. + pub fn established(mut self) -> Result<Self, BuilderError> { + let allowed_states = ConnTrackState::ESTABLISHED.bits(); + self.add_expr(Conntrack::new(ConntrackKey::State)); + self.add_expr(Bitwise::new( + allowed_states.to_le_bytes(), + 0u32.to_be_bytes(), + )?); + self.add_expr(Cmp::new(CmpOp::Neq, 0u32.to_be_bytes())); + Ok(self) + } + /// Matches packets going through `iface_index`. Interface indexes can be queried with + /// `iface_index()`. + pub fn iface_id(mut self, iface_index: libc::c_uint) -> Self { + self.add_expr(Meta::new(MetaType::Iif)); + self.add_expr(Cmp::new(CmpOp::Eq, iface_index.to_be_bytes())); self } - fn drop(mut self) -> Self { - self.add_expr(&nft_expr!(verdict drop)); + /// Matches packets going through `iface_name`, an interface name, as in "wlan0" or "lo" + pub fn iface(mut self, iface_name: &str) -> Result<Self, BuilderError> { + if iface_name.len() >= libc::IFNAMSIZ { + return Err(BuilderError::InterfaceNameTooLong); + } + let mut iface_vec = iface_name.as_bytes().to_vec(); + // null terminator + iface_vec.push(0u8); + + self.add_expr(Meta::new(MetaType::IifName)); + self.add_expr(Cmp::new(CmpOp::Eq, iface_vec)); + Ok(self) + } + /// Matches packets whose source IP address is `saddr`. + pub fn saddr(self, ip: IpAddr) -> Self { + self.match_ip(ip, true) + } + /// Matches packets whose destination IP address is `saddr`. + pub fn daddr(self, ip: IpAddr) -> Self { + self.match_ip(ip, false) + } + /// Matches packets whose source network is `net`. + pub fn snetwork(self, net: IpNetwork) -> Result<Self, BuilderError> { + self.match_network(net, true) + } + /// Matches packets whose destination network is `net`. + pub fn dnetwork(self, net: IpNetwork) -> Result<Self, BuilderError> { + self.match_network(net, false) + } + /// Adds the `Accept` verdict to the rule. The packet will be sent to destination. + pub fn accept(mut self) -> Self { + self.add_expr(Immediate::new_verdict(VerdictKind::Accept)); self } - fn add_to_batch(self, batch: &mut Batch) -> Self { - batch.add(&self, crate::MsgType::Add); + /// Adds the `Drop` verdict to the rule. The packet will be dropped. + pub fn drop(mut self) -> Self { + self.add_expr(Immediate::new_verdict(VerdictKind::Drop)); self } } /// Looks up the interface index for a given interface name. -pub fn iface_index(name: &str) -> Result<libc::c_uint, Error> { +pub fn iface_index(name: &str) -> Result<libc::c_uint, std::io::Error> { let c_name = CString::new(name)?; let index = unsafe { libc::if_nametoindex(c_name.as_ptr()) }; match index { - 0 => Err(Error::NoSuchIface), - _ => Ok(index) + 0 => Err(std::io::Error::last_os_error()), + _ => Ok(index), } } - - @@ -1,273 +1,116 @@ -use crate::sys::{self, libc}; -use crate::{table::Table, MsgType, ProtoFamily}; -use std::{ - cell::Cell, - ffi::{c_void, CStr, CString}, - fmt::Debug, - net::{Ipv4Addr, Ipv6Addr}, - os::raw::c_char, - rc::Rc, +use rustables_macros::nfnetlink_struct; + +use crate::data_type::DataType; +use crate::error::BuilderError; +use crate::nlmsg::NfNetlinkObject; +use crate::parser_impls::{NfNetlinkData, NfNetlinkList}; +use crate::sys::{ + NFTA_SET_ELEM_KEY, NFTA_SET_ELEM_LIST_ELEMENTS, NFTA_SET_ELEM_LIST_SET, + NFTA_SET_ELEM_LIST_TABLE, NFTA_SET_FLAGS, NFTA_SET_ID, NFTA_SET_KEY_LEN, NFTA_SET_KEY_TYPE, + NFTA_SET_NAME, NFTA_SET_TABLE, NFTA_SET_USERDATA, NFT_MSG_DELSET, NFT_MSG_DELSETELEM, + NFT_MSG_NEWSET, NFT_MSG_NEWSETELEM, }; - -#[macro_export] -macro_rules! nft_set { - ($name:expr, $id:expr, $table:expr, $family:expr) => { - $crate::set::Set::new($name, $id, $table, $family) - }; - ($name:expr, $id:expr, $table:expr, $family:expr; [ ]) => { - nft_set!($name, $id, $table, $family) - }; - ($name:expr, $id:expr, $table:expr, $family:expr; [ $($value:expr,)* ]) => {{ - let mut set = nft_set!($name, $id, $table, $family).expect("Set allocation failed"); - $( - set.add($value).expect(stringify!(Unable to add $value to set $name)); - )* - set - }}; -} - -pub struct Set<K> { - pub(crate) set: *mut sys::nftnl_set, - pub(crate) table: Rc<Table>, - pub(crate) family: ProtoFamily, - _marker: ::std::marker::PhantomData<K>, +use crate::table::Table; +use crate::ProtocolFamily; +use std::fmt::Debug; +use std::marker::PhantomData; + +#[derive(Default, Debug, Clone, PartialEq, Eq)] +#[nfnetlink_struct(derive_deserialize = false)] +pub struct Set { + pub family: ProtocolFamily, + #[field(NFTA_SET_TABLE)] + pub table: String, + #[field(NFTA_SET_NAME)] + pub name: String, + #[field(NFTA_SET_FLAGS)] + pub flags: u32, + #[field(NFTA_SET_KEY_TYPE)] + pub key_type: u32, + #[field(NFTA_SET_KEY_LEN)] + pub key_len: u32, + #[field(NFTA_SET_ID)] + pub id: u32, + #[field(NFTA_SET_USERDATA)] + pub userdata: String, } -impl<K> Set<K> { - pub fn new(name: &CStr, id: u32, table: Rc<Table>, family: ProtoFamily) -> Self - where - K: SetKey, - { - unsafe { - let set = try_alloc!(sys::nftnl_set_alloc()); - - sys::nftnl_set_set_u32(set, sys::NFTNL_SET_FAMILY as u16, family as u32); - sys::nftnl_set_set_str(set, sys::NFTNL_SET_TABLE as u16, table.get_name().as_ptr()); - sys::nftnl_set_set_str(set, sys::NFTNL_SET_NAME as u16, name.as_ptr()); - sys::nftnl_set_set_u32(set, sys::NFTNL_SET_ID as u16, id); - - sys::nftnl_set_set_u32( - set, - sys::NFTNL_SET_FLAGS as u16, - (libc::NFT_SET_ANONYMOUS | libc::NFT_SET_CONSTANT) as u32, - ); - sys::nftnl_set_set_u32(set, sys::NFTNL_SET_KEY_TYPE as u16, K::TYPE); - sys::nftnl_set_set_u32(set, sys::NFTNL_SET_KEY_LEN as u16, K::LEN); - - Set { - set, - table, - family, - _marker: ::std::marker::PhantomData, - } - } - } - - pub unsafe fn from_raw(set: *mut sys::nftnl_set, table: Rc<Table>, family: ProtoFamily) -> Self - where - K: SetKey, - { - Set { - set, - table, - family, - _marker: ::std::marker::PhantomData, - } - } - - pub fn add(&mut self, key: &K) - where - K: SetKey, - { - unsafe { - let elem = try_alloc!(sys::nftnl_set_elem_alloc()); - - let data = key.data(); - let data_len = data.len() as u32; - trace!("Adding key {:?} with len {}", data, data_len); - sys::nftnl_set_elem_set( - elem, - sys::NFTNL_SET_ELEM_KEY as u16, - data.as_ref() as *const _ as *const c_void, - data_len, - ); - sys::nftnl_set_elem_add(self.set, elem); - } - } - - pub fn elems_iter(&self) -> SetElemsIter<K> { - SetElemsIter::new(self) - } +impl NfNetlinkObject for Set { + const MSG_TYPE_ADD: u32 = NFT_MSG_NEWSET; + const MSG_TYPE_DEL: u32 = NFT_MSG_DELSET; - #[cfg(feature = "unsafe-raw-handles")] - /// Returns the raw handle. - pub fn as_ptr(&self) -> *const sys::nftnl_set { - self.set as *const sys::nftnl_set - } - - #[cfg(feature = "unsafe-raw-handles")] - /// Returns a mutable version of the raw handle. - pub fn as_mut_ptr(&self) -> *mut sys::nftnl_set { - self.set - } - - pub fn get_family(&self) -> ProtoFamily { + fn get_family(&self) -> ProtocolFamily { self.family } - /// Returns a textual description of the set. - pub fn get_str(&self) -> CString { - let mut descr_buf = vec![0i8; 4096]; - unsafe { - sys::nftnl_set_snprintf( - descr_buf.as_mut_ptr() as *mut c_char, - (descr_buf.len() - 1) as u64, - self.set, - sys::NFTNL_OUTPUT_DEFAULT, - 0, - ); - CStr::from_ptr(descr_buf.as_ptr() as *mut c_char).to_owned() - } - } - - pub fn get_name(&self) -> Option<&CStr> { - unsafe { - let ptr = sys::nftnl_set_get_str(self.set, sys::NFTNL_SET_NAME as u16); - if !ptr.is_null() { - Some(CStr::from_ptr(ptr)) - } else { - None - } - } - } - - pub fn get_id(&self) -> u32 { - unsafe { sys::nftnl_set_get_u32(self.set, sys::NFTNL_SET_ID as u16) } + fn set_family(&mut self, family: ProtocolFamily) { + self.family = family; } } -impl<K> Debug for Set<K> { - fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { - write!(f, "{:?}", self.get_str()) - } -} - -unsafe impl<K> crate::NlMsg for Set<K> { - unsafe fn write(&self, buf: *mut c_void, seq: u32, msg_type: MsgType) { - let type_ = match msg_type { - MsgType::Add => libc::NFT_MSG_NEWSET, - MsgType::Del => libc::NFT_MSG_DELSET, - }; - let header = sys::nftnl_nlmsg_build_hdr( - buf as *mut c_char, - type_ as u16, - self.table.get_family() as u16, - (libc::NLM_F_APPEND | libc::NLM_F_CREATE | libc::NLM_F_ACK) as u16, - seq, - ); - sys::nftnl_set_nlmsg_build_payload(header, self.set); - } +pub struct SetBuilder<K: DataType> { + inner: Set, + list: SetElementList, + _phantom: PhantomData<K>, } -impl<K> Drop for Set<K> { - fn drop(&mut self) { - unsafe { sys::nftnl_set_free(self.set) }; - } -} - -pub struct SetElemsIter<'a, K> { - set: &'a Set<K>, - iter: *mut sys::nftnl_set_elems_iter, - ret: Rc<Cell<i32>>, -} - -impl<'a, K> SetElemsIter<'a, K> { - fn new(set: &'a Set<K>) -> Self { - let iter = try_alloc!(unsafe { - sys::nftnl_set_elems_iter_create(set.set as *const sys::nftnl_set) +impl<K: DataType> SetBuilder<K> { + pub fn new(name: impl Into<String>, table: &Table) -> Result<Self, BuilderError> { + let table_name = table.get_name().ok_or(BuilderError::MissingTableName)?; + let set_name = name.into(); + let set = Set::default() + .with_key_type(K::TYPE) + .with_key_len(K::LEN) + .with_table(table_name) + .with_name(&set_name); + + Ok(SetBuilder { + inner: set, + list: SetElementList { + table: Some(table_name.clone()), + set: Some(set_name), + elements: Some(SetElementListElements::default()), + }, + _phantom: PhantomData, + }) + } + + pub fn add(&mut self, key: &K) { + self.list.elements.as_mut().unwrap().add_value(SetElement { + key: Some(NfNetlinkData::default().with_value(key.data())), }); - SetElemsIter { - set, - iter, - ret: Rc::new(Cell::new(1)), - } } -} - -impl<'a, K> Iterator for SetElemsIter<'a, K> { - type Item = SetElemsMsg<'a, K>; - fn next(&mut self) -> Option<Self::Item> { - if self.ret.get() <= 0 || unsafe { sys::nftnl_set_elems_iter_cur(self.iter).is_null() } { - trace!("SetElemsIter iterator ending"); - None - } else { - trace!("SetElemsIter returning new SetElemsMsg"); - Some(SetElemsMsg { - set: self.set, - iter: self.iter, - ret: self.ret.clone(), - }) - } + pub fn finish(self) -> (Set, SetElementList) { + (self.inner, self.list) } } -impl<'a, K> Drop for SetElemsIter<'a, K> { - fn drop(&mut self) { - unsafe { sys::nftnl_set_elems_iter_destroy(self.iter) }; - } +#[derive(Default, Debug, Clone, PartialEq, Eq)] +#[nfnetlink_struct(nested = true, derive_deserialize = false)] +pub struct SetElementList { + #[field(NFTA_SET_ELEM_LIST_TABLE)] + pub table: String, + #[field(NFTA_SET_ELEM_LIST_SET)] + pub set: String, + #[field(NFTA_SET_ELEM_LIST_ELEMENTS)] + pub elements: SetElementListElements, } -pub struct SetElemsMsg<'a, K> { - set: &'a Set<K>, - iter: *mut sys::nftnl_set_elems_iter, - ret: Rc<Cell<i32>>, -} +impl NfNetlinkObject for SetElementList { + const MSG_TYPE_ADD: u32 = NFT_MSG_NEWSETELEM; + const MSG_TYPE_DEL: u32 = NFT_MSG_DELSETELEM; -unsafe impl<'a, K> crate::NlMsg for SetElemsMsg<'a, K> { - unsafe fn write(&self, buf: *mut c_void, seq: u32, msg_type: MsgType) { - trace!("Writing SetElemsMsg to NlMsg"); - let (type_, flags) = match msg_type { - MsgType::Add => ( - libc::NFT_MSG_NEWSETELEM, - libc::NLM_F_CREATE | libc::NLM_F_EXCL | libc::NLM_F_ACK, - ), - MsgType::Del => (libc::NFT_MSG_DELSETELEM, libc::NLM_F_ACK), - }; - let header = sys::nftnl_nlmsg_build_hdr( - buf as *mut c_char, - type_ as u16, - self.set.get_family() as u16, - flags as u16, - seq, - ); - self.ret.set(sys::nftnl_set_elems_nlmsg_build_payload_iter( - header, self.iter, - )); + fn get_family(&self) -> ProtocolFamily { + ProtocolFamily::Unspec } } -pub trait SetKey { - const TYPE: u32; - const LEN: u32; - - fn data(&self) -> Box<[u8]>; -} - -impl SetKey for Ipv4Addr { - const TYPE: u32 = 7; - const LEN: u32 = 4; - - fn data(&self) -> Box<[u8]> { - self.octets().to_vec().into_boxed_slice() - } +#[derive(Default, Debug, Clone, PartialEq, Eq)] +#[nfnetlink_struct(nested = true)] +pub struct SetElement { + #[field(NFTA_SET_ELEM_KEY)] + pub key: NfNetlinkData, } -impl SetKey for Ipv6Addr { - const TYPE: u32 = 8; - const LEN: u32 = 16; - - fn data(&self) -> Box<[u8]> { - self.octets().to_vec().into_boxed_slice() - } -} +type SetElementListElements = NfNetlinkList<SetElement>; diff --git a/src/sys.rs b/src/sys.rs new file mode 100644 index 0000000..4384a1c --- /dev/null +++ b/src/sys.rs @@ -0,0 +1,3 @@ +#![allow(non_camel_case_types, dead_code)] + +include!(concat!(env!("OUT_DIR"), "/sys.rs")); diff --git a/src/table.rs b/src/table.rs index 593fffb..81a26ef 100644 --- a/src/table.rs +++ b/src/table.rs @@ -1,171 +1,68 @@ -use crate::{MsgType, ProtoFamily}; -use crate::sys::{self, libc}; -#[cfg(feature = "query")] -use std::convert::TryFrom; -use std::{ - ffi::{c_void, CStr, CString}, - fmt::Debug, - os::raw::c_char, +use std::fmt::Debug; + +use rustables_macros::nfnetlink_struct; + +use crate::error::QueryError; +use crate::nlmsg::NfNetlinkObject; +use crate::sys::{ + NFTA_TABLE_FLAGS, NFTA_TABLE_NAME, NFTA_TABLE_USERDATA, NFT_MSG_DELTABLE, NFT_MSG_GETTABLE, + NFT_MSG_NEWTABLE, }; +use crate::{Batch, ProtocolFamily}; -/// Abstraction of `nftnl_table`, the top level container in netfilter. A table has a protocol +/// Abstraction of a `nftnl_table`, the top level container in netfilter. A table has a protocol /// family and contains [`Chain`]s that in turn hold the rules. /// /// [`Chain`]: struct.Chain.html +#[derive(Default, PartialEq, Eq, Debug)] +#[nfnetlink_struct(derive_deserialize = false)] pub struct Table { - table: *mut sys::nftnl_table, - family: ProtoFamily, + family: ProtocolFamily, + #[field(NFTA_TABLE_NAME)] + name: String, + #[field(NFTA_TABLE_FLAGS)] + flags: u32, + #[field(NFTA_TABLE_USERDATA)] + userdata: Vec<u8>, } impl Table { - /// Creates a new table instance with the given name and protocol family. - pub fn new<T: AsRef<CStr>>(name: &T, family: ProtoFamily) -> Table { - unsafe { - let table = try_alloc!(sys::nftnl_table_alloc()); - - sys::nftnl_table_set_u32(table, sys::NFTNL_TABLE_FAMILY as u16, family as u32); - sys::nftnl_table_set_str(table, sys::NFTNL_TABLE_NAME as u16, name.as_ref().as_ptr()); - sys::nftnl_table_set_u32(table, sys::NFTNL_TABLE_FLAGS as u16, 0u32); - Table { table, family } - } - } - - pub unsafe fn from_raw(table: *mut sys::nftnl_table, family: ProtoFamily) -> Self { - Table { table, family } - } - - /// Returns the name of this table. - pub fn get_name(&self) -> &CStr { - unsafe { - let ptr = sys::nftnl_table_get_str(self.table, sys::NFTNL_TABLE_NAME as u16); - if ptr.is_null() { - panic!("Impossible situation: retrieving the name of a chain failed") - } else { - CStr::from_ptr(ptr) - } - } - } - - /// Returns a textual description of the table. - pub fn get_str(&self) -> CString { - let mut descr_buf = vec![0i8; 4096]; - unsafe { - sys::nftnl_table_snprintf( - descr_buf.as_mut_ptr() as *mut c_char, - (descr_buf.len() - 1) as u64, - self.table, - sys::NFTNL_OUTPUT_DEFAULT, - 0, - ); - CStr::from_ptr(descr_buf.as_ptr() as *mut c_char).to_owned() - } - } - - /// Returns the protocol family for this table. - pub fn get_family(&self) -> ProtoFamily { - self.family - } - - /// Returns the userdata of this chain. - pub fn get_userdata(&self) -> Option<&CStr> { - unsafe { - let ptr = sys::nftnl_table_get_str(self.table, sys::NFTNL_TABLE_USERDATA as u16); - if !ptr.is_null() { - Some(CStr::from_ptr(ptr)) - } else { - None - } - } + pub fn new(family: ProtocolFamily) -> Table { + let mut res = Self::default(); + res.family = family; + res } - /// Updates the userdata of this chain. - pub fn set_userdata(&self, data: &CStr) { - unsafe { - sys::nftnl_table_set_str(self.table, sys::NFTNL_TABLE_USERDATA as u16, data.as_ptr()); - } - } - - #[cfg(feature = "unsafe-raw-handles")] - /// Returns the raw handle. - pub fn as_ptr(&self) -> *const sys::nftnl_table { - self.table as *const sys::nftnl_table - } - - #[cfg(feature = "unsafe-raw-handles")] - /// Returns a mutable version of the raw handle. - pub fn as_mut_ptr(&self) -> *mut sys::nftnl_table { - self.table - } -} - -impl PartialEq for Table { - fn eq(&self, other: &Self) -> bool { - self.get_name() == other.get_name() && self.get_family() == other.get_family() + /// Appends this rule to `batch` + pub fn add_to_batch(self, batch: &mut Batch) -> Self { + batch.add(&self, crate::MsgType::Add); + self } } -impl Debug for Table { - fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { - write!(f, "{:?}", self.get_str()) - } -} +impl NfNetlinkObject for Table { + const MSG_TYPE_ADD: u32 = NFT_MSG_NEWTABLE; + const MSG_TYPE_DEL: u32 = NFT_MSG_DELTABLE; -unsafe impl crate::NlMsg for Table { - unsafe fn write(&self, buf: *mut c_void, seq: u32, msg_type: MsgType) { - let raw_msg_type = match msg_type { - MsgType::Add => libc::NFT_MSG_NEWTABLE, - MsgType::Del => libc::NFT_MSG_DELTABLE, - }; - let header = sys::nftnl_nlmsg_build_hdr( - buf as *mut c_char, - raw_msg_type as u16, - self.family as u16, - libc::NLM_F_ACK as u16, - seq, - ); - sys::nftnl_table_nlmsg_build_payload(header, self.table); - } -} - -impl Drop for Table { - fn drop(&mut self) { - unsafe { sys::nftnl_table_free(self.table) }; + fn get_family(&self) -> ProtocolFamily { + self.family } -} -#[cfg(feature = "query")] -/// A callback to parse the response for messages created with `get_tables_nlmsg`. -pub fn get_tables_cb( - header: &libc::nlmsghdr, - (_, tables): &mut (&(), &mut Vec<Table>), -) -> libc::c_int { - unsafe { - let table = sys::nftnl_table_alloc(); - if table == std::ptr::null_mut() { - return mnl::mnl_sys::MNL_CB_ERROR; - } - let err = sys::nftnl_table_nlmsg_parse(header, table); - if err < 0 { - error!("Failed to parse nelink table message - {}", err); - sys::nftnl_table_free(table); - return err; - } - let family = sys::nftnl_table_get_u32(table, sys::NFTNL_TABLE_FAMILY as u16); - match crate::ProtoFamily::try_from(family as i32) { - Ok(family) => { - tables.push(Table::from_raw(table, family)); - mnl::mnl_sys::MNL_CB_OK - } - Err(crate::InvalidProtocolFamily) => { - error!("The netlink table didn't have a valid protocol family !?"); - sys::nftnl_table_free(table); - mnl::mnl_sys::MNL_CB_ERROR - } - } + fn set_family(&mut self, family: ProtocolFamily) { + self.family = family; } } -#[cfg(feature = "query")] -pub fn list_tables() -> Result<Vec<Table>, crate::query::Error> { - crate::query::list_objects_with_data(libc::NFT_MSG_GETTABLE as u16, get_tables_cb, &(), None) +pub fn list_tables() -> Result<Vec<Table>, QueryError> { + let mut result = Vec::new(); + crate::query::list_objects_with_data( + NFT_MSG_GETTABLE as u16, + &|table: Table, tables: &mut Vec<Table>| { + tables.push(table); + Ok(()) + }, + None, + &mut result, + )?; + Ok(result) } diff --git a/src/tests/batch.rs b/src/tests/batch.rs new file mode 100644 index 0000000..12f373f --- /dev/null +++ b/src/tests/batch.rs @@ -0,0 +1,96 @@ +use std::mem::size_of; + +use libc::{AF_UNSPEC, NFNL_MSG_BATCH_BEGIN, NLM_F_REQUEST}; +use nix::libc::NFNL_MSG_BATCH_END; + +use crate::nlmsg::{pad_netlink_object_with_variable_size, NfNetlinkDeserializable}; +use crate::parser::{parse_nlmsg, NlMsg}; +use crate::sys::{nfgenmsg, nlmsghdr, NFNETLINK_V0, NFNL_SUBSYS_NFTABLES}; +use crate::{Batch, MsgType, Table}; + +use super::get_test_table; + +const HEADER_SIZE: u32 = + pad_netlink_object_with_variable_size(size_of::<nlmsghdr>() + size_of::<nfgenmsg>()) as u32; + +const DEFAULT_BATCH_BEGIN_HDR: nlmsghdr = nlmsghdr { + nlmsg_len: HEADER_SIZE, + nlmsg_flags: NLM_F_REQUEST as u16, + nlmsg_type: NFNL_MSG_BATCH_BEGIN as u16, + nlmsg_seq: 0, + nlmsg_pid: 0, +}; +const DEFAULT_BATCH_MSG: NlMsg = NlMsg::NfGenMsg( + nfgenmsg { + nfgen_family: AF_UNSPEC as u8, + version: NFNETLINK_V0 as u8, + res_id: NFNL_SUBSYS_NFTABLES as u16, + }, + &[], +); + +const DEFAULT_BATCH_END_HDR: nlmsghdr = nlmsghdr { + nlmsg_len: HEADER_SIZE, + nlmsg_flags: NLM_F_REQUEST as u16, + nlmsg_type: NFNL_MSG_BATCH_END as u16, + nlmsg_seq: 1, + nlmsg_pid: 0, +}; + +#[test] +fn batch_empty() { + let batch = Batch::new(); + let buf = batch.finalize(); + + let (hdr, msg) = parse_nlmsg(&buf).expect("Invalid nlmsg message"); + assert_eq!(hdr, DEFAULT_BATCH_BEGIN_HDR); + assert_eq!(msg, DEFAULT_BATCH_MSG); + + let remaining_data_offset = pad_netlink_object_with_variable_size(hdr.nlmsg_len as usize); + + let (hdr, msg) = parse_nlmsg(&buf[remaining_data_offset..]).expect("Invalid nlmsg message"); + assert_eq!(hdr, DEFAULT_BATCH_END_HDR); + assert_eq!(msg, DEFAULT_BATCH_MSG); +} + +#[test] +fn batch_with_objects() { + let mut original_tables = vec![]; + for i in 0..10 { + let mut table = get_test_table(); + table.set_userdata(vec![i as u8]); + original_tables.push(table); + } + + let mut batch = Batch::new(); + for i in 0..10 { + batch.add( + &original_tables[i], + if i % 2 == 0 { + MsgType::Add + } else { + MsgType::Del + }, + ); + } + let buf = batch.finalize(); + + let (hdr, msg) = parse_nlmsg(&buf).expect("Invalid nlmsg message"); + assert_eq!(hdr, DEFAULT_BATCH_BEGIN_HDR); + assert_eq!(msg, DEFAULT_BATCH_MSG); + let mut remaining_data = &buf[pad_netlink_object_with_variable_size(hdr.nlmsg_len as usize)..]; + + for i in 0..10 { + let (deserialized_table, rest) = + Table::deserialize(&remaining_data).expect("could not deserialize a table"); + remaining_data = rest; + + assert_eq!(deserialized_table, original_tables[i]); + } + + let (hdr, msg) = parse_nlmsg(&remaining_data).expect("Invalid nlmsg message"); + let mut end_hdr = DEFAULT_BATCH_END_HDR; + end_hdr.nlmsg_seq = 11; + assert_eq!(hdr, end_hdr); + assert_eq!(msg, DEFAULT_BATCH_MSG); +} diff --git a/src/tests/chain.rs b/src/tests/chain.rs new file mode 100644 index 0000000..7f696e6 --- /dev/null +++ b/src/tests/chain.rs @@ -0,0 +1,120 @@ +use crate::{ + nlmsg::get_operation_from_nlmsghdr_type, + sys::{ + NFTA_CHAIN_HOOK, NFTA_CHAIN_NAME, NFTA_CHAIN_TABLE, NFTA_CHAIN_TYPE, NFTA_CHAIN_USERDATA, + NFTA_HOOK_HOOKNUM, NFTA_HOOK_PRIORITY, NFT_MSG_DELCHAIN, NFT_MSG_NEWCHAIN, + }, + ChainType, Hook, HookClass, MsgType, +}; + +use super::{ + get_test_chain, get_test_nlmsg, get_test_nlmsg_with_msg_type, NetlinkExpr, CHAIN_NAME, + CHAIN_USERDATA, TABLE_NAME, +}; + +#[test] +fn new_empty_chain() { + let mut chain = get_test_chain(); + + let mut buf = Vec::new(); + let (nlmsghdr, _nfgenmsg, raw_expr) = get_test_nlmsg(&mut buf, &mut chain); + assert_eq!( + get_operation_from_nlmsghdr_type(nlmsghdr.nlmsg_type), + NFT_MSG_NEWCHAIN as u8 + ); + assert_eq!(nlmsghdr.nlmsg_len, 52); + + assert_eq!( + raw_expr, + NetlinkExpr::List(vec![ + NetlinkExpr::Final(NFTA_CHAIN_TABLE, TABLE_NAME.as_bytes().to_vec()), + NetlinkExpr::Final(NFTA_CHAIN_NAME, CHAIN_NAME.as_bytes().to_vec()), + ]) + .to_raw() + ); +} + +#[test] +fn new_empty_chain_with_hook_and_type() { + let mut chain = get_test_chain() + .with_hook(Hook::new(HookClass::In, 0)) + .with_type(ChainType::Filter); + + let mut buf = Vec::new(); + let (nlmsghdr, _nfgenmsg, raw_expr) = get_test_nlmsg(&mut buf, &mut chain); + assert_eq!( + get_operation_from_nlmsghdr_type(nlmsghdr.nlmsg_type), + NFT_MSG_NEWCHAIN as u8 + ); + assert_eq!(nlmsghdr.nlmsg_len, 84); + + assert_eq!( + raw_expr, + NetlinkExpr::List(vec![ + NetlinkExpr::Final(NFTA_CHAIN_TABLE, TABLE_NAME.as_bytes().to_vec()), + NetlinkExpr::Final(NFTA_CHAIN_NAME, CHAIN_NAME.as_bytes().to_vec()), + NetlinkExpr::Final(NFTA_CHAIN_TYPE, "filter".as_bytes().to_vec()), + NetlinkExpr::Nested( + NFTA_CHAIN_HOOK, + vec![ + NetlinkExpr::List(vec![NetlinkExpr::Final( + NFTA_HOOK_HOOKNUM, + vec![0, 0, 0, 1] + )]), + NetlinkExpr::List(vec![NetlinkExpr::Final( + NFTA_HOOK_PRIORITY, + vec![0, 0, 0, 0] + )]) + ] + ), + ]) + .to_raw() + ); +} + +#[test] +fn new_empty_chain_with_userdata() { + let mut chain = get_test_chain(); + chain.set_userdata(CHAIN_USERDATA); + + let mut buf = Vec::new(); + let (nlmsghdr, _nfgenmsg, raw_expr) = get_test_nlmsg(&mut buf, &mut chain); + assert_eq!( + get_operation_from_nlmsghdr_type(nlmsghdr.nlmsg_type), + NFT_MSG_NEWCHAIN as u8 + ); + assert_eq!(nlmsghdr.nlmsg_len, 72); + + assert_eq!( + raw_expr, + NetlinkExpr::List(vec![ + NetlinkExpr::Final(NFTA_CHAIN_TABLE, TABLE_NAME.as_bytes().to_vec()), + NetlinkExpr::Final(NFTA_CHAIN_NAME, CHAIN_NAME.as_bytes().to_vec()), + NetlinkExpr::Final(NFTA_CHAIN_USERDATA, CHAIN_USERDATA.as_bytes().to_vec()) + ]) + .to_raw() + ); +} + +#[test] +fn delete_empty_chain() { + let mut chain = get_test_chain(); + + let mut buf = Vec::new(); + let (nlmsghdr, _nfgenmsg, raw_expr) = + get_test_nlmsg_with_msg_type(&mut buf, &mut chain, MsgType::Del); + assert_eq!( + get_operation_from_nlmsghdr_type(nlmsghdr.nlmsg_type), + NFT_MSG_DELCHAIN as u8 + ); + assert_eq!(nlmsghdr.nlmsg_len, 52); + + assert_eq!( + raw_expr, + NetlinkExpr::List(vec![ + NetlinkExpr::Final(NFTA_CHAIN_TABLE, TABLE_NAME.as_bytes().to_vec()), + NetlinkExpr::Final(NFTA_CHAIN_NAME, CHAIN_NAME.as_bytes().to_vec()), + ]) + .to_raw() + ); +} diff --git a/src/tests/expr.rs b/src/tests/expr.rs new file mode 100644 index 0000000..35c4fea --- /dev/null +++ b/src/tests/expr.rs @@ -0,0 +1,591 @@ +use std::net::Ipv4Addr; + +use libc::NF_DROP; + +use crate::{ + expr::{ + Bitwise, Cmp, CmpOp, Conntrack, ConntrackKey, Counter, ExpressionList, HeaderField, + HighLevelPayload, IcmpCode, Immediate, Log, Lookup, Masquerade, Meta, MetaType, Nat, + NatType, Register, Reject, RejectType, TCPHeaderField, TransportHeaderField, VerdictKind, + }, + set::SetBuilder, + sys::{ + NFTA_BITWISE_DREG, NFTA_BITWISE_LEN, NFTA_BITWISE_MASK, NFTA_BITWISE_SREG, + NFTA_BITWISE_XOR, NFTA_CMP_DATA, NFTA_CMP_OP, NFTA_CMP_SREG, NFTA_COUNTER_BYTES, + NFTA_COUNTER_PACKETS, NFTA_CT_DREG, NFTA_CT_KEY, NFTA_DATA_VALUE, NFTA_DATA_VERDICT, + NFTA_EXPR_DATA, NFTA_EXPR_NAME, NFTA_IMMEDIATE_DATA, NFTA_IMMEDIATE_DREG, NFTA_LIST_ELEM, + NFTA_LOG_GROUP, NFTA_LOG_PREFIX, NFTA_LOOKUP_SET, NFTA_LOOKUP_SREG, NFTA_META_DREG, + NFTA_META_KEY, NFTA_NAT_FAMILY, NFTA_NAT_REG_ADDR_MIN, NFTA_NAT_TYPE, NFTA_PAYLOAD_BASE, + NFTA_PAYLOAD_DREG, NFTA_PAYLOAD_LEN, NFTA_PAYLOAD_OFFSET, NFTA_REJECT_ICMP_CODE, + NFTA_REJECT_TYPE, NFTA_RULE_CHAIN, NFTA_RULE_EXPRESSIONS, NFTA_RULE_TABLE, + NFTA_VERDICT_CODE, NFT_CMP_EQ, NFT_CT_STATE, NFT_META_PROTOCOL, NFT_NAT_SNAT, + NFT_PAYLOAD_TRANSPORT_HEADER, NFT_REG_1, NFT_REG_VERDICT, NFT_REJECT_ICMPX_UNREACH, + }, + tests::{get_test_table, SET_NAME}, + ProtocolFamily, +}; + +use super::{get_test_nlmsg, get_test_rule, NetlinkExpr, CHAIN_NAME, TABLE_NAME}; + +#[test] +fn bitwise_expr_is_valid() { + let netmask = Ipv4Addr::new(255, 255, 255, 0); + let bitwise = Bitwise::new(netmask.octets(), [0, 0, 0, 0]).unwrap(); + let mut rule = get_test_rule().with_expressions(ExpressionList::default().with_value(bitwise)); + + let mut buf = Vec::new(); + let (nlmsghdr, _nfgenmsg, raw_expr) = get_test_nlmsg(&mut buf, &mut rule); + assert_eq!(nlmsghdr.nlmsg_len, 124); + + assert_eq!( + raw_expr, + NetlinkExpr::List(vec![ + NetlinkExpr::Final(NFTA_RULE_TABLE, TABLE_NAME.as_bytes().to_vec()), + NetlinkExpr::Final(NFTA_RULE_CHAIN, CHAIN_NAME.as_bytes().to_vec()), + NetlinkExpr::Nested( + NFTA_RULE_EXPRESSIONS, + vec![NetlinkExpr::Nested( + NFTA_LIST_ELEM, + vec![ + NetlinkExpr::Final(NFTA_EXPR_NAME, b"bitwise".to_vec()), + NetlinkExpr::Nested( + NFTA_EXPR_DATA, + vec![ + NetlinkExpr::Final( + NFTA_BITWISE_SREG, + NFT_REG_1.to_be_bytes().to_vec() + ), + NetlinkExpr::Final( + NFTA_BITWISE_DREG, + NFT_REG_1.to_be_bytes().to_vec() + ), + NetlinkExpr::Final(NFTA_BITWISE_LEN, 4u32.to_be_bytes().to_vec()), + NetlinkExpr::Nested( + NFTA_BITWISE_MASK, + vec![NetlinkExpr::Final( + NFTA_DATA_VALUE, + vec![255, 255, 255, 0] + )] + ), + NetlinkExpr::Nested( + NFTA_BITWISE_XOR, + vec![NetlinkExpr::Final( + NFTA_DATA_VALUE, + 0u32.to_be_bytes().to_vec() + )] + ) + ] + ) + ] + )] + ) + ]) + .to_raw() + ); +} + +#[test] +fn cmp_expr_is_valid() { + let val = [1u8, 2, 3, 4]; + let cmp = Cmp::new(CmpOp::Eq, val.clone()); + let mut rule = get_test_rule().with_expressions(vec![cmp]); + + let mut buf = Vec::new(); + let (nlmsghdr, _nfgenmsg, raw_expr) = get_test_nlmsg(&mut buf, &mut rule); + assert_eq!(nlmsghdr.nlmsg_len, 100); + + assert_eq!( + raw_expr, + NetlinkExpr::List(vec![ + NetlinkExpr::Final(NFTA_RULE_TABLE, TABLE_NAME.as_bytes().to_vec()), + NetlinkExpr::Final(NFTA_RULE_CHAIN, CHAIN_NAME.as_bytes().to_vec()), + NetlinkExpr::Nested( + NFTA_RULE_EXPRESSIONS, + vec![NetlinkExpr::Nested( + NFTA_LIST_ELEM, + vec![ + NetlinkExpr::Final(NFTA_EXPR_NAME, b"cmp".to_vec()), + NetlinkExpr::Nested( + NFTA_EXPR_DATA, + vec![ + NetlinkExpr::Final(NFTA_CMP_SREG, NFT_REG_1.to_be_bytes().to_vec()), + NetlinkExpr::Final(NFTA_CMP_OP, NFT_CMP_EQ.to_be_bytes().to_vec()), + NetlinkExpr::Nested( + NFTA_CMP_DATA, + vec![NetlinkExpr::Final(NFTA_DATA_VALUE, val.to_vec())] + ) + ] + ) + ] + )] + ) + ]) + .to_raw() + ); +} + +#[test] +fn counter_expr_is_valid() { + let nb_bytes = 123456u64; + let nb_packets = 987u64; + let counter = Counter::default() + .with_nb_bytes(nb_bytes) + .with_nb_packets(nb_packets); + + let mut rule = get_test_rule().with_expressions(vec![counter]); + + let mut buf = Vec::new(); + let (nlmsghdr, _nfgenmsg, raw_expr) = get_test_nlmsg(&mut buf, &mut rule); + assert_eq!(nlmsghdr.nlmsg_len, 100); + + assert_eq!( + raw_expr, + NetlinkExpr::List(vec![ + NetlinkExpr::Final(NFTA_RULE_TABLE, TABLE_NAME.as_bytes().to_vec()), + NetlinkExpr::Final(NFTA_RULE_CHAIN, CHAIN_NAME.as_bytes().to_vec()), + NetlinkExpr::Nested( + NFTA_RULE_EXPRESSIONS, + vec![NetlinkExpr::Nested( + NFTA_LIST_ELEM, + vec![ + NetlinkExpr::Final(NFTA_EXPR_NAME, b"counter".to_vec()), + NetlinkExpr::Nested( + NFTA_EXPR_DATA, + vec![ + NetlinkExpr::Final( + NFTA_COUNTER_BYTES, + nb_bytes.to_be_bytes().to_vec() + ), + NetlinkExpr::Final( + NFTA_COUNTER_PACKETS, + nb_packets.to_be_bytes().to_vec() + ) + ] + ) + ] + )] + ) + ]) + .to_raw() + ); +} + +#[test] +fn ct_expr_is_valid() { + let ct = Conntrack::default().with_retrieve_value(ConntrackKey::State); + let mut rule = get_test_rule().with_expressions(vec![ct]); + + let mut buf = Vec::new(); + let (nlmsghdr, _nfgenmsg, raw_expr) = get_test_nlmsg(&mut buf, &mut rule); + assert_eq!(nlmsghdr.nlmsg_len, 88); + + assert_eq!( + raw_expr, + NetlinkExpr::List(vec![ + NetlinkExpr::Final(NFTA_RULE_TABLE, TABLE_NAME.as_bytes().to_vec()), + NetlinkExpr::Final(NFTA_RULE_CHAIN, CHAIN_NAME.as_bytes().to_vec()), + NetlinkExpr::Nested( + NFTA_RULE_EXPRESSIONS, + vec![NetlinkExpr::Nested( + NFTA_LIST_ELEM, + vec![ + NetlinkExpr::Final(NFTA_EXPR_NAME, b"ct".to_vec()), + NetlinkExpr::Nested( + NFTA_EXPR_DATA, + vec![ + NetlinkExpr::Final( + NFTA_CT_KEY, + NFT_CT_STATE.to_be_bytes().to_vec() + ), + NetlinkExpr::Final(NFTA_CT_DREG, NFT_REG_1.to_be_bytes().to_vec()) + ] + ) + ] + )] + ) + ]) + .to_raw() + ) +} + +#[test] +fn immediate_expr_is_valid() { + let immediate = Immediate::new_data(vec![42u8], Register::Reg1); + let mut rule = + get_test_rule().with_expressions(ExpressionList::default().with_value(immediate)); + + let mut buf = Vec::new(); + let (nlmsghdr, _nfgenmsg, raw_expr) = get_test_nlmsg(&mut buf, &mut rule); + assert_eq!(nlmsghdr.nlmsg_len, 100); + + assert_eq!( + raw_expr, + NetlinkExpr::List(vec![ + NetlinkExpr::Final(NFTA_RULE_TABLE, TABLE_NAME.as_bytes().to_vec()), + NetlinkExpr::Final(NFTA_RULE_CHAIN, CHAIN_NAME.as_bytes().to_vec()), + NetlinkExpr::Nested( + NFTA_RULE_EXPRESSIONS, + vec![NetlinkExpr::Nested( + NFTA_LIST_ELEM, + vec![ + NetlinkExpr::Final(NFTA_EXPR_NAME, b"immediate".to_vec()), + NetlinkExpr::Nested( + NFTA_EXPR_DATA, + vec![ + NetlinkExpr::Final( + NFTA_IMMEDIATE_DREG, + NFT_REG_1.to_be_bytes().to_vec() + ), + NetlinkExpr::Nested( + NFTA_IMMEDIATE_DATA, + vec![NetlinkExpr::Final(1u16, 42u8.to_be_bytes().to_vec())] + ) + ] + ) + ] + )] + ) + ]) + .to_raw() + ); +} + +#[test] +fn log_expr_is_valid() { + let log = Log::new(Some(1337), Some("mockprefix")).expect("Could not build a log expression"); + let mut rule = get_test_rule().with_expressions(ExpressionList::default().with_value(log)); + + let mut buf = Vec::new(); + let (nlmsghdr, _nfgenmsg, raw_expr) = get_test_nlmsg(&mut buf, &mut rule); + assert_eq!(nlmsghdr.nlmsg_len, 96); + + assert_eq!( + raw_expr, + NetlinkExpr::List(vec![ + NetlinkExpr::Final(NFTA_RULE_TABLE, TABLE_NAME.as_bytes().to_vec()), + NetlinkExpr::Final(NFTA_RULE_CHAIN, CHAIN_NAME.as_bytes().to_vec()), + NetlinkExpr::Nested( + NFTA_RULE_EXPRESSIONS, + vec![NetlinkExpr::Nested( + NFTA_LIST_ELEM, + vec![ + NetlinkExpr::Final(NFTA_EXPR_NAME, b"log".to_vec()), + NetlinkExpr::Nested( + NFTA_EXPR_DATA, + vec![ + NetlinkExpr::Final(NFTA_LOG_GROUP, 1337u16.to_be_bytes().to_vec()), + NetlinkExpr::Final(NFTA_LOG_PREFIX, b"mockprefix".to_vec()), + ] + ) + ] + )] + ) + ]) + .to_raw() + ); +} + +#[test] +fn lookup_expr_is_valid() { + let table = get_test_table(); + let mut set_builder = SetBuilder::new(SET_NAME, &table).unwrap(); + let address: Ipv4Addr = [8, 8, 8, 8].into(); + set_builder.add(&address); + let (set, _set_elements) = set_builder.finish(); + let lookup = Lookup::new(&set).unwrap(); + + let mut rule = get_test_rule().with_expressions(ExpressionList::default().with_value(lookup)); + + let mut buf = Vec::new(); + let (nlmsghdr, _nfgenmsg, raw_expr) = get_test_nlmsg(&mut buf, &mut rule); + assert_eq!(nlmsghdr.nlmsg_len, 96); + + assert_eq!( + raw_expr, + NetlinkExpr::List(vec![ + NetlinkExpr::Final(NFTA_RULE_TABLE, TABLE_NAME.as_bytes().to_vec()), + NetlinkExpr::Final(NFTA_RULE_CHAIN, CHAIN_NAME.as_bytes().to_vec()), + NetlinkExpr::Nested( + NFTA_RULE_EXPRESSIONS, + vec![NetlinkExpr::Nested( + NFTA_LIST_ELEM, + vec![ + NetlinkExpr::Final(NFTA_EXPR_NAME, b"lookup".to_vec()), + NetlinkExpr::Nested( + NFTA_EXPR_DATA, + vec![ + NetlinkExpr::Final(NFTA_LOOKUP_SET, b"mockset".to_vec()), + NetlinkExpr::Final( + NFTA_LOOKUP_SREG, + NFT_REG_1.to_be_bytes().to_vec() + ), + ] + ) + ] + )] + ) + ]) + .to_raw() + ); +} + +#[test] +fn masquerade_expr_is_valid() { + let masquerade = Masquerade::default(); + let mut rule = get_test_rule().with_expressions(vec![masquerade]); + + let mut buf = Vec::new(); + let (nlmsghdr, _nfgenmsg, raw_expr) = get_test_nlmsg(&mut buf, &mut rule); + assert_eq!(nlmsghdr.nlmsg_len, 72); + + assert_eq!( + raw_expr, + NetlinkExpr::List(vec![ + NetlinkExpr::Final(NFTA_RULE_TABLE, TABLE_NAME.as_bytes().to_vec()), + NetlinkExpr::Final(NFTA_RULE_CHAIN, CHAIN_NAME.as_bytes().to_vec()), + NetlinkExpr::Nested( + NFTA_RULE_EXPRESSIONS, + vec![NetlinkExpr::Nested( + NFTA_LIST_ELEM, + vec![ + NetlinkExpr::Final(NFTA_EXPR_NAME, b"masq".to_vec()), + NetlinkExpr::Nested(NFTA_EXPR_DATA, vec![]), + ] + )] + ) + ]) + .to_raw() + ); +} + +#[test] +fn meta_expr_is_valid() { + let meta = Meta::default() + .with_key(MetaType::Protocol) + .with_dreg(Register::Reg1); + let mut rule = get_test_rule().with_expressions(vec![meta]); + + let mut buf = Vec::new(); + let (nlmsghdr, _nfgenmsg, raw_expr) = get_test_nlmsg(&mut buf, &mut rule); + assert_eq!(nlmsghdr.nlmsg_len, 88); + + assert_eq!( + raw_expr, + NetlinkExpr::List(vec![ + NetlinkExpr::Final(NFTA_RULE_TABLE, TABLE_NAME.as_bytes().to_vec()), + NetlinkExpr::Final(NFTA_RULE_CHAIN, CHAIN_NAME.as_bytes().to_vec()), + NetlinkExpr::Nested( + NFTA_RULE_EXPRESSIONS, + vec![NetlinkExpr::Nested( + NFTA_LIST_ELEM, + vec![ + NetlinkExpr::Final(NFTA_EXPR_NAME, b"meta".to_vec()), + NetlinkExpr::Nested( + NFTA_EXPR_DATA, + vec![ + NetlinkExpr::Final( + NFTA_META_KEY, + NFT_META_PROTOCOL.to_be_bytes().to_vec() + ), + NetlinkExpr::Final( + NFTA_META_DREG, + NFT_REG_1.to_be_bytes().to_vec() + ) + ] + ) + ] + )] + ) + ]) + .to_raw() + ); +} + +#[test] +fn nat_expr_is_valid() { + let nat = Nat::default() + .with_nat_type(NatType::SNat) + .with_family(ProtocolFamily::Ipv4) + .with_ip_register(Register::Reg1); + let mut rule = get_test_rule().with_expressions(vec![nat]); + + let mut buf = Vec::new(); + let (nlmsghdr, _nfgenmsg, raw_expr) = get_test_nlmsg(&mut buf, &mut rule); + assert_eq!(nlmsghdr.nlmsg_len, 96); + + assert_eq!( + raw_expr, + NetlinkExpr::List(vec![ + NetlinkExpr::Final(NFTA_RULE_TABLE, TABLE_NAME.as_bytes().to_vec()), + NetlinkExpr::Final(NFTA_RULE_CHAIN, CHAIN_NAME.as_bytes().to_vec()), + NetlinkExpr::Nested( + NFTA_RULE_EXPRESSIONS, + vec![NetlinkExpr::Nested( + NFTA_LIST_ELEM, + vec![ + NetlinkExpr::Final(NFTA_EXPR_NAME, b"nat".to_vec()), + NetlinkExpr::Nested( + NFTA_EXPR_DATA, + vec![ + NetlinkExpr::Final( + NFTA_NAT_TYPE, + NFT_NAT_SNAT.to_be_bytes().to_vec() + ), + NetlinkExpr::Final( + NFTA_NAT_FAMILY, + (ProtocolFamily::Ipv4 as u32).to_be_bytes().to_vec(), + ), + NetlinkExpr::Final( + NFTA_NAT_REG_ADDR_MIN, + NFT_REG_1.to_be_bytes().to_vec() + ) + ] + ) + ] + )] + ) + ]) + .to_raw() + ); +} + +#[test] +fn payload_expr_is_valid() { + let tcp_header_field = TCPHeaderField::Sport; + let transport_header_field = TransportHeaderField::Tcp(tcp_header_field); + let payload = HighLevelPayload::Transport(transport_header_field); + let mut rule = get_test_rule().with_expressions(vec![payload.build()]); + + let mut buf = Vec::new(); + let (nlmsghdr, _nfgenmsg, raw_expr) = get_test_nlmsg(&mut buf, &mut rule); + assert_eq!(nlmsghdr.nlmsg_len, 108); + + assert_eq!( + raw_expr, + NetlinkExpr::List(vec![ + NetlinkExpr::Final(NFTA_RULE_TABLE, TABLE_NAME.as_bytes().to_vec()), + NetlinkExpr::Final(NFTA_RULE_CHAIN, CHAIN_NAME.as_bytes().to_vec()), + NetlinkExpr::Nested( + NFTA_RULE_EXPRESSIONS, + vec![NetlinkExpr::Nested( + NFTA_LIST_ELEM, + vec![ + NetlinkExpr::Final(NFTA_EXPR_NAME, b"payload".to_vec()), + NetlinkExpr::Nested( + NFTA_EXPR_DATA, + vec![ + NetlinkExpr::Final( + NFTA_PAYLOAD_DREG, + NFT_REG_1.to_be_bytes().to_vec() + ), + NetlinkExpr::Final( + NFTA_PAYLOAD_BASE, + NFT_PAYLOAD_TRANSPORT_HEADER.to_be_bytes().to_vec() + ), + NetlinkExpr::Final( + NFTA_PAYLOAD_OFFSET, + tcp_header_field.offset().to_be_bytes().to_vec() + ), + NetlinkExpr::Final( + NFTA_PAYLOAD_LEN, + tcp_header_field.len().to_be_bytes().to_vec() + ), + ] + ) + ] + )] + ) + ]) + .to_raw() + ); +} + +#[test] +fn reject_expr_is_valid() { + let code = IcmpCode::NoRoute; + let reject = Reject::default() + .with_type(RejectType::IcmpxUnreach) + .with_icmp_code(code); + let mut rule = get_test_rule().with_expressions(vec![reject]); + let mut buf = Vec::new(); + let (nlmsghdr, _nfgenmsg, raw_expr) = get_test_nlmsg(&mut buf, &mut rule); + assert_eq!(nlmsghdr.nlmsg_len, 92); + + assert_eq!( + raw_expr, + NetlinkExpr::List(vec![ + NetlinkExpr::Final(NFTA_RULE_TABLE, TABLE_NAME.as_bytes().to_vec()), + NetlinkExpr::Final(NFTA_RULE_CHAIN, CHAIN_NAME.as_bytes().to_vec()), + NetlinkExpr::Nested( + NFTA_RULE_EXPRESSIONS, + vec![NetlinkExpr::Nested( + NFTA_LIST_ELEM, + vec![ + NetlinkExpr::Final(NFTA_EXPR_NAME, b"reject".to_vec()), + NetlinkExpr::Nested( + NFTA_EXPR_DATA, + vec![ + NetlinkExpr::Final( + NFTA_REJECT_TYPE, + NFT_REJECT_ICMPX_UNREACH.to_be_bytes().to_vec() + ), + NetlinkExpr::Final( + NFTA_REJECT_ICMP_CODE, + (code as u8).to_be_bytes().to_vec() + ), + ] + ) + ] + )] + ) + ]) + .to_raw() + ); +} + +#[test] +fn verdict_expr_is_valid() { + let verdict = Immediate::new_verdict(VerdictKind::Drop); + let mut rule = get_test_rule().with_expressions(ExpressionList::default().with_value(verdict)); + + let mut buf = Vec::new(); + let (nlmsghdr, _nfgenmsg, raw_expr) = get_test_nlmsg(&mut buf, &mut rule); + assert_eq!(nlmsghdr.nlmsg_len, 104); + + assert_eq!( + raw_expr, + NetlinkExpr::List(vec![ + NetlinkExpr::Final(NFTA_RULE_TABLE, TABLE_NAME.as_bytes().to_vec()), + NetlinkExpr::Final(NFTA_RULE_CHAIN, CHAIN_NAME.as_bytes().to_vec()), + NetlinkExpr::Nested( + NFTA_RULE_EXPRESSIONS, + vec![NetlinkExpr::Nested( + NFTA_LIST_ELEM, + vec![ + NetlinkExpr::Final(NFTA_EXPR_NAME, b"immediate".to_vec()), + NetlinkExpr::Nested( + NFTA_EXPR_DATA, + vec![ + NetlinkExpr::Final( + NFTA_IMMEDIATE_DREG, + NFT_REG_VERDICT.to_be_bytes().to_vec() + ), + NetlinkExpr::Nested( + NFTA_IMMEDIATE_DATA, + vec![NetlinkExpr::Nested( + NFTA_DATA_VERDICT, + vec![NetlinkExpr::Final( + NFTA_VERDICT_CODE, + NF_DROP.to_be_bytes().to_vec() + ),] + )], + ), + ] + ) + ] + )] + ) + ]) + .to_raw() + ); +} diff --git a/src/tests/mod.rs b/src/tests/mod.rs new file mode 100644 index 0000000..75fe8b0 --- /dev/null +++ b/src/tests/mod.rs @@ -0,0 +1,193 @@ +use crate::data_type::DataType; +use crate::nlmsg::{NfNetlinkObject, NfNetlinkWriter}; +use crate::parser::{parse_nlmsg, NlMsg}; +use crate::set::{Set, SetBuilder}; +use crate::{sys::*, Chain, MsgType, ProtocolFamily, Rule, Table}; + +mod batch; +mod chain; +mod expr; +mod rule; +mod set; +mod table; + +pub const TABLE_NAME: &'static str = "mocktable"; +pub const CHAIN_NAME: &'static str = "mockchain"; +pub const SET_NAME: &'static str = "mockset"; + +pub const TABLE_USERDATA: &'static str = "mocktabledata"; +pub const CHAIN_USERDATA: &'static str = "mockchaindata"; +pub const RULE_USERDATA: &'static str = "mockruledata"; +pub const SET_USERDATA: &'static str = "mocksetdata"; + +type NetLinkType = u16; + +#[derive(Debug, thiserror::Error)] +#[error("empty data")] +pub struct EmptyDataError; + +#[derive(Debug, Clone, Eq, Ord)] +pub enum NetlinkExpr { + Nested(NetLinkType, Vec<NetlinkExpr>), + Final(NetLinkType, Vec<u8>), + List(Vec<NetlinkExpr>), +} + +impl NetlinkExpr { + pub fn to_raw(self) -> Vec<u8> { + match self.sort() { + NetlinkExpr::Final(ty, val) => { + let len = val.len() + 4; + let mut res = Vec::with_capacity(len); + + res.extend(&(len as u16).to_le_bytes()); + res.extend(&ty.to_le_bytes()); + res.extend(val); + // alignment + while res.len() % 4 != 0 { + res.push(0); + } + + res + } + NetlinkExpr::Nested(ty, exprs) => { + // some heuristic to decrease allocations (even though this is + // only useful for testing so performance is not an objective) + let mut sub = Vec::with_capacity(exprs.len() * 50); + + for expr in exprs { + sub.append(&mut expr.to_raw()); + } + + let len = sub.len() + 4; + let mut res = Vec::with_capacity(len); + + // set the "NESTED" flag + res.extend(&(len as u16).to_le_bytes()); + res.extend(&(ty | NLA_F_NESTED as u16).to_le_bytes()); + res.extend(sub); + + res + } + NetlinkExpr::List(exprs) => { + // some heuristic to decrease allocations (even though this is + // only useful for testing so performance is not an objective) + let mut list = Vec::with_capacity(exprs.len() * 50); + + for expr in exprs { + list.append(&mut expr.to_raw()); + } + + list + } + } + } + + pub fn sort(self) -> Self { + match self { + NetlinkExpr::Final(_, _) => self, + NetlinkExpr::Nested(ty, mut exprs) => { + exprs.sort(); + NetlinkExpr::Nested(ty, exprs) + } + NetlinkExpr::List(mut exprs) => { + exprs.sort(); + NetlinkExpr::List(exprs) + } + } + } +} + +impl PartialEq for NetlinkExpr { + fn eq(&self, other: &Self) -> bool { + match (self.clone().sort(), other.clone().sort()) { + (NetlinkExpr::Nested(k1, v1), NetlinkExpr::Nested(k2, v2)) => k1 == k2 && v1 == v2, + (NetlinkExpr::Final(k1, v1), NetlinkExpr::Final(k2, v2)) => k1 == k2 && v1 == v2, + (NetlinkExpr::List(v1), NetlinkExpr::List(v2)) => v1 == v2, + _ => false, + } + } +} + +impl PartialOrd for NetlinkExpr { + fn partial_cmp(&self, other: &Self) -> Option<std::cmp::Ordering> { + match (self, other) { + ( + NetlinkExpr::Nested(k1, _) | NetlinkExpr::Final(k1, _), + NetlinkExpr::Nested(k2, _) | NetlinkExpr::Final(k2, _), + ) => k1.partial_cmp(k2), + (NetlinkExpr::List(v1), NetlinkExpr::List(v2)) => v1.partial_cmp(v2), + (_, NetlinkExpr::List(_)) => Some(std::cmp::Ordering::Less), + (NetlinkExpr::List(_), _) => Some(std::cmp::Ordering::Greater), + } + } +} + +pub fn get_test_table() -> Table { + Table::new(ProtocolFamily::Inet) + .with_name(TABLE_NAME) + .with_flags(0u32) +} + +pub fn get_test_table_raw_expr() -> NetlinkExpr { + NetlinkExpr::List(vec![ + NetlinkExpr::Final(NFTA_TABLE_FLAGS, 0u32.to_be_bytes().to_vec()), + NetlinkExpr::Final(NFTA_TABLE_NAME, TABLE_NAME.as_bytes().to_vec()), + ]) + .sort() +} + +pub fn get_test_table_with_userdata_raw_expr() -> NetlinkExpr { + NetlinkExpr::List(vec![ + NetlinkExpr::Final(NFTA_TABLE_FLAGS, 0u32.to_be_bytes().to_vec()), + NetlinkExpr::Final(NFTA_TABLE_NAME, TABLE_NAME.as_bytes().to_vec()), + NetlinkExpr::Final(NFTA_TABLE_USERDATA, TABLE_USERDATA.as_bytes().to_vec()), + ]) + .sort() +} + +pub fn get_test_chain() -> Chain { + Chain::new(&get_test_table()).with_name(CHAIN_NAME) +} + +pub fn get_test_rule() -> Rule { + Rule::new(&get_test_chain()).unwrap() +} + +pub fn get_test_set<K: DataType>() -> Set { + SetBuilder::<K>::new(SET_NAME, &get_test_table()) + .expect("Couldn't create a set") + .finish() + .0 + .with_userdata(SET_USERDATA) +} + +pub fn get_test_nlmsg_with_msg_type<'a>( + buf: &'a mut Vec<u8>, + obj: &mut impl NfNetlinkObject, + msg_type: MsgType, +) -> (nlmsghdr, nfgenmsg, &'a [u8]) { + let mut writer = NfNetlinkWriter::new(buf); + obj.add_or_remove(&mut writer, msg_type, 0); + + let (hdr, msg) = parse_nlmsg(buf.as_slice()).expect("Couldn't parse the message"); + + let (nfgenmsg, raw_value) = match msg { + NlMsg::NfGenMsg(nfgenmsg, raw_value) => (nfgenmsg, raw_value), + _ => panic!("Invalid return value type, expected a valid message"), + }; + + // sanity checks on the global message (this should be very similar/factorisable for the + // most part in other tests) + // TODO: check the messages flags + assert_eq!(nfgenmsg.res_id.to_be(), 0); + + (hdr, nfgenmsg, raw_value) +} + +pub fn get_test_nlmsg<'a>( + buf: &'a mut Vec<u8>, + obj: &mut impl NfNetlinkObject, +) -> (nlmsghdr, nfgenmsg, &'a [u8]) { + get_test_nlmsg_with_msg_type(buf, obj, MsgType::Add) +} diff --git a/src/tests/rule.rs b/src/tests/rule.rs new file mode 100644 index 0000000..08b4139 --- /dev/null +++ b/src/tests/rule.rs @@ -0,0 +1,132 @@ +use crate::{ + nlmsg::get_operation_from_nlmsghdr_type, + sys::{ + NFTA_RULE_CHAIN, NFTA_RULE_HANDLE, NFTA_RULE_POSITION, NFTA_RULE_TABLE, NFTA_RULE_USERDATA, + NFT_MSG_DELRULE, NFT_MSG_NEWRULE, + }, + MsgType, +}; + +use super::{ + get_test_nlmsg, get_test_nlmsg_with_msg_type, get_test_rule, NetlinkExpr, CHAIN_NAME, + RULE_USERDATA, TABLE_NAME, +}; + +#[test] +fn new_empty_rule() { + let mut rule = get_test_rule(); + + let mut buf = Vec::new(); + let (nlmsghdr, _nfgenmsg, raw_expr) = get_test_nlmsg(&mut buf, &mut rule); + assert_eq!( + get_operation_from_nlmsghdr_type(nlmsghdr.nlmsg_type), + NFT_MSG_NEWRULE as u8 + ); + assert_eq!(nlmsghdr.nlmsg_len, 52); + + assert_eq!( + raw_expr, + NetlinkExpr::List(vec![ + NetlinkExpr::Final(NFTA_RULE_TABLE, TABLE_NAME.as_bytes().to_vec()), + NetlinkExpr::Final(NFTA_RULE_CHAIN, CHAIN_NAME.as_bytes().to_vec()), + ]) + .to_raw() + ); +} + +#[test] +fn new_empty_rule_with_userdata() { + let mut rule = get_test_rule().with_userdata(RULE_USERDATA); + + let mut buf = Vec::new(); + let (nlmsghdr, _nfgenmsg, raw_expr) = get_test_nlmsg(&mut buf, &mut rule); + assert_eq!( + get_operation_from_nlmsghdr_type(nlmsghdr.nlmsg_type), + NFT_MSG_NEWRULE as u8 + ); + assert_eq!(nlmsghdr.nlmsg_len, 68); + + assert_eq!( + raw_expr, + NetlinkExpr::List(vec![ + NetlinkExpr::Final(NFTA_RULE_TABLE, TABLE_NAME.as_bytes().to_vec()), + NetlinkExpr::Final(NFTA_RULE_CHAIN, CHAIN_NAME.as_bytes().to_vec()), + NetlinkExpr::Final(NFTA_RULE_USERDATA, RULE_USERDATA.as_bytes().to_vec()) + ]) + .to_raw() + ); +} + +#[test] +fn new_empty_rule_with_position_and_handle() { + let handle: u64 = 1337; + let position: u64 = 42; + let mut rule = get_test_rule().with_handle(handle).with_position(position); + + let mut buf = Vec::new(); + let (nlmsghdr, _nfgenmsg, raw_expr) = get_test_nlmsg(&mut buf, &mut rule); + assert_eq!( + get_operation_from_nlmsghdr_type(nlmsghdr.nlmsg_type), + NFT_MSG_NEWRULE as u8 + ); + assert_eq!(nlmsghdr.nlmsg_len, 76); + + assert_eq!( + raw_expr, + NetlinkExpr::List(vec![ + NetlinkExpr::Final(NFTA_RULE_TABLE, TABLE_NAME.as_bytes().to_vec()), + NetlinkExpr::Final(NFTA_RULE_CHAIN, CHAIN_NAME.as_bytes().to_vec()), + NetlinkExpr::Final(NFTA_RULE_HANDLE, handle.to_be_bytes().to_vec()), + NetlinkExpr::Final(NFTA_RULE_POSITION, position.to_be_bytes().to_vec()), + ]) + .to_raw() + ); +} + +#[test] +fn delete_empty_rule() { + let mut rule = get_test_rule(); + + let mut buf = Vec::new(); + let (nlmsghdr, _nfgenmsg, raw_expr) = + get_test_nlmsg_with_msg_type(&mut buf, &mut rule, MsgType::Del); + assert_eq!( + get_operation_from_nlmsghdr_type(nlmsghdr.nlmsg_type), + NFT_MSG_DELRULE as u8 + ); + assert_eq!(nlmsghdr.nlmsg_len, 52); + + assert_eq!( + raw_expr, + NetlinkExpr::List(vec![ + NetlinkExpr::Final(NFTA_RULE_TABLE, TABLE_NAME.as_bytes().to_vec()), + NetlinkExpr::Final(NFTA_RULE_CHAIN, CHAIN_NAME.as_bytes().to_vec()), + ]) + .to_raw() + ); +} + +#[test] +fn delete_empty_rule_with_handle() { + let handle: u64 = 42; + let mut rule = get_test_rule().with_handle(handle); + + let mut buf = Vec::new(); + let (nlmsghdr, _nfgenmsg, raw_expr) = + get_test_nlmsg_with_msg_type(&mut buf, &mut rule, MsgType::Del); + assert_eq!( + get_operation_from_nlmsghdr_type(nlmsghdr.nlmsg_type), + NFT_MSG_DELRULE as u8 + ); + assert_eq!(nlmsghdr.nlmsg_len, 64); + + assert_eq!( + raw_expr, + NetlinkExpr::List(vec![ + NetlinkExpr::Final(NFTA_RULE_TABLE, TABLE_NAME.as_bytes().to_vec()), + NetlinkExpr::Final(NFTA_RULE_CHAIN, CHAIN_NAME.as_bytes().to_vec()), + NetlinkExpr::Final(NFTA_RULE_HANDLE, handle.to_be_bytes().to_vec()), + ]) + .to_raw() + ); +} diff --git a/src/tests/set.rs b/src/tests/set.rs new file mode 100644 index 0000000..6c8247c --- /dev/null +++ b/src/tests/set.rs @@ -0,0 +1,119 @@ +use std::net::{Ipv4Addr, Ipv6Addr}; + +use crate::{ + data_type::DataType, + nlmsg::get_operation_from_nlmsghdr_type, + set::SetBuilder, + sys::{ + NFTA_DATA_VALUE, NFTA_LIST_ELEM, NFTA_SET_ELEM_KEY, NFTA_SET_ELEM_LIST_ELEMENTS, + NFTA_SET_ELEM_LIST_SET, NFTA_SET_ELEM_LIST_TABLE, NFTA_SET_KEY_LEN, NFTA_SET_KEY_TYPE, + NFTA_SET_NAME, NFTA_SET_TABLE, NFTA_SET_USERDATA, NFT_MSG_DELSET, NFT_MSG_NEWSET, + NFT_MSG_NEWSETELEM, + }, + MsgType, +}; + +use super::{ + get_test_nlmsg, get_test_nlmsg_with_msg_type, get_test_set, get_test_table, NetlinkExpr, + SET_NAME, SET_USERDATA, TABLE_NAME, +}; + +#[test] +fn new_empty_set() { + let mut set = get_test_set::<Ipv4Addr>(); + + let mut buf = Vec::new(); + let (nlmsghdr, _nfgenmsg, raw_expr) = get_test_nlmsg(&mut buf, &mut set); + assert_eq!( + get_operation_from_nlmsghdr_type(nlmsghdr.nlmsg_type), + NFT_MSG_NEWSET as u8 + ); + assert_eq!(nlmsghdr.nlmsg_len, 80); + + assert_eq!( + raw_expr, + NetlinkExpr::List(vec![ + NetlinkExpr::Final(NFTA_SET_TABLE, TABLE_NAME.as_bytes().to_vec()), + NetlinkExpr::Final(NFTA_SET_NAME, SET_NAME.as_bytes().to_vec()), + NetlinkExpr::Final(NFTA_SET_KEY_TYPE, Ipv4Addr::TYPE.to_be_bytes().to_vec()), + NetlinkExpr::Final(NFTA_SET_KEY_LEN, Ipv4Addr::LEN.to_be_bytes().to_vec()), + NetlinkExpr::Final(NFTA_SET_USERDATA, SET_USERDATA.as_bytes().to_vec()), + ]) + .to_raw() + ); +} + +#[test] +fn delete_empty_set() { + let mut set = get_test_set::<Ipv6Addr>(); + + let mut buf = Vec::new(); + let (nlmsghdr, _nfgenmsg, raw_expr) = + get_test_nlmsg_with_msg_type(&mut buf, &mut set, MsgType::Del); + assert_eq!( + get_operation_from_nlmsghdr_type(nlmsghdr.nlmsg_type), + NFT_MSG_DELSET as u8 + ); + assert_eq!(nlmsghdr.nlmsg_len, 80); + + assert_eq!( + raw_expr, + NetlinkExpr::List(vec![ + NetlinkExpr::Final(NFTA_SET_TABLE, TABLE_NAME.as_bytes().to_vec()), + NetlinkExpr::Final(NFTA_SET_NAME, SET_NAME.as_bytes().to_vec()), + NetlinkExpr::Final(NFTA_SET_KEY_TYPE, Ipv6Addr::TYPE.to_be_bytes().to_vec()), + NetlinkExpr::Final(NFTA_SET_KEY_LEN, Ipv6Addr::LEN.to_be_bytes().to_vec()), + NetlinkExpr::Final(NFTA_SET_USERDATA, SET_USERDATA.as_bytes().to_vec()), + ]) + .to_raw() + ); +} + +#[test] +fn new_set_with_data() { + let ip1 = Ipv4Addr::new(127, 0, 0, 1); + let ip2 = Ipv4Addr::new(1, 1, 1, 1); + let mut set_builder = SetBuilder::<Ipv4Addr>::new(SET_NAME.to_string(), &get_test_table()) + .expect("Couldn't create a set"); + + set_builder.add(&ip1); + set_builder.add(&ip2); + let (_set, mut elem_list) = set_builder.finish(); + + let mut buf = Vec::new(); + + let (nlmsghdr, _nfgenmsg, raw_expr) = get_test_nlmsg(&mut buf, &mut elem_list); + assert_eq!( + get_operation_from_nlmsghdr_type(nlmsghdr.nlmsg_type), + NFT_MSG_NEWSETELEM as u8 + ); + assert_eq!(nlmsghdr.nlmsg_len, 84); + + assert_eq!( + raw_expr, + NetlinkExpr::List(vec![ + NetlinkExpr::Final(NFTA_SET_ELEM_LIST_TABLE, TABLE_NAME.as_bytes().to_vec()), + NetlinkExpr::Final(NFTA_SET_ELEM_LIST_SET, SET_NAME.as_bytes().to_vec()), + NetlinkExpr::Nested( + NFTA_SET_ELEM_LIST_ELEMENTS, + vec![ + NetlinkExpr::Nested( + NFTA_LIST_ELEM, + vec![NetlinkExpr::Nested( + NFTA_DATA_VALUE, + vec![NetlinkExpr::Final(NFTA_SET_ELEM_KEY, ip1.data().to_vec())] + )] + ), + NetlinkExpr::Nested( + NFTA_LIST_ELEM, + vec![NetlinkExpr::Nested( + NFTA_DATA_VALUE, + vec![NetlinkExpr::Final(NFTA_SET_ELEM_KEY, ip2.data().to_vec())] + )] + ), + ] + ), + ]) + .to_raw() + ); +} diff --git a/src/tests/table.rs b/src/tests/table.rs new file mode 100644 index 0000000..39bf399 --- /dev/null +++ b/src/tests/table.rs @@ -0,0 +1,67 @@ +use crate::{ + nlmsg::{get_operation_from_nlmsghdr_type, nft_nlmsg_maxsize, NfNetlinkDeserializable}, + sys::{NFT_MSG_DELTABLE, NFT_MSG_NEWTABLE}, + MsgType, Table, +}; + +use super::{ + get_test_nlmsg, get_test_nlmsg_with_msg_type, get_test_table, get_test_table_raw_expr, + get_test_table_with_userdata_raw_expr, TABLE_USERDATA, +}; + +#[test] +fn new_empty_table() { + let mut table = get_test_table(); + let mut buf = Vec::with_capacity(nft_nlmsg_maxsize() as usize); + let (nlmsghdr, _nfgenmsg, raw_expr) = get_test_nlmsg(&mut buf, &mut table); + assert_eq!( + get_operation_from_nlmsghdr_type(nlmsghdr.nlmsg_type), + NFT_MSG_NEWTABLE as u8 + ); + assert_eq!(nlmsghdr.nlmsg_len, 44); + + assert_eq!(raw_expr, get_test_table_raw_expr().to_raw()); +} + +#[test] +fn new_empty_table_with_userdata() { + let mut table = get_test_table(); + table.set_userdata(TABLE_USERDATA.as_bytes().to_vec()); + let mut buf = Vec::with_capacity(nft_nlmsg_maxsize() as usize); + let (nlmsghdr, _nfgenmsg, raw_expr) = get_test_nlmsg(&mut buf, &mut table); + assert_eq!( + get_operation_from_nlmsghdr_type(nlmsghdr.nlmsg_type), + NFT_MSG_NEWTABLE as u8 + ); + assert_eq!(nlmsghdr.nlmsg_len, 64); + + assert_eq!(raw_expr, get_test_table_with_userdata_raw_expr().to_raw()); +} + +#[test] +fn delete_empty_table() { + let mut table = get_test_table(); + let mut buf = Vec::with_capacity(nft_nlmsg_maxsize() as usize); + let (nlmsghdr, _nfgenmsg, raw_expr) = + get_test_nlmsg_with_msg_type(&mut buf, &mut table, MsgType::Del); + assert_eq!( + get_operation_from_nlmsghdr_type(nlmsghdr.nlmsg_type), + NFT_MSG_DELTABLE as u8 + ); + assert_eq!(nlmsghdr.nlmsg_len, 44); + + assert_eq!(raw_expr, get_test_table_raw_expr().to_raw()); +} + +#[test] +fn parse_table() { + let mut table = get_test_table(); + table.set_userdata(TABLE_USERDATA.as_bytes().to_vec()); + let mut buf = Vec::with_capacity(nft_nlmsg_maxsize() as usize); + let (_nlmsghdr, _nfgenmsg, _raw_expr) = get_test_nlmsg(&mut buf, &mut table); + + let (deserialized_table, remaining) = + Table::deserialize(&buf).expect("Couldn't deserialize the object"); + assert_eq!(table, deserialized_table); + assert_eq!(remaining.len(), 0); +} |